diff --git a/docs/static/resources/openapi.json b/docs/static/resources/openapi.json index bb8858fb6cc..831b120bf05 100644 --- a/docs/static/resources/openapi.json +++ b/docs/static/resources/openapi.json @@ -962,8 +962,11 @@ "ChartDataDatasource": { "properties": { "id": { - "description": "Datasource id", - "type": "integer" + "description": "Datasource id/uuid", + "oneOf": [ + { "type": "integer" }, + { "type": "string" } + ] }, "type": { "description": "Datasource type", diff --git a/pyproject.toml b/pyproject.toml index 1fea1dae741..11d031d2b92 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,6 +69,7 @@ dependencies = [ "markdown>=3.0", # marshmallow>=4 has issues: https://github.com/apache/superset/issues/33162 "marshmallow>=3.0, <4", + "marshmallow-union>=0.1", "msgpack>=1.0.0, <1.1", "nh3>=0.2.11, <0.3", "numpy>1.23.5, <2.3", @@ -227,7 +228,7 @@ combine_as_imports = true include_trailing_comma = true line_length = 88 known_first_party = "superset, apache-superset-core, apache-superset-extensions-cli" -known_third_party = "alembic, apispec, backoff, celery, click, colorama, cron_descriptor, croniter, cryptography, dateutil, deprecation, flask, flask_appbuilder, flask_babel, flask_caching, flask_compress, flask_jwt_extended, flask_login, flask_migrate, flask_sqlalchemy, flask_talisman, flask_testing, flask_wtf, freezegun, geohash, geopy, holidays, humanize, isodate, jinja2, jwt, markdown, markupsafe, marshmallow, msgpack, nh3, numpy, pandas, parameterized, parsedatetime, pgsanity, polyline, prison, progress, pyarrow, sqlalchemy_bigquery, pyhive, pyparsing, pytest, pytest_mock, pytz, redis, requests, selenium, setuptools, shillelagh, simplejson, slack, sqlalchemy, sqlalchemy_utils, typing_extensions, urllib3, werkzeug, wtforms, wtforms_json, yaml" +known_third_party = "alembic, apispec, backoff, celery, click, colorama, cron_descriptor, croniter, cryptography, dateutil, deprecation, flask, flask_appbuilder, flask_babel, flask_caching, flask_compress, 
flask_jwt_extended, flask_login, flask_migrate, flask_sqlalchemy, flask_talisman, flask_testing, flask_wtf, freezegun, geohash, geopy, holidays, humanize, isodate, jinja2, jwt, markdown, markupsafe, marshmallow, marshmallow_union, msgpack, nh3, numpy, pandas, parameterized, parsedatetime, pgsanity, polyline, prison, progress, pyarrow, sqlalchemy_bigquery, pyhive, pyparsing, pytest, pytest_mock, pytz, redis, requests, selenium, setuptools, shillelagh, simplejson, slack, sqlalchemy, sqlalchemy_utils, typing_extensions, urllib3, werkzeug, wtforms, wtforms_json, yaml" multi_line_output = 3 order_by_type = false diff --git a/requirements/base.txt b/requirements/base.txt index 84934fd421f..e733e594aa1 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -158,6 +158,7 @@ greenlet==3.1.1 # via # apache-superset (pyproject.toml) # shillelagh + # sqlalchemy gunicorn==23.0.0 # via apache-superset (pyproject.toml) h11==0.16.0 @@ -219,10 +220,13 @@ marshmallow==3.26.1 # apache-superset (pyproject.toml) # flask-appbuilder # marshmallow-sqlalchemy + # marshmallow-union marshmallow-sqlalchemy==1.4.0 # via # -r requirements/base.in # flask-appbuilder +marshmallow-union==0.1.15 + # via apache-superset (pyproject.toml) mdurl==0.1.2 # via markdown-it-py msgpack==1.0.8 diff --git a/requirements/development.txt b/requirements/development.txt index cdd80f47b30..5dd5e689f86 100644 --- a/requirements/development.txt +++ b/requirements/development.txt @@ -327,6 +327,7 @@ greenlet==3.1.1 # apache-superset # gevent # shillelagh + # sqlalchemy grpcio==1.71.0 # via # apache-superset @@ -444,10 +445,15 @@ marshmallow==3.26.1 # apache-superset # flask-appbuilder # marshmallow-sqlalchemy + # marshmallow-union marshmallow-sqlalchemy==1.4.0 # via # -c requirements/base-constraint.txt # flask-appbuilder +marshmallow-union==0.1.15 + # via + # -c requirements/base.txt + # apache-superset matplotlib==3.9.0 # via prophet mccabe==0.7.0 diff --git a/superset/charts/schemas.py
b/superset/charts/schemas.py index 3a05c3ff176..e794f180910 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -24,6 +24,7 @@ from flask import current_app from flask_babel import gettext as _ from marshmallow import EXCLUDE, fields, post_load, Schema, validate from marshmallow.validate import Length, Range +from marshmallow_union import Union from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType from superset.db_engine_specs.base import builtin_time_grains @@ -1130,8 +1131,9 @@ class AnnotationLayerSchema(Schema): class ChartDataDatasourceSchema(Schema): description = "Chart datasource" - id = fields.Integer( - metadata={"description": "Datasource id"}, + id = Union( + [fields.Integer(), fields.UUID()], + metadata={"description": "Datasource id or uuid"}, required=True, ) type = fields.String( diff --git a/superset/common/query_context_factory.py b/superset/common/query_context_factory.py index 9205ba418ef..81f460c08b4 100644 --- a/superset/common/query_context_factory.py +++ b/superset/common/query_context_factory.py @@ -106,7 +106,7 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource: return DatasourceDAO.get_datasource( datasource_type=DatasourceType(datasource["type"]), - datasource_id=int(datasource["id"]), + database_id_or_uuid=datasource["id"], ) def _get_slice(self, slice_id: Any) -> Slice | None: diff --git a/superset/common/query_object_factory.py b/superset/common/query_object_factory.py index 063ba023481..31d2843da56 100644 --- a/superset/common/query_object_factory.py +++ b/superset/common/query_object_factory.py @@ -91,7 +91,7 @@ class QueryObjectFactory: # pylint: disable=too-few-public-methods def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource: return self._datasource_dao.get_datasource( datasource_type=DatasourceType(datasource["type"]), - datasource_id=int(datasource["id"]), + 
database_id_or_uuid=datasource["id"], ) def _process_extras( diff --git a/superset/daos/base.py b/superset/daos/base.py index 0f8d8f388fc..2b40dc0e333 100644 --- a/superset/daos/base.py +++ b/superset/daos/base.py @@ -47,12 +47,41 @@ class BaseDAO(Generic[T]): Child classes can register base filtering to be applied to all filter methods """ id_column_name = "id" + uuid_column_name = "uuid" def __init_subclass__(cls) -> None: cls.model_cls = get_args( cls.__orig_bases__[0] # type: ignore # pylint: disable=no-member )[0] + @classmethod + def find_by_id_or_uuid( + cls, + model_id_or_uuid: str, + skip_base_filter: bool = False, + ) -> T | None: + """ + Find a model by id or uuid, if defined applies `base_filter` + """ + query = db.session.query(cls.model_cls) + if cls.base_filter and not skip_base_filter: + data_model = SQLAInterface(cls.model_cls, db.session) + query = cls.base_filter( # pylint: disable=not-callable + cls.id_column_name, data_model + ).apply(query, None) + id_column = getattr(cls.model_cls, cls.id_column_name) + uuid_column = getattr(cls.model_cls, cls.uuid_column_name) + + if model_id_or_uuid.isdigit(): + filter = id_column == int(model_id_or_uuid) + else: + filter = uuid_column == model_id_or_uuid + try: + return query.filter(filter).one_or_none() + except StatementError: + # can happen if neither uuid nor int is passed + return None + @classmethod def find_by_id( cls, diff --git a/superset/daos/datasource.py b/superset/daos/datasource.py index 9a48bb86b7b..c754fd8e1f5 100644 --- a/superset/daos/datasource.py +++ b/superset/daos/datasource.py @@ -16,12 +16,17 @@ # under the License. 
import logging +import uuid from typing import Union from superset import db from superset.connectors.sqla.models import SqlaTable from superset.daos.base import BaseDAO -from superset.daos.exceptions import DatasourceNotFound, DatasourceTypeNotSupportedError +from superset.daos.exceptions import ( + DatasourceNotFound, + DatasourceTypeNotSupportedError, + DatasourceValueIsIncorrect, +) from superset.models.sql_lab import Query, SavedQuery from superset.utils.core import DatasourceType @@ -41,22 +46,34 @@ class DatasourceDAO(BaseDAO[Datasource]): def get_datasource( cls, datasource_type: Union[DatasourceType, str], - datasource_id: int, + database_id_or_uuid: int | str, ) -> Datasource: if datasource_type not in cls.sources: raise DatasourceTypeNotSupportedError() + model = cls.sources[datasource_type] + + if str(database_id_or_uuid).isdigit(): + filter = model.id == int(database_id_or_uuid) + else: + try: + uuid.UUID(str(database_id_or_uuid)) # uuid validation + filter = model.uuid == database_id_or_uuid + except ValueError as err: + logger.warning( + f"database_id_or_uuid {database_id_or_uuid} isn't valid uuid" + ) + raise DatasourceValueIsIncorrect() from err + datasource = ( - db.session.query(cls.sources[datasource_type]) - .filter_by(id=datasource_id) - .one_or_none() + db.session.query(cls.sources[datasource_type]).filter(filter).one_or_none() ) if not datasource: logger.warning( - "Datasource not found datasource_type: %s, datasource_id: %s", + "Datasource not found datasource_type: %s, database_id_or_uuid: %s", datasource_type, - datasource_id, + database_id_or_uuid, ) raise DatasourceNotFound() diff --git a/superset/daos/exceptions.py b/superset/daos/exceptions.py index 1b9fdf606d9..fb132546b0b 100644 --- a/superset/daos/exceptions.py +++ b/superset/daos/exceptions.py @@ -68,3 +68,8 @@ class DatasourceTypeNotSupportedError(DAOException): class DatasourceNotFound(DAOException): status = 404 message = "Datasource does not exist" + + +class 
DatasourceValueIsIncorrect(DAOException): + status = 422 + message = "Datasource value is neither id nor uuid" diff --git a/superset/datasets/api.py b/superset/datasets/api.py index 2afd0411c28..6821b837e3c 100644 --- a/superset/datasets/api.py +++ b/superset/datasets/api.py @@ -702,7 +702,7 @@ class DatasetRestApi(BaseSupersetModelRestApi): ) return self.response_422(message=str(ex)) - @expose("/<pk>/related_objects", methods=("GET",)) + @expose("/<id_or_uuid>/related_objects", methods=("GET",)) @protect() @safe @statsd_metrics @@ -711,16 +711,16 @@ f".related_objects", log_to_statsd=False, ) - def related_objects(self, pk: int) -> Response: + def related_objects(self, id_or_uuid: str) -> Response: """Get charts and dashboards count associated to a dataset. --- get: summary: Get charts and dashboards count associated to a dataset parameters: - in: path - name: pk + name: id_or_uuid schema: - type: integer + type: string responses: 200: 200: $ref: '#/components/responses/500' @@ -736,10 +736,10 @@ 500: $ref: '#/components/responses/500' """ - dataset = DatasetDAO.find_by_id(pk) + dataset = DatasetDAO.find_by_id_or_uuid(id_or_uuid) if not dataset: return self.response_404() - data = DatasetDAO.get_related_objects(pk) + data = DatasetDAO.get_related_objects(dataset.id) charts = [ { "id": chart.id, @@ -1081,7 +1081,7 @@ except CommandException as ex: return self.response(ex.status, message=ex.message) - @expose("/<pk>", methods=("GET",)) + @expose("/<id_or_uuid>", methods=("GET",)) @protect() @safe @rison(get_item_schema) @@ -1091,7 +1091,7 @@ action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.get", log_to_statsd=False, ) - def get(self, pk: int, **kwargs: Any) -> Response: + def get(self, id_or_uuid: str, **kwargs: Any) -> Response: """Get a dataset.
--- get: @@ -1100,9 +1100,9 @@ parameters: - in: path schema: - type: integer - description: The dataset ID - name: pk + type: string + description: Either the id of the dataset, or its uuid + name: id_or_uuid - in: query name: q content: @@ -1138,13 +1138,8 @@ 500: $ref: '#/components/responses/500' """ - item: SqlaTable | None = self.datamodel.get( - pk, - self._base_filters, - self.show_select_columns, - self.show_outer_default_load, - ) - if not item: + table = DatasetDAO.find_by_id_or_uuid(id_or_uuid) + if not table: return self.response_404() response: dict[str, Any] = {} @@ -1162,8 +1157,8 @@ else: show_model_schema = self.show_model_schema - response["id"] = pk - response[API_RESULT_RES_KEY] = show_model_schema.dump(item, many=False) + response["id"] = table.id + response[API_RESULT_RES_KEY] = show_model_schema.dump(table, many=False) # remove folders from response if `DATASET_FOLDERS` is disabled, so that it's # possible to inspect if the feature is supported or not @@ -1175,12 +1170,13 @@ if parse_boolean_string(request.args.get("include_rendered_sql")): try: - processor = get_template_processor(database=item.database) + processor = get_template_processor(database=table.database) response["result"] = self.render_dataset_fields( response["result"], processor ) except SupersetTemplateException as ex: return self.response_400(message=str(ex)) + return self.response(200, **response) @expose("/<pk>/drill_info/", methods=("GET",)) diff --git a/superset/utils/core.py b/superset/utils/core.py index 3185148ff86..bc758f1899a 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -218,7 +218,7 @@ class HeaderDataType(TypedDict): class DatasourceDict(TypedDict): type: str # todo(hugh): update this to be DatasourceType - id: int + id: int | str class
AdhocFilterClause(TypedDict, total=False): diff --git a/superset/views/datasource/utils.py b/superset/views/datasource/utils.py index 75c9eb7b0d3..e0751bb19ab 100644 --- a/superset/views/datasource/utils.py +++ b/superset/views/datasource/utils.py @@ -54,7 +54,7 @@ def get_samples( # pylint: disable=too-many-arguments ) -> dict[str, Any]: datasource = DatasourceDAO.get_datasource( datasource_type=datasource_type, - datasource_id=datasource_id, + database_id_or_uuid=str(datasource_id), ) limit_clause = get_limit_clause(page, per_page) diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py index c57ea11419d..751f96631ea 100644 --- a/tests/integration_tests/query_context_tests.py +++ b/tests/integration_tests/query_context_tests.py @@ -153,7 +153,7 @@ class TestQueryContext(SupersetTestCase): # make temporary change and revert it to refresh the changed_on property datasource = DatasourceDAO.get_datasource( datasource_type=DatasourceType(payload["datasource"]["type"]), - datasource_id=payload["datasource"]["id"], + database_id_or_uuid=payload["datasource"]["id"], ) description_original = datasource.description datasource.description = "temporary description" diff --git a/tests/unit_tests/datasource/dao_tests.py b/tests/unit_tests/datasource/dao_tests.py index bf486010abf..17e170f12b0 100644 --- a/tests/unit_tests/datasource/dao_tests.py +++ b/tests/unit_tests/datasource/dao_tests.py @@ -76,7 +76,7 @@ def test_get_datasource_sqlatable(session_with_data: Session) -> None: result = DatasourceDAO.get_datasource( datasource_type=DatasourceType.TABLE, - datasource_id=1, + database_id_or_uuid=1, ) assert 1 == result.id @@ -89,7 +89,7 @@ def test_get_datasource_query(session_with_data: Session) -> None: from superset.models.sql_lab import Query result = DatasourceDAO.get_datasource( - datasource_type=DatasourceType.QUERY, datasource_id=1 + datasource_type=DatasourceType.QUERY, database_id_or_uuid=1 ) assert result.id == 1 
@@ -102,7 +102,7 @@ def test_get_datasource_saved_query(session_with_data: Session) -> None: result = DatasourceDAO.get_datasource( datasource_type=DatasourceType.SAVEDQUERY, - datasource_id=1, + database_id_or_uuid=1, ) assert result.id == 1 @@ -116,7 +116,7 @@ def test_get_datasource_w_str_param(session_with_data: Session) -> None: assert isinstance( DatasourceDAO.get_datasource( datasource_type="table", - datasource_id=1, + database_id_or_uuid=1, ), SqlaTable, ) @@ -136,5 +136,5 @@ def test_not_found_datasource(session_with_data: Session) -> None: with pytest.raises(DatasourceNotFound): DatasourceDAO.get_datasource( datasource_type="table", - datasource_id=500000, + database_id_or_uuid=500000, )