Mirror of https://github.com/apache/superset.git, synced 2026-04-19 08:04:53 +00:00.
feat: the samples endpoint supports filters and pagination (#20683)
This commit is contained in:
@@ -14,11 +14,15 @@
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from typing import Any
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from marshmallow import fields, post_load, Schema
|
||||
from marshmallow import fields, post_load, pre_load, Schema, validate
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from superset import app
|
||||
from superset.charts.schemas import ChartDataFilterSchema
|
||||
from superset.utils.core import DatasourceType
|
||||
|
||||
|
||||
class ExternalMetadataParams(TypedDict):
|
||||
datasource_type: str
|
||||
@@ -54,3 +58,27 @@ class ExternalMetadataSchema(Schema):
|
||||
schema_name=data.get("schema_name", ""),
|
||||
table_name=data["table_name"],
|
||||
)
|
||||
|
||||
|
||||
class SamplesPayloadSchema(Schema):
    """Optional POST body for the samples endpoint: a list of ad-hoc filters."""

    filters = fields.List(fields.Nested(ChartDataFilterSchema), required=False)

    @pre_load
    # pylint: disable=no-self-use, unused-argument
    def handle_none(
        self, data: Optional[Dict[str, Any]], **kwargs: Any
    ) -> Dict[str, Any]:
        """Coerce an absent (``None``) request body into an empty payload.

        ``request.json`` is ``None`` when no body is sent; without this hook
        marshmallow would reject the load with an invalid-input-type error.
        """
        if data is None:
            return {}
        return data
class SamplesRequestSchema(Schema):
    """Query-string parameters accepted by the samples endpoint."""

    datasource_type = fields.String(
        required=True,
        validate=validate.OneOf([member.value for member in DatasourceType]),
    )
    datasource_id = fields.Integer(required=True)
    force = fields.Boolean(load_default=False)
    page = fields.Integer(load_default=1)
    # per_page is capped at the configured sample row limit and defaults to it.
    per_page = fields.Integer(
        load_default=app.config.get("SAMPLES_ROW_LIMIT", 1000),
        validate=validate.Range(min=1, max=app.config.get("SAMPLES_ROW_LIMIT", 1000)),
    )
115
superset/views/datasource/utils.py
Normal file
115
superset/views/datasource/utils.py
Normal file
@@ -0,0 +1,115 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from superset import app, db
|
||||
from superset.common.chart_data import ChartDataResultType
|
||||
from superset.common.query_context_factory import QueryContextFactory
|
||||
from superset.common.utils.query_cache_manager import QueryCacheManager
|
||||
from superset.constants import CacheRegion
|
||||
from superset.datasets.commands.exceptions import DatasetSamplesFailedError
|
||||
from superset.datasource.dao import DatasourceDAO
|
||||
from superset.utils.core import QueryStatus
|
||||
from superset.views.datasource.schemas import SamplesPayloadSchema
|
||||
|
||||
|
||||
def get_limit_clause(page: Optional[int], per_page: Optional[int]) -> Dict[str, int]:
    """Translate 1-based pagination parameters into a row offset/limit pair.

    :param page: 1-based page number; values below 1 clamp to the first page.
    :param per_page: requested page size; negative or above the configured
        ``SAMPLES_ROW_LIMIT`` (default 1000) resets to that limit.
    :returns: dict with ``row_offset`` and ``row_limit`` keys, suitable for
        merging into a query object.
    """
    samples_row_limit = app.config.get("SAMPLES_ROW_LIMIT", 1000)
    limit = samples_row_limit
    offset = 0

    if isinstance(page, int) and isinstance(per_page, int):
        # the isinstance guard already guarantees ints; no int() cast needed
        limit = per_page
        if limit < 0 or limit > samples_row_limit:
            # reset limit value if input is invalid
            limit = samples_row_limit

        # clamp so page numbers below 1 never produce a negative offset
        offset = max((page - 1) * limit, 0)

    return {"row_offset": offset, "row_limit": limit}
def get_samples(  # pylint: disable=too-many-arguments,too-many-locals
    datasource_type: str,
    datasource_id: int,
    force: bool = False,
    page: int = 1,
    per_page: int = 1000,
    payload: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Return one page of raw sample rows for a datasource plus a total count.

    :param datasource_type: one of the ``DatasourceType`` values.
    :param datasource_id: primary key of the datasource.
    :param force: bypass the query cache when True.
    :param page: 1-based page number.
    :param per_page: rows per page (bounded by ``SAMPLES_ROW_LIMIT``).
    :param payload: the dict produced by ``SamplesPayloadSchema().load(...)``
        (optionally containing ``filters``) — note: a plain dict, not a schema
        instance; it is spread into the query below.
    :returns: the samples query result annotated with ``page``, ``per_page``
        and ``total_count``.
    :raises DatasetSamplesFailedError: when either query fails or returns an
        unexpected shape.
    """
    datasource = DatasourceDAO.get_datasource(
        session=db.session,
        datasource_type=datasource_type,
        datasource_id=datasource_id,
    )

    limit_clause = get_limit_clause(page, per_page)

    # both query contexts target the same datasource; build the ref once
    datasource_ref = {
        "type": datasource.type,
        "id": datasource.id,
    }

    # todo(yongjie): Constructing count(*) and samples in the same query_context,
    # then remove query_type==SAMPLES
    # constructing samples query
    samples_instance = QueryContextFactory().create(
        datasource=datasource_ref,
        queries=[{**payload, **limit_clause} if payload else limit_clause],
        result_type=ChartDataResultType.SAMPLES,
        force=force,
    )

    # constructing count(*) query (applies the same filters as the samples
    # query so total_count matches the filtered result set)
    count_star_metric = {
        "metrics": [
            {
                "expressionType": "SQL",
                "sqlExpression": "COUNT(*)",
                "label": "COUNT(*)",
            }
        ]
    }
    count_star_instance = QueryContextFactory().create(
        datasource=datasource_ref,
        queries=[{**payload, **count_star_metric} if payload else count_star_metric],
        result_type=ChartDataResultType.FULL,
        force=force,
    )
    samples_results = samples_instance.get_payload()
    count_star_results = count_star_instance.get_payload()

    try:
        sample_data = samples_results["queries"][0]
        count_star_data = count_star_results["queries"][0]
        failed_status = (
            sample_data.get("status") == QueryStatus.FAILED
            or count_star_data.get("status") == QueryStatus.FAILED
        )
        error_msg = sample_data.get("error") or count_star_data.get("error")
        if failed_status and error_msg:
            # evict the failed result so a retry is not served from cache
            cache_key = sample_data.get("cache_key")
            QueryCacheManager.delete(cache_key, region=CacheRegion.DATA)
            raise DatasetSamplesFailedError(error_msg)

        sample_data["page"] = page
        sample_data["per_page"] = per_page
        sample_data["total_count"] = count_star_data["data"][0]["COUNT(*)"]
        return sample_data
    except (IndexError, KeyError) as exc:
        raise DatasetSamplesFailedError from exc
@@ -50,7 +50,10 @@ from superset.views.datasource.schemas import (
|
||||
ExternalMetadataParams,
|
||||
ExternalMetadataSchema,
|
||||
get_external_metadata_schema,
|
||||
SamplesPayloadSchema,
|
||||
SamplesRequestSchema,
|
||||
)
|
||||
from superset.views.datasource.utils import get_samples
|
||||
from superset.views.utils import sanitize_datasource_data
|
||||
|
||||
|
||||
@@ -179,3 +182,24 @@ class Datasource(BaseSupersetView):
|
||||
except (NoResultFound, NoSuchTableError) as ex:
|
||||
raise DatasetNotFoundError() from ex
|
||||
return self.json_response(external_metadata)
|
||||
|
||||
@expose("/samples", methods=["POST"])
@has_access_api
@api
@handle_api_exception
def samples(self) -> FlaskResponse:
    """Return a paginated sample of a datasource's rows.

    Query-string args are validated by SamplesRequestSchema and the
    optional JSON body (filters) by SamplesPayloadSchema; validation
    failures produce a 400 response.
    """
    try:
        query_params = SamplesRequestSchema().load(request.args)
        body = SamplesPayloadSchema().load(request.json)
    except ValidationError as err:
        return json_error_response(err.messages, status=400)

    result = get_samples(
        datasource_type=query_params["datasource_type"],
        datasource_id=query_params["datasource_id"],
        force=query_params["force"],
        page=query_params["page"],
        per_page=query_params["per_page"],
        payload=body,
    )
    return self.json_response({"result": result})
|
||||
Reference in New Issue
Block a user