fix: search_path in RDS (#24739)

This commit is contained in:
Beto Dealmeida
2023-07-20 12:57:48 -07:00
committed by GitHub
parent b2831b419e
commit 7675e0db10
5 changed files with 185 additions and 85 deletions

View File

@@ -14,13 +14,17 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import json
import logging
import re
from datetime import datetime
from re import Pattern
from typing import Any, Optional, TYPE_CHECKING
from typing import Any, TYPE_CHECKING
import sqlparse
from flask_babel import gettext as __
from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, ENUM, JSON
from sqlalchemy.dialects.postgresql.base import PGInspector
@@ -30,8 +34,8 @@ from sqlalchemy.types import Date, DateTime, String
from superset.constants import TimeGrain
from superset.db_engine_specs.base import BaseEngineSpec, BasicParametersMixin
from superset.errors import SupersetErrorType
from superset.exceptions import SupersetException
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetException, SupersetSecurityException
from superset.models.sql_lab import Query
from superset.utils import core as utils
from superset.utils.core import GenericDataType
@@ -169,9 +173,7 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
}
@classmethod
def fetch_data(cls, cursor: Any, limit: int | None = None) -> list[tuple[Any, ...]]:
    """
    Fetch rows from ``cursor``, tolerating cursors without a description.

    :param cursor: cursor object exposing a ``description`` attribute
    :param limit: optional row limit, forwarded unchanged to the parent class
    :returns: the parent implementation's result, or an empty list when
        ``cursor.description`` is falsy
    """
    # Only delegate when the cursor reports a description; otherwise there is
    # nothing to hand over to the parent implementation.
    if cursor.description:
        return super().fetch_data(cursor, limit)
    return []
@@ -224,7 +226,7 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
cls,
sqlalchemy_uri: URL,
connect_args: dict[str, Any],
) -> Optional[str]:
) -> str | None:
"""
Return the configured schema.
@@ -237,6 +239,9 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
in Superset because it breaks schema-level permissions, since it's impossible
to determine the schema for a non-qualified table in a query. In cases like
that we raise an exception.
Note that because the DB engine supports dynamic schema this method is never
called. It's left here as an implementation reference.
"""
options = parse_options(connect_args)
if search_path := options.get("search_path"):
@@ -252,23 +257,50 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
return None
@classmethod
def get_default_schema_for_query(
    cls,
    database: Database,
    query: Query,
) -> str | None:
    """
    Return the default schema for a given query.

    The query text is first checked for any attempt to change the search path;
    if none is found the decision is delegated to the parent implementation.

    :param database: database the query runs against, forwarded to the parent
    :param query: query whose ``sql`` attribute is inspected
    :returns: whatever the parent ``get_default_schema_for_query`` returns
    :raises SupersetSecurityException: if the SQL contains a statement that
        sets the search path
    """
    # Strip comments before matching so the check runs on the statement text
    # only.
    sql = sqlparse.format(query.sql, strip_comments=True)

    forbidden_patterns = (
        # SET [SESSION | LOCAL] search_path { = | TO } ... -- PostgreSQL accepts
        # both "=" and "TO", an optional SESSION/LOCAL keyword, and a quoted
        # parameter name; all of them must be rejected, not only the "=" form.
        r'set\s+(?:(?:session|local)\s+)?"?search_path"?\s*(?:=|to\b)',
        # SELECT set_config('search_path', ...) changes the same setting
        # through a function call.
        r"\bset_config\s*\(\s*'search_path'",
    )
    if any(re.search(pattern, sql, re.IGNORECASE) for pattern in forbidden_patterns):
        raise SupersetSecurityException(
            SupersetError(
                error_type=SupersetErrorType.QUERY_SECURITY_ACCESS_ERROR,
                message=__(
                    "Users are not allowed to set a search path for security reasons."
                ),
                level=ErrorLevel.ERROR,
            )
        )

    return super().get_default_schema_for_query(database, query)
@classmethod
def get_prequeries(
    cls,
    catalog: str | None = None,
    schema: str | None = None,
) -> list[str]:
    """
    Set the search path to the specified schema.

    This is important for two reasons: in SQL Lab it will allow queries to run in
    the schema selected in the dropdown, resolving unqualified table names to the
    expected schema.

    But more importantly, in SQL Lab this is used to check if the user has access to
    any tables with unqualified names. If the schema is not set by SQL Lab it could
    be anything, and we would have to block users from running any queries
    referencing tables without an explicit schema.

    :param catalog: accepted for interface compatibility; not used here
    :param schema: schema to place on the search path
    :returns: a single ``set search_path`` statement, or an empty list when no
        schema is given
    """
    if not schema:
        return []

    # Inside a double-quoted identifier a literal double quote must be written
    # twice; without this a schema name containing `"` would end the identifier
    # early and produce broken (or injected) SQL.
    quoted_schema = schema.replace('"', '""')
    return [f'set search_path = "{quoted_schema}"']
@classmethod
def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool:
@@ -298,7 +330,7 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
@classmethod
def get_catalog_names(
cls,
database: "Database",
database: Database,
inspector: Inspector,
) -> list[str]:
"""
@@ -318,7 +350,7 @@ WHERE datistemplate = false;
@classmethod
def get_table_names(
cls, database: "Database", inspector: PGInspector, schema: Optional[str]
cls, database: Database, inspector: PGInspector, schema: str | None
) -> set[str]:
"""Need to consider foreign tables for PostgreSQL"""
return set(inspector.get_table_names(schema)) | set(
@@ -327,8 +359,8 @@ WHERE datistemplate = false;
@classmethod
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None
) -> str | None:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, Date):
@@ -339,7 +371,7 @@ WHERE datistemplate = false;
return None
@staticmethod
def get_extra_params(database: "Database") -> dict[str, Any]:
def get_extra_params(database: Database) -> dict[str, Any]:
"""
For Postgres, the path to a SSL certificate is placed in `connect_args`.
@@ -363,7 +395,7 @@ WHERE datistemplate = false;
return extra
@classmethod
def get_datatype(cls, type_code: Any) -> Optional[str]:
def get_datatype(cls, type_code: Any) -> str | None:
# pylint: disable=import-outside-toplevel
from psycopg2.extensions import binary_types, string_types
@@ -374,7 +406,7 @@ WHERE datistemplate = false;
return None
@classmethod
def get_cancel_query_id(cls, cursor: Any, query: Query) -> Optional[str]:
def get_cancel_query_id(cls, cursor: Any, query: Query) -> str | None:
"""
Get Postgres PID that will be used to cancel all other running
queries in the same session.