diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index c7d9f340aef..24d1243b0a7 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -18,6 +18,7 @@ from collections import OrderedDict from datetime import datetime import logging +import re from typing import Any, Dict, List, NamedTuple, Optional, Union from flask import escape, Markup @@ -1064,9 +1065,36 @@ class SqlaTable(Model, BaseDatasource): def default_query(qry): return qry.filter_by(is_sqllab_view=False) + def has_extra_cache_keys(self, query_obj: Dict) -> bool: + """ + Detects the presence of calls to cache_key_wrapper in items in query_obj that can + be templated. + + :param query_obj: query object to analyze + :return: True if at least one item calls cache_key_wrapper, otherwise False + """ + regex = re.compile(r"\{\{.*cache_key_wrapper\(.*\).*\}\}") + templatable_statements: List[str] = [] + if self.sql: + templatable_statements.append(self.sql) + if self.fetch_values_predicate: + templatable_statements.append(self.fetch_values_predicate) + extras = query_obj.get("extras", {}) + if "where" in extras: + templatable_statements.append(extras["where"]) + if "having" in extras: + templatable_statements.append(extras["having"]) + for statement in templatable_statements: + if regex.search(statement): + return True + return False + def get_extra_cache_keys(self, query_obj: Dict) -> List[Any]: - sqla_query = self.get_sqla_query(**query_obj) - return sqla_query.extra_cache_keys + if self.has_extra_cache_keys(query_obj): + sqla_query = self.get_sqla_query(**query_obj) + extra_cache_keys = sqla_query.extra_cache_keys + return extra_cache_keys + return [] sa.event.listen(SqlaTable, "after_insert", security_manager.set_perm) diff --git a/tests/sqla_models_tests.py b/tests/sqla_models_tests.py index 8f6212708e1..a46f3880fa9 100644 --- a/tests/sqla_models_tests.py +++ b/tests/sqla_models_tests.py @@ -41,7 +41,7 @@ class DatabaseModelTestCase(SupersetTestCase): col = TableColumn(column_name="foo", type="STRING") self.assertEquals(col.is_time, False) - def test_cache_key_wrapper(self): + def test_has_extra_cache_keys(self): query = "SELECT '{{ cache_key_wrapper('user_1') }}' as user" table = SqlaTable(sql=query, database=get_main_database()) query_obj = { @@ -55,4 +55,22 @@ class DatabaseModelTestCase(SupersetTestCase): "extras": {"where": "(user != '{{ cache_key_wrapper('user_2') }}')"}, } extra_cache_keys = table.get_extra_cache_keys(query_obj) + self.assertTrue(table.has_extra_cache_keys(query_obj)) self.assertListEqual(extra_cache_keys, ["user_1", "user_2"]) + + def test_has_no_extra_cache_keys(self): + query = "SELECT 'abc' as user" + table = SqlaTable(sql=query, database=get_main_database()) + query_obj = { + "granularity": None, + "from_dttm": None, + "to_dttm": None, + "groupby": ["user"], + "metrics": [], + "is_timeseries": False, + "filter": [], + "extras": {"where": "(user != 'abc')"}, + } + extra_cache_keys = table.get_extra_cache_keys(query_obj) + self.assertFalse(table.has_extra_cache_keys(query_obj)) + self.assertListEqual(extra_cache_keys, [])