diff --git a/superset/commands/database/sync_permissions.py b/superset/commands/database/sync_permissions.py index c4213da98fe..4f041ce94cb 100644 --- a/superset/commands/database/sync_permissions.py +++ b/superset/commands/database/sync_permissions.py @@ -140,13 +140,7 @@ class SyncPermissionsCommand(BaseCommand): """ Syncs the permissions for a DB connection. """ - catalogs = ( - self._get_catalog_names() - if self.db_connection.db_engine_spec.supports_catalog - else [None] - ) - - for catalog in catalogs: + for catalog in self._get_catalog_names(): try: schemas = self._get_schema_names(catalog) @@ -192,15 +186,29 @@ class SyncPermissionsCommand(BaseCommand): if self.old_db_connection_name != self.db_connection.database_name: self._rename_database_in_permissions(catalog, schemas) - def _get_catalog_names(self) -> set[str]: + def _get_catalog_names(self) -> set[str | None]: """ Helper method to load catalogs. """ + if not self.db_connection.db_engine_spec.supports_catalog: + return {None} + try: - return self.db_connection.get_all_catalog_names( - force=True, - ssh_tunnel=self.db_connection_ssh_tunnel, - ) + # Adding permissions to all catalogs (and all their schemas) can take a long + # time (minutes, while importing a chart, eg). If the database does not + # support cross-catalog queries (like Postgres), and the multi-catalog + # feature is not enabled, then we only need to add permissions to the + # default catalog. + if ( + self.db_connection.db_engine_spec.supports_cross_catalog_queries + or self.db_connection.allow_multi_catalog + ): + return self.db_connection.get_all_catalog_names( + force=True, + ssh_tunnel=self.db_connection_ssh_tunnel, + ) + else: + return {self.db_connection.get_default_catalog()} except OAuth2RedirectError: # raise OAuth2 exceptions as-is raise diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index eaecb74020f..502c8ac82b4 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -428,6 +428,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods # Can the catalog be changed on a per-query basis? supports_dynamic_catalog = False + # Does the DB engine spec support cross-catalog queries? + supports_cross_catalog_queries = False + # Does the engine supports OAuth 2.0? This requires logic to be added to one of the # the user impersonation methods to handle personal tokens. supports_oauth2 = False diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py index cf5cfaad511..bd7eb6faaf1 100644 --- a/superset/db_engine_specs/bigquery.py +++ b/superset/db_engine_specs/bigquery.py @@ -136,7 +136,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met allows_hidden_cc_in_orderby = True - supports_catalog = supports_dynamic_catalog = True + supports_catalog = supports_dynamic_catalog = supports_cross_catalog_queries = True # when editing the database, mask this field in `encrypted_extra` # pylint: disable=invalid-name @@ -539,7 +539,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met ] @classmethod - def get_default_catalog(cls, database: Database) -> str | None: + def get_default_catalog(cls, database: Database) -> str: """ Get the default catalog. """ diff --git a/superset/db_engine_specs/databricks.py b/superset/db_engine_specs/databricks.py index fe7f9084009..48144bb35bc 100644 --- a/superset/db_engine_specs/databricks.py +++ b/superset/db_engine_specs/databricks.py @@ -373,7 +373,10 @@ class DatabricksNativeEngineSpec(DatabricksDynamicBaseEngineSpec): "extra", } - supports_dynamic_schema = supports_catalog = supports_dynamic_catalog = True + supports_dynamic_schema = True + supports_catalog = True + supports_dynamic_catalog = True + supports_cross_catalog_queries = True @classmethod def build_sqlalchemy_uri( # type: ignore @@ -433,10 +436,7 @@ class DatabricksNativeEngineSpec(DatabricksDynamicBaseEngineSpec): return spec.to_dict()["components"]["schemas"][cls.__name__] @classmethod - def get_default_catalog( - cls, - database: Database, - ) -> str | None: + def get_default_catalog(cls, database: Database) -> str: """ Return the default catalog. diff --git a/superset/db_engine_specs/doris.py b/superset/db_engine_specs/doris.py index d14d21661b5..c00dd8040d2 100644 --- a/superset/db_engine_specs/doris.py +++ b/superset/db_engine_specs/doris.py @@ -113,7 +113,7 @@ class DorisEngineSpec(MySQLEngineSpec): ) encryption_parameters = {"ssl": "0"} supports_dynamic_schema = True - supports_catalog = supports_dynamic_catalog = True + supports_catalog = supports_dynamic_catalog = supports_cross_catalog_queries = True column_type_mappings = ( # type: ignore ( diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index 16117062b06..b86172cb3a3 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -319,7 +319,7 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec): return uri, connect_args @classmethod - def get_default_catalog(cls, database: Database) -> str | None: + def get_default_catalog(cls, database: Database) -> str: """ Return the default catalog for a given database. """ diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index fa1bad8c2ca..10f65cf004b 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -162,7 +162,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta): """ supports_dynamic_schema = True - supports_catalog = supports_dynamic_catalog = True + supports_catalog = supports_dynamic_catalog = supports_cross_catalog_queries = True column_type_mappings = ( ( diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py index 5c9c42ea764..1bd604aa35d 100644 --- a/superset/db_engine_specs/snowflake.py +++ b/superset/db_engine_specs/snowflake.py @@ -88,7 +88,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec): sqlalchemy_uri_placeholder = "snowflake://" supports_dynamic_schema = True - supports_catalog = supports_dynamic_catalog = True + supports_catalog = supports_dynamic_catalog = supports_cross_catalog_queries = True # pylint: disable=invalid-name encrypted_extra_sensitive_fields = { @@ -189,7 +189,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec): return parse.unquote(database.split("/")[1]) @classmethod - def get_default_catalog(cls, database: "Database") -> Optional[str]: + def get_default_catalog(cls, database: "Database") -> str: """ Return the default catalog. """ diff --git a/tests/unit_tests/commands/databases/sync_permissions_test.py b/tests/unit_tests/commands/databases/sync_permissions_test.py index 10d0723f504..e597e384db9 100644 --- a/tests/unit_tests/commands/databases/sync_permissions_test.py +++ b/tests/unit_tests/commands/databases/sync_permissions_test.py @@ -231,7 +231,7 @@ def test_sync_permissions_command_async_mode_new_db_name( async_task_mock.delay.assert_called_once_with(1, "admin", "Old Name") -def test_resync_permissions_command_get_catalogs(database_with_catalog: MagicMock): +def test_sync_permissions_command_get_catalogs(database_with_catalog: MagicMock): """ Test the ``_get_catalog_names`` method. """ @@ -239,6 +239,23 @@ def test_resync_permissions_command_get_catalogs(database_with_catalog: MagicMoc assert cmmd._get_catalog_names() == ["catalog1", "catalog2"] +def test_sync_permissions_command_get_default_catalog(database_with_catalog: MagicMock): + """ + Test ``_get_catalog_names`` when only the default one should be returned. + + When the database doesn't not support cross-catalog queries (like Postgres), we + should only return all catalogs if multi-catalog is enabled. + """ + database_with_catalog.db_engine_spec.supports_cross_catalog_queries = False + database_with_catalog.allow_multi_catalog = False + cmmd = SyncPermissionsCommand(1, None, db_connection=database_with_catalog) + assert cmmd._get_catalog_names() == {"catalog2"} + + database_with_catalog.allow_multi_catalog = True + cmmd = SyncPermissionsCommand(1, None, db_connection=database_with_catalog) + assert cmmd._get_catalog_names() == ["catalog1", "catalog2"] + + @pytest.mark.parametrize( ("inner_exception, outer_exception"), [ @@ -249,7 +266,7 @@ def test_resync_permissions_command_get_catalogs(database_with_catalog: MagicMoc (GenericDBException, DatabaseConnectionFailedError), ], ) -def test_resync_permissions_command_raise_on_getting_catalogs( +def test_sync_permissions_command_raise_on_getting_catalogs( inner_exception: Exception, outer_exception: Exception, database_with_catalog: MagicMock, @@ -263,7 +280,7 @@ def test_resync_permissions_command_raise_on_getting_catalogs( cmmd._get_catalog_names() -def test_resync_permissions_command_get_schemas(database_with_catalog: MagicMock): +def test_sync_permissions_command_get_schemas(database_with_catalog: MagicMock): """ Test the ``_get_schema_names`` method. """ @@ -282,7 +299,7 @@ def test_resync_permissions_command_get_schemas(database_with_catalog: MagicMock (GenericDBException, DatabaseConnectionFailedError), ], ) -def test_resync_permissions_command_raise_on_getting_schemas( +def test_sync_permissions_command_raise_on_getting_schemas( inner_exception: Exception, outer_exception: Exception, database_with_catalog: MagicMock, @@ -296,7 +313,7 @@ def test_resync_permissions_command_raise_on_getting_schemas( cmmd._get_schema_names("blah") -def test_resync_permissions_command_refresh_schemas( +def test_sync_permissions_command_refresh_schemas( mocker: MockerFixture, database_with_catalog: MagicMock ): """ @@ -319,7 +336,7 @@ def test_resync_permissions_command_refresh_schemas( ) -def test_resync_permissions_command_rename_db_in_perms( +def test_sync_permissions_command_rename_db_in_perms( mocker: MockerFixture, database_with_catalog: MagicMock ): """