feat: the samples endpoint supports filters and pagination (#20683)

This commit is contained in:
Yongjie Zhao
2022-07-22 20:14:42 +08:00
committed by GitHub
parent 39545352d2
commit f011abae2b
13 changed files with 479 additions and 437 deletions

View File

@@ -23,13 +23,15 @@ import prison
import pytest
from superset import app, db
from superset.connectors.sqla.models import SqlaTable
from superset.common.utils.query_cache_manager import QueryCacheManager
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.constants import CacheRegion
from superset.dao.exceptions import DatasourceNotFound, DatasourceTypeNotSupportedError
from superset.datasets.commands.exceptions import DatasetNotFoundError
from superset.exceptions import SupersetGenericDBErrorException
from superset.models.core import Database
from superset.utils.core import DatasourceType, get_example_default_schema
from superset.utils.database import get_example_database
from superset.utils.core import backend, get_example_default_schema
from superset.utils.database import get_example_database, get_main_database
from tests.integration_tests.base_tests import db_insert_temp_object, SupersetTestCase
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
@@ -416,3 +418,189 @@ class TestDatasource(SupersetTestCase):
self.login(username="admin")
resp = self.get_json_resp("/datasource/get/druid/500000/", raise_on_error=False)
self.assertEqual(resp.get("error"), "'druid' is not a valid DatasourceType")
def test_get_samples(test_client, login_as_admin, virtual_dataset):
    """
    Dataset API: Test get dataset samples
    """
    # 1. should cache data
    uri = (
        f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table"
    )
    # warm the cache with a first request
    test_client.post(uri)
    # the second identical request should be served from cache
    cached_resp = test_client.post(uri)
    cached_payload = json.loads(cached_resp.data)
    assert cached_resp.status_code == 200
    assert len(cached_payload["result"]["data"]) == 10
    assert QueryCacheManager.has(
        cached_payload["result"]["cache_key"], region=CacheRegion.DATA
    )
    assert cached_payload["result"]["is_cached"]

    # 2. should read through cache data
    forced_uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&force=true"
    # warm up, then force a fresh query past the cache
    test_client.post(forced_uri)
    forced_resp = test_client.post(forced_uri)
    forced_payload = json.loads(forced_resp.data)
    assert forced_resp.status_code == 200
    assert len(forced_payload["result"]["data"]) == 10
    assert QueryCacheManager.has(
        forced_payload["result"]["cache_key"], region=CacheRegion.DATA
    )
    assert not forced_payload["result"]["is_cached"]

    # 3. data precision
    result = forced_payload["result"]
    assert "colnames" in result
    assert "coltypes" in result
    assert "data" in result

    eager_samples = virtual_dataset.database.get_df(
        f"select * from ({virtual_dataset.sql}) as tbl"
        f' limit {app.config["SAMPLES_ROW_LIMIT"]}'
    )
    # col3 is returned as Decimal by the driver; coerce to float before comparing
    eager_samples["col3"] = eager_samples["col3"].apply(float)
    assert eager_samples.to_dict(orient="records") == result["data"]
def test_get_samples_with_incorrect_cc(test_client, login_as_admin, virtual_dataset):
    """A calculated column with invalid SQL makes the samples endpoint fail with 422."""
    TableColumn(
        column_name="DUMMY CC",
        type="VARCHAR(255)",
        table=virtual_dataset,
        expression="INCORRECT SQL",
    )
    db.session.merge(virtual_dataset)

    uri = (
        f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table"
    )
    resp = test_client.post(uri)
    assert resp.status_code == 422

    payload = json.loads(resp.data)
    assert "error" in payload
    # PostgreSQL surfaces the offending expression in its error message
    if virtual_dataset.database.db_engine_spec.engine_name == "PostgreSQL":
        assert "INCORRECT SQL" in payload.get("error")
def test_get_samples_on_physical_dataset(test_client, login_as_admin, physical_dataset):
    """Samples endpoint also works (and caches) for a physical dataset."""
    uri = (
        f"/datasource/samples?datasource_id={physical_dataset.id}&datasource_type=table"
    )
    resp = test_client.post(uri)
    assert resp.status_code == 200

    payload = json.loads(resp.data)
    assert QueryCacheManager.has(
        payload["result"]["cache_key"], region=CacheRegion.DATA
    )
    assert len(payload["result"]["data"]) == 10
def test_get_samples_with_filters(test_client, login_as_admin, virtual_dataset):
    """Body filters narrow the sampled rows; malformed payloads are rejected."""
    uri = (
        f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table"
    )

    # an absent or empty JSON body is accepted
    assert test_client.post(uri, json=None).status_code == 200
    assert test_client.post(uri, json={}).status_code == 200

    # unknown payload keys are rejected
    assert test_client.post(uri, json={"foo": "bar"}).status_code == 400

    # an invalid filter operator is rejected
    assert (
        test_client.post(
            uri, json={"filters": [{"col": "col1", "op": "INVALID", "val": 0}]}
        ).status_code
        == 400
    )

    # two conjunctive filters match exactly one row
    resp = test_client.post(
        uri,
        json={
            "filters": [
                {"col": "col2", "op": "==", "val": "a"},
                {"col": "col1", "op": "==", "val": 0},
            ]
        },
    )
    assert resp.status_code == 200
    payload = json.loads(resp.data)
    assert payload["result"]["colnames"] == ["col1", "col2", "col3", "col4", "col5"]
    assert payload["result"]["rowcount"] == 1

    # a filter that matches nothing yields an empty result set
    resp = test_client.post(
        uri,
        json={
            "filters": [
                {"col": "col2", "op": "==", "val": "x"},
            ]
        },
    )
    assert resp.status_code == 200
    payload = json.loads(resp.data)
    assert payload["result"]["colnames"] == []
    assert payload["result"]["rowcount"] == 0
def test_get_samples_pagination(test_client, login_as_admin, virtual_dataset):
    """per_page/page query args paginate the sampled rows."""
    base = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table"

    # 1. default page, per_page and total_count
    resp = test_client.post(base)
    payload = json.loads(resp.data)
    assert payload["result"]["page"] == 1
    assert payload["result"]["per_page"] == app.config["SAMPLES_ROW_LIMIT"]
    assert payload["result"]["total_count"] == 10

    # 2. incorrect per_page values (too large, zero, non-numeric) are rejected
    for bad_per_page in (app.config["SAMPLES_ROW_LIMIT"] + 1, 0, "xx"):
        resp = test_client.post(f"{base}&per_page={bad_per_page}")
        assert resp.status_code == 400

    # 3. incorrect page or datasource_type
    assert test_client.post(f"{base}&page=xx").status_code == 400
    assert (
        test_client.post(
            f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=xx"
        ).status_code
        == 400
    )

    # 4. turning pages: two rows per page, check the rows land on the right page
    for page, expected_col1 in ((1, [0, 1]), (2, [2, 3])):
        resp = test_client.post(f"{base}&per_page=2&page={page}")
        payload = json.loads(resp.data)
        assert payload["result"]["page"] == page
        assert payload["result"]["per_page"] == 2
        assert payload["result"]["total_count"] == 10
        assert [row["col1"] for row in payload["result"]["data"]] == expected_col1

    # 5. requesting past the last page returns an empty data list
    resp = test_client.post(f"{base}&per_page=2&page=6")
    payload = json.loads(resp.data)
    assert payload["result"]["page"] == 6
    assert payload["result"]["per_page"] == 2
    assert payload["result"]["total_count"] == 10
    assert [row["col1"] for row in payload["result"]["data"]] == []