mirror of
https://github.com/apache/superset.git
synced 2026-04-18 23:55:00 +00:00
fix: search_path in RDS (#24739)
This commit is contained in:
@@ -14,13 +14,17 @@
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime
|
||||
from re import Pattern
|
||||
from typing import Any, Optional, TYPE_CHECKING
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
import sqlparse
|
||||
from flask_babel import gettext as __
|
||||
from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, ENUM, JSON
|
||||
from sqlalchemy.dialects.postgresql.base import PGInspector
|
||||
@@ -30,8 +34,8 @@ from sqlalchemy.types import Date, DateTime, String
|
||||
|
||||
from superset.constants import TimeGrain
|
||||
from superset.db_engine_specs.base import BaseEngineSpec, BasicParametersMixin
|
||||
from superset.errors import SupersetErrorType
|
||||
from superset.exceptions import SupersetException
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.exceptions import SupersetException, SupersetSecurityException
|
||||
from superset.models.sql_lab import Query
|
||||
from superset.utils import core as utils
|
||||
from superset.utils.core import GenericDataType
|
||||
@@ -169,9 +173,7 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def fetch_data(
|
||||
cls, cursor: Any, limit: Optional[int] = None
|
||||
) -> list[tuple[Any, ...]]:
|
||||
def fetch_data(cls, cursor: Any, limit: int | None = None) -> list[tuple[Any, ...]]:
|
||||
if not cursor.description:
|
||||
return []
|
||||
return super().fetch_data(cursor, limit)
|
||||
@@ -224,7 +226,7 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
|
||||
cls,
|
||||
sqlalchemy_uri: URL,
|
||||
connect_args: dict[str, Any],
|
||||
) -> Optional[str]:
|
||||
) -> str | None:
|
||||
"""
|
||||
Return the configured schema.
|
||||
|
||||
@@ -237,6 +239,9 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
|
||||
in Superset because it breaks schema-level permissions, since it's impossible
|
||||
to determine the schema for a non-qualified table in a query. In cases like
|
||||
that we raise an exception.
|
||||
|
||||
Note that because the DB engine supports dynamic schema this method is never
|
||||
called. It's left here as an implementation reference.
|
||||
"""
|
||||
options = parse_options(connect_args)
|
||||
if search_path := options.get("search_path"):
|
||||
@@ -252,23 +257,50 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def adjust_engine_params(
|
||||
def get_default_schema_for_query(
|
||||
cls,
|
||||
uri: URL,
|
||||
connect_args: dict[str, Any],
|
||||
catalog: Optional[str] = None,
|
||||
schema: Optional[str] = None,
|
||||
) -> tuple[URL, dict[str, Any]]:
|
||||
if not schema:
|
||||
return uri, connect_args
|
||||
database: Database,
|
||||
query: Query,
|
||||
) -> str | None:
|
||||
"""
|
||||
Return the default schema for a given query.
|
||||
|
||||
options = parse_options(connect_args)
|
||||
options["search_path"] = schema
|
||||
connect_args["options"] = " ".join(
|
||||
f"-c{key}={value}" for key, value in options.items()
|
||||
)
|
||||
This method simply uses the parent method after checking that there are no
|
||||
malicious path setting in the query.
|
||||
"""
|
||||
sql = sqlparse.format(query.sql, strip_comments=True)
|
||||
if re.search(r"set\s+search_path\s*=", sql, re.IGNORECASE):
|
||||
raise SupersetSecurityException(
|
||||
SupersetError(
|
||||
error_type=SupersetErrorType.QUERY_SECURITY_ACCESS_ERROR,
|
||||
message=__(
|
||||
"Users are not allowed to set a search path for security reasons."
|
||||
),
|
||||
level=ErrorLevel.ERROR,
|
||||
)
|
||||
)
|
||||
|
||||
return uri, connect_args
|
||||
return super().get_default_schema_for_query(database, query)
|
||||
|
||||
@classmethod
|
||||
def get_prequeries(
|
||||
cls,
|
||||
catalog: str | None = None,
|
||||
schema: str | None = None,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Set the search path to the specified schema.
|
||||
|
||||
This is important for two reasons: in SQL Lab it will allow queries to run in
|
||||
the schema selected in the dropdown, resolving unqualified table names to the
|
||||
expected schema.
|
||||
|
||||
But more importantly, in SQL Lab this is used to check if the user has access to
|
||||
any tables with unqualified names. If the schema is not set by SQL Lab it could
|
||||
be anything, and we would have to block users from running any queries
|
||||
referencing tables without an explicit schema.
|
||||
"""
|
||||
return [f'set search_path = "{schema}"'] if schema else []
|
||||
|
||||
@classmethod
|
||||
def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool:
|
||||
@@ -298,7 +330,7 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
|
||||
@classmethod
|
||||
def get_catalog_names(
|
||||
cls,
|
||||
database: "Database",
|
||||
database: Database,
|
||||
inspector: Inspector,
|
||||
) -> list[str]:
|
||||
"""
|
||||
@@ -318,7 +350,7 @@ WHERE datistemplate = false;
|
||||
|
||||
@classmethod
|
||||
def get_table_names(
|
||||
cls, database: "Database", inspector: PGInspector, schema: Optional[str]
|
||||
cls, database: Database, inspector: PGInspector, schema: str | None
|
||||
) -> set[str]:
|
||||
"""Need to consider foreign tables for PostgreSQL"""
|
||||
return set(inspector.get_table_names(schema)) | set(
|
||||
@@ -327,8 +359,8 @@ WHERE datistemplate = false;
|
||||
|
||||
@classmethod
|
||||
def convert_dttm(
|
||||
cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
|
||||
) -> Optional[str]:
|
||||
cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None
|
||||
) -> str | None:
|
||||
sqla_type = cls.get_sqla_column_type(target_type)
|
||||
|
||||
if isinstance(sqla_type, Date):
|
||||
@@ -339,7 +371,7 @@ WHERE datistemplate = false;
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_extra_params(database: "Database") -> dict[str, Any]:
|
||||
def get_extra_params(database: Database) -> dict[str, Any]:
|
||||
"""
|
||||
For Postgres, the path to a SSL certificate is placed in `connect_args`.
|
||||
|
||||
@@ -363,7 +395,7 @@ WHERE datistemplate = false;
|
||||
return extra
|
||||
|
||||
@classmethod
|
||||
def get_datatype(cls, type_code: Any) -> Optional[str]:
|
||||
def get_datatype(cls, type_code: Any) -> str | None:
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from psycopg2.extensions import binary_types, string_types
|
||||
|
||||
@@ -374,7 +406,7 @@ WHERE datistemplate = false;
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_cancel_query_id(cls, cursor: Any, query: Query) -> Optional[str]:
|
||||
def get_cancel_query_id(cls, cursor: Any, query: Query) -> str | None:
|
||||
"""
|
||||
Get Postgres PID that will be used to cancel all other running
|
||||
queries in the same session.
|
||||
|
||||
Reference in New Issue
Block a user