mirror of
https://github.com/apache/superset.git
synced 2026-04-12 12:47:53 +00:00
feat: improve adhoc SQL validation (#19454)
* feat: improve adhoc SQL validation * Small changes * Add more unit tests
This commit is contained in:
@@ -14,21 +14,24 @@
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
# pylint: disable=invalid-name, too-many-lines
|
||||
# pylint: disable=invalid-name, redefined-outer-name, unused-argument, protected-access, too-many-lines
|
||||
|
||||
import unittest
|
||||
from typing import Set
|
||||
from typing import Optional, Set
|
||||
|
||||
import pytest
|
||||
import sqlparse
|
||||
from pytest_mock import MockerFixture
|
||||
from sqlalchemy import text
|
||||
from sqlparse.sql import Identifier, Token, TokenList
|
||||
from sqlparse.tokens import Name
|
||||
|
||||
from superset.exceptions import QueryClauseValidationException
|
||||
from superset.sql_parse import (
|
||||
add_table_name,
|
||||
get_rls_for_table,
|
||||
has_table_query,
|
||||
insert_rls,
|
||||
matches_table_name,
|
||||
ParsedQuery,
|
||||
sanitize_clause,
|
||||
strip_comments_from_sql,
|
||||
@@ -1391,13 +1394,37 @@ def test_has_table_query(sql: str, expected: bool) -> None:
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_insert_rls(sql: str, table: str, rls: str, expected: str) -> None:
|
||||
def test_insert_rls(
|
||||
mocker: MockerFixture, sql: str, table: str, rls: str, expected: str
|
||||
) -> None:
|
||||
"""
|
||||
Insert into a statement a given RLS condition associated with a table.
|
||||
"""
|
||||
statement = sqlparse.parse(sql)[0]
|
||||
condition = sqlparse.parse(rls)[0]
|
||||
assert str(insert_rls(statement, table, condition)).strip() == expected.strip()
|
||||
add_table_name(condition, table)
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def get_rls_for_table(
|
||||
candidate: Token, database_id: int, default_schema: str
|
||||
) -> Optional[TokenList]:
|
||||
"""
|
||||
Return the RLS ``condition`` if ``candidate`` matches ``table``.
|
||||
"""
|
||||
# compare ignoring schema
|
||||
for left, right in zip(str(candidate).split(".")[::-1], table.split(".")[::-1]):
|
||||
if left != right:
|
||||
return None
|
||||
return condition
|
||||
|
||||
mocker.patch("superset.sql_parse.get_rls_for_table", new=get_rls_for_table)
|
||||
|
||||
statement = sqlparse.parse(sql)[0]
|
||||
assert (
|
||||
str(
|
||||
insert_rls(token_list=statement, database_id=1, default_schema="my_schema")
|
||||
).strip()
|
||||
== expected.strip()
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -1415,16 +1442,29 @@ def test_add_table_name(rls: str, table: str, expected: str) -> None:
|
||||
assert str(condition) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"candidate,table,expected",
|
||||
[
|
||||
("table", "table", True),
|
||||
("schema.table", "table", True),
|
||||
("table", "schema.table", True),
|
||||
('schema."my table"', '"my table"', True),
|
||||
('schema."my.table"', '"my.table"', True),
|
||||
],
|
||||
)
|
||||
def test_matches_table_name(candidate: str, table: str, expected: bool) -> None:
|
||||
token = sqlparse.parse(candidate)[0].tokens[0]
|
||||
assert matches_table_name(token, table) == expected
|
||||
def test_get_rls_for_table(mocker: MockerFixture, app_context: None) -> None:
|
||||
"""
|
||||
Tests for ``get_rls_for_table``.
|
||||
"""
|
||||
candidate = Identifier([Token(Name, "some_table")])
|
||||
db = mocker.patch("superset.db")
|
||||
dataset = db.session.query().filter().one_or_none()
|
||||
dataset.__str__.return_value = "some_table"
|
||||
|
||||
dataset.get_sqla_row_level_filters.return_value = [text("organization_id = 1")]
|
||||
assert (
|
||||
str(get_rls_for_table(candidate, 1, "public"))
|
||||
== "some_table.organization_id = 1"
|
||||
)
|
||||
|
||||
dataset.get_sqla_row_level_filters.return_value = [
|
||||
text("organization_id = 1"),
|
||||
text("foo = 'bar'"),
|
||||
]
|
||||
assert (
|
||||
str(get_rls_for_table(candidate, 1, "public"))
|
||||
== "some_table.organization_id = 1 AND some_table.foo = 'bar'"
|
||||
)
|
||||
|
||||
dataset.get_sqla_row_level_filters.return_value = []
|
||||
assert get_rls_for_table(candidate, 1, "public") is None
|
||||
|
||||
Reference in New Issue
Block a user