feat: improve adhoc SQL validation (#19454)

* feat: improve adhoc SQL validation

* Small changes

* Add more unit tests
This commit is contained in:
Beto Dealmeida
2022-03-31 11:55:19 -07:00
committed by GitHub
parent 1a1322d3d9
commit 6828624f61
4 changed files with 170 additions and 72 deletions

View File

@@ -18,10 +18,11 @@ import logging
import re
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Set, Tuple
from typing import cast, List, Optional, Set, Tuple
from urllib import parse
import sqlparse
from sqlalchemy import and_
from sqlparse.sql import (
Identifier,
IdentifierList,
@@ -283,7 +284,7 @@ class ParsedQuery:
return statements
@staticmethod
def _get_table(tlist: TokenList) -> Optional[Table]:
def get_table(tlist: TokenList) -> Optional[Table]:
"""
Return the table if valid, i.e., conforms to the [[catalog.]schema.]table
construct.
@@ -324,7 +325,7 @@ class ParsedQuery:
"""
# exclude subselects
if "(" not in str(token_list):
table = self._get_table(token_list)
table = self.get_table(token_list)
if table and not table.table.startswith(CTE_PREFIX):
self._tables.add(table)
return
@@ -500,7 +501,7 @@ def has_table_query(token_list: TokenList) -> bool:
state = InsertRLSState.SCANNING
for token in token_list.tokens:
# # Recurse into child token list
# Recurse into child token list
if isinstance(token, TokenList) and has_table_query(token):
return True
@@ -523,7 +524,7 @@ def has_table_query(token_list: TokenList) -> bool:
def add_table_name(rls: TokenList, table: str) -> None:
"""
Modify a RLS expression ensuring columns are fully qualified.
Modify a RLS expression inplace ensuring columns are fully qualified.
"""
tokens = rls.tokens[:]
while tokens:
@@ -539,45 +540,70 @@ def add_table_name(rls: TokenList, table: str) -> None:
tokens.extend(token.tokens)
def matches_table_name(candidate: Token, table: str) -> bool:
def get_rls_for_table(
candidate: Token,
database_id: int,
default_schema: Optional[str],
) -> Optional[TokenList]:
"""
Returns if the token represents a reference to the table.
Tables can be fully qualified with periods.
Note that in theory a table should be represented as an identifier, but due to
sqlparse's aggressive list of keywords (spanning multiple dialects) often it gets
classified as a keyword.
Given a table name, return any associated RLS predicates.
"""
# pylint: disable=import-outside-toplevel
from superset import db
from superset.connectors.sqla.models import SqlaTable
if not isinstance(candidate, Identifier):
candidate = Identifier([Token(Name, candidate.value)])
target = sqlparse.parse(table)[0].tokens[0]
if not isinstance(target, Identifier):
target = Identifier([Token(Name, target.value)])
table = ParsedQuery.get_table(candidate)
if not table:
return None
# match from right to left, splitting on the period, eg, schema.table == table
for left, right in zip(candidate.tokens[::-1], target.tokens[::-1]):
if left.value != right.value:
return False
dataset = (
db.session.query(SqlaTable)
.filter(
and_(
SqlaTable.database_id == database_id,
SqlaTable.schema == (table.schema or default_schema),
SqlaTable.table_name == table.table,
)
)
.one_or_none()
)
if not dataset:
return None
return True
template_processor = dataset.get_template_processor()
# pylint: disable=protected-access
predicate = " AND ".join(
str(filter_)
for filter_ in dataset.get_sqla_row_level_filters(template_processor)
)
if not predicate:
return None
rls = sqlparse.parse(predicate)[0]
add_table_name(rls, str(dataset))
return rls
def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList:
def insert_rls(
token_list: TokenList,
database_id: int,
default_schema: Optional[str],
) -> TokenList:
"""
Update a statement inplace applying an RLS associated with a given table.
Update a statement inplace applying any associated RLS predicates.
"""
# make sure the identifier has the table name
add_table_name(rls, table)
rls: Optional[TokenList] = None
state = InsertRLSState.SCANNING
for token in token_list.tokens:
# Recurse into child token list
if isinstance(token, TokenList):
i = token_list.tokens.index(token)
token_list.tokens[i] = insert_rls(token, table, rls)
token_list.tokens[i] = insert_rls(token, database_id, default_schema)
# Found a source keyword (FROM/JOIN)
if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]):
@@ -587,12 +613,14 @@ def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList:
elif state == InsertRLSState.SEEN_SOURCE and (
isinstance(token, Identifier) or token.ttype == Keyword
):
if matches_table_name(token, table):
rls = get_rls_for_table(token, database_id, default_schema)
if rls:
state = InsertRLSState.FOUND_TABLE
# Found WHERE clause, insert RLS. Note that we insert it even if it already exists,
# to be on the safe side: it could be present in a clause like `1=1 OR RLS`.
elif state == InsertRLSState.FOUND_TABLE and isinstance(token, Where):
rls = cast(TokenList, rls)
token.tokens[1:1] = [Token(Whitespace, " "), Token(Punctuation, "(")]
token.tokens.extend(
[