chore: enable lint PT009 'use regular assert over self.assert.*' (#30521)

This commit is contained in:
Maxime Beauchemin
2024-10-07 13:17:27 -07:00
committed by GitHub
parent 1f013055d2
commit a849c29288
62 changed files with 2218 additions and 2422 deletions

View File

@@ -70,15 +70,15 @@ class TestQueryContext(SupersetTestCase):
payload = get_query_context("birth_names", add_postprocessing_operations=True)
query_context = ChartDataQueryContextSchema().load(payload)
self.assertEqual(len(query_context.queries), len(payload["queries"]))
assert len(query_context.queries) == len(payload["queries"])
for query_idx, query in enumerate(query_context.queries):
payload_query = payload["queries"][query_idx]
# check basic properties
self.assertEqual(query.extras, payload_query["extras"])
self.assertEqual(query.filter, payload_query["filters"])
self.assertEqual(query.columns, payload_query["columns"])
assert query.extras == payload_query["extras"]
assert query.filter == payload_query["filters"]
assert query.columns == payload_query["columns"]
# metrics are mutated during creation
for metric_idx, metric in enumerate(query.metrics):
@@ -88,16 +88,16 @@ class TestQueryContext(SupersetTestCase):
if "expressionType" in payload_metric
else payload_metric["label"]
)
self.assertEqual(metric, payload_metric)
assert metric == payload_metric
self.assertEqual(query.orderby, payload_query["orderby"])
self.assertEqual(query.time_range, payload_query["time_range"])
assert query.orderby == payload_query["orderby"]
assert query.time_range == payload_query["time_range"]
# check post processing operation properties
for post_proc_idx, post_proc in enumerate(query.post_processing):
payload_post_proc = payload_query["post_processing"][post_proc_idx]
self.assertEqual(post_proc["operation"], payload_post_proc["operation"])
self.assertEqual(post_proc["options"], payload_post_proc["options"])
assert post_proc["operation"] == payload_post_proc["operation"]
assert post_proc["options"] == payload_post_proc["options"]
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_cache(self):
@@ -128,12 +128,12 @@ class TestQueryContext(SupersetTestCase):
rehydrated_qo = rehydrated_qc.queries[0]
rehydrated_query_cache_key = rehydrated_qc.query_cache_key(rehydrated_qo)
self.assertEqual(rehydrated_qc.datasource, query_context.datasource)
self.assertEqual(len(rehydrated_qc.queries), 1)
self.assertEqual(query_cache_key, rehydrated_query_cache_key)
self.assertEqual(rehydrated_qc.result_type, query_context.result_type)
self.assertEqual(rehydrated_qc.result_format, query_context.result_format)
self.assertFalse(rehydrated_qc.force)
assert rehydrated_qc.datasource == query_context.datasource
assert len(rehydrated_qc.queries) == 1
assert query_cache_key == rehydrated_query_cache_key
assert rehydrated_qc.result_type == query_context.result_type
assert rehydrated_qc.result_format == query_context.result_format
assert not rehydrated_qc.force
def test_query_cache_key_changes_when_datasource_is_updated(self):
payload = get_query_context("birth_names")
@@ -164,7 +164,7 @@ class TestQueryContext(SupersetTestCase):
cache_key_new = query_context.query_cache_key(query_object)
# the new cache_key should be different due to updated datasource
self.assertNotEqual(cache_key_original, cache_key_new)
assert cache_key_original != cache_key_new
def test_query_cache_key_changes_when_metric_is_updated(self):
payload = get_query_context("birth_names")
@@ -198,7 +198,7 @@ class TestQueryContext(SupersetTestCase):
db.session.commit()
# the new cache_key should be different due to updated datasource
self.assertNotEqual(cache_key_original, cache_key_new)
assert cache_key_original != cache_key_new
def test_query_cache_key_does_not_change_for_non_existent_or_null(self):
payload = get_query_context("birth_names", add_postprocessing_operations=True)
@@ -228,14 +228,14 @@ class TestQueryContext(SupersetTestCase):
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
cache_key = query_context.query_cache_key(query_object)
self.assertEqual(cache_key_original, cache_key)
assert cache_key_original == cache_key
# ensure query without post processing operation is different
payload["queries"][0].pop("post_processing")
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
cache_key = query_context.query_cache_key(query_object)
self.assertNotEqual(cache_key_original, cache_key)
assert cache_key_original != cache_key
def test_query_cache_key_changes_when_time_offsets_is_updated(self):
payload = get_query_context("birth_names", add_time_offsets=True)
@@ -248,7 +248,7 @@ class TestQueryContext(SupersetTestCase):
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
cache_key = query_context.query_cache_key(query_object)
self.assertNotEqual(cache_key_original, cache_key)
assert cache_key_original != cache_key
def test_handle_metrics_field(self):
"""
@@ -265,7 +265,7 @@ class TestQueryContext(SupersetTestCase):
payload["queries"][0]["metrics"] = ["sum__num", {"label": "abc"}, adhoc_metric]
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
self.assertEqual(query_object.metrics, ["sum__num", "abc", adhoc_metric])
assert query_object.metrics == ["sum__num", "abc", adhoc_metric]
def test_convert_deprecated_fields(self):
"""
@@ -280,12 +280,12 @@ class TestQueryContext(SupersetTestCase):
payload["queries"][0]["granularity_sqla"] = "timecol"
payload["queries"][0]["having_filters"] = [{"col": "a", "op": "==", "val": "b"}]
query_context = ChartDataQueryContextSchema().load(payload)
self.assertEqual(len(query_context.queries), 1)
assert len(query_context.queries) == 1
query_object = query_context.queries[0]
self.assertEqual(query_object.granularity, "timecol")
self.assertEqual(query_object.columns, columns)
self.assertEqual(query_object.series_limit, 99)
self.assertEqual(query_object.series_limit_metric, "sum__num")
assert query_object.granularity == "timecol"
assert query_object.columns == columns
assert query_object.series_limit == 99
assert query_object.series_limit_metric == "sum__num"
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_csv_response_format(self):
@@ -297,10 +297,10 @@ class TestQueryContext(SupersetTestCase):
payload["queries"][0]["row_limit"] = 10
query_context: QueryContext = ChartDataQueryContextSchema().load(payload)
responses = query_context.get_payload()
self.assertEqual(len(responses), 1)
assert len(responses) == 1
data = responses["queries"][0]["data"]
self.assertIn("name,sum__num\n", data)
self.assertEqual(len(data.split("\n")), 12)
assert "name,sum__num\n" in data
assert len(data.split("\n")) == 12
def test_sql_injection_via_groupby(self):
"""
@@ -352,11 +352,11 @@ class TestQueryContext(SupersetTestCase):
payload["queries"][0]["row_limit"] = 5
query_context = ChartDataQueryContextSchema().load(payload)
responses = query_context.get_payload()
self.assertEqual(len(responses), 1)
assert len(responses) == 1
data = responses["queries"][0]["data"]
self.assertIsInstance(data, list)
self.assertEqual(len(data), 5)
self.assertNotIn("sum__num", data[0])
assert isinstance(data, list)
assert len(data) == 5
assert "sum__num" not in data[0]
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_query_response_type(self):
@@ -489,7 +489,7 @@ class TestQueryContext(SupersetTestCase):
query_context = ChartDataQueryContextSchema().load(payload)
responses = query_context.get_payload()
new_cache_key = responses["queries"][0]["cache_key"]
self.assertEqual(orig_cache_key, new_cache_key)
assert orig_cache_key == new_cache_key
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_time_offsets_in_query_object(self):
@@ -505,21 +505,18 @@ class TestQueryContext(SupersetTestCase):
payload["queries"][0]["time_range"] = "1990 : 1991"
query_context = ChartDataQueryContextSchema().load(payload)
responses = query_context.get_payload()
self.assertEqual(
responses["queries"][0]["colnames"],
[
"__timestamp",
"name",
"sum__num",
"sum__num__1 year ago",
"sum__num__1 year later",
],
)
assert responses["queries"][0]["colnames"] == [
"__timestamp",
"name",
"sum__num",
"sum__num__1 year ago",
"sum__num__1 year later",
]
sqls = [
sql for sql in responses["queries"][0]["query"].split(";") if sql.strip()
]
self.assertEqual(len(sqls), 3)
assert len(sqls) == 3
# 1 year ago
assert re.search(r"1989-01-01.+1990-01-01", sqls[1], re.S)
assert re.search(r"1990-01-01.+1991-01-01", sqls[1], re.S)
@@ -560,9 +557,9 @@ class TestQueryContext(SupersetTestCase):
cache_keys = rv["cache_keys"]
cache_keys__1_year_ago = cache_keys[0]
cache_keys__1_year_later = cache_keys[1]
self.assertIsNotNone(cache_keys__1_year_ago)
self.assertIsNotNone(cache_keys__1_year_later)
self.assertNotEqual(cache_keys__1_year_ago, cache_keys__1_year_later)
assert cache_keys__1_year_ago is not None
assert cache_keys__1_year_later is not None
assert cache_keys__1_year_ago != cache_keys__1_year_later
# swap offsets
payload["queries"][0]["time_offsets"] = ["1 year later", "1 year ago"]
@@ -570,8 +567,8 @@ class TestQueryContext(SupersetTestCase):
query_object = query_context.queries[0]
rv = query_context.processing_time_offsets(df.copy(), query_object)
cache_keys = rv["cache_keys"]
self.assertEqual(cache_keys__1_year_ago, cache_keys[1])
self.assertEqual(cache_keys__1_year_later, cache_keys[0])
assert cache_keys__1_year_ago == cache_keys[1]
assert cache_keys__1_year_later == cache_keys[0]
# remove all offsets
payload["queries"][0]["time_offsets"] = []
@@ -582,9 +579,9 @@ class TestQueryContext(SupersetTestCase):
query_object,
)
self.assertEqual(rv["df"].shape, df.shape)
self.assertEqual(rv["queries"], [])
self.assertEqual(rv["cache_keys"], [])
assert rv["df"].shape == df.shape
assert rv["queries"] == []
assert rv["cache_keys"] == []
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_time_offsets_sql(self):
@@ -732,7 +729,7 @@ class TestQueryContext(SupersetTestCase):
row_limit_pattern_with_config_value = r"LIMIT " + re.escape(
str(row_limit_value)
)
self.assertEqual(len(sqls), 2)
assert len(sqls) == 2
# 1 year ago
assert re.search(r"1989-01-01.+1990-01-01", sqls[0], re.S)
assert not re.search(r"LIMIT 100", sqls[0], re.S)