fix: Refactor SQL username logic (#19914)

Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
John Bodley
2022-05-12 21:03:05 -07:00
committed by GitHub
parent fff9ad05d4
commit 449d08b25e
22 changed files with 388 additions and 340 deletions

View File

@@ -50,7 +50,13 @@ from superset.result_set import SupersetResultSet
from superset.sql_parse import CtasMethod, insert_rls, ParsedQuery
from superset.sqllab.limiting_factor import LimitingFactor
from superset.utils.celery import session_scope
from superset.utils.core import json_iso_dttm_ser, QuerySource, zlib_compress
from superset.utils.core import (
get_username,
json_iso_dttm_ser,
override_user,
QuerySource,
zlib_compress,
)
from superset.utils.dates import now_as_float
from superset.utils.decorators import stats_timing
@@ -155,37 +161,35 @@ def get_sql_results( # pylint: disable=too-many-arguments
rendered_query: str,
return_results: bool = True,
store_results: bool = False,
user_name: Optional[str] = None,
username: Optional[str] = None,
start_time: Optional[float] = None,
expand_data: bool = False,
log_params: Optional[Dict[str, Any]] = None,
) -> Optional[Dict[str, Any]]:
"""Executes the sql query returns the results."""
with session_scope(not ctask.request.called_directly) as session:
try:
return execute_sql_statements(
query_id,
rendered_query,
return_results,
store_results,
user_name,
session=session,
start_time=start_time,
expand_data=expand_data,
log_params=log_params,
)
except Exception as ex: # pylint: disable=broad-except
logger.debug("Query %d: %s", query_id, ex)
stats_logger.incr("error_sqllab_unhandled")
query = get_query(query_id, session)
return handle_query_error(ex, query, session)
with override_user(security_manager.find_user(username)):
try:
return execute_sql_statements(
query_id,
rendered_query,
return_results,
store_results,
session=session,
start_time=start_time,
expand_data=expand_data,
log_params=log_params,
)
except Exception as ex: # pylint: disable=broad-except
logger.debug("Query %d: %s", query_id, ex)
stats_logger.incr("error_sqllab_unhandled")
query = get_query(query_id, session)
return handle_query_error(ex, query, session)
def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-locals,too-many-statements
sql_statement: str,
query: Query,
user_name: Optional[str],
session: Session,
cursor: Any,
log_params: Optional[Dict[str, Any]],
@@ -204,7 +208,7 @@ def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-locals
parsed_query._parsed[0], # pylint: disable=protected-access
database.id,
query.schema,
username=user_name,
username=get_username(),
)
)
)
@@ -246,7 +250,10 @@ def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-locals
# Hook to allow environment-specific mutation (usually comments) to the SQL
sql = SQL_QUERY_MUTATOR(
sql, user_name=user_name, security_manager=security_manager, database=database
sql,
user_name=get_username(), # TODO(john-bodley): Deprecate in 3.0.
security_manager=security_manager,
database=database,
)
try:
query.executed_sql = sql
@@ -255,7 +262,7 @@ def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-locals
query.database.sqlalchemy_uri,
query.executed_sql,
query.schema,
user_name,
get_username(),
__name__,
security_manager,
log_params,
@@ -375,7 +382,6 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca
rendered_query: str,
return_results: bool,
store_results: bool,
user_name: Optional[str],
session: Session,
start_time: Optional[float],
expand_data: bool,
@@ -452,12 +458,7 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca
)
)
engine = database.get_sqla_engine(
schema=query.schema,
nullpool=True,
user_name=user_name,
source=QuerySource.SQL_LAB,
)
engine = database.get_sqla_engine(query.schema, source=QuerySource.SQL_LAB)
# Sharing a single connection and cursor across the
# execution of all statements (if many)
with closing(engine.raw_connection()) as conn:
@@ -490,7 +491,6 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca
result_set = execute_sql_statement(
statement,
query,
user_name,
session,
cursor,
log_params,
@@ -597,7 +597,7 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca
return None
def cancel_query(query: Query, user_name: Optional[str] = None) -> bool:
def cancel_query(query: Query) -> bool:
"""
Cancel a running query.
@@ -605,7 +605,6 @@ def cancel_query(query: Query, user_name: Optional[str] = None) -> bool:
action is required.
:param query: Query to cancel
:param user_name: Default username
:return: True if query cancelled successfully, False otherwise
"""
@@ -616,12 +615,7 @@ def cancel_query(query: Query, user_name: Optional[str] = None) -> bool:
if cancel_query_id is None:
return False
engine = query.database.get_sqla_engine(
schema=query.schema,
nullpool=True,
user_name=user_name,
source=QuerySource.SQL_LAB,
)
engine = query.database.get_sqla_engine(query.schema, source=QuerySource.SQL_LAB)
with closing(engine.raw_connection()) as conn:
with closing(conn.cursor()) as cursor: