chore(command): Organize Commands according to SIP-92 (#25850)

This commit is contained in:
John Bodley
2023-11-22 11:55:54 -08:00
committed by GitHub
parent 984c278c4c
commit 07bcfa9b5f
265 changed files with 786 additions and 808 deletions

View File

@@ -29,16 +29,9 @@ from marshmallow import ValidationError
from sqlalchemy.exc import NoSuchTableError, OperationalError, SQLAlchemyError
from superset import app, event_logger
from superset.commands.importers.exceptions import (
IncorrectFormatError,
NoValidFilesFoundError,
)
from superset.commands.importers.v1.utils import get_contents_from_bundle
from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod
from superset.daos.database import DatabaseDAO
from superset.databases.commands.create import CreateDatabaseCommand
from superset.databases.commands.delete import DeleteDatabaseCommand
from superset.databases.commands.exceptions import (
from superset.commands.database.create import CreateDatabaseCommand
from superset.commands.database.delete import DeleteDatabaseCommand
from superset.commands.database.exceptions import (
DatabaseConnectionFailedError,
DatabaseCreateFailedError,
DatabaseDeleteDatasetsExistFailedError,
@@ -49,13 +42,26 @@ from superset.databases.commands.exceptions import (
DatabaseUpdateFailedError,
InvalidParametersError,
)
from superset.databases.commands.export import ExportDatabasesCommand
from superset.databases.commands.importers.dispatcher import ImportDatabasesCommand
from superset.databases.commands.tables import TablesDatabaseCommand
from superset.databases.commands.test_connection import TestConnectionDatabaseCommand
from superset.databases.commands.update import UpdateDatabaseCommand
from superset.databases.commands.validate import ValidateDatabaseParametersCommand
from superset.databases.commands.validate_sql import ValidateSQLCommand
from superset.commands.database.export import ExportDatabasesCommand
from superset.commands.database.importers.dispatcher import ImportDatabasesCommand
from superset.commands.database.ssh_tunnel.delete import DeleteSSHTunnelCommand
from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelDeleteFailedError,
SSHTunnelingNotEnabledError,
SSHTunnelNotFoundError,
)
from superset.commands.database.tables import TablesDatabaseCommand
from superset.commands.database.test_connection import TestConnectionDatabaseCommand
from superset.commands.database.update import UpdateDatabaseCommand
from superset.commands.database.validate import ValidateDatabaseParametersCommand
from superset.commands.database.validate_sql import ValidateSQLCommand
from superset.commands.importers.exceptions import (
IncorrectFormatError,
NoValidFilesFoundError,
)
from superset.commands.importers.v1.utils import get_contents_from_bundle
from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod
from superset.daos.database import DatabaseDAO
from superset.databases.decorators import check_datasource_access
from superset.databases.filters import DatabaseFilter, DatabaseUploadEnabledFilter
from superset.databases.schemas import (
@@ -79,12 +85,6 @@ from superset.databases.schemas import (
ValidateSQLRequest,
ValidateSQLResponse,
)
from superset.databases.ssh_tunnel.commands.delete import DeleteSSHTunnelCommand
from superset.databases.ssh_tunnel.commands.exceptions import (
SSHTunnelDeleteFailedError,
SSHTunnelingNotEnabledError,
SSHTunnelNotFoundError,
)
from superset.databases.utils import get_table_metadata
from superset.db_engine_specs import get_available_engine_specs
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType

View File

@@ -1,16 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -1,152 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Optional
from flask import current_app
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
from superset import is_feature_enabled
from superset.commands.base import BaseCommand
from superset.daos.database import DatabaseDAO
from superset.daos.exceptions import DAOCreateFailedError
from superset.databases.commands.exceptions import (
DatabaseConnectionFailedError,
DatabaseCreateFailedError,
DatabaseExistsValidationError,
DatabaseInvalidError,
DatabaseRequiredFieldValidationError,
)
from superset.databases.commands.test_connection import TestConnectionDatabaseCommand
from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand
from superset.databases.ssh_tunnel.commands.exceptions import (
SSHTunnelCreateFailedError,
SSHTunnelingNotEnabledError,
SSHTunnelInvalidError,
)
from superset.exceptions import SupersetErrorsException
from superset.extensions import db, event_logger, security_manager
logger = logging.getLogger(__name__)
stats_logger = current_app.config["STATS_LOGGER"]
class CreateDatabaseCommand(BaseCommand):
def __init__(self, data: dict[str, Any]):
self._properties = data.copy()
def run(self) -> Model:
self.validate()
try:
# Test connection before starting create transaction
TestConnectionDatabaseCommand(self._properties).run()
except (SupersetErrorsException, SSHTunnelingNotEnabledError) as ex:
event_logger.log_with_context(
action=f"db_creation_failed.{ex.__class__.__name__}",
engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
)
# So we can show the original message
raise ex
except Exception as ex:
event_logger.log_with_context(
action=f"db_creation_failed.{ex.__class__.__name__}",
engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
)
raise DatabaseConnectionFailedError() from ex
# when creating a new database we don't need to unmask encrypted extra
self._properties["encrypted_extra"] = self._properties.pop(
"masked_encrypted_extra",
"{}",
)
try:
database = DatabaseDAO.create(attributes=self._properties, commit=False)
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
ssh_tunnel = None
if ssh_tunnel_properties := self._properties.get("ssh_tunnel"):
if not is_feature_enabled("SSH_TUNNELING"):
db.session.rollback()
raise SSHTunnelingNotEnabledError()
try:
# So database.id is not None
db.session.flush()
ssh_tunnel = CreateSSHTunnelCommand(
database.id, ssh_tunnel_properties
).run()
except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex:
event_logger.log_with_context(
action=f"db_creation_failed.{ex.__class__.__name__}.ssh_tunnel",
engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
)
# So we can show the original message
raise ex
except Exception as ex:
event_logger.log_with_context(
action=f"db_creation_failed.{ex.__class__.__name__}.ssh_tunnel",
engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
)
raise DatabaseCreateFailedError() from ex
# adding a new database we always want to force refresh schema list
schemas = database.get_all_schema_names(cache=False, ssh_tunnel=ssh_tunnel)
for schema in schemas:
security_manager.add_permission_view_menu(
"schema_access", security_manager.get_schema_perm(database, schema)
)
db.session.commit()
except DAOCreateFailedError as ex:
db.session.rollback()
event_logger.log_with_context(
action=f"db_creation_failed.{ex.__class__.__name__}",
engine=database.db_engine_spec.__name__,
)
raise DatabaseCreateFailedError() from ex
if ssh_tunnel:
stats_logger.incr("db_creation_success.ssh_tunnel")
return database
def validate(self) -> None:
exceptions: list[ValidationError] = []
sqlalchemy_uri: Optional[str] = self._properties.get("sqlalchemy_uri")
database_name: Optional[str] = self._properties.get("database_name")
if not sqlalchemy_uri:
exceptions.append(DatabaseRequiredFieldValidationError("sqlalchemy_uri"))
if not database_name:
exceptions.append(DatabaseRequiredFieldValidationError("database_name"))
else:
# Check database_name uniqueness
if not DatabaseDAO.validate_uniqueness(database_name):
exceptions.append(DatabaseExistsValidationError())
if exceptions:
exception = DatabaseInvalidError()
exception.extend(exceptions)
event_logger.log_with_context(
# pylint: disable=consider-using-f-string
action="db_connection_failed.{}.{}".format(
exception.__class__.__name__,
".".join(exception.get_list_classnames()),
)
)
raise exception

View File

@@ -1,66 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Optional
from flask_babel import lazy_gettext as _
from superset.commands.base import BaseCommand
from superset.daos.database import DatabaseDAO
from superset.daos.exceptions import DAODeleteFailedError
from superset.daos.report import ReportScheduleDAO
from superset.databases.commands.exceptions import (
DatabaseDeleteDatasetsExistFailedError,
DatabaseDeleteFailedError,
DatabaseDeleteFailedReportsExistError,
DatabaseNotFoundError,
)
from superset.models.core import Database
logger = logging.getLogger(__name__)
class DeleteDatabaseCommand(BaseCommand):
def __init__(self, model_id: int):
self._model_id = model_id
self._model: Optional[Database] = None
def run(self) -> None:
self.validate()
assert self._model
try:
DatabaseDAO.delete([self._model])
except DAODeleteFailedError as ex:
logger.exception(ex.exception)
raise DatabaseDeleteFailedError() from ex
def validate(self) -> None:
# Validate/populate model exists
self._model = DatabaseDAO.find_by_id(self._model_id)
if not self._model:
raise DatabaseNotFoundError()
# Check there are no associated ReportSchedules
if reports := ReportScheduleDAO.find_by_database_id(self._model_id):
report_names = [report.name for report in reports]
raise DatabaseDeleteFailedReportsExistError(
_(f"There are associated alerts or reports: {','.join(report_names)}")
)
# Check if there are datasets for this database
if self._model.tables:
raise DatabaseDeleteDatasetsExistFailedError()

View File

@@ -1,183 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from flask_babel import lazy_gettext as _
from marshmallow.validate import ValidationError
from superset.commands.exceptions import (
CommandException,
CommandInvalidError,
CreateFailedError,
DeleteFailedError,
ImportFailedError,
UpdateFailedError,
)
from superset.exceptions import SupersetErrorException, SupersetErrorsException
class DatabaseInvalidError(CommandInvalidError):
message = _("Database parameters are invalid.")
class DatabaseExistsValidationError(ValidationError):
"""
Marshmallow validation error for dataset already exists
"""
def __init__(self) -> None:
super().__init__(
_("A database with the same name already exists."),
field_name="database_name",
)
class DatabaseRequiredFieldValidationError(ValidationError):
def __init__(self, field_name: str) -> None:
super().__init__(
[_("Field is required")],
field_name=field_name,
)
class DatabaseExtraJSONValidationError(ValidationError):
"""
Marshmallow validation error for database encrypted extra must be a valid JSON
"""
def __init__(self, json_error: str = "") -> None:
super().__init__(
[
_(
"Field cannot be decoded by JSON. %(json_error)s",
json_error=json_error,
)
],
field_name="extra",
)
class DatabaseExtraValidationError(ValidationError):
"""
Marshmallow validation error for database encrypted extra must be a valid JSON
"""
def __init__(self, key: str = "") -> None:
super().__init__(
[
_(
"The metadata_params in Extra field "
"is not configured correctly. The key "
"%{key}s is invalid.",
key=key,
)
],
field_name="extra",
)
class DatabaseNotFoundError(CommandException):
message = _("Database not found.")
class DatabaseCreateFailedError(CreateFailedError):
message = _("Database could not be created.")
class DatabaseUpdateFailedError(UpdateFailedError):
message = _("Database could not be updated.")
class DatabaseConnectionFailedError( # pylint: disable=too-many-ancestors
DatabaseCreateFailedError,
DatabaseUpdateFailedError,
):
message = _("Connection failed, please check your connection settings")
class DatabaseDeleteDatasetsExistFailedError(DeleteFailedError):
message = _("Cannot delete a database that has datasets attached")
class DatabaseDeleteFailedError(DeleteFailedError):
message = _("Database could not be deleted.")
class DatabaseDeleteFailedReportsExistError(DatabaseDeleteFailedError):
message = _("There are associated alerts or reports")
class DatabaseTestConnectionFailedError(SupersetErrorsException):
status = 422
message = _("Connection failed, please check your connection settings")
class DatabaseSecurityUnsafeError(CommandInvalidError):
message = _("Stopped an unsafe database connection")
class DatabaseTestConnectionDriverError(CommandInvalidError):
message = _("Could not load database driver")
class DatabaseTestConnectionUnexpectedError(SupersetErrorsException):
status = 422
message = _("Unexpected error occurred, please check your logs for details")
class DatabaseTablesUnexpectedError(Exception):
status = 422
message = _("Unexpected error occurred, please check your logs for details")
class NoValidatorConfigFoundError(SupersetErrorException):
status = 422
message = _("no SQL validator is configured")
class NoValidatorFoundError(SupersetErrorException):
status = 422
message = _("No validator found (configured for the engine)")
class ValidatorSQLError(SupersetErrorException):
status = 422
message = _("Was unable to check your query")
class ValidatorSQLUnexpectedError(CommandException):
status = 422
message = _("An unexpected error occurred")
class ValidatorSQL400Error(SupersetErrorException):
status = 400
message = _("Was unable to check your query")
class DatabaseImportError(ImportFailedError):
message = _("Import database failed for an unknown reason")
class InvalidEngineError(SupersetErrorException):
status = 422
class DatabaseOfflineError(SupersetErrorException):
status = 422
class InvalidParametersError(SupersetErrorsException):
status = 422

View File

@@ -1,122 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
import json
import logging
from typing import Any
from collections.abc import Iterator
import yaml
from superset.databases.commands.exceptions import DatabaseNotFoundError
from superset.daos.database import DatabaseDAO
from superset.commands.export.models import ExportModelsCommand
from superset.models.core import Database
from superset.utils.dict_import_export import EXPORT_VERSION
from superset.utils.file import get_filename
from superset.utils.ssh_tunnel import mask_password_info
logger = logging.getLogger(__name__)
def parse_extra(extra_payload: str) -> dict[str, Any]:
try:
extra = json.loads(extra_payload)
except json.decoder.JSONDecodeError:
logger.info("Unable to decode `extra` field: %s", extra_payload)
return {}
# Fix for DBs saved with an invalid ``schemas_allowed_for_csv_upload``
schemas_allowed_for_csv_upload = extra.get("schemas_allowed_for_csv_upload")
if isinstance(schemas_allowed_for_csv_upload, str):
extra["schemas_allowed_for_csv_upload"] = json.loads(
schemas_allowed_for_csv_upload
)
return extra
class ExportDatabasesCommand(ExportModelsCommand):
dao = DatabaseDAO
not_found = DatabaseNotFoundError
@staticmethod
def _export(
model: Database, export_related: bool = True
) -> Iterator[tuple[str, str]]:
db_file_name = get_filename(model.database_name, model.id, skip_id=True)
file_path = f"databases/{db_file_name}.yaml"
payload = model.export_to_dict(
recursive=False,
include_parent_ref=False,
include_defaults=True,
export_uuids=True,
)
# https://github.com/apache/superset/pull/16756 renamed ``allow_csv_upload``
# to ``allow_file_upload`, but we can't change the V1 schema
replacements = {"allow_file_upload": "allow_csv_upload"}
# this preserves key order, which is important
payload = {replacements.get(key, key): value for key, value in payload.items()}
# TODO (betodealmeida): move this logic to export_to_dict once this
# becomes the default export endpoint
if payload.get("extra"):
extra = payload["extra"] = parse_extra(payload["extra"])
# ``schemas_allowed_for_csv_upload`` was also renamed to
# ``schemas_allowed_for_file_upload``, we need to change to preserve the
# V1 schema
if "schemas_allowed_for_file_upload" in extra:
extra["schemas_allowed_for_csv_upload"] = extra.pop(
"schemas_allowed_for_file_upload"
)
if ssh_tunnel := DatabaseDAO.get_ssh_tunnel(model.id):
ssh_tunnel_payload = ssh_tunnel.export_to_dict(
recursive=False,
include_parent_ref=False,
include_defaults=True,
export_uuids=False,
)
payload["ssh_tunnel"] = mask_password_info(ssh_tunnel_payload)
payload["version"] = EXPORT_VERSION
file_content = yaml.safe_dump(payload, sort_keys=False)
yield file_path, file_content
if export_related:
for dataset in model.tables:
ds_file_name = get_filename(
dataset.table_name, dataset.id, skip_id=True
)
file_path = f"datasets/{db_file_name}/{ds_file_name}.yaml"
payload = dataset.export_to_dict(
recursive=True,
include_parent_ref=False,
include_defaults=True,
export_uuids=True,
)
payload["version"] = EXPORT_VERSION
payload["database_uuid"] = str(model.uuid)
file_content = yaml.safe_dump(payload, sort_keys=False)
yield file_path, file_content

View File

@@ -1,16 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -1,68 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any
from marshmallow.exceptions import ValidationError
from superset.commands.base import BaseCommand
from superset.commands.exceptions import CommandInvalidError
from superset.commands.importers.exceptions import IncorrectVersionError
from superset.databases.commands.importers import v1
logger = logging.getLogger(__name__)
command_versions = [v1.ImportDatabasesCommand]
class ImportDatabasesCommand(BaseCommand):
"""
Import databases.
This command dispatches the import to different versions of the command
until it finds one that matches.
"""
def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any):
self.contents = contents
self.args = args
self.kwargs = kwargs
def run(self) -> None:
# iterate over all commands until we find a version that can
# handle the contents
for version in command_versions:
command = version(self.contents, *self.args, **self.kwargs)
try:
command.run()
return
except IncorrectVersionError:
logger.debug("File not handled by command, skipping")
except (CommandInvalidError, ValidationError) as exc:
# found right version, but file is invalid
logger.info("Command failed validation")
raise exc
except Exception as exc:
# validation succeeded but something went wrong
logger.exception("Error running import command")
raise exc
raise CommandInvalidError("Could not find a valid command to import file")
def validate(self) -> None:
pass

View File

@@ -1,64 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any
from marshmallow import Schema
from sqlalchemy.orm import Session
from superset.commands.importers.v1 import ImportModelsCommand
from superset.daos.database import DatabaseDAO
from superset.databases.commands.exceptions import DatabaseImportError
from superset.databases.commands.importers.v1.utils import import_database
from superset.databases.schemas import ImportV1DatabaseSchema
from superset.datasets.commands.importers.v1.utils import import_dataset
from superset.datasets.schemas import ImportV1DatasetSchema
class ImportDatabasesCommand(ImportModelsCommand):
"""Import databases"""
dao = DatabaseDAO
model_name = "database"
prefix = "databases/"
schemas: dict[str, Schema] = {
"databases/": ImportV1DatabaseSchema(),
"datasets/": ImportV1DatasetSchema(),
}
import_error = DatabaseImportError
@staticmethod
def _import(
session: Session, configs: dict[str, Any], overwrite: bool = False
) -> None:
# first import databases
database_ids: dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("databases/"):
database = import_database(session, config, overwrite=overwrite)
database_ids[str(database.uuid)] = database.id
# import related datasets
for file_name, config in configs.items():
if (
file_name.startswith("datasets/")
and config["database_uuid"] in database_ids
):
config["database_id"] = database_ids[config["database_uuid"]]
# overwrite=False prevents deleting any non-imported columns/metrics
import_dataset(session, config, overwrite=False)

View File

@@ -1,78 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
from typing import Any
from sqlalchemy.orm import Session
from superset import app, security_manager
from superset.commands.exceptions import ImportFailedError
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.databases.utils import make_url_safe
from superset.exceptions import SupersetSecurityException
from superset.models.core import Database
from superset.security.analytics_db_safety import check_sqlalchemy_uri
def import_database(
session: Session,
config: dict[str, Any],
overwrite: bool = False,
ignore_permissions: bool = False,
) -> Database:
can_write = ignore_permissions or security_manager.can_access(
"can_write",
"Database",
)
existing = session.query(Database).filter_by(uuid=config["uuid"]).first()
if existing:
if not overwrite or not can_write:
return existing
config["id"] = existing.id
elif not can_write:
raise ImportFailedError(
"Database doesn't exist and user doesn't have permission to create databases"
)
# Check if this URI is allowed
if app.config["PREVENT_UNSAFE_DB_CONNECTIONS"]:
try:
check_sqlalchemy_uri(make_url_safe(config["sqlalchemy_uri"]))
except SupersetSecurityException as exc:
raise ImportFailedError(exc.message) from exc
# https://github.com/apache/superset/pull/16756 renamed ``csv`` to ``file``.
config["allow_file_upload"] = config.pop("allow_csv_upload")
if "schemas_allowed_for_csv_upload" in config["extra"]:
config["extra"]["schemas_allowed_for_file_upload"] = config["extra"].pop(
"schemas_allowed_for_csv_upload"
)
# TODO (betodealmeida): move this logic to import_from_dict
config["extra"] = json.dumps(config["extra"])
# Before it gets removed in import_from_dict
ssh_tunnel = config.pop("ssh_tunnel", None)
database = Database.import_from_dict(session, config, recursive=False)
if database.id is None:
session.flush()
if ssh_tunnel:
ssh_tunnel["database_id"] = database.id
SSHTunnel.import_from_dict(session, ssh_tunnel, recursive=False)
return database

View File

@@ -1,123 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, cast
from sqlalchemy.orm import lazyload, load_only
from superset.commands.base import BaseCommand
from superset.connectors.sqla.models import SqlaTable
from superset.daos.database import DatabaseDAO
from superset.databases.commands.exceptions import (
DatabaseNotFoundError,
DatabaseTablesUnexpectedError,
)
from superset.exceptions import SupersetException
from superset.extensions import db, security_manager
from superset.models.core import Database
from superset.utils.core import DatasourceName
logger = logging.getLogger(__name__)
class TablesDatabaseCommand(BaseCommand):
_model: Database
def __init__(self, db_id: int, schema_name: str, force: bool):
self._db_id = db_id
self._schema_name = schema_name
self._force = force
def run(self) -> dict[str, Any]:
self.validate()
try:
tables = security_manager.get_datasources_accessible_by_user(
database=self._model,
schema=self._schema_name,
datasource_names=sorted(
DatasourceName(*datasource_name)
for datasource_name in self._model.get_all_table_names_in_schema(
schema=self._schema_name,
force=self._force,
cache=self._model.table_cache_enabled,
cache_timeout=self._model.table_cache_timeout,
)
),
)
views = security_manager.get_datasources_accessible_by_user(
database=self._model,
schema=self._schema_name,
datasource_names=sorted(
DatasourceName(*datasource_name)
for datasource_name in self._model.get_all_view_names_in_schema(
schema=self._schema_name,
force=self._force,
cache=self._model.table_cache_enabled,
cache_timeout=self._model.table_cache_timeout,
)
),
)
extra_dict_by_name = {
table.name: table.extra_dict
for table in (
db.session.query(SqlaTable)
.filter(
SqlaTable.database_id == self._model.id,
SqlaTable.schema == self._schema_name,
)
.options(
load_only(
SqlaTable.schema, SqlaTable.table_name, SqlaTable.extra
),
lazyload(SqlaTable.columns),
lazyload(SqlaTable.metrics),
)
).all()
}
options = sorted(
[
{
"value": table.table,
"type": "table",
"extra": extra_dict_by_name.get(table.table, None),
}
for table in tables
]
+ [
{
"value": view.table,
"type": "view",
}
for view in views
],
key=lambda item: item["value"],
)
payload = {"count": len(tables) + len(views), "result": options}
return payload
except SupersetException as ex:
raise ex
except Exception as ex:
raise DatabaseTablesUnexpectedError(ex) from ex
def validate(self) -> None:
self._model = cast(Database, DatabaseDAO.find_by_id(self._db_id))
if not self._model:
raise DatabaseNotFoundError()

View File

@@ -1,231 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
import sqlite3
from contextlib import closing
from typing import Any, Optional
from flask import current_app as app
from flask_babel import gettext as _
from func_timeout import func_timeout, FunctionTimedOut
from sqlalchemy.engine import Engine
from sqlalchemy.exc import DBAPIError, NoSuchModuleError
from superset import is_feature_enabled
from superset.commands.base import BaseCommand
from superset.daos.database import DatabaseDAO, SSHTunnelDAO
from superset.databases.commands.exceptions import (
DatabaseSecurityUnsafeError,
DatabaseTestConnectionDriverError,
DatabaseTestConnectionUnexpectedError,
)
from superset.databases.ssh_tunnel.commands.exceptions import (
SSHTunnelingNotEnabledError,
)
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.databases.utils import make_url_safe
from superset.errors import ErrorLevel, SupersetErrorType
from superset.exceptions import (
SupersetErrorsException,
SupersetSecurityException,
SupersetTimeoutException,
)
from superset.extensions import event_logger
from superset.models.core import Database
from superset.utils.ssh_tunnel import unmask_password_info
logger = logging.getLogger(__name__)
def get_log_connection_action(
action: str, ssh_tunnel: Optional[Any], exc: Optional[Exception] = None
) -> str:
action_modified = action
if exc:
action_modified += f".{exc.__class__.__name__}"
if ssh_tunnel:
action_modified += ".ssh_tunnel"
return action_modified
class TestConnectionDatabaseCommand(BaseCommand):
def __init__(self, data: dict[str, Any]):
self._properties = data.copy()
self._model: Optional[Database] = None
def run(self) -> None: # pylint: disable=too-many-statements, too-many-branches
self.validate()
ex_str = ""
uri = self._properties.get("sqlalchemy_uri", "")
if self._model and uri == self._model.safe_sqlalchemy_uri():
uri = self._model.sqlalchemy_uri_decrypted
ssh_tunnel = self._properties.get("ssh_tunnel")
# context for error messages
url = make_url_safe(uri)
context = {
"hostname": url.host,
"password": url.password,
"port": url.port,
"username": url.username,
"database": url.database,
}
serialized_encrypted_extra = self._properties.get(
"masked_encrypted_extra",
"{}",
)
if self._model:
serialized_encrypted_extra = (
self._model.db_engine_spec.unmask_encrypted_extra(
self._model.encrypted_extra,
serialized_encrypted_extra,
)
)
try:
database = DatabaseDAO.build_db_for_connection_test(
server_cert=self._properties.get("server_cert", ""),
extra=self._properties.get("extra", "{}"),
impersonate_user=self._properties.get("impersonate_user", False),
encrypted_extra=serialized_encrypted_extra,
)
database.set_sqlalchemy_uri(uri)
database.db_engine_spec.mutate_db_for_connection_test(database)
# Generate tunnel if present in the properties
if ssh_tunnel:
if not is_feature_enabled("SSH_TUNNELING"):
raise SSHTunnelingNotEnabledError()
# If there's an existing tunnel for that DB we need to use the stored
# password, private_key and private_key_password instead
if ssh_tunnel_id := ssh_tunnel.pop("id", None):
if existing_ssh_tunnel := SSHTunnelDAO.find_by_id(ssh_tunnel_id):
ssh_tunnel = unmask_password_info(
ssh_tunnel, existing_ssh_tunnel
)
ssh_tunnel = SSHTunnel(**ssh_tunnel)
event_logger.log_with_context(
action=get_log_connection_action("test_connection_attempt", ssh_tunnel),
engine=database.db_engine_spec.__name__,
)
def ping(engine: Engine) -> bool:
with closing(engine.raw_connection()) as conn:
return engine.dialect.do_ping(conn)
with database.get_sqla_engine_with_context(
override_ssh_tunnel=ssh_tunnel
) as engine:
try:
alive = func_timeout(
app.config["TEST_DATABASE_CONNECTION_TIMEOUT"].total_seconds(),
ping,
args=(engine,),
)
except (sqlite3.ProgrammingError, RuntimeError):
# SQLite can't run on a separate thread, so ``func_timeout`` fails
# RuntimeError catches the equivalent error from duckdb.
alive = engine.dialect.do_ping(engine)
except FunctionTimedOut as ex:
raise SupersetTimeoutException(
error_type=SupersetErrorType.CONNECTION_DATABASE_TIMEOUT,
message=(
"Please check your connection details and database settings, "
"and ensure that your database is accepting connections, "
"then try connecting again."
),
level=ErrorLevel.ERROR,
extra={"sqlalchemy_uri": database.sqlalchemy_uri},
) from ex
except Exception as ex: # pylint: disable=broad-except
alive = False
# So we stop losing the original message if any
ex_str = str(ex)
if not alive:
raise DBAPIError(ex_str or None, None, None)
# Log succesful connection test with engine
event_logger.log_with_context(
action=get_log_connection_action("test_connection_success", ssh_tunnel),
engine=database.db_engine_spec.__name__,
)
except (NoSuchModuleError, ModuleNotFoundError) as ex:
event_logger.log_with_context(
action=get_log_connection_action(
"test_connection_error", ssh_tunnel, ex
),
engine=database.db_engine_spec.__name__,
)
raise DatabaseTestConnectionDriverError(
message=_("Could not load database driver: {}").format(
database.db_engine_spec.__name__
),
) from ex
except DBAPIError as ex:
event_logger.log_with_context(
action=get_log_connection_action(
"test_connection_error", ssh_tunnel, ex
),
engine=database.db_engine_spec.__name__,
)
# check for custom errors (wrong username, wrong password, etc)
errors = database.db_engine_spec.extract_errors(ex, context)
raise SupersetErrorsException(errors) from ex
except SupersetSecurityException as ex:
event_logger.log_with_context(
action=get_log_connection_action(
"test_connection_error", ssh_tunnel, ex
),
engine=database.db_engine_spec.__name__,
)
raise DatabaseSecurityUnsafeError(message=str(ex)) from ex
except SupersetTimeoutException as ex:
event_logger.log_with_context(
action=get_log_connection_action(
"test_connection_error", ssh_tunnel, ex
),
engine=database.db_engine_spec.__name__,
)
# bubble up the exception to return a 408
raise ex
except SSHTunnelingNotEnabledError as ex:
event_logger.log_with_context(
action=get_log_connection_action(
"test_connection_error", ssh_tunnel, ex
),
engine=database.db_engine_spec.__name__,
)
# bubble up the exception to return a 400
raise ex
except Exception as ex:
event_logger.log_with_context(
action=get_log_connection_action(
"test_connection_error", ssh_tunnel, ex
),
engine=database.db_engine_spec.__name__,
)
errors = database.db_engine_spec.extract_errors(ex, context)
raise DatabaseTestConnectionUnexpectedError(errors) from ex
def validate(self) -> None:
if (database_name := self._properties.get("database_name")) is not None:
self._model = DatabaseDAO.get_database_by_name(database_name)

View File

@@ -1,182 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
from superset import is_feature_enabled
from superset.commands.base import BaseCommand
from superset.daos.database import DatabaseDAO
from superset.daos.exceptions import DAOCreateFailedError, DAOUpdateFailedError
from superset.databases.commands.exceptions import (
DatabaseConnectionFailedError,
DatabaseExistsValidationError,
DatabaseInvalidError,
DatabaseNotFoundError,
DatabaseUpdateFailedError,
)
from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand
from superset.databases.ssh_tunnel.commands.exceptions import (
SSHTunnelCreateFailedError,
SSHTunnelingNotEnabledError,
SSHTunnelInvalidError,
SSHTunnelUpdateFailedError,
)
from superset.databases.ssh_tunnel.commands.update import UpdateSSHTunnelCommand
from superset.extensions import db, security_manager
from superset.models.core import Database
from superset.utils.core import DatasourceType
logger = logging.getLogger(__name__)
class UpdateDatabaseCommand(BaseCommand):
def __init__(self, model_id: int, data: dict[str, Any]):
self._properties = data.copy()
self._model_id = model_id
self._model: Optional[Database] = None
def run(self) -> Model:
self.validate()
if not self._model:
raise DatabaseNotFoundError()
old_database_name = self._model.database_name
# unmask ``encrypted_extra``
self._properties[
"encrypted_extra"
] = self._model.db_engine_spec.unmask_encrypted_extra(
self._model.encrypted_extra,
self._properties.pop("masked_encrypted_extra", "{}"),
)
try:
database = DatabaseDAO.update(self._model, self._properties, commit=False)
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
if ssh_tunnel_properties := self._properties.get("ssh_tunnel"):
if not is_feature_enabled("SSH_TUNNELING"):
db.session.rollback()
raise SSHTunnelingNotEnabledError()
existing_ssh_tunnel_model = DatabaseDAO.get_ssh_tunnel(database.id)
if existing_ssh_tunnel_model is None:
# We couldn't found an existing tunnel so we need to create one
try:
CreateSSHTunnelCommand(database.id, ssh_tunnel_properties).run()
except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex:
# So we can show the original message
raise ex
except Exception as ex:
raise DatabaseUpdateFailedError() from ex
else:
# We found an existing tunnel so we need to update it
try:
UpdateSSHTunnelCommand(
existing_ssh_tunnel_model.id, ssh_tunnel_properties
).run()
except (SSHTunnelInvalidError, SSHTunnelUpdateFailedError) as ex:
# So we can show the original message
raise ex
except Exception as ex:
raise DatabaseUpdateFailedError() from ex
# adding a new database we always want to force refresh schema list
# TODO Improve this simplistic implementation for catching DB conn fails
try:
ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
schemas = database.get_all_schema_names(ssh_tunnel=ssh_tunnel)
except Exception as ex:
db.session.rollback()
raise DatabaseConnectionFailedError() from ex
# Update database schema permissions
new_schemas: list[str] = []
for schema in schemas:
old_view_menu_name = security_manager.get_schema_perm(
old_database_name, schema
)
new_view_menu_name = security_manager.get_schema_perm(
database.database_name, schema
)
schema_pvm = security_manager.find_permission_view_menu(
"schema_access", old_view_menu_name
)
# Update the schema permission if the database name changed
if schema_pvm and old_database_name != database.database_name:
schema_pvm.view_menu.name = new_view_menu_name
self._propagate_schema_permissions(
old_view_menu_name, new_view_menu_name
)
else:
new_schemas.append(schema)
for schema in new_schemas:
security_manager.add_permission_view_menu(
"schema_access", security_manager.get_schema_perm(database, schema)
)
db.session.commit()
except (DAOUpdateFailedError, DAOCreateFailedError) as ex:
raise DatabaseUpdateFailedError() from ex
return database
@staticmethod
def _propagate_schema_permissions(
old_view_menu_name: str, new_view_menu_name: str
) -> None:
from superset.connectors.sqla.models import ( # pylint: disable=import-outside-toplevel
SqlaTable,
)
from superset.models.slice import ( # pylint: disable=import-outside-toplevel
Slice,
)
# Update schema_perm on all datasets
datasets = (
db.session.query(SqlaTable)
.filter(SqlaTable.schema_perm == old_view_menu_name)
.all()
)
for dataset in datasets:
dataset.schema_perm = new_view_menu_name
charts = db.session.query(Slice).filter(
Slice.datasource_type == DatasourceType.TABLE,
Slice.datasource_id == dataset.id,
)
# Update schema_perm on all charts
for chart in charts:
chart.schema_perm = new_view_menu_name
def validate(self) -> None:
exceptions: list[ValidationError] = []
# Validate/populate model exists
self._model = DatabaseDAO.find_by_id(self._model_id)
if not self._model:
raise DatabaseNotFoundError()
database_name: Optional[str] = self._properties.get("database_name")
if database_name:
# Check database_name uniqueness
if not DatabaseDAO.validate_update_uniqueness(
self._model_id, database_name
):
exceptions.append(DatabaseExistsValidationError())
if exceptions:
raise DatabaseInvalidError(exceptions=exceptions)

View File

@@ -1,132 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
from contextlib import closing
from typing import Any, Optional
from flask_babel import gettext as __
from superset.commands.base import BaseCommand
from superset.daos.database import DatabaseDAO
from superset.databases.commands.exceptions import (
DatabaseOfflineError,
DatabaseTestConnectionFailedError,
InvalidEngineError,
InvalidParametersError,
)
from superset.databases.utils import make_url_safe
from superset.db_engine_specs import get_engine_spec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.extensions import event_logger
from superset.models.core import Database
BYPASS_VALIDATION_ENGINES = {"bigquery"}
class ValidateDatabaseParametersCommand(BaseCommand):
def __init__(self, properties: dict[str, Any]):
self._properties = properties.copy()
self._model: Optional[Database] = None
def run(self) -> None:
self.validate()
engine = self._properties["engine"]
driver = self._properties.get("driver")
if engine in BYPASS_VALIDATION_ENGINES:
# Skip engines that are only validated onCreate
return
engine_spec = get_engine_spec(engine, driver)
if not hasattr(engine_spec, "parameters_schema"):
raise InvalidEngineError(
SupersetError(
message=__(
'Engine "%(engine)s" cannot be configured through parameters.',
engine=engine,
),
error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
level=ErrorLevel.ERROR,
),
)
# perform initial validation
errors = engine_spec.validate_parameters(self._properties) # type: ignore
if errors:
event_logger.log_with_context(action="validation_error", engine=engine)
raise InvalidParametersError(errors)
serialized_encrypted_extra = self._properties.get(
"masked_encrypted_extra",
"{}",
)
if self._model:
serialized_encrypted_extra = engine_spec.unmask_encrypted_extra(
self._model.encrypted_extra,
serialized_encrypted_extra,
)
try:
encrypted_extra = json.loads(serialized_encrypted_extra)
except json.decoder.JSONDecodeError:
encrypted_extra = {}
# try to connect
sqlalchemy_uri = engine_spec.build_sqlalchemy_uri( # type: ignore
self._properties.get("parameters"),
encrypted_extra,
)
if self._model and sqlalchemy_uri == self._model.safe_sqlalchemy_uri():
sqlalchemy_uri = self._model.sqlalchemy_uri_decrypted
database = DatabaseDAO.build_db_for_connection_test(
server_cert=self._properties.get("server_cert", ""),
extra=self._properties.get("extra", "{}"),
impersonate_user=self._properties.get("impersonate_user", False),
encrypted_extra=serialized_encrypted_extra,
)
database.set_sqlalchemy_uri(sqlalchemy_uri)
database.db_engine_spec.mutate_db_for_connection_test(database)
alive = False
with database.get_sqla_engine_with_context() as engine:
try:
with closing(engine.raw_connection()) as conn:
alive = engine.dialect.do_ping(conn)
except Exception as ex:
url = make_url_safe(sqlalchemy_uri)
context = {
"hostname": url.host,
"password": url.password,
"port": url.port,
"username": url.username,
"database": url.database,
}
errors = database.db_engine_spec.extract_errors(ex, context)
raise DatabaseTestConnectionFailedError(errors) from ex
if not alive:
raise DatabaseOfflineError(
SupersetError(
message=__("Database is offline."),
error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
level=ErrorLevel.ERROR,
),
)
def validate(self) -> None:
if (database_id := self._properties.get("id")) is not None:
self._model = DatabaseDAO.find_by_id(database_id)

View File

@@ -1,117 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
import re
from typing import Any, Optional
from flask import current_app
from flask_babel import gettext as __
from superset.commands.base import BaseCommand
from superset.daos.database import DatabaseDAO
from superset.databases.commands.exceptions import (
DatabaseNotFoundError,
NoValidatorConfigFoundError,
NoValidatorFoundError,
ValidatorSQL400Error,
ValidatorSQLError,
ValidatorSQLUnexpectedError,
)
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.models.core import Database
from superset.sql_validators import get_validator_by_name
from superset.sql_validators.base import BaseSQLValidator
from superset.utils import core as utils
logger = logging.getLogger(__name__)
class ValidateSQLCommand(BaseCommand):
def __init__(self, model_id: int, data: dict[str, Any]):
self._properties = data.copy()
self._model_id = model_id
self._model: Optional[Database] = None
self._validator: Optional[type[BaseSQLValidator]] = None
def run(self) -> list[dict[str, Any]]:
"""
Validates a SQL statement
:return: A List of SQLValidationAnnotation
:raises: DatabaseNotFoundError, NoValidatorConfigFoundError
NoValidatorFoundError, ValidatorSQLUnexpectedError, ValidatorSQLError
ValidatorSQL400Error
"""
self.validate()
if not self._validator or not self._model:
raise ValidatorSQLUnexpectedError()
sql = self._properties["sql"]
schema = self._properties.get("schema")
try:
timeout = current_app.config["SQLLAB_VALIDATION_TIMEOUT"]
timeout_msg = f"The query exceeded the {timeout} seconds timeout."
with utils.timeout(seconds=timeout, error_message=timeout_msg):
errors = self._validator.validate(sql, schema, self._model)
return [err.to_dict() for err in errors]
except Exception as ex:
logger.exception(ex)
superset_error = SupersetError(
message=__(
"%(validator)s was unable to check your query.\n"
"Please recheck your query.\n"
"Exception: %(ex)s",
validator=self._validator.name,
ex=ex,
),
error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
level=ErrorLevel.ERROR,
)
# Return as a 400 if the database error message says we got a 4xx error
if re.search(r"([\W]|^)4\d{2}([\W]|$)", str(ex)):
raise ValidatorSQL400Error(superset_error) from ex
raise ValidatorSQLError(superset_error) from ex
def validate(self) -> None:
# Validate/populate model exists
self._model = DatabaseDAO.find_by_id(self._model_id)
if not self._model:
raise DatabaseNotFoundError()
spec = self._model.db_engine_spec
validators_by_engine = current_app.config["SQL_VALIDATORS_BY_ENGINE"]
if not validators_by_engine or spec.engine not in validators_by_engine:
raise NoValidatorConfigFoundError(
SupersetError(
message=__(f"no SQL validator is configured for {spec.engine}"),
error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
level=ErrorLevel.ERROR,
),
)
validator_name = validators_by_engine[spec.engine]
self._validator = get_validator_by_name(validator_name)
if not self._validator:
raise NoValidatorFoundError(
SupersetError(
message=__(
f"No validator named {validator_name} found "
f"(configured for the {spec.engine} engine)"
),
error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
level=ErrorLevel.ERROR,
),
)

View File

@@ -28,13 +28,13 @@ from marshmallow.validate import Length, ValidationError
from sqlalchemy import MetaData
from superset import db, is_feature_enabled
from superset.constants import PASSWORD_MASK
from superset.databases.commands.exceptions import DatabaseInvalidError
from superset.databases.ssh_tunnel.commands.exceptions import (
from superset.commands.database.exceptions import DatabaseInvalidError
from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelingNotEnabledError,
SSHTunnelInvalidCredentials,
SSHTunnelMissingCredentials,
)
from superset.constants import PASSWORD_MASK
from superset.databases.utils import make_url_safe
from superset.db_engine_specs import get_engine_spec
from superset.exceptions import CertificateException, SupersetSecurityException

View File

@@ -1,16 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -1,91 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
from superset.commands.base import BaseCommand
from superset.daos.database import SSHTunnelDAO
from superset.daos.exceptions import DAOCreateFailedError
from superset.databases.ssh_tunnel.commands.exceptions import (
SSHTunnelCreateFailedError,
SSHTunnelInvalidError,
SSHTunnelRequiredFieldValidationError,
)
from superset.extensions import db, event_logger
logger = logging.getLogger(__name__)
class CreateSSHTunnelCommand(BaseCommand):
def __init__(self, database_id: int, data: dict[str, Any]):
self._properties = data.copy()
self._properties["database_id"] = database_id
def run(self) -> Model:
try:
# Start nested transaction since we are always creating the tunnel
# through a DB command (Create or Update). Without this, we cannot
# safely rollback changes to databases if any, i.e, things like
# test_do_not_create_database_if_ssh_tunnel_creation_fails test will fail
db.session.begin_nested()
self.validate()
return SSHTunnelDAO.create(attributes=self._properties, commit=False)
except DAOCreateFailedError as ex:
# Rollback nested transaction
db.session.rollback()
raise SSHTunnelCreateFailedError() from ex
except SSHTunnelInvalidError as ex:
# Rollback nested transaction
db.session.rollback()
raise ex
def validate(self) -> None:
# TODO(hughhh): check to make sure the server port is not localhost
# using the config.SSH_TUNNEL_MANAGER
exceptions: list[ValidationError] = []
database_id: Optional[int] = self._properties.get("database_id")
server_address: Optional[str] = self._properties.get("server_address")
server_port: Optional[int] = self._properties.get("server_port")
username: Optional[str] = self._properties.get("username")
private_key: Optional[str] = self._properties.get("private_key")
private_key_password: Optional[str] = self._properties.get(
"private_key_password"
)
if not database_id:
exceptions.append(SSHTunnelRequiredFieldValidationError("database_id"))
if not server_address:
exceptions.append(SSHTunnelRequiredFieldValidationError("server_address"))
if not server_port:
exceptions.append(SSHTunnelRequiredFieldValidationError("server_port"))
if not username:
exceptions.append(SSHTunnelRequiredFieldValidationError("username"))
if private_key_password and private_key is None:
exceptions.append(SSHTunnelRequiredFieldValidationError("private_key"))
if exceptions:
exception = SSHTunnelInvalidError()
exception.extend(exceptions)
event_logger.log_with_context(
# pylint: disable=consider-using-f-string
action="ssh_tunnel_creation_failed.{}.{}".format(
exception.__class__.__name__,
".".join(exception.get_list_classnames()),
)
)
raise exception

View File

@@ -1,54 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Optional
from superset import is_feature_enabled
from superset.commands.base import BaseCommand
from superset.daos.database import SSHTunnelDAO
from superset.daos.exceptions import DAODeleteFailedError
from superset.databases.ssh_tunnel.commands.exceptions import (
SSHTunnelDeleteFailedError,
SSHTunnelingNotEnabledError,
SSHTunnelNotFoundError,
)
from superset.databases.ssh_tunnel.models import SSHTunnel
logger = logging.getLogger(__name__)
class DeleteSSHTunnelCommand(BaseCommand):
def __init__(self, model_id: int):
self._model_id = model_id
self._model: Optional[SSHTunnel] = None
def run(self) -> None:
if not is_feature_enabled("SSH_TUNNELING"):
raise SSHTunnelingNotEnabledError()
self.validate()
assert self._model
try:
SSHTunnelDAO.delete([self._model])
except DAODeleteFailedError as ex:
raise SSHTunnelDeleteFailedError() from ex
def validate(self) -> None:
# Validate/populate model exists
self._model = SSHTunnelDAO.find_by_id(self._model_id)
if not self._model:
raise SSHTunnelNotFoundError()

View File

@@ -1,67 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from flask_babel import lazy_gettext as _
from marshmallow import ValidationError
from superset.commands.exceptions import (
CommandException,
CommandInvalidError,
DeleteFailedError,
UpdateFailedError,
)
class SSHTunnelDeleteFailedError(DeleteFailedError):
message = _("SSH Tunnel could not be deleted.")
class SSHTunnelNotFoundError(CommandException):
status = 404
message = _("SSH Tunnel not found.")
class SSHTunnelInvalidError(CommandInvalidError):
message = _("SSH Tunnel parameters are invalid.")
class SSHTunnelUpdateFailedError(UpdateFailedError):
message = _("SSH Tunnel could not be updated.")
class SSHTunnelCreateFailedError(CommandException):
message = _("Creating SSH Tunnel failed for an unknown reason")
class SSHTunnelingNotEnabledError(CommandException):
status = 400
message = _("SSH Tunneling is not enabled")
class SSHTunnelRequiredFieldValidationError(ValidationError):
def __init__(self, field_name: str) -> None:
super().__init__(
[_("Field is required")],
field_name=field_name,
)
class SSHTunnelMissingCredentials(CommandInvalidError):
message = _("Must provide credentials for the SSH Tunnel")
class SSHTunnelInvalidCredentials(CommandInvalidError):
message = _("Cannot have multiple credentials for the SSH Tunnel")

View File

@@ -1,63 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
from superset.commands.base import BaseCommand
from superset.daos.database import SSHTunnelDAO
from superset.daos.exceptions import DAOUpdateFailedError
from superset.databases.ssh_tunnel.commands.exceptions import (
SSHTunnelInvalidError,
SSHTunnelNotFoundError,
SSHTunnelRequiredFieldValidationError,
SSHTunnelUpdateFailedError,
)
from superset.databases.ssh_tunnel.models import SSHTunnel
logger = logging.getLogger(__name__)
class UpdateSSHTunnelCommand(BaseCommand):
def __init__(self, model_id: int, data: dict[str, Any]):
self._properties = data.copy()
self._model_id = model_id
self._model: Optional[SSHTunnel] = None
def run(self) -> Model:
self.validate()
try:
if self._model is not None: # So we dont get incompatible types error
tunnel = SSHTunnelDAO.update(self._model, self._properties)
except DAOUpdateFailedError as ex:
raise SSHTunnelUpdateFailedError() from ex
return tunnel
def validate(self) -> None:
# Validate/populate model exists
self._model = SSHTunnelDAO.find_by_id(self._model_id)
if not self._model:
raise SSHTunnelNotFoundError()
private_key: Optional[str] = self._properties.get("private_key")
private_key_password: Optional[str] = self._properties.get(
"private_key_password"
)
if private_key_password and private_key is None:
raise SSHTunnelInvalidError(
exceptions=[SSHTunnelRequiredFieldValidationError("private_key")]
)

View File

@@ -18,7 +18,7 @@ from typing import Any, Optional, Union
from sqlalchemy.engine.url import make_url, URL
from superset.databases.commands.exceptions import DatabaseInvalidError
from superset.commands.database.exceptions import DatabaseInvalidError
def get_foreign_keys_metadata(