diff --git a/superset-frontend/src/features/databases/UploadDataModel/UploadDataModal.test.tsx b/superset-frontend/src/features/databases/UploadDataModel/UploadDataModal.test.tsx index 7c587264129..97cc9c1bb8c 100644 --- a/superset-frontend/src/features/databases/UploadDataModel/UploadDataModal.test.tsx +++ b/superset-frontend/src/features/databases/UploadDataModel/UploadDataModal.test.tsx @@ -78,11 +78,11 @@ beforeEach(() => { result: [], }); - fetchMock.get('glob:*api/v1/database/1/schemas/', { + fetchMock.get('glob:*api/v1/database/1/schemas/?q=(upload_allowed:!t)', { result: ['information_schema', 'public'], }); - fetchMock.get('glob:*api/v1/database/2/schemas/', { + fetchMock.get('glob:*api/v1/database/2/schemas/?q=(upload_allowed:!t)', { result: ['schema1', 'schema2'], }); }); diff --git a/superset-frontend/src/features/databases/UploadDataModel/index.tsx b/superset-frontend/src/features/databases/UploadDataModel/index.tsx index 7b51b7a0b49..46e9b33f39a 100644 --- a/superset-frontend/src/features/databases/UploadDataModel/index.tsx +++ b/superset-frontend/src/features/databases/UploadDataModel/index.tsx @@ -363,7 +363,7 @@ const UploadDataModal: FunctionComponent = ({ return Promise.resolve({ data: [], totalCount: 0 }); } return SupersetClient.get({ - endpoint: `/api/v1/database/${currentDatabaseId}/schemas/`, + endpoint: `/api/v1/database/${currentDatabaseId}/schemas/?q=(upload_allowed:!t)`, }).then(response => { const list = response.json.result.map((item: string) => ({ value: item, diff --git a/superset/databases/api.py b/superset/databases/api.py index fdf4d980c0b..f87935144ea 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -776,18 +776,35 @@ class DatabaseRestApi(BaseSupersetModelRestApi): if not database: return self.response_404() try: - catalog = kwargs["rison"].get("catalog") + params = kwargs["rison"] + catalog = params.get("catalog") schemas = database.get_all_schema_names( catalog=catalog, cache=database.schema_cache_enabled, cache_timeout=database.schema_cache_timeout or None, - force=kwargs["rison"].get("force", False), + force=params.get("force", False), ) schemas = security_manager.get_schemas_accessible_by_user( database, catalog, schemas, ) + if params.get("upload_allowed"): + if not database.allow_file_upload: + return self.response(200, result=[]) + if allowed_schemas := database.get_schema_access_for_file_upload(): + # some databases might return the list of schemas in uppercase, + # while the list of allowed schemas is manually inputted so + # could be lowercase + allowed_schemas = {schema.lower() for schema in allowed_schemas} + return self.response( + 200, + result=[ + schema + for schema in schemas + if schema.lower() in allowed_schemas + ], + ) return self.response(200, result=list(schemas)) except OperationalError: return self.response( diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index f225a1458d5..b91dbd25d5d 100644 --- a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -64,6 +64,7 @@ database_schemas_query_schema = { "type": "object", "properties": { "force": {"type": "boolean"}, + "upload_allowed": {"type": "boolean"}, "catalog": {"type": "string"}, }, } diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index a62888e0d01..fc558e519fb 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -2088,6 +2088,93 @@ class TestDatabaseApi(SupersetTestCase): ) assert rv.status_code == 400 + def test_database_schemas_upload_allowed_filter(self): + """ + Database API: Test database schemas when filtering for upload allowed + and there is not schema restriction + """ + with self.create_app().app_context(): + example_db = get_example_database() + + extra = { + "metadata_params": {}, + "engine_params": {}, + "metadata_cache_timeout": {}, + "schemas_allowed_for_file_upload": [], + } + self.login(ADMIN_USERNAME) + database = self.insert_database( + "database_with_upload", + example_db.sqlalchemy_uri_decrypted, + extra=json.dumps(extra), + allow_file_upload=True, + ) + db.session.commit() + yield database + + mock_schemas = ["schema_1", "schema_2", "schema_3"] + mock.patch.object( + database, "get_all_schema_names", return_value=mock_schemas + ) + arguments = {"upload_allowed": True} + uri = f"api/v1/database/{database.id}/schemas/?q={prison.dumps(arguments)}" + rv = self.client.get(uri) + data = json.loads(rv.data.decode("utf-8")) + assert data["result"] == mock_schemas + db.session.delete(database) + db.session.commit() + + def test_database_schemas_upload_allowed_filter_specific_schemas(self): + """ + Database API: Test database schemas when filtering for upload allowed + with an schema restriction set + """ + with self.create_app().app_context(): + example_db = get_example_database() + + extra = { + "metadata_params": {}, + "engine_params": {}, + "metadata_cache_timeout": {}, + "schemas_allowed_for_file_upload": ["schema_2"], + } + self.login(ADMIN_USERNAME) + database = self.insert_database( + "database_with_upload", + example_db.sqlalchemy_uri_decrypted, + extra=json.dumps(extra), + allow_file_upload=True, + ) + db.session.commit() + yield database + + mock.patch.object( + database, + "get_all_schema_names", + return_value=["schema_1", "schema_2", "schema_3"], + ) + arguments = {"upload_allowed": True} + uri = f"api/v1/database/{database.id}/schemas/?q={prison.dumps(arguments)}" + rv = self.client.get(uri) + data = json.loads(rv.data.decode("utf-8")) + assert data["result"] == ["schema_2"] + db.session.delete(database) + db.session.commit() + + def test_database_schemas_upload_allowed_filter_disabled(self): + """ + Database API: Test database schemas when filtering for upload allowed + for a DB connection that has file uploads disabled + """ + database = db.session.query(Database).filter_by(database_name="examples").one() + self.login(ADMIN_USERNAME) + arguments = {"upload_allowed": True} + uri = f"api/v1/database/{database.id}/schemas/?q={prison.dumps(arguments)}" + rv = self.client.get(uri) + assert rv.status_code == 200 + data = json.loads(rv.data.decode("utf-8")) + assert data["result"] == [] + def test_database_tables(self): """ Database API: Test database tables