mirror of
https://github.com/apache/superset.git
synced 2026-04-19 08:04:53 +00:00
feat: use sqlglot to validate adhoc subquery (#33560)
This commit is contained in:
@@ -15,15 +15,17 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
# pylint: disable=consider-using-transaction
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
import sys
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from contextlib import closing
|
||||
from datetime import datetime
|
||||
from sys import getsizeof
|
||||
from typing import Any, cast, Optional, TypeVar, Union
|
||||
from typing import Any, cast, Optional, TYPE_CHECKING, TypeVar, Union
|
||||
|
||||
import backoff
|
||||
import msgpack
|
||||
@@ -56,10 +58,9 @@ from superset.exceptions import (
|
||||
SupersetResultsBackendNotConfigureException,
|
||||
)
|
||||
from superset.extensions import celery_app, event_logger
|
||||
from superset.models.core import Database
|
||||
from superset.models.sql_lab import Query
|
||||
from superset.result_set import SupersetResultSet
|
||||
from superset.sql.parse import BaseSQLStatement, CTASMethod, RLSMethod, SQLScript, Table
|
||||
from superset.sql.parse import BaseSQLStatement, CTASMethod, SQLScript, Table
|
||||
from superset.sqllab.limiting_factor import LimitingFactor
|
||||
from superset.sqllab.utils import write_ipc_buffer
|
||||
from superset.utils import json
|
||||
@@ -71,6 +72,9 @@ from superset.utils.core import (
|
||||
from superset.utils.dates import now_as_float
|
||||
from superset.utils.decorators import stats_timing
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.models.core import Database
|
||||
|
||||
config = app.config
|
||||
stats_logger = config["STATS_LOGGER"]
|
||||
SQLLAB_TIMEOUT = config["SQLLAB_ASYNC_TIME_LIMIT_SEC"]
|
||||
@@ -197,52 +201,47 @@ def get_sql_results( # pylint: disable=too-many-arguments
|
||||
return handle_query_error(ex, query)
|
||||
|
||||
|
||||
def apply_rls(query: Query, parsed_statement: BaseSQLStatement[Any]) -> None:
|
||||
def apply_rls(
|
||||
database: Database,
|
||||
catalog: str | None,
|
||||
schema: str,
|
||||
parsed_statement: BaseSQLStatement[Any],
|
||||
) -> None:
|
||||
"""
|
||||
Modify statement inplace to ensure RLS rules are applied.
|
||||
"""
|
||||
database = query.database
|
||||
|
||||
# we need the default schema to fully qualify the table names
|
||||
default_schema = database.get_default_schema_for_query(query)
|
||||
|
||||
# There are two ways to insert RLS: either replacing the table with a subquery
|
||||
# that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is
|
||||
# safer, but not supported in all databases.
|
||||
method = (
|
||||
RLSMethod.AS_SUBQUERY
|
||||
if database.db_engine_spec.allows_subqueries
|
||||
and database.db_engine_spec.allows_alias_in_select
|
||||
else RLSMethod.AS_PREDICATE
|
||||
)
|
||||
method = database.db_engine_spec.get_rls_method()
|
||||
|
||||
# collect all RLS predicates for all tables in the query
|
||||
predicates: dict[Table, list[Any]] = defaultdict(list)
|
||||
predicates: dict[Table, list[Any]] = {}
|
||||
for table in parsed_statement.tables:
|
||||
# fully qualify table
|
||||
table = Table(
|
||||
table.table,
|
||||
table.schema or default_schema,
|
||||
table.catalog or query.catalog,
|
||||
table.schema or schema,
|
||||
table.catalog or catalog,
|
||||
)
|
||||
|
||||
if table_predicates := get_predicates_for_table(
|
||||
table,
|
||||
database,
|
||||
query.catalog == database.get_default_catalog(),
|
||||
):
|
||||
predicates[table].extend(
|
||||
parsed_statement.parse_predicate(predicate)
|
||||
for predicate in table_predicates
|
||||
predicates[table] = [
|
||||
parsed_statement.parse_predicate(predicate)
|
||||
for predicate in get_predicates_for_table(
|
||||
table,
|
||||
database,
|
||||
database.get_default_catalog(),
|
||||
)
|
||||
if predicate
|
||||
]
|
||||
|
||||
parsed_statement.apply_rls(query.catalog, default_schema, predicates, method)
|
||||
parsed_statement.apply_rls(catalog, schema, predicates, method)
|
||||
|
||||
|
||||
def get_predicates_for_table(
|
||||
table: Table,
|
||||
database: Database,
|
||||
is_default_catalog: bool,
|
||||
default_catalog: str | None,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Get the RLS predicates for a table.
|
||||
@@ -254,7 +253,7 @@ def get_predicates_for_table(
|
||||
# if the dataset in the RLS has null catalog, match it when using the default
|
||||
# catalog
|
||||
catalog_predicate = SqlaTable.catalog == table.catalog
|
||||
if table.catalog and is_default_catalog:
|
||||
if table.catalog and table.catalog == default_catalog:
|
||||
catalog_predicate = or_(
|
||||
catalog_predicate,
|
||||
SqlaTable.catalog.is_(None),
|
||||
@@ -483,8 +482,9 @@ def execute_sql_statements( # noqa: C901
|
||||
raise SupersetDMLNotAllowedException()
|
||||
|
||||
if is_feature_enabled("RLS_IN_SQLLAB"):
|
||||
default_schema = query.database.get_default_schema_for_query(query)
|
||||
for statement in parsed_script.statements:
|
||||
apply_rls(query, statement)
|
||||
apply_rls(query.database, query.catalog, default_schema, statement)
|
||||
|
||||
if query.select_as_cta:
|
||||
# CTAS is valid when the last statement is a SELECT, while CVAS is valid when
|
||||
|
||||
Reference in New Issue
Block a user