fix: adhoc metrics (#30202)

This commit is contained in:
Beto Dealmeida
2024-10-10 16:46:17 -04:00
committed by GitHub
parent ef0ede7c13
commit 0db59b45b8
7 changed files with 80 additions and 45 deletions

View File

@@ -64,6 +64,7 @@ from superset.sql.parse import (
extract_tables_from_statement,
SQLGLOT_DIALECTS,
SQLScript,
SQLStatement,
Table,
)
from superset.utils.backports import StrEnum
@@ -570,46 +571,31 @@ class InsertRLSState(StrEnum):
FOUND_TABLE = "FOUND_TABLE"
def has_table_query(token_list: TokenList) -> bool:
def has_table_query(expression: str, engine: str) -> bool:
"""
Return if a statement has a query reading from a table.
>>> has_table_query(sqlparse.parse("COUNT(*)")[0])
>>> has_table_query("COUNT(*)", "postgresql")
False
>>> has_table_query(sqlparse.parse("SELECT * FROM table")[0])
>>> has_table_query("SELECT * FROM table", "postgresql")
True
Note that queries reading from constant values return false:
>>> has_table_query(sqlparse.parse("SELECT * FROM (SELECT 1)")[0])
>>> has_table_query("SELECT * FROM (SELECT 1)", "postgresql")
False
"""
state = InsertRLSState.SCANNING
for token in token_list.tokens:
# Ignore comments
if isinstance(token, sqlparse.sql.Comment):
continue
# Remove trailing semicolon.
expression = expression.strip().rstrip(";")
# Recurse into child token list
if isinstance(token, TokenList) and has_table_query(token):
return True
# Wrap the expression in parentheses if it's not already.
if not expression.startswith("("):
expression = f"({expression})"
# Found a source keyword (FROM/JOIN)
if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]):
state = InsertRLSState.SEEN_SOURCE
# Found identifier/keyword after FROM/JOIN
elif state == InsertRLSState.SEEN_SOURCE and (
isinstance(token, sqlparse.sql.Identifier) or token.ttype == Keyword
):
return True
# Found nothing, leaving source
elif state == InsertRLSState.SEEN_SOURCE and token.ttype != Whitespace:
state = InsertRLSState.SCANNING
return False
sql = f"SELECT {expression}"
statement = SQLStatement(sql, engine)
return any(statement.tables)
def add_table_name(rls: TokenList, table: str) -> None: