mirror of
https://github.com/apache/superset.git
synced 2026-04-18 23:55:00 +00:00
fix: Refactor SQL username logic (#19914)
Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
@@ -20,13 +20,11 @@ import time
|
||||
from contextlib import closing
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from flask import g
|
||||
|
||||
from superset import app, security_manager
|
||||
from superset.models.core import Database
|
||||
from superset.sql_parse import ParsedQuery
|
||||
from superset.sql_validators.base import BaseSQLValidator, SQLValidationAnnotation
|
||||
from superset.utils.core import QuerySource
|
||||
from superset.utils.core import get_username, QuerySource
|
||||
|
||||
MAX_ERROR_ROWS = 10
|
||||
|
||||
@@ -45,7 +43,10 @@ class PrestoDBSQLValidator(BaseSQLValidator):
|
||||
|
||||
@classmethod
|
||||
def validate_statement(
|
||||
cls, statement: str, database: Database, cursor: Any, user_name: str
|
||||
cls,
|
||||
statement: str,
|
||||
database: Database,
|
||||
cursor: Any,
|
||||
) -> Optional[SQLValidationAnnotation]:
|
||||
# pylint: disable=too-many-locals
|
||||
db_engine_spec = database.db_engine_spec
|
||||
@@ -57,7 +58,7 @@ class PrestoDBSQLValidator(BaseSQLValidator):
|
||||
if sql_query_mutator:
|
||||
sql = sql_query_mutator(
|
||||
sql,
|
||||
user_name=user_name,
|
||||
user_name=get_username(), # TODO(john-bodley): Deprecate in 3.0.
|
||||
security_manager=security_manager,
|
||||
database=database,
|
||||
)
|
||||
@@ -157,26 +158,18 @@ class PrestoDBSQLValidator(BaseSQLValidator):
|
||||
For example, "SELECT 1 FROM default.mytable" becomes "EXPLAIN (TYPE
|
||||
VALIDATE) SELECT 1 FROM default.mytable.
|
||||
"""
|
||||
user_name = g.user.username if g.user and hasattr(g.user, "username") else None
|
||||
parsed_query = ParsedQuery(sql)
|
||||
statements = parsed_query.get_statements()
|
||||
|
||||
logger.info("Validating %i statement(s)", len(statements))
|
||||
engine = database.get_sqla_engine(
|
||||
schema=schema,
|
||||
nullpool=True,
|
||||
user_name=user_name,
|
||||
source=QuerySource.SQL_LAB,
|
||||
)
|
||||
engine = database.get_sqla_engine(schema, source=QuerySource.SQL_LAB)
|
||||
# Sharing a single connection and cursor across the
|
||||
# execution of all statements (if many)
|
||||
annotations: List[SQLValidationAnnotation] = []
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
cursor = conn.cursor()
|
||||
for statement in parsed_query.get_statements():
|
||||
annotation = cls.validate_statement(
|
||||
statement, database, cursor, user_name
|
||||
)
|
||||
annotation = cls.validate_statement(statement, database, cursor)
|
||||
if annotation:
|
||||
annotations.append(annotation)
|
||||
logger.debug("Validation found %i error(s)", len(annotations))
|
||||
|
||||
Reference in New Issue
Block a user