feat: use sqlglot to validate adhoc subquery (#33560)

This commit is contained in:
Beto Dealmeida
2025-05-30 18:09:19 -04:00
committed by GitHub
parent cf315388f2
commit 401ce56fa1
10 changed files with 123 additions and 92 deletions

View File

@@ -15,15 +15,17 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=consider-using-transaction
from __future__ import annotations
import dataclasses
import logging
import sys
import uuid
from collections import defaultdict
from contextlib import closing
from datetime import datetime
from sys import getsizeof
from typing import Any, cast, Optional, TypeVar, Union
from typing import Any, cast, Optional, TYPE_CHECKING, TypeVar, Union
import backoff
import msgpack
@@ -56,10 +58,9 @@ from superset.exceptions import (
SupersetResultsBackendNotConfigureException,
)
from superset.extensions import celery_app, event_logger
from superset.models.core import Database
from superset.models.sql_lab import Query
from superset.result_set import SupersetResultSet
from superset.sql.parse import BaseSQLStatement, CTASMethod, RLSMethod, SQLScript, Table
from superset.sql.parse import BaseSQLStatement, CTASMethod, SQLScript, Table
from superset.sqllab.limiting_factor import LimitingFactor
from superset.sqllab.utils import write_ipc_buffer
from superset.utils import json
@@ -71,6 +72,9 @@ from superset.utils.core import (
from superset.utils.dates import now_as_float
from superset.utils.decorators import stats_timing
if TYPE_CHECKING:
from superset.models.core import Database
config = app.config
stats_logger = config["STATS_LOGGER"]
SQLLAB_TIMEOUT = config["SQLLAB_ASYNC_TIME_LIMIT_SEC"]
@@ -197,52 +201,47 @@ def get_sql_results( # pylint: disable=too-many-arguments
return handle_query_error(ex, query)
def apply_rls(query: Query, parsed_statement: BaseSQLStatement[Any]) -> None:
def apply_rls(
database: Database,
catalog: str | None,
schema: str,
parsed_statement: BaseSQLStatement[Any],
) -> None:
"""
Modify statement inplace to ensure RLS rules are applied.
"""
database = query.database
# we need the default schema to fully qualify the table names
default_schema = database.get_default_schema_for_query(query)
# There are two ways to insert RLS: either replacing the table with a subquery
# that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is
# safer, but not supported in all databases.
method = (
RLSMethod.AS_SUBQUERY
if database.db_engine_spec.allows_subqueries
and database.db_engine_spec.allows_alias_in_select
else RLSMethod.AS_PREDICATE
)
method = database.db_engine_spec.get_rls_method()
# collect all RLS predicates for all tables in the query
predicates: dict[Table, list[Any]] = defaultdict(list)
predicates: dict[Table, list[Any]] = {}
for table in parsed_statement.tables:
# fully qualify table
table = Table(
table.table,
table.schema or default_schema,
table.catalog or query.catalog,
table.schema or schema,
table.catalog or catalog,
)
if table_predicates := get_predicates_for_table(
table,
database,
query.catalog == database.get_default_catalog(),
):
predicates[table].extend(
parsed_statement.parse_predicate(predicate)
for predicate in table_predicates
predicates[table] = [
parsed_statement.parse_predicate(predicate)
for predicate in get_predicates_for_table(
table,
database,
database.get_default_catalog(),
)
if predicate
]
parsed_statement.apply_rls(query.catalog, default_schema, predicates, method)
parsed_statement.apply_rls(catalog, schema, predicates, method)
def get_predicates_for_table(
table: Table,
database: Database,
is_default_catalog: bool,
default_catalog: str | None,
) -> list[str]:
"""
Get the RLS predicates for a table.
@@ -254,7 +253,7 @@ def get_predicates_for_table(
# if the dataset in the RLS has null catalog, match it when using the default
# catalog
catalog_predicate = SqlaTable.catalog == table.catalog
if table.catalog and is_default_catalog:
if table.catalog and table.catalog == default_catalog:
catalog_predicate = or_(
catalog_predicate,
SqlaTable.catalog.is_(None),
@@ -483,8 +482,9 @@ def execute_sql_statements( # noqa: C901
raise SupersetDMLNotAllowedException()
if is_feature_enabled("RLS_IN_SQLLAB"):
default_schema = query.database.get_default_schema_for_query(query)
for statement in parsed_script.statements:
apply_rls(query, statement)
apply_rls(query.database, query.catalog, default_schema, statement)
if query.select_as_cta:
# CTAS is valid when the last statement is a SELECT, while CVAS is valid when