fix: Refactor SQL username logic (#19914)

Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
John Bodley
2022-05-12 21:03:05 -07:00
committed by GitHub
parent fff9ad05d4
commit 449d08b25e
22 changed files with 388 additions and 340 deletions

View File

@@ -20,13 +20,11 @@ import time
from contextlib import closing
from typing import Any, Dict, List, Optional
from flask import g
from superset import app, security_manager
from superset.models.core import Database
from superset.sql_parse import ParsedQuery
from superset.sql_validators.base import BaseSQLValidator, SQLValidationAnnotation
from superset.utils.core import QuerySource
from superset.utils.core import get_username, QuerySource
MAX_ERROR_ROWS = 10
@@ -45,7 +43,10 @@ class PrestoDBSQLValidator(BaseSQLValidator):
@classmethod
def validate_statement(
cls, statement: str, database: Database, cursor: Any, user_name: str
cls,
statement: str,
database: Database,
cursor: Any,
) -> Optional[SQLValidationAnnotation]:
# pylint: disable=too-many-locals
db_engine_spec = database.db_engine_spec
@@ -57,7 +58,7 @@ class PrestoDBSQLValidator(BaseSQLValidator):
if sql_query_mutator:
sql = sql_query_mutator(
sql,
user_name=user_name,
user_name=get_username(), # TODO(john-bodley): Deprecate in 3.0.
security_manager=security_manager,
database=database,
)
@@ -157,26 +158,18 @@ class PrestoDBSQLValidator(BaseSQLValidator):
For example, "SELECT 1 FROM default.mytable" becomes "EXPLAIN (TYPE
VALIDATE) SELECT 1 FROM default.mytable.
"""
user_name = g.user.username if g.user and hasattr(g.user, "username") else None
parsed_query = ParsedQuery(sql)
statements = parsed_query.get_statements()
logger.info("Validating %i statement(s)", len(statements))
engine = database.get_sqla_engine(
schema=schema,
nullpool=True,
user_name=user_name,
source=QuerySource.SQL_LAB,
)
engine = database.get_sqla_engine(schema, source=QuerySource.SQL_LAB)
# Sharing a single connection and cursor across the
# execution of all statements (if many)
annotations: List[SQLValidationAnnotation] = []
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
for statement in parsed_query.get_statements():
annotation = cls.validate_statement(
statement, database, cursor, user_name
)
annotation = cls.validate_statement(statement, database, cursor)
if annotation:
annotations.append(annotation)
logger.debug("Validation found %i error(s)", len(annotations))