feat: supports multiple filters in samples endpoint (#21008)

This commit is contained in:
Yongjie Zhao
2022-08-08 22:42:14 +08:00
committed by GitHub
parent e214e1ace6
commit 802b69f97b
6 changed files with 154 additions and 60 deletions

View File

@@ -432,14 +432,13 @@ def test_get_samples(test_client, login_as_admin, virtual_dataset):
test_client.post(uri)
# get from cache
rv = test_client.post(uri)
rv_data = json.loads(rv.data)
assert rv.status_code == 200
assert len(rv_data["result"]["data"]) == 10
assert len(rv.json["result"]["data"]) == 10
assert QueryCacheManager.has(
rv_data["result"]["cache_key"],
rv.json["result"]["cache_key"],
region=CacheRegion.DATA,
)
assert rv_data["result"]["is_cached"]
assert rv.json["result"]["is_cached"]
# 2. should read through cache data
uri2 = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&force=true"
@@ -447,19 +446,18 @@ def test_get_samples(test_client, login_as_admin, virtual_dataset):
test_client.post(uri2)
# force query
rv2 = test_client.post(uri2)
rv_data2 = json.loads(rv2.data)
assert rv2.status_code == 200
assert len(rv_data2["result"]["data"]) == 10
assert len(rv2.json["result"]["data"]) == 10
assert QueryCacheManager.has(
rv_data2["result"]["cache_key"],
rv2.json["result"]["cache_key"],
region=CacheRegion.DATA,
)
assert not rv_data2["result"]["is_cached"]
assert not rv2.json["result"]["is_cached"]
# 3. data precision
assert "colnames" in rv_data2["result"]
assert "coltypes" in rv_data2["result"]
assert "data" in rv_data2["result"]
assert "colnames" in rv2.json["result"]
assert "coltypes" in rv2.json["result"]
assert "data" in rv2.json["result"]
eager_samples = virtual_dataset.database.get_df(
f"select * from ({virtual_dataset.sql}) as tbl"
@@ -468,7 +466,7 @@ def test_get_samples(test_client, login_as_admin, virtual_dataset):
# the col3 is Decimal
eager_samples["col3"] = eager_samples["col3"].apply(float)
eager_samples = eager_samples.to_dict(orient="records")
assert eager_samples == rv_data2["result"]["data"]
assert eager_samples == rv2.json["result"]["data"]
def test_get_samples_with_incorrect_cc(test_client, login_as_admin, virtual_dataset):
@@ -486,10 +484,9 @@ def test_get_samples_with_incorrect_cc(test_client, login_as_admin, virtual_data
rv = test_client.post(uri)
assert rv.status_code == 422
rv_data = json.loads(rv.data)
assert "error" in rv_data
assert "error" in rv.json
if virtual_dataset.database.db_engine_spec.engine_name == "PostgreSQL":
assert "INCORRECT SQL" in rv_data.get("error")
assert "INCORRECT SQL" in rv.json.get("error")
def test_get_samples_on_physical_dataset(test_client, login_as_admin, physical_dataset):
@@ -498,11 +495,10 @@ def test_get_samples_on_physical_dataset(test_client, login_as_admin, physical_d
)
rv = test_client.post(uri)
assert rv.status_code == 200
rv_data = json.loads(rv.data)
assert QueryCacheManager.has(
rv_data["result"]["cache_key"], region=CacheRegion.DATA
rv.json["result"]["cache_key"], region=CacheRegion.DATA
)
assert len(rv_data["result"]["data"]) == 10
assert len(rv.json["result"]["data"]) == 10
def test_get_samples_with_filters(test_client, login_as_admin, virtual_dataset):
@@ -533,9 +529,8 @@ def test_get_samples_with_filters(test_client, login_as_admin, virtual_dataset):
},
)
assert rv.status_code == 200
rv_data = json.loads(rv.data)
assert rv_data["result"]["colnames"] == ["col1", "col2", "col3", "col4", "col5"]
assert rv_data["result"]["rowcount"] == 1
assert rv.json["result"]["colnames"] == ["col1", "col2", "col3", "col4", "col5"]
assert rv.json["result"]["rowcount"] == 1
# empty results
rv = test_client.post(
@@ -547,9 +542,64 @@ def test_get_samples_with_filters(test_client, login_as_admin, virtual_dataset):
},
)
assert rv.status_code == 200
rv_data = json.loads(rv.data)
assert rv_data["result"]["colnames"] == []
assert rv_data["result"]["rowcount"] == 0
assert rv.json["result"]["colnames"] == []
assert rv.json["result"]["rowcount"] == 0
def test_get_samples_with_time_filter(test_client, login_as_admin, physical_dataset):
    """The samples endpoint applies a temporal filter (granularity + time_range)
    and returns paginated results.

    NOTE(review): indentation reconstructed from a flattened diff rendering —
    confirm against the committed file.
    """
    uri = (
        f"/datasource/samples?datasource_id={physical_dataset.id}&datasource_type=table"
    )
    # time_range is half-open-ish here: the assertions below show rows for
    # Jan 2 and Jan 3 only, so Jan 4 is excluded.
    payload = {
        "granularity": "col5",
        "time_range": "2000-01-02 : 2000-01-04",
    }
    rv = test_client.post(uri, json=payload)
    assert len(rv.json["result"]["data"]) == 2
    if physical_dataset.database.backend != "sqlite":
        # Non-sqlite backends serialize the temporal column as epoch
        # milliseconds; sqlite is skipped — presumably it returns a different
        # representation (verify against the serializer).
        assert [row["col5"] for row in rv.json["result"]["data"]] == [
            946771200000.0,  # 2000-01-02 00:00:00
            946857600000.0,  # 2000-01-03 00:00:00
        ]
    # Default pagination: first page, per_page taken from app config.
    assert rv.json["result"]["page"] == 1
    assert rv.json["result"]["per_page"] == app.config["SAMPLES_ROW_LIMIT"]
    assert rv.json["result"]["total_count"] == 2
def test_get_samples_with_multiple_filters(
    test_client, login_as_admin, physical_dataset
):
    """The samples endpoint combines a time filter, adhoc filters, and a
    custom WHERE clause into a single query.

    NOTE(review): indentation reconstructed from a flattened diff rendering —
    confirm against the committed file.
    """
    # 1. empty response
    uri = (
        f"/datasource/samples?datasource_id={physical_dataset.id}&datasource_type=table"
    )
    # Time filter plus an IS NOT NULL adhoc filter that matches no rows
    # in the 2000-01-02..2000-01-04 window.
    payload = {
        "granularity": "col5",
        "time_range": "2000-01-02 : 2000-01-04",
        "filters": [
            {"col": "col4", "op": "IS NOT NULL"},
        ],
    }
    rv = test_client.post(uri, json=payload)
    assert len(rv.json["result"]["data"]) == 0

    # 2. adhoc filters, time filters, and custom where
    payload = {
        "granularity": "col5",
        "time_range": "2000-01-02 : 2000-01-04",
        "filters": [
            {"col": "col2", "op": "==", "val": "c"},
        ],
        "extras": {"where": "col3 = 1.2 and col4 is null"},
    }
    rv = test_client.post(uri, json=payload)
    assert len(rv.json["result"]["data"]) == 1
    assert rv.json["result"]["total_count"] == 1
    # All three filter sources must appear in the generated SQL, which the
    # endpoint echoes back under result["query"].
    assert "2000-01-02" in rv.json["result"]["query"]
    assert "2000-01-04" in rv.json["result"]["query"]
    assert "col3 = 1.2" in rv.json["result"]["query"]
    assert "col4 is null" in rv.json["result"]["query"]
    assert "col2 = 'c'" in rv.json["result"]["query"]
def test_get_samples_pagination(test_client, login_as_admin, virtual_dataset):
@@ -558,10 +608,9 @@ def test_get_samples_pagination(test_client, login_as_admin, virtual_dataset):
f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table"
)
rv = test_client.post(uri)
rv_data = json.loads(rv.data)
assert rv_data["result"]["page"] == 1
assert rv_data["result"]["per_page"] == app.config["SAMPLES_ROW_LIMIT"]
assert rv_data["result"]["total_count"] == 10
assert rv.json["result"]["page"] == 1
assert rv.json["result"]["per_page"] == app.config["SAMPLES_ROW_LIMIT"]
assert rv.json["result"]["total_count"] == 10
# 2. incorrect per_page
per_pages = (app.config["SAMPLES_ROW_LIMIT"] + 1, 0, "xx")
@@ -582,25 +631,22 @@ def test_get_samples_pagination(test_client, login_as_admin, virtual_dataset):
# 4. turning pages
uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page=2&page=1"
rv = test_client.post(uri)
rv_data = json.loads(rv.data)
assert rv_data["result"]["page"] == 1
assert rv_data["result"]["per_page"] == 2
assert rv_data["result"]["total_count"] == 10
assert [row["col1"] for row in rv_data["result"]["data"]] == [0, 1]
assert rv.json["result"]["page"] == 1
assert rv.json["result"]["per_page"] == 2
assert rv.json["result"]["total_count"] == 10
assert [row["col1"] for row in rv.json["result"]["data"]] == [0, 1]
uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page=2&page=2"
rv = test_client.post(uri)
rv_data = json.loads(rv.data)
assert rv_data["result"]["page"] == 2
assert rv_data["result"]["per_page"] == 2
assert rv_data["result"]["total_count"] == 10
assert [row["col1"] for row in rv_data["result"]["data"]] == [2, 3]
assert rv.json["result"]["page"] == 2
assert rv.json["result"]["per_page"] == 2
assert rv.json["result"]["total_count"] == 10
assert [row["col1"] for row in rv.json["result"]["data"]] == [2, 3]
# 5. Exceeding the maximum pages
uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page=2&page=6"
rv = test_client.post(uri)
rv_data = json.loads(rv.data)
assert rv_data["result"]["page"] == 6
assert rv_data["result"]["per_page"] == 2
assert rv_data["result"]["total_count"] == 10
assert [row["col1"] for row in rv_data["result"]["data"]] == []
assert rv.json["result"]["page"] == 6
assert rv.json["result"]["per_page"] == 2
assert rv.json["result"]["total_count"] == 10
assert [row["col1"] for row in rv.json["result"]["data"]] == []