Files
superset2/superset/daos/database.py
2025-12-03 14:26:35 -05:00

258 lines
8.4 KiB
Python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import logging
from typing import Any
from sqlalchemy.orm import joinedload
from superset import is_feature_enabled
from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelingNotEnabledError
from superset.connectors.sqla.models import SqlaTable
from superset.daos.base import BaseDAO
from superset.databases.filters import DatabaseFilter
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.extensions import db
from superset.models.core import Database, DatabaseUserOAuth2Tokens
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.models.sql_lab import TabState
from superset.utils.core import DatasourceType
from superset.utils.ssh_tunnel import unmask_password_info
logger = logging.getLogger(__name__)
class DatabaseDAO(BaseDAO[Database]):
    """
    Data access object for ``Database`` models.

    On top of the operations inherited from ``BaseDAO`` this class handles the
    optional SSH tunnel when creating/updating a database, eagerly loads the
    tunnel on lookups by id, and offers name-uniqueness checks and queries for
    the objects that hang off a database.
    """

    # Filter class handed to the base DAO machinery (see ``_apply_base_filter``).
    base_filter = DatabaseFilter

    @classmethod
    def create(
        cls,
        item: Database | None = None,
        attributes: dict[str, Any] | None = None,
    ) -> Database:
        """
        Create a new database, with an optional SSH tunnel.

        :param item: Optional ``Database`` instance, forwarded to the base DAO
        :param attributes: Attributes for the new database. An ``ssh_tunnel`` key,
            if present, is removed from this dict and used to build the tunnel
        :return: The created database
        :raises SSHTunnelingNotEnabledError: If SSH tunnel attributes are given
            while the ``SSH_TUNNELING`` feature flag is disabled
        """
        ssh_tunnel_attributes = (
            attributes.pop("ssh_tunnel", None) if attributes else None
        )

        # Check the feature flag *before* delegating to the base DAO, so that a
        # rejected request fails without the database having been created first.
        if ssh_tunnel_attributes and not is_feature_enabled("SSH_TUNNELING"):
            raise SSHTunnelingNotEnabledError()

        database = super().create(item, attributes)
        if ssh_tunnel_attributes:
            database.ssh_tunnel = SSHTunnel(**ssh_tunnel_attributes)

        return database

    @classmethod
    def find_by_id(
        cls,
        model_id: str | int,
        skip_base_filter: bool = False,
        id_column: str | None = None,
    ) -> Database | None:
        """
        Find a database by id, eagerly loading the SSH tunnel relationship.

        :param model_id: Value to match against the id column
        :param skip_base_filter: Forwarded to ``_apply_base_filter``
        :param id_column: Name of the model attribute to match on; defaults to
            ``cls.id_column_name``
        :return: The matching database, or ``None`` if no row matches
        :raises AttributeError: If the model has no attribute with that name
        """
        query = db.session.query(cls.model_cls).options(joinedload(Database.ssh_tunnel))
        query = cls._apply_base_filter(query, skip_base_filter)

        column_name = id_column or cls.id_column_name
        if not hasattr(cls.model_cls, column_name):
            raise AttributeError(f"{cls.model_cls} has no column {column_name}")

        column = getattr(cls.model_cls, column_name)
        return query.filter(column == model_id).one_or_none()

    @classmethod
    def update(
        cls,
        item: Database | None = None,
        attributes: dict[str, Any] | None = None,
    ) -> Database:
        """
        Unmask ``encrypted_extra`` and reconcile the SSH tunnel before updating.

        When a database is edited the user sees a masked version of
        ``encrypted_extra``, depending on the engine spec. Eg, BigQuery will mask
        the ``private_key`` attribute of the credentials. The masked values should
        be unmasked before the database is updated.

        The ``ssh_tunnel`` key of ``attributes`` selects what happens to the tunnel:

        - key absent: the existing tunnel (if any) is kept;
        - key present with ``None``: the tunnel is removed;
        - key present with a dict: a new tunnel is built from it, replacing any
          existing one.

        :param item: The database being updated, if any
        :param attributes: New attribute values; a non-``None`` ``ssh_tunnel``
            entry is removed from this dict
        :return: The updated database
        """
        attributes = attributes or {}

        if item and "encrypted_extra" in attributes:
            attributes["encrypted_extra"] = item.db_engine_spec.unmask_encrypted_extra(
                item.encrypted_extra,
                attributes["encrypted_extra"],
            )

        # update SSH tunnel
        if "ssh_tunnel" not in attributes:
            # keep existing tunnel if it exists
            ssh_tunnel = item.ssh_tunnel if item else None
        elif attributes["ssh_tunnel"] is None:
            # remove existing tunnel
            ssh_tunnel = None
        else:
            # update existing tunnel or create a new one
            ssh_tunnel_attributes = attributes.pop("ssh_tunnel")

            # when updating the tunnel, passwords are sent masked to the frontend; if
            # they arrive back masked we need to unmask them
            if item and item.ssh_tunnel:
                ssh_tunnel_attributes = unmask_password_info(
                    ssh_tunnel_attributes,
                    item.ssh_tunnel,
                )

                # delete existing SSH tunnel first; the flush pushes that change
                # out before the replacement tunnel is attached below
                item.ssh_tunnel = None
                db.session.flush()

            ssh_tunnel = SSHTunnel(**ssh_tunnel_attributes)

        database = super().update(item, attributes)
        database.ssh_tunnel = ssh_tunnel

        return database

    @staticmethod
    def validate_uniqueness(database_name: str) -> bool:
        """
        Check that no database uses the given name.

        :param database_name: The name to check
        :return: ``True`` if no database with that name exists
        """
        database_query = db.session.query(Database).filter(
            Database.database_name == database_name
        )
        return not db.session.query(database_query.exists()).scalar()

    @staticmethod
    def validate_update_uniqueness(database_id: int, database_name: str) -> bool:
        """
        Check that no *other* database uses the given name.

        :param database_id: ID of the database being renamed; it is excluded
            from the check
        :param database_name: The name to check
        :return: ``True`` if no database with a different ID has that name
        """
        database_query = db.session.query(Database).filter(
            Database.database_name == database_name,
            Database.id != database_id,
        )
        return not db.session.query(database_query.exists()).scalar()

    @staticmethod
    def get_database_by_name(database_name: str) -> Database | None:
        """
        Return the database with the given name.

        :param database_name: The name to look up
        :return: The matching database, or ``None`` if no row matches
        """
        return (
            db.session.query(Database)
            .filter(Database.database_name == database_name)
            .one_or_none()
        )

    @staticmethod
    def build_db_for_connection_test(
        server_cert: str,
        extra: str,
        impersonate_user: bool,
        encrypted_extra: str,
        ssh_tunnel: dict[str, Any] | None = None,
    ) -> Database:
        """
        Build a ``Database`` object from the given connection settings.

        The object is only constructed here; this method does not add it to the
        session.

        :param server_cert: Value for ``Database.server_cert``
        :param extra: Value for ``Database.extra``
        :param impersonate_user: Value for ``Database.impersonate_user``
        :param encrypted_extra: Value for ``Database.encrypted_extra``
        :param ssh_tunnel: Optional keyword arguments for an ``SSHTunnel``
        :return: The new ``Database`` instance
        """
        return Database(
            server_cert=server_cert,
            extra=extra,
            impersonate_user=impersonate_user,
            encrypted_extra=encrypted_extra,
            ssh_tunnel=SSHTunnel(**ssh_tunnel) if ssh_tunnel else None,
        )

    @classmethod
    def get_related_objects(cls, database_id: int) -> dict[str, Any]:
        """
        Collect the charts, dashboards and SQL Lab tab states tied to a database.

        Charts are the ones built on the database's datasets, dashboards are the
        ones containing any of those charts, and tab states are matched on
        ``TabState.database_id``.

        :param database_id: The database ID
        :return: A dict with the keys ``charts``, ``dashboards`` and
            ``sqllab_tab_states``; all three are empty lists when
            ``find_by_id`` does not return a database
        """
        database = cls.find_by_id(database_id)
        if database is None:
            # ``find_by_id`` found nothing (or the base filter excluded it);
            # report no related objects instead of failing on ``None.tables``.
            return {"charts": [], "dashboards": [], "sqllab_tab_states": []}

        datasets = database.tables
        dataset_ids = [dataset.id for dataset in datasets]

        charts = (
            db.session.query(Slice)
            .filter(
                Slice.datasource_id.in_(dataset_ids),
                Slice.datasource_type == DatasourceType.TABLE,
            )
            .all()
        )
        chart_ids = [chart.id for chart in charts]

        dashboards = (
            (
                db.session.query(Dashboard)
                .join(Dashboard.slices)
                .filter(Slice.id.in_(chart_ids))
            )
            # a dashboard holding several matching charts must be listed once
            .distinct()
            .all()
        )

        sqllab_tab_states = (
            db.session.query(TabState).filter(TabState.database_id == database_id).all()
        )

        return {
            "charts": charts,
            "dashboards": dashboards,
            "sqllab_tab_states": sqllab_tab_states,
        }

    @classmethod
    def get_datasets(
        cls,
        database_id: int,
        catalog: str | None,
        schema: str | None,
    ) -> list[SqlaTable]:
        """
        Return the datasets of a database that match the given catalog and schema.

        Both ``catalog`` and ``schema`` are always part of the filter, as plain
        equality comparisons. NOTE(review): SQLAlchemy renders ``column == None``
        as ``IS NULL``, so passing ``None`` should select datasets *without* a
        catalog/schema rather than disable that filter -- confirm callers expect
        this.

        :param database_id: The database ID
        :param catalog: The catalog name
        :param schema: The schema name
        :return: A list of SqlaTable objects
        """
        return (
            db.session.query(SqlaTable)
            .filter(
                SqlaTable.database_id == database_id,
                SqlaTable.catalog == catalog,
                SqlaTable.schema == schema,
            )
            .all()
        )
class DatabaseUserOAuth2TokensDAO(BaseDAO[DatabaseUserOAuth2Tokens]):
    """
    Data access object for OAuth2 tokens.
    """

    @classmethod
    def get_database(cls, database_id: int) -> Database | None:
        """
        Fetch a database by its id, or ``None`` when no row matches.

        This is not the same as `DatabaseDAO.find_by_id(database_id)`: no filters
        are configured on this DAO, so the lookup also succeeds for users who
        have no access to the database (which the OAuth2 flow requires).
        """
        lookup = db.session.query(Database).filter(Database.id == database_id)
        return lookup.one_or_none()