feat: improve adhoc SQL validation (#19454)

* feat: improve adhoc SQL validation

* Small changes

* Add more unit tests
This commit is contained in:
Beto Dealmeida
2022-03-31 11:55:19 -07:00
committed by GitHub
parent 1a1322d3d9
commit 6828624f61
4 changed files with 170 additions and 72 deletions

View File

@@ -14,21 +14,24 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, too-many-lines
# pylint: disable=invalid-name, redefined-outer-name, unused-argument, protected-access, too-many-lines
import unittest
from typing import Set
from typing import Optional, Set
import pytest
import sqlparse
from pytest_mock import MockerFixture
from sqlalchemy import text
from sqlparse.sql import Identifier, Token, TokenList
from sqlparse.tokens import Name
from superset.exceptions import QueryClauseValidationException
from superset.sql_parse import (
add_table_name,
get_rls_for_table,
has_table_query,
insert_rls,
matches_table_name,
ParsedQuery,
sanitize_clause,
strip_comments_from_sql,
@@ -1391,13 +1394,37 @@ def test_has_table_query(sql: str, expected: bool) -> None:
),
],
)
def test_insert_rls(sql: str, table: str, rls: str, expected: str) -> None:
def test_insert_rls(
mocker: MockerFixture, sql: str, table: str, rls: str, expected: str
) -> None:
"""
Insert into a statement a given RLS condition associated with a table.
"""
statement = sqlparse.parse(sql)[0]
condition = sqlparse.parse(rls)[0]
assert str(insert_rls(statement, table, condition)).strip() == expected.strip()
add_table_name(condition, table)
# pylint: disable=unused-argument
def get_rls_for_table(
candidate: Token, database_id: int, default_schema: str
) -> Optional[TokenList]:
"""
Return the RLS ``condition`` if ``candidate`` matches ``table``.
"""
# compare ignoring schema
for left, right in zip(str(candidate).split(".")[::-1], table.split(".")[::-1]):
if left != right:
return None
return condition
mocker.patch("superset.sql_parse.get_rls_for_table", new=get_rls_for_table)
statement = sqlparse.parse(sql)[0]
assert (
str(
insert_rls(token_list=statement, database_id=1, default_schema="my_schema")
).strip()
== expected.strip()
)
@pytest.mark.parametrize(
@@ -1415,16 +1442,29 @@ def test_add_table_name(rls: str, table: str, expected: str) -> None:
assert str(condition) == expected
@pytest.mark.parametrize(
"candidate,table,expected",
[
("table", "table", True),
("schema.table", "table", True),
("table", "schema.table", True),
('schema."my table"', '"my table"', True),
('schema."my.table"', '"my.table"', True),
],
)
def test_matches_table_name(candidate: str, table: str, expected: bool) -> None:
token = sqlparse.parse(candidate)[0].tokens[0]
assert matches_table_name(token, table) == expected
def test_get_rls_for_table(mocker: MockerFixture, app_context: None) -> None:
"""
Tests for ``get_rls_for_table``.
"""
candidate = Identifier([Token(Name, "some_table")])
db = mocker.patch("superset.db")
dataset = db.session.query().filter().one_or_none()
dataset.__str__.return_value = "some_table"
dataset.get_sqla_row_level_filters.return_value = [text("organization_id = 1")]
assert (
str(get_rls_for_table(candidate, 1, "public"))
== "some_table.organization_id = 1"
)
dataset.get_sqla_row_level_filters.return_value = [
text("organization_id = 1"),
text("foo = 'bar'"),
]
assert (
str(get_rls_for_table(candidate, 1, "public"))
== "some_table.organization_id = 1 AND some_table.foo = 'bar'"
)
dataset.get_sqla_row_level_filters.return_value = []
assert get_rls_for_table(candidate, 1, "public") is None