diff --git a/superset/common/query_context_factory.py b/superset/common/query_context_factory.py index cb40b954081..2056109bbff 100644 --- a/superset/common/query_context_factory.py +++ b/superset/common/query_context_factory.py @@ -58,7 +58,9 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods result_type = result_type or ChartDataResultType.FULL result_format = result_format or ChartDataResultFormat.JSON queries_ = [ - self._query_object_factory.create(result_type, **query_obj) + self._query_object_factory.create( + result_type, datasource=datasource, **query_obj + ) for query_obj in queries ] cache_values = { diff --git a/superset/common/query_object.py b/superset/common/query_object.py index 40d37041b91..a8585fd47e0 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -23,9 +23,11 @@ from datetime import datetime, timedelta from pprint import pformat from typing import Any, Dict, List, NamedTuple, Optional, TYPE_CHECKING +from flask import g from flask_babel import gettext as _ from pandas import DataFrame +from superset import feature_flag_manager from superset.common.chart_data import ChartDataResultType from superset.exceptions import ( InvalidPostProcessingError, @@ -396,6 +398,24 @@ class QueryObject: # pylint: disable=too-many-instance-attributes if annotation_layers: cache_dict["annotation_layers"] = annotation_layers + # Add an impersonation key to cache if impersonation is enabled on the db + if ( + feature_flag_manager.is_feature_enabled("CACHE_IMPERSONATION") + and self.datasource + and hasattr(self.datasource, "database") + and self.datasource.database.impersonate_user + ): + + if key := self.datasource.database.db_engine_spec.get_impersonation_key( + getattr(g, "user", None) + ): + + logger.debug( + "Adding impersonation key to QueryObject cache dict: %s", key + ) + + cache_dict["impersonation_key"] = key + return md5_sha_from_dict(cache_dict, default=json_int_dttm_ser, ignore_nan=True) def exec_post_processing(self, df: DataFrame) -> DataFrame: diff --git a/superset/config.py b/superset/config.py index 8a5ec248fb8..17c6a55412d 100644 --- a/superset/config.py +++ b/superset/config.py @@ -429,6 +429,9 @@ DEFAULT_FEATURE_FLAGS: Dict[str, bool] = { # Apply RLS rules to SQL Lab queries. This requires parsing and manipulating the # query, and might break queries and/or allow users to bypass RLS. Use with care! "RLS_IN_SQLLAB": False, + # Enable caching per impersonation key (e.g username) in a datasource where user + # impersonation is enabled + "CACHE_IMPERSONATION": False, } # Feature flags may also be set via 'SUPERSET_FEATURE_' prefixed environment vars. diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 6a2ddc5e5c3..b4f4ec25c45 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -41,6 +41,7 @@ import sqlparse from apispec import APISpec from apispec.ext.marshmallow import MarshmallowPlugin from flask import current_app +from flask_appbuilder.security.sqla.models import User from flask_babel import gettext as __, lazy_gettext as _ from marshmallow import fields, Schema from marshmallow.validate import Range @@ -1537,6 +1538,17 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods def parse_sql(cls, sql: str) -> List[str]: return [str(s).strip(" ;") for s in sqlparse.parse(sql)] + @classmethod + def get_impersonation_key(cls, user: Optional[User]) -> Any: + """ + Construct an impersonation key, by default it's the given username. + + :param user: logged in user + + :returns: username if given user is not null + """ + return user.username if user else None + # schema for adding a database by providing parameters instead of the # full SQLAlchemy URI