mirror of
https://github.com/apache/superset.git
synced 2026-04-19 08:04:53 +00:00
Add check for calls to cache_key_wrapper (#8128)
* Add check for calls to cache_key_wrapper to avoid unavoidable compilation of query * Add fetch_values_predicate to check * Only check relevant attributes * Address nit
This commit is contained in:
committed by
Grace Guo
parent
6dc760a054
commit
1982b74af2
@@ -18,6 +18,7 @@
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, List, NamedTuple, Optional, Union
|
||||
|
||||
from flask import escape, Markup
|
||||
@@ -1064,9 +1065,36 @@ class SqlaTable(Model, BaseDatasource):
|
||||
def default_query(qry):
|
||||
return qry.filter_by(is_sqllab_view=False)
|
||||
|
||||
def has_extra_cache_keys(self, query_obj: Dict) -> bool:
|
||||
"""
|
||||
Detects the presence of calls to cache_key_wrapper in items in query_obj that can
|
||||
be templated.
|
||||
|
||||
:param query_obj: query object to analyze
|
||||
:return: True if at least one item calls cache_key_wrapper, otherwise False
|
||||
"""
|
||||
regex = re.compile(r"\{\{.*cache_key_wrapper\(.*\).*\}\}")
|
||||
templatable_statements: List[str] = []
|
||||
if self.sql:
|
||||
templatable_statements.append(self.sql)
|
||||
if self.fetch_values_predicate:
|
||||
templatable_statements.append(self.fetch_values_predicate)
|
||||
extras = query_obj.get("extras", {})
|
||||
if "where" in extras:
|
||||
templatable_statements.append(extras["where"])
|
||||
if "having" in extras:
|
||||
templatable_statements.append(extras["having"])
|
||||
for statement in templatable_statements:
|
||||
if regex.search(statement):
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_extra_cache_keys(self, query_obj: Dict) -> List[Any]:
|
||||
sqla_query = self.get_sqla_query(**query_obj)
|
||||
return sqla_query.extra_cache_keys
|
||||
if self.has_extra_cache_keys(query_obj):
|
||||
sqla_query = self.get_sqla_query(**query_obj)
|
||||
extra_cache_keys = sqla_query.extra_cache_keys
|
||||
return extra_cache_keys
|
||||
return []
|
||||
|
||||
|
||||
sa.event.listen(SqlaTable, "after_insert", security_manager.set_perm)
|
||||
|
||||
@@ -41,7 +41,7 @@ class DatabaseModelTestCase(SupersetTestCase):
|
||||
col = TableColumn(column_name="foo", type="STRING")
|
||||
self.assertEquals(col.is_time, False)
|
||||
|
||||
def test_cache_key_wrapper(self):
|
||||
def test_has_extra_cache_keys(self):
|
||||
query = "SELECT '{{ cache_key_wrapper('user_1') }}' as user"
|
||||
table = SqlaTable(sql=query, database=get_main_database())
|
||||
query_obj = {
|
||||
@@ -55,4 +55,22 @@ class DatabaseModelTestCase(SupersetTestCase):
|
||||
"extras": {"where": "(user != '{{ cache_key_wrapper('user_2') }}')"},
|
||||
}
|
||||
extra_cache_keys = table.get_extra_cache_keys(query_obj)
|
||||
self.assertTrue(table.has_extra_cache_keys(query_obj))
|
||||
self.assertListEqual(extra_cache_keys, ["user_1", "user_2"])
|
||||
|
||||
def test_has_no_extra_cache_keys(self):
|
||||
query = "SELECT 'abc' as user"
|
||||
table = SqlaTable(sql=query, database=get_main_database())
|
||||
query_obj = {
|
||||
"granularity": None,
|
||||
"from_dttm": None,
|
||||
"to_dttm": None,
|
||||
"groupby": ["user"],
|
||||
"metrics": [],
|
||||
"is_timeseries": False,
|
||||
"filter": [],
|
||||
"extras": {"where": "(user != 'abc')"},
|
||||
}
|
||||
extra_cache_keys = table.get_extra_cache_keys(query_obj)
|
||||
self.assertFalse(table.has_extra_cache_keys(query_obj))
|
||||
self.assertListEqual(extra_cache_keys, [])
|
||||
|
||||
Reference in New Issue
Block a user