mirror of
https://github.com/apache/superset.git
synced 2026-04-18 23:55:00 +00:00
feat: improve adhoc SQL validation (#19454)
* feat: improve adhoc SQL validation * Small changes * Add more unit tests
This commit is contained in:
@@ -18,10 +18,11 @@ import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Set, Tuple
|
||||
from typing import cast, List, Optional, Set, Tuple
|
||||
from urllib import parse
|
||||
|
||||
import sqlparse
|
||||
from sqlalchemy import and_
|
||||
from sqlparse.sql import (
|
||||
Identifier,
|
||||
IdentifierList,
|
||||
@@ -283,7 +284,7 @@ class ParsedQuery:
|
||||
return statements
|
||||
|
||||
@staticmethod
|
||||
def _get_table(tlist: TokenList) -> Optional[Table]:
|
||||
def get_table(tlist: TokenList) -> Optional[Table]:
|
||||
"""
|
||||
Return the table if valid, i.e., conforms to the [[catalog.]schema.]table
|
||||
construct.
|
||||
@@ -324,7 +325,7 @@ class ParsedQuery:
|
||||
"""
|
||||
# exclude subselects
|
||||
if "(" not in str(token_list):
|
||||
table = self._get_table(token_list)
|
||||
table = self.get_table(token_list)
|
||||
if table and not table.table.startswith(CTE_PREFIX):
|
||||
self._tables.add(table)
|
||||
return
|
||||
@@ -500,7 +501,7 @@ def has_table_query(token_list: TokenList) -> bool:
|
||||
state = InsertRLSState.SCANNING
|
||||
for token in token_list.tokens:
|
||||
|
||||
# # Recurse into child token list
|
||||
# Recurse into child token list
|
||||
if isinstance(token, TokenList) and has_table_query(token):
|
||||
return True
|
||||
|
||||
@@ -523,7 +524,7 @@ def has_table_query(token_list: TokenList) -> bool:
|
||||
|
||||
def add_table_name(rls: TokenList, table: str) -> None:
|
||||
"""
|
||||
Modify a RLS expression ensuring columns are fully qualified.
|
||||
Modify an RLS expression in place, ensuring columns are fully qualified.
|
||||
"""
|
||||
tokens = rls.tokens[:]
|
||||
while tokens:
|
||||
@@ -539,45 +540,70 @@ def add_table_name(rls: TokenList, table: str) -> None:
|
||||
tokens.extend(token.tokens)
|
||||
|
||||
|
||||
def matches_table_name(candidate: Token, table: str) -> bool:
|
||||
def get_rls_for_table(
|
||||
candidate: Token,
|
||||
database_id: int,
|
||||
default_schema: Optional[str],
|
||||
) -> Optional[TokenList]:
|
||||
"""
|
||||
Returns if the token represents a reference to the table.
|
||||
|
||||
Tables can be fully qualified with periods.
|
||||
|
||||
Note that in theory a table should be represented as an identifier, but due to
|
||||
sqlparse's aggressive list of keywords (spanning multiple dialects) often it gets
|
||||
classified as a keyword.
|
||||
Given a table name, return any associated RLS predicates.
|
||||
"""
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from superset import db
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
|
||||
if not isinstance(candidate, Identifier):
|
||||
candidate = Identifier([Token(Name, candidate.value)])
|
||||
|
||||
target = sqlparse.parse(table)[0].tokens[0]
|
||||
if not isinstance(target, Identifier):
|
||||
target = Identifier([Token(Name, target.value)])
|
||||
table = ParsedQuery.get_table(candidate)
|
||||
if not table:
|
||||
return None
|
||||
|
||||
# match from right to left, splitting on the period, eg, schema.table == table
|
||||
for left, right in zip(candidate.tokens[::-1], target.tokens[::-1]):
|
||||
if left.value != right.value:
|
||||
return False
|
||||
dataset = (
|
||||
db.session.query(SqlaTable)
|
||||
.filter(
|
||||
and_(
|
||||
SqlaTable.database_id == database_id,
|
||||
SqlaTable.schema == (table.schema or default_schema),
|
||||
SqlaTable.table_name == table.table,
|
||||
)
|
||||
)
|
||||
.one_or_none()
|
||||
)
|
||||
if not dataset:
|
||||
return None
|
||||
|
||||
return True
|
||||
template_processor = dataset.get_template_processor()
|
||||
# pylint: disable=protected-access
|
||||
predicate = " AND ".join(
|
||||
str(filter_)
|
||||
for filter_ in dataset.get_sqla_row_level_filters(template_processor)
|
||||
)
|
||||
if not predicate:
|
||||
return None
|
||||
|
||||
rls = sqlparse.parse(predicate)[0]
|
||||
add_table_name(rls, str(dataset))
|
||||
|
||||
return rls
|
||||
|
||||
|
||||
def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList:
|
||||
def insert_rls(
|
||||
token_list: TokenList,
|
||||
database_id: int,
|
||||
default_schema: Optional[str],
|
||||
) -> TokenList:
|
||||
"""
|
||||
Update a statement inplace applying an RLS associated with a given table.
|
||||
Update a statement inplace applying any associated RLS predicates.
|
||||
"""
|
||||
# make sure the identifier has the table name
|
||||
add_table_name(rls, table)
|
||||
|
||||
rls: Optional[TokenList] = None
|
||||
state = InsertRLSState.SCANNING
|
||||
for token in token_list.tokens:
|
||||
|
||||
# Recurse into child token list
|
||||
if isinstance(token, TokenList):
|
||||
i = token_list.tokens.index(token)
|
||||
token_list.tokens[i] = insert_rls(token, table, rls)
|
||||
token_list.tokens[i] = insert_rls(token, database_id, default_schema)
|
||||
|
||||
# Found a source keyword (FROM/JOIN)
|
||||
if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]):
|
||||
@@ -587,12 +613,14 @@ def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList:
|
||||
elif state == InsertRLSState.SEEN_SOURCE and (
|
||||
isinstance(token, Identifier) or token.ttype == Keyword
|
||||
):
|
||||
if matches_table_name(token, table):
|
||||
rls = get_rls_for_table(token, database_id, default_schema)
|
||||
if rls:
|
||||
state = InsertRLSState.FOUND_TABLE
|
||||
|
||||
# Found WHERE clause, insert RLS. Note that we insert it even if it already exists,
|
||||
# to be on the safe side: it could be present in a clause like `1=1 OR RLS`.
|
||||
elif state == InsertRLSState.FOUND_TABLE and isinstance(token, Where):
|
||||
rls = cast(TokenList, rls)
|
||||
token.tokens[1:1] = [Token(Whitespace, " "), Token(Punctuation, "(")]
|
||||
token.tokens.extend(
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user