diff --git a/superset/assets/spec/javascripts/sqllab/actions/sqlLab_spec.js b/superset/assets/spec/javascripts/sqllab/actions/sqlLab_spec.js index f8fab7579cf..36eb1fe6dff 100644 --- a/superset/assets/spec/javascripts/sqllab/actions/sqlLab_spec.js +++ b/superset/assets/spec/javascripts/sqllab/actions/sqlLab_spec.js @@ -331,7 +331,7 @@ describe('async actions', () => { fetchMock.delete(updateTableSchemaEndpoint, {}); fetchMock.post(updateTableSchemaEndpoint, JSON.stringify({ id: 1 })); - const getTableMetadataEndpoint = 'glob:*/superset/table/*'; + const getTableMetadataEndpoint = 'glob:*/api/v1/database/*'; fetchMock.get(getTableMetadataEndpoint, {}); const getExtraTableMetadataEndpoint = 'glob:*/superset/extra_table_metadata/*'; diff --git a/superset/assets/src/SqlLab/actions/sqlLab.js b/superset/assets/src/SqlLab/actions/sqlLab.js index 4ec670d821a..9dad773219f 100644 --- a/superset/assets/src/SqlLab/actions/sqlLab.js +++ b/superset/assets/src/SqlLab/actions/sqlLab.js @@ -934,7 +934,7 @@ export function mergeTable(table, query) { function getTableMetadata(table, query, dispatch) { return SupersetClient.get({ endpoint: encodeURI( - `/superset/table/${query.dbId}/` + + `/api/v1/database/${query.dbId}/table/` + `${encodeURIComponent(table.name)}/${encodeURIComponent( table.schema, )}/`, diff --git a/superset/views/core.py b/superset/views/core.py index 5cfa4a58d5d..0d7ccd8804c 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -1943,70 +1943,6 @@ class Superset(BaseSupersetView): db.session.commit() return json_success(json.dumps({"table_id": table.id})) - @has_access - @expose("/table/<database_id>/<table_name>/<schema>/") - @event_logger.log_this - def table(self, database_id, table_name, schema): - schema = utils.parse_js_uri_path_item(schema, eval_undefined=True) - table_name = utils.parse_js_uri_path_item(table_name) - mydb = db.session.query(models.Database).filter_by(id=database_id).one() - payload_columns = [] - indexes = [] - primary_key = [] - foreign_keys = [] - try: - 
columns = mydb.get_columns(table_name, schema) - indexes = mydb.get_indexes(table_name, schema) - primary_key = mydb.get_pk_constraint(table_name, schema) - foreign_keys = mydb.get_foreign_keys(table_name, schema) - except Exception as e: - return json_error_response(utils.error_msg_from_exception(e)) - keys = [] - if primary_key and primary_key.get("constrained_columns"): - primary_key["column_names"] = primary_key.pop("constrained_columns") - primary_key["type"] = "pk" - keys += [primary_key] - for fk in foreign_keys: - fk["column_names"] = fk.pop("constrained_columns") - fk["type"] = "fk" - keys += foreign_keys - for idx in indexes: - idx["type"] = "index" - keys += indexes - - for col in columns: - dtype = "" - try: - dtype = "{}".format(col["type"]) - except Exception: - # sqla.types.JSON __str__ has a bug, so using __class__. - dtype = col["type"].__class__.__name__ - pass - payload_columns.append( - { - "name": col["name"], - "type": dtype.split("(")[0] if "(" in dtype else dtype, - "longType": dtype, - "keys": [k for k in keys if col["name"] in k.get("column_names")], - } - ) - tbl = { - "name": table_name, - "columns": payload_columns, - "selectStar": mydb.select_star( - table_name, - schema=schema, - show_cols=True, - indent=True, - cols=columns, - latest_partition=True, - ), - "primaryKey": primary_key, - "foreignKeys": foreign_keys, - "indexes": keys, - } - return json_success(json.dumps(tbl)) - @has_access @expose("/extra_table_metadata/<database_id>/<table_name>/<schema>/") @event_logger.log_this diff --git a/superset/views/database/api.py b/superset/views/database/api.py index ce6c0076efe..bbf60a2eafb 100644 --- a/superset/views/database/api.py +++ b/superset/views/database/api.py @@ -14,20 +14,107 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from typing import Any, Dict, List, Optional + +from flask_appbuilder.api import expose, protect, safe from flask_appbuilder.models.sqla.interface import SQLAInterface +from flask_babel import lazy_gettext as _ +from sqlalchemy.exc import SQLAlchemyError -import superset.models.core as models +from superset import event_logger +from superset.models.core import Database +from superset.utils.core import error_msg_from_exception, parse_js_uri_path_item from superset.views.base_api import BaseSupersetModelRestApi +from superset.views.database.filters import DatabaseFilter +from superset.views.database.mixins import DatabaseMixin +from superset.views.database.validators import sqlalchemy_uri_validator -from .mixins import DatabaseFilter, DatabaseMixin -from .validators import sqlalchemy_uri_validator + +def get_foreign_keys_metadata( + database: Database, table_name: str, schema_name: Optional[str] +) -> List[Dict[str, Any]]: + foreign_keys = database.get_foreign_keys(table_name, schema_name) + for fk in foreign_keys: + fk["column_names"] = fk.pop("constrained_columns") + fk["type"] = "fk" + return foreign_keys + + +def get_indexes_metadata( + database: Database, table_name: str, schema_name: Optional[str] +) -> List[Dict[str, Any]]: + indexes = database.get_indexes(table_name, schema_name) + for idx in indexes: + idx["type"] = "index" + return indexes + + +def get_col_type(col: Dict) -> str: + try: + dtype = f"{col['type']}" + except Exception: # pylint: disable=broad-except + # sqla.types.JSON __str__ has a bug, so using __class__. + dtype = col["type"].__class__.__name__ + return dtype + + +def get_table_metadata( + database: Database, table_name: str, schema_name: Optional[str] +) -> Dict: + """ + Get table metadata information, including type, pk, fks. + This function raises SQLAlchemyError when a schema is not found. 
+ + + :param database: The database model + :param table_name: Table name + :param schema_name: schema name + :return: Dict table metadata ready for API response + """ + keys: List = [] + columns = database.get_columns(table_name, schema_name) + primary_key = database.get_pk_constraint(table_name, schema_name) + if primary_key and primary_key.get("constrained_columns"): + primary_key["column_names"] = primary_key.pop("constrained_columns") + primary_key["type"] = "pk" + keys += [primary_key] + foreign_keys = get_foreign_keys_metadata(database, table_name, schema_name) + indexes = get_indexes_metadata(database, table_name, schema_name) + keys += foreign_keys + indexes + payload_columns: List[Dict] = [] + for col in columns: + dtype = get_col_type(col) + payload_columns.append( + { + "name": col["name"], + "type": dtype.split("(")[0] if "(" in dtype else dtype, + "longType": dtype, + "keys": [k for k in keys if col["name"] in k.get("column_names")], + } + ) + return { + "name": table_name, + "columns": payload_columns, + "selectStar": database.select_star( + table_name, + schema=schema_name, + show_cols=True, + indent=True, + cols=columns, + latest_partition=True, + ), + "primaryKey": primary_key, + "foreignKeys": foreign_keys, + "indexes": keys, + } class DatabaseRestApi(DatabaseMixin, BaseSupersetModelRestApi): - datamodel = SQLAInterface(models.Database) - include_route_methods = {"get_list"} + datamodel = SQLAInterface(Database) + include_route_methods = {"get_list", "table_metadata"} class_permission_name = "DatabaseView" + method_permission_name = {"get_list": "list", "table_metadata": "list"} resource_name = "database" allow_browser_login = True base_filters = [["id", DatabaseFilter, lambda: []]] @@ -51,3 +138,143 @@ class DatabaseRestApi(DatabaseMixin, BaseSupersetModelRestApi): # Removes the local limit for the page size max_page_size = -1 validators_columns = {"sqlalchemy_uri": sqlalchemy_uri_validator} + + @expose( + "/<int:pk>/table/<table_name>/<schema_name>/", methods=["GET"] + ) + 
@protect() + @safe + @event_logger.log_this + def table_metadata( + self, pk: int, table_name: str, schema_name: str + ): # pylint: disable=invalid-name + """ Table schema info + --- + get: + description: Get database table metadata + parameters: + - in: path + schema: + type: integer + name: pk + description: The database id + - in: path + schema: + type: string + name: table_name + description: Table name + - in: path + schema: + type: string + name: schema + description: Table schema + responses: + 200: + description: Table schema info + content: + text/plain: + schema: + type: object + properties: + columns: + type: array + description: Table columns info + items: + type: object + properties: + keys: + type: array + items: + type: string + longType: + type: string + name: + type: string + type: + type: string + foreignKeys: + type: array + description: Table list of foreign keys + items: + type: object + properties: + column_names: + type: array + items: + type: string + name: + type: string + options: + type: object + referred_columns: + type: array + items: + type: string + referred_schema: + type: string + referred_table: + type: string + type: + type: string + indexes: + type: array + description: Table list of indexes + items: + type: object + properties: + column_names: + type: array + items: + type: string + name: + type: string + options: + type: object + referred_columns: + type: array + items: + type: string + referred_schema: + type: string + referred_table: + type: string + type: + type: string + primaryKey: + type: object + properties: + column_names: + type: array + items: + type: string + name: + type: string + type: + type: string + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 404: + $ref: '#/components/responses/404' + 422: + $ref: '#/components/responses/422' + 500: + $ref: '#/components/responses/500' + """ + table_name_parsed = parse_js_uri_path_item(table_name) + schema_parsed = 
parse_js_uri_path_item(schema_name, eval_undefined=True) + # schemas can be None but not tables + if not table_name_parsed: + return self.response_422(message=_(f"Could not parse table name or schema")) + database: Database = self.datamodel.get(pk, self._base_filters) + if not database: + return self.response_404() + + try: + table_info: Dict = get_table_metadata( + database, table_name_parsed, schema_parsed + ) + except SQLAlchemyError as e: + return self.response_422(error_msg_from_exception(e)) + return self.response(200, **table_info) diff --git a/superset/views/database/filters.py b/superset/views/database/filters.py new file mode 100644 index 00000000000..4999e3cdd99 --- /dev/null +++ b/superset/views/database/filters.py @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from sqlalchemy import or_ + +from superset import security_manager +from superset.views.base import BaseFilter + + +class DatabaseFilter(BaseFilter): + # TODO(bogdan): consider caching. 
+ def schema_access_databases(self): # noqa pylint: disable=no-self-use + found_databases = set() + for vm in security_manager.user_view_menu_names("schema_access"): + database_name, _ = security_manager.unpack_schema_perm(vm) + found_databases.add(database_name) + return found_databases + + def apply( + self, query, func + ): # noqa pylint: disable=unused-argument,arguments-differ + if security_manager.all_database_access(): + return query + database_perms = security_manager.user_view_menu_names("database_access") + # TODO(bogdan): consider adding datasource access here as well. + schema_access_databases = self.schema_access_databases() + return query.filter( + or_( + self.model.perm.in_(database_perms), + self.model.database_name.in_(schema_access_databases), + ) + ) diff --git a/superset/views/database/mixins.py b/superset/views/database/mixins.py index f9cebe9f97a..219900cdeb6 100644 --- a/superset/views/database/mixins.py +++ b/superset/views/database/mixins.py @@ -18,37 +18,12 @@ import inspect from flask import Markup from flask_babel import lazy_gettext as _ -from sqlalchemy import MetaData, or_ +from sqlalchemy import MetaData from superset import security_manager from superset.exceptions import SupersetException from superset.utils import core as utils -from superset.views.base import BaseFilter - - -class DatabaseFilter(BaseFilter): - # TODO(bogdan): consider caching. - def schema_access_databases(self): # noqa pylint: disable=no-self-use - found_databases = set() - for vm in security_manager.user_view_menu_names("schema_access"): - database_name, _ = security_manager.unpack_schema_perm(vm) - found_databases.add(database_name) - return found_databases - - def apply( - self, query, func - ): # noqa pylint: disable=unused-argument,arguments-differ - if security_manager.all_database_access(): - return query - database_perms = security_manager.user_view_menu_names("database_access") - # TODO(bogdan): consider adding datasource access here as well. 
- schema_access_databases = self.schema_access_databases() - return query.filter( - or_( - self.model.perm.in_(database_perms), - self.model.database_name.in_(schema_access_databases), - ) - ) +from superset.views.database.filters import DatabaseFilter class DatabaseMixin: diff --git a/tests/core_tests.py b/tests/core_tests.py index 566317ecabe..4de0171b511 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -554,13 +554,6 @@ class CoreTests(SupersetTestCase): data = self.run_sql(sql, "fdaklj3ws") self.assertEqual(data["data"][0]["test"], "2017-01-01T00:00:00") - def test_table_metadata(self): - maindb = utils.get_example_database() - data = self.get_json_resp(f"/superset/table/{maindb.id}/birth_names/null/") - self.assertEqual(data["name"], "birth_names") - assert len(data["columns"]) > 5 - assert data.get("selectStar").startswith("SELECT") - def test_fetch_datasource_metadata(self): self.login(username="admin") url = "/superset/fetch_datasource_metadata?" "datasourceKey=1__table" diff --git a/tests/database_api_tests.py b/tests/database_api_tests.py new file mode 100644 index 00000000000..7baeb6ca3e7 --- /dev/null +++ b/tests/database_api_tests.py @@ -0,0 +1,142 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Unit tests for Superset""" +import json + +import prison + +from superset import db +from superset.models.core import Database +from superset.utils.core import get_example_database + +from .base_tests import SupersetTestCase + + +class DatabaseApiTests(SupersetTestCase): + def test_get_items(self): + """ + Database API: Test get items + """ + self.login(username="admin") + uri = "api/v1/database/" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 200) + response = json.loads(rv.data.decode("utf-8")) + expected_columns = [ + "allow_csv_upload", + "allow_ctas", + "allow_dml", + "allow_multi_schema_metadata_fetch", + "allow_run_async", + "allows_cost_estimate", + "allows_subquery", + "backend", + "database_name", + "expose_in_sqllab", + "force_ctas_schema", + "function_names", + "id", + ] + self.assertEqual(response["count"], 2) + self.assertEqual(list(response["result"][0].keys()), expected_columns) + + def test_get_items_filter(self): + fake_db = ( + db.session.query(Database).filter_by(database_name="fake_db_100").one() + ) + old_expose_in_sqllab = fake_db.expose_in_sqllab + fake_db.expose_in_sqllab = False + db.session.commit() + self.login(username="admin") + arguments = { + "keys": ["none"], + "filters": [{"col": "expose_in_sqllab", "opr": "eq", "value": True}], + "order_columns": "database_name", + "order_direction": "asc", + "page": 0, + "page_size": -1, + } + uri = f"api/v1/database/?q={prison.dumps(arguments)}" + rv = self.client.get(uri) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 200) + self.assertEqual(response["count"], 1) + + fake_db = ( + db.session.query(Database).filter_by(database_name="fake_db_100").one() + ) + fake_db.expose_in_sqllab = old_expose_in_sqllab + db.session.commit() + + def test_get_items_not_allowed(self): + """ + Database API: Test get items not allowed + """ + self.login(username="gamma") + uri = f"api/v1/database/" + rv = self.client.get(uri) + 
self.assertEqual(rv.status_code, 200) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(response["count"], 0) + + def test_get_table_metadata(self): + """ + Database API: Test get table metadata info + """ + example_db = get_example_database() + self.login(username="admin") + uri = f"api/v1/database/{example_db.id}/table/birth_names/null/" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 200) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(response["name"], "birth_names") + self.assertTrue(len(response["columns"]) > 5) + self.assertTrue(response.get("selectStar").startswith("SELECT")) + + def test_get_invalid_database_table_metadata(self): + """ + Database API: Test get invalid database from table metadata + """ + database_id = 1000 + self.login(username="admin") + uri = f"api/v1/database/{database_id}/table/some_table/some_schema/" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 404) + + uri = f"api/v1/database/some_database/table/some_table/some_schema/" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 404) + + def test_get_invalid_table_table_metadata(self): + """ + Database API: Test get invalid table from table metadata + """ + example_db = get_example_database() + uri = f"api/v1/database/{example_db.id}/wrong_table/null/" + self.login(username="admin") + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 404) + + def test_get_table_metadata_no_db_permission(self): + """ + Database API: Test get table metadata from not permitted db + """ + self.login(username="gamma") + example_db = get_example_database() + uri = f"api/v1/database/{example_db.id}/birth_names/null/" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 404)