Mirror of https://github.com/apache/superset.git, synced 2026-04-14 05:34:38 +00:00.
Commit: "feat: the samples endpoint supports filters and pagination (#20683)".
This page shows the changes contained in this commit.
@@ -23,13 +23,15 @@ import prison
|
||||
import pytest
|
||||
|
||||
from superset import app, db
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.common.utils.query_cache_manager import QueryCacheManager
|
||||
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
|
||||
from superset.constants import CacheRegion
|
||||
from superset.dao.exceptions import DatasourceNotFound, DatasourceTypeNotSupportedError
|
||||
from superset.datasets.commands.exceptions import DatasetNotFoundError
|
||||
from superset.exceptions import SupersetGenericDBErrorException
|
||||
from superset.models.core import Database
|
||||
from superset.utils.core import DatasourceType, get_example_default_schema
|
||||
from superset.utils.database import get_example_database
|
||||
from superset.utils.core import backend, get_example_default_schema
|
||||
from superset.utils.database import get_example_database, get_main_database
|
||||
from tests.integration_tests.base_tests import db_insert_temp_object, SupersetTestCase
|
||||
from tests.integration_tests.fixtures.birth_names_dashboard import (
|
||||
load_birth_names_dashboard_with_slices,
|
||||
@@ -416,3 +418,189 @@ class TestDatasource(SupersetTestCase):
|
||||
self.login(username="admin")
|
||||
resp = self.get_json_resp("/datasource/get/druid/500000/", raise_on_error=False)
|
||||
self.assertEqual(resp.get("error"), "'druid' is not a valid DatasourceType")
|
||||
|
||||
|
||||
def test_get_samples(test_client, login_as_admin, virtual_dataset):
    """
    Dataset API: Test get dataset samples
    """
    base_uri = (
        f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table"
    )

    # 1. should cache data: the first POST warms the cache, the second reads from it
    test_client.post(base_uri)
    cached_resp = test_client.post(base_uri)
    cached_payload = json.loads(cached_resp.data)
    assert cached_resp.status_code == 200
    assert len(cached_payload["result"]["data"]) == 10
    assert QueryCacheManager.has(
        cached_payload["result"]["cache_key"],
        region=CacheRegion.DATA,
    )
    assert cached_payload["result"]["is_cached"]

    # 2. should read through cache data: force=true bypasses the cached result
    forced_uri = base_uri + "&force=true"
    test_client.post(forced_uri)
    forced_resp = test_client.post(forced_uri)
    forced_payload = json.loads(forced_resp.data)
    assert forced_resp.status_code == 200
    assert len(forced_payload["result"]["data"]) == 10
    assert QueryCacheManager.has(
        forced_payload["result"]["cache_key"],
        region=CacheRegion.DATA,
    )
    assert not forced_payload["result"]["is_cached"]

    # 3. data precision: the endpoint payload must match an eagerly fetched dataframe
    for key in ("colnames", "coltypes", "data"):
        assert key in forced_payload["result"]

    eager_samples = virtual_dataset.database.get_df(
        f"select * from ({virtual_dataset.sql}) as tbl"
        f' limit {app.config["SAMPLES_ROW_LIMIT"]}'
    )
    # col3 comes back as Decimal; normalize to float before comparing
    eager_samples["col3"] = eager_samples["col3"].apply(float)
    assert eager_samples.to_dict(orient="records") == forced_payload["result"]["data"]
|
||||
|
||||
|
||||
def test_get_samples_with_incorrect_cc(test_client, login_as_admin, virtual_dataset):
    """Samples endpoint returns 422 when a calculated column has invalid SQL."""
    # Attach a calculated column whose expression is not valid SQL.
    TableColumn(
        column_name="DUMMY CC",
        type="VARCHAR(255)",
        table=virtual_dataset,
        expression="INCORRECT SQL",
    )
    db.session.merge(virtual_dataset)

    samples_uri = (
        f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table"
    )
    resp = test_client.post(samples_uri)
    assert resp.status_code == 422

    payload = json.loads(resp.data)
    assert "error" in payload
    # PostgreSQL surfaces the offending expression inside its error message.
    if virtual_dataset.database.db_engine_spec.engine_name == "PostgreSQL":
        assert "INCORRECT SQL" in payload.get("error")
|
||||
|
||||
|
||||
def test_get_samples_on_physical_dataset(test_client, login_as_admin, physical_dataset):
    """Samples endpoint works for physical (non-virtual) datasets and caches data."""
    samples_uri = (
        f"/datasource/samples?datasource_id={physical_dataset.id}&datasource_type=table"
    )
    resp = test_client.post(samples_uri)
    assert resp.status_code == 200

    payload = json.loads(resp.data)
    assert QueryCacheManager.has(
        payload["result"]["cache_key"], region=CacheRegion.DATA
    )
    assert len(payload["result"]["data"]) == 10
|
||||
|
||||
|
||||
def test_get_samples_with_filters(test_client, login_as_admin, virtual_dataset):
    """Samples endpoint validates the request body and applies ad-hoc filters."""
    samples_uri = (
        f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table"
    )

    # Body validation: missing/empty bodies are accepted, malformed ones rejected.
    for body, expected_status in (
        (None, 200),
        ({}, 200),
        ({"foo": "bar"}, 400),
        ({"filters": [{"col": "col1", "op": "INVALID", "val": 0}]}, 400),
    ):
        resp = test_client.post(samples_uri, json=body)
        assert resp.status_code == expected_status

    # Two ANDed filters should narrow the result down to a single row.
    resp = test_client.post(
        samples_uri,
        json={
            "filters": [
                {"col": "col2", "op": "==", "val": "a"},
                {"col": "col1", "op": "==", "val": 0},
            ]
        },
    )
    assert resp.status_code == 200
    payload = json.loads(resp.data)
    assert payload["result"]["colnames"] == ["col1", "col2", "col3", "col4", "col5"]
    assert payload["result"]["rowcount"] == 1

    # A filter matching nothing yields an empty (but successful) result.
    resp = test_client.post(
        samples_uri,
        json={
            "filters": [
                {"col": "col2", "op": "==", "val": "x"},
            ]
        },
    )
    assert resp.status_code == 200
    payload = json.loads(resp.data)
    assert payload["result"]["colnames"] == []
    assert payload["result"]["rowcount"] == 0
|
||||
|
||||
|
||||
def test_get_samples_pagination(test_client, login_as_admin, virtual_dataset):
    """Samples endpoint supports page/per_page parameters with validation."""
    base_uri = (
        f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table"
    )

    # 1. default page, per_page and total_count
    payload = json.loads(test_client.post(base_uri).data)
    assert payload["result"]["page"] == 1
    assert payload["result"]["per_page"] == app.config["SAMPLES_ROW_LIMIT"]
    assert payload["result"]["total_count"] == 10

    # 2. incorrect per_page values (too large, zero, non-numeric) are rejected
    for bad_per_page in (app.config["SAMPLES_ROW_LIMIT"] + 1, 0, "xx"):
        resp = test_client.post(f"{base_uri}&per_page={bad_per_page}")
        assert resp.status_code == 400

    # 3. incorrect page or datasource_type
    resp = test_client.post(f"{base_uri}&page=xx")
    assert resp.status_code == 400

    resp = test_client.post(
        f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=xx"
    )
    assert resp.status_code == 400

    # 4. turning pages: per_page=2 slices rows [0,1] then [2,3]
    for page, expected_col1 in ((1, [0, 1]), (2, [2, 3])):
        resp = test_client.post(f"{base_uri}&per_page=2&page={page}")
        payload = json.loads(resp.data)
        assert payload["result"]["page"] == page
        assert payload["result"]["per_page"] == 2
        assert payload["result"]["total_count"] == 10
        assert [row["col1"] for row in payload["result"]["data"]] == expected_col1

    # 5. exceeding the maximum page returns an empty data list, not an error
    resp = test_client.post(f"{base_uri}&per_page=2&page=6")
    payload = json.loads(resp.data)
    assert payload["result"]["page"] == 6
    assert payload["result"]["per_page"] == 2
    assert payload["result"]["total_count"] == 10
    assert [row["col1"] for row in payload["result"]["data"]] == []
|
||||
|
||||
Reference in New Issue
Block a user