Add check for calls to cache_key_wrapper (#8128)

* Add check for calls to cache_key_wrapper to avoid unavoidable compilation of query

* Add fetch_values_predicate to check

* Only check relevant attributes

* Address nit
This commit is contained in:
Ville Brofeldt
2019-08-27 23:36:05 +03:00
committed by Grace Guo
parent 6dc760a054
commit 1982b74af2
2 changed files with 49 additions and 3 deletions

View File

@@ -18,6 +18,7 @@
from collections import OrderedDict
from datetime import datetime
import logging
import re
from typing import Any, Dict, List, NamedTuple, Optional, Union
from flask import escape, Markup
@@ -1064,9 +1065,36 @@ class SqlaTable(Model, BaseDatasource):
def default_query(qry):
return qry.filter_by(is_sqllab_view=False)
def has_extra_cache_keys(self, query_obj: Dict) -> bool:
"""
Detects the presence of calls to cache_key_wrapper in items in query_obj that can
be templated.
:param query_obj: query object to analyze
:return: True if at least one item calls cache_key_wrapper, otherwise False
"""
regex = re.compile(r"\{\{.*cache_key_wrapper\(.*\).*\}\}")
templatable_statements: List[str] = []
if self.sql:
templatable_statements.append(self.sql)
if self.fetch_values_predicate:
templatable_statements.append(self.fetch_values_predicate)
extras = query_obj.get("extras", {})
if "where" in extras:
templatable_statements.append(extras["where"])
if "having" in extras:
templatable_statements.append(extras["having"])
for statement in templatable_statements:
if regex.search(statement):
return True
return False
def get_extra_cache_keys(self, query_obj: Dict) -> List[Any]:
sqla_query = self.get_sqla_query(**query_obj)
return sqla_query.extra_cache_keys
if self.has_extra_cache_keys(query_obj):
sqla_query = self.get_sqla_query(**query_obj)
extra_cache_keys = sqla_query.extra_cache_keys
return extra_cache_keys
return []
sa.event.listen(SqlaTable, "after_insert", security_manager.set_perm)

View File

@@ -41,7 +41,7 @@ class DatabaseModelTestCase(SupersetTestCase):
col = TableColumn(column_name="foo", type="STRING")
self.assertEquals(col.is_time, False)
def test_cache_key_wrapper(self):
def test_has_extra_cache_keys(self):
query = "SELECT '{{ cache_key_wrapper('user_1') }}' as user"
table = SqlaTable(sql=query, database=get_main_database())
query_obj = {
@@ -55,4 +55,22 @@ class DatabaseModelTestCase(SupersetTestCase):
"extras": {"where": "(user != '{{ cache_key_wrapper('user_2') }}')"},
}
extra_cache_keys = table.get_extra_cache_keys(query_obj)
self.assertTrue(table.has_extra_cache_keys(query_obj))
self.assertListEqual(extra_cache_keys, ["user_1", "user_2"])
def test_has_no_extra_cache_keys(self):
query = "SELECT 'abc' as user"
table = SqlaTable(sql=query, database=get_main_database())
query_obj = {
"granularity": None,
"from_dttm": None,
"to_dttm": None,
"groupby": ["user"],
"metrics": [],
"is_timeseries": False,
"filter": [],
"extras": {"where": "(user != 'abc')"},
}
extra_cache_keys = table.get_extra_cache_keys(query_obj)
self.assertFalse(table.has_extra_cache_keys(query_obj))
self.assertListEqual(extra_cache_keys, [])