[format] Using Black (#7769)

This commit is contained in:
John Bodley
2019-06-25 13:34:48 -07:00
committed by GitHub
parent 0c9e6d0985
commit 5c58fd1802
270 changed files with 15592 additions and 14772 deletions

View File

@@ -18,21 +18,13 @@
from contextlib import closing
import logging
import time
from typing import (
Any,
Dict,
List,
Optional,
)
from typing import Any, Dict, List, Optional
from flask import g
from superset import app, security_manager
from superset.sql_parse import ParsedQuery
from superset.sql_validators.base import (
BaseSQLValidator,
SQLValidationAnnotation,
)
from superset.sql_validators.base import BaseSQLValidator, SQLValidationAnnotation
from superset.utils.core import sources
MAX_ERROR_ROWS = 10
@@ -47,15 +39,11 @@ class PrestoSQLValidationError(Exception):
class PrestoDBSQLValidator(BaseSQLValidator):
"""Validate SQL queries using Presto's built-in EXPLAIN subtype"""
name = 'PrestoDBSQLValidator'
name = "PrestoDBSQLValidator"
@classmethod
def validate_statement(
cls,
statement,
database,
cursor,
user_name,
cls, statement, database, cursor, user_name
) -> Optional[SQLValidationAnnotation]:
# pylint: disable=too-many-locals
db_engine_spec = database.db_engine_spec
@@ -64,28 +52,29 @@ class PrestoDBSQLValidator(BaseSQLValidator):
# Hook to allow environment-specific mutation (usually comments) to the SQL
# pylint: disable=invalid-name
SQL_QUERY_MUTATOR = config.get('SQL_QUERY_MUTATOR')
SQL_QUERY_MUTATOR = config.get("SQL_QUERY_MUTATOR")
if SQL_QUERY_MUTATOR:
sql = SQL_QUERY_MUTATOR(sql, user_name, security_manager, database)
# Transform the final statement to an explain call before sending it on
# to presto to validate
sql = f'EXPLAIN (TYPE VALIDATE) {sql}'
sql = f"EXPLAIN (TYPE VALIDATE) {sql}"
# Invoke the query against presto. NB this deliberately doesn't use the
# engine spec's handle_cursor implementation since we don't record
# these EXPLAIN queries done in validation as proper Query objects
# in the superset ORM.
from pyhive.exc import DatabaseError
try:
db_engine_spec.execute(cursor, sql)
polled = cursor.poll()
while polled:
logging.info('polling presto for validation progress')
stats = polled.get('stats', {})
logging.info("polling presto for validation progress")
stats = polled.get("stats", {})
if stats:
state = stats.get('state')
if state == 'FINISHED':
state = stats.get("state")
if state == "FINISHED":
break
time.sleep(0.2)
polled = cursor.poll()
@@ -107,30 +96,29 @@ class PrestoDBSQLValidator(BaseSQLValidator):
# we update at some point in the future.
if not db_error.args or not isinstance(db_error.args[0], dict):
raise PrestoSQLValidationError(
'The pyhive presto client returned an unhandled '
'database error.',
"The pyhive presto client returned an unhandled " "database error."
) from db_error
error_args: Dict[str, Any] = db_error.args[0]
# Confirm the two fields we need to be able to present an annotation
# are present in the error response -- a message, and a location.
if 'message' not in error_args:
if "message" not in error_args:
raise PrestoSQLValidationError(
'The pyhive presto client did not report an error message',
"The pyhive presto client did not report an error message"
) from db_error
if 'errorLocation' not in error_args:
if "errorLocation" not in error_args:
raise PrestoSQLValidationError(
'The pyhive presto client did not report an error location',
"The pyhive presto client did not report an error location"
) from db_error
# Pylint is confused about the type of error_args, despite the hints
# and checks above.
# pylint: disable=invalid-sequence-index
message = error_args['message']
err_loc = error_args['errorLocation']
line_number = err_loc.get('lineNumber', None)
start_column = err_loc.get('columnNumber', None)
end_column = err_loc.get('columnNumber', None)
message = error_args["message"]
err_loc = error_args["errorLocation"]
line_number = err_loc.get("lineNumber", None)
start_column = err_loc.get("columnNumber", None)
end_column = err_loc.get("columnNumber", None)
return SQLValidationAnnotation(
message=message,
@@ -139,15 +127,12 @@ class PrestoDBSQLValidator(BaseSQLValidator):
end_column=end_column,
)
except Exception as e:
logging.exception(f'Unexpected error running validation query: {e}')
logging.exception(f"Unexpected error running validation query: {e}")
raise e
@classmethod
def validate(
cls,
sql: str,
schema: str,
database: Any,
cls, sql: str, schema: str, database: Any
) -> List[SQLValidationAnnotation]:
"""
Presto supports query-validation queries by running them with a
@@ -160,12 +145,12 @@ class PrestoDBSQLValidator(BaseSQLValidator):
parsed_query = ParsedQuery(sql)
statements = parsed_query.get_statements()
logging.info(f'Validating {len(statements)} statement(s)')
logging.info(f"Validating {len(statements)} statement(s)")
engine = database.get_sqla_engine(
schema=schema,
nullpool=True,
user_name=user_name,
source=sources.get('sql_lab', None),
source=sources.get("sql_lab", None),
)
# Sharing a single connection and cursor across the
# execution of all statements (if many)
@@ -174,13 +159,10 @@ class PrestoDBSQLValidator(BaseSQLValidator):
with closing(conn.cursor()) as cursor:
for statement in parsed_query.get_statements():
annotation = cls.validate_statement(
statement,
database,
cursor,
user_name,
statement, database, cursor, user_name
)
if annotation:
annotations.append(annotation)
logging.debug(f'Validation found {len(annotations)} error(s)')
logging.debug(f"Validation found {len(annotations)} error(s)")
return annotations