chore: enable lint PT009 'use regular assert over self.assert.*' (#30521)

This commit is contained in:
Maxime Beauchemin
2024-10-07 13:17:27 -07:00
committed by GitHub
parent 1f013055d2
commit a849c29288
62 changed files with 2218 additions and 2422 deletions

View File

@@ -22,11 +22,11 @@ class TestAscendDbEngineSpec(TestDbEngineSpec):
def test_convert_dttm(self):
dttm = self.get_dttm()
self.assertEqual(
AscendEngineSpec.convert_dttm("DATE", dttm), "CAST('2019-01-02' AS DATE)"
assert (
AscendEngineSpec.convert_dttm("DATE", dttm) == "CAST('2019-01-02' AS DATE)"
)
self.assertEqual(
AscendEngineSpec.convert_dttm("TIMESTAMP", dttm),
"CAST('2019-01-02T03:04:05.678900' AS TIMESTAMP)",
assert (
AscendEngineSpec.convert_dttm("TIMESTAMP", dttm)
== "CAST('2019-01-02T03:04:05.678900' AS TIMESTAMP)"
)

View File

@@ -61,18 +61,18 @@ class TestDbEngineSpecs(TestDbEngineSpec):
q10 = "select * from mytable limit 20, x"
q11 = "select * from mytable limit x offset 20"
self.assertEqual(engine_spec_class.get_limit_from_sql(q0), None)
self.assertEqual(engine_spec_class.get_limit_from_sql(q1), 10)
self.assertEqual(engine_spec_class.get_limit_from_sql(q2), 20)
self.assertEqual(engine_spec_class.get_limit_from_sql(q3), None)
self.assertEqual(engine_spec_class.get_limit_from_sql(q4), 20)
self.assertEqual(engine_spec_class.get_limit_from_sql(q5), 10)
self.assertEqual(engine_spec_class.get_limit_from_sql(q6), 10)
self.assertEqual(engine_spec_class.get_limit_from_sql(q7), None)
self.assertEqual(engine_spec_class.get_limit_from_sql(q8), None)
self.assertEqual(engine_spec_class.get_limit_from_sql(q9), None)
self.assertEqual(engine_spec_class.get_limit_from_sql(q10), None)
self.assertEqual(engine_spec_class.get_limit_from_sql(q11), None)
assert engine_spec_class.get_limit_from_sql(q0) is None
assert engine_spec_class.get_limit_from_sql(q1) == 10
assert engine_spec_class.get_limit_from_sql(q2) == 20
assert engine_spec_class.get_limit_from_sql(q3) is None
assert engine_spec_class.get_limit_from_sql(q4) == 20
assert engine_spec_class.get_limit_from_sql(q5) == 10
assert engine_spec_class.get_limit_from_sql(q6) == 10
assert engine_spec_class.get_limit_from_sql(q7) is None
assert engine_spec_class.get_limit_from_sql(q8) is None
assert engine_spec_class.get_limit_from_sql(q9) is None
assert engine_spec_class.get_limit_from_sql(q10) is None
assert engine_spec_class.get_limit_from_sql(q11) is None
def test_wrapped_semi_tabs(self):
self.sql_limit_regex(
@@ -141,7 +141,7 @@ class TestDbEngineSpecs(TestDbEngineSpec):
)
def test_get_datatype(self):
self.assertEqual("VARCHAR", BaseEngineSpec.get_datatype("VARCHAR"))
assert "VARCHAR" == BaseEngineSpec.get_datatype("VARCHAR")
def test_limit_with_implicit_offset(self):
self.sql_limit_regex(
@@ -198,29 +198,26 @@ class TestDbEngineSpecs(TestDbEngineSpec):
for engine in load_engine_specs():
if engine is not BaseEngineSpec:
# make sure time grain functions have been defined
self.assertGreater(len(engine.get_time_grain_expressions()), 0)
assert len(engine.get_time_grain_expressions()) > 0
# make sure all defined time grains are supported
defined_grains = {grain.duration for grain in engine.get_time_grains()}
intersection = time_grains.intersection(defined_grains)
self.assertSetEqual(defined_grains, intersection, engine)
self.assertSetEqual(defined_grains, intersection, engine) # noqa: PT009
def test_get_time_grain_expressions(self):
time_grains = MySQLEngineSpec.get_time_grain_expressions()
self.assertEqual(
list(time_grains.keys()),
[
None,
"PT1S",
"PT1M",
"PT1H",
"P1D",
"P1W",
"P1M",
"P3M",
"P1Y",
"1969-12-29T00:00:00Z/P1W",
],
)
assert list(time_grains.keys()) == [
None,
"PT1S",
"PT1M",
"PT1H",
"P1D",
"P1W",
"P1M",
"P3M",
"P1Y",
"1969-12-29T00:00:00Z/P1W",
]
def test_get_table_names(self):
inspector = mock.Mock()
@@ -255,11 +252,11 @@ class TestDbEngineSpecs(TestDbEngineSpec):
expected = ["STRING", "STRING", "FLOAT"]
else:
expected = ["VARCHAR(255)", "VARCHAR(255)", "FLOAT"]
self.assertEqual(col_names, expected)
assert col_names == expected
def test_convert_dttm(self):
dttm = self.get_dttm()
self.assertIsNone(BaseEngineSpec.convert_dttm("", dttm, db_extra=None))
assert BaseEngineSpec.convert_dttm("", dttm, db_extra=None) is None
def test_pyodbc_rows_to_tuples(self):
# Test for case when pyodbc.Row is returned (odbc driver)
@@ -272,7 +269,7 @@ class TestDbEngineSpecs(TestDbEngineSpec):
(2, 2, datetime.datetime(2018, 10, 19, 23, 39, 16, 660000)),
]
result = BaseEngineSpec.pyodbc_rows_to_tuples(data)
self.assertListEqual(result, expected)
self.assertListEqual(result, expected) # noqa: PT009
def test_pyodbc_rows_to_tuples_passthrough(self):
# Test for case when tuples are returned
@@ -281,7 +278,7 @@ class TestDbEngineSpecs(TestDbEngineSpec):
(2, 2, datetime.datetime(2018, 10, 19, 23, 39, 16, 660000)),
]
result = BaseEngineSpec.pyodbc_rows_to_tuples(data)
self.assertListEqual(result, data)
self.assertListEqual(result, data) # noqa: PT009
@mock.patch("superset.models.core.Database.db_engine_spec", BaseEngineSpec)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")

View File

@@ -33,4 +33,4 @@ class TestDbEngineSpec(SupersetTestCase):
):
main = Database(database_name="test_database", sqlalchemy_uri="sqlite://")
limited = engine_spec_class.apply_limit_to_sql(sql, limit, main, force)
self.assertEqual(expected_sql, limited)
assert expected_sql == limited

View File

@@ -45,7 +45,7 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec):
}
for original, expected in test_cases.items():
actual = BigQueryEngineSpec.make_label_compatible(column(original).name)
self.assertEqual(actual, expected)
assert actual == expected
def test_timegrain_expressions(self):
"""
@@ -63,7 +63,7 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec):
actual = BigQueryEngineSpec.get_timestamp_expr(
col=col, pdf=None, time_grain="PT1H"
)
self.assertEqual(str(actual), expected)
assert str(actual) == expected
def test_custom_minute_timegrain_expressions(self):
"""
@@ -104,12 +104,12 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec):
data1 = [(1, "foo")]
with mock.patch.object(BaseEngineSpec, "fetch_data", return_value=data1):
result = BigQueryEngineSpec.fetch_data(None, 0)
self.assertEqual(result, data1)
assert result == data1
data2 = [Row(1), Row(2)]
with mock.patch.object(BaseEngineSpec, "fetch_data", return_value=data2):
result = BigQueryEngineSpec.fetch_data(None, 0)
self.assertEqual(result, [1, 2])
assert result == [1, 2]
def test_get_extra_table_metadata(self):
"""
@@ -122,7 +122,7 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec):
database,
Table("some_table", "some_schema"),
)
self.assertEqual(result, {})
assert result == {}
index_metadata = [
{
@@ -143,7 +143,7 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec):
database,
Table("some_table", "some_schema"),
)
self.assertEqual(result, expected_result)
assert result == expected_result
def test_get_indexes(self):
database = mock.Mock()

View File

@@ -40,4 +40,4 @@ class TestElasticsearchDbEngineSpec(TestDbEngineSpec):
actual = ElasticSearchEngineSpec.get_timestamp_expr(
col=col, pdf=None, time_grain=time_grain
)
self.assertEqual(str(actual), expected_time_grain_expression)
assert str(actual) == expected_time_grain_expression

View File

@@ -30,8 +30,8 @@ class TestMySQLEngineSpecsDbEngineSpec(TestDbEngineSpec):
)
def test_get_datatype_mysql(self):
"""Tests related to datatype mapping for MySQL"""
self.assertEqual("TINY", MySQLEngineSpec.get_datatype(1))
self.assertEqual("VARCHAR", MySQLEngineSpec.get_datatype(15))
assert "TINY" == MySQLEngineSpec.get_datatype(1)
assert "VARCHAR" == MySQLEngineSpec.get_datatype(15)
def test_column_datatype_to_string(self):
test_cases = (
@@ -49,7 +49,7 @@ class TestMySQLEngineSpecsDbEngineSpec(TestDbEngineSpec):
actual = MySQLEngineSpec.column_datatype_to_string(
original, mysql.dialect()
)
self.assertEqual(actual, expected)
assert actual == expected
def test_extract_error_message(self):
from MySQLdb._exceptions import OperationalError

View File

@@ -32,20 +32,14 @@ class TestPinotDbEngineSpec(TestDbEngineSpec):
+ "DATETIMECONVERT(tstamp, '1:SECONDS:EPOCH', "
+ "'1:SECONDS:EPOCH', '1:SECONDS') AS TIMESTAMP)) AS TIMESTAMP)"
)
self.assertEqual(
result,
expected,
)
assert result == expected
def test_pinot_time_expression_simple_date_format_1d_grain(self):
col = column("tstamp")
expr = PinotEngineSpec.get_timestamp_expr(col, "%Y-%m-%d %H:%M:%S", "P1D")
result = str(expr.compile())
expected = "CAST(DATE_TRUNC('day', CAST(tstamp AS TIMESTAMP)) AS TIMESTAMP)"
self.assertEqual(
result,
expected,
)
assert result == expected
def test_pinot_time_expression_simple_date_format_10m_grain(self):
col = column("tstamp")
@@ -55,20 +49,14 @@ class TestPinotDbEngineSpec(TestDbEngineSpec):
"CAST(ROUND(DATE_TRUNC('minute', CAST(tstamp AS "
+ "TIMESTAMP)), 600000) AS TIMESTAMP)"
)
self.assertEqual(
result,
expected,
)
assert result == expected
def test_pinot_time_expression_simple_date_format_1w_grain(self):
col = column("tstamp")
expr = PinotEngineSpec.get_timestamp_expr(col, "%Y-%m-%d %H:%M:%S", "P1W")
result = str(expr.compile())
expected = "CAST(DATE_TRUNC('week', CAST(tstamp AS TIMESTAMP)) AS TIMESTAMP)"
self.assertEqual(
result,
expected,
)
assert result == expected
def test_pinot_time_expression_sec_one_1m_grain(self):
col = column("tstamp")
@@ -79,10 +67,7 @@ class TestPinotDbEngineSpec(TestDbEngineSpec):
+ "DATETIMECONVERT(tstamp, '1:SECONDS:EPOCH', "
+ "'1:SECONDS:EPOCH', '1:SECONDS') AS TIMESTAMP)) AS TIMESTAMP)"
)
self.assertEqual(
result,
expected,
)
assert result == expected
def test_pinot_time_expression_millisec_one_1m_grain(self):
col = column("tstamp")
@@ -93,10 +78,7 @@ class TestPinotDbEngineSpec(TestDbEngineSpec):
+ "DATETIMECONVERT(tstamp, '1:MILLISECONDS:EPOCH', "
+ "'1:MILLISECONDS:EPOCH', '1:MILLISECONDS') AS TIMESTAMP)) AS TIMESTAMP)"
)
self.assertEqual(
result,
expected,
)
assert result == expected
def test_invalid_get_time_expression_arguments(self):
with self.assertRaises(NotImplementedError):

View File

@@ -57,7 +57,7 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
col = literal_column("COALESCE(a, b)")
expr = PostgresEngineSpec.get_timestamp_expr(col, None, None)
result = str(expr.compile(None, dialect=postgresql.dialect()))
self.assertEqual(result, "COALESCE(a, b)")
assert result == "COALESCE(a, b)"
def test_time_exp_literal_1y_grain(self):
"""
@@ -66,7 +66,7 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
col = literal_column("COALESCE(a, b)")
expr = PostgresEngineSpec.get_timestamp_expr(col, None, "P1Y")
result = str(expr.compile(None, dialect=postgresql.dialect()))
self.assertEqual(result, "DATE_TRUNC('year', COALESCE(a, b))")
assert result == "DATE_TRUNC('year', COALESCE(a, b))"
def test_time_ex_lowr_col_no_grain(self):
"""
@@ -75,7 +75,7 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
col = column("lower_case")
expr = PostgresEngineSpec.get_timestamp_expr(col, None, None)
result = str(expr.compile(None, dialect=postgresql.dialect()))
self.assertEqual(result, "lower_case")
assert result == "lower_case"
def test_time_exp_lowr_col_sec_1y(self):
"""
@@ -84,10 +84,9 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
col = column("lower_case")
expr = PostgresEngineSpec.get_timestamp_expr(col, "epoch_s", "P1Y")
result = str(expr.compile(None, dialect=postgresql.dialect()))
self.assertEqual(
result,
"DATE_TRUNC('year', "
"(timestamp 'epoch' + lower_case * interval '1 second'))",
assert (
result == "DATE_TRUNC('year', "
"(timestamp 'epoch' + lower_case * interval '1 second'))"
)
def test_time_exp_mixed_case_col_1y(self):
@@ -97,7 +96,7 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
col = column("MixedCase")
expr = PostgresEngineSpec.get_timestamp_expr(col, None, "P1Y")
result = str(expr.compile(None, dialect=postgresql.dialect()))
self.assertEqual(result, "DATE_TRUNC('year', \"MixedCase\")")
assert result == "DATE_TRUNC('year', \"MixedCase\")"
def test_empty_dbapi_cursor_description(self):
"""
@@ -107,7 +106,7 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
# empty description mean no columns, this mocks the following SQL: "SELECT"
cursor.description = []
results = PostgresEngineSpec.fetch_data(cursor, 1000)
self.assertEqual(results, [])
assert results == []
def test_engine_alias_name(self):
"""
@@ -158,13 +157,7 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
)
sql = "SELECT * FROM birth_names"
results = PostgresEngineSpec.estimate_statement_cost(sql, cursor)
self.assertEqual(
results,
{
"Start-up cost": 0.00,
"Total cost": 1537.91,
},
)
assert results == {"Start-up cost": 0.0, "Total cost": 1537.91}
def test_estimate_statement_invalid_syntax(self):
"""
@@ -199,19 +192,10 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
},
]
result = PostgresEngineSpec.query_cost_formatter(raw_cost)
self.assertEqual(
result,
[
{
"Start-up cost": "0.0",
"Total cost": "1537.91",
},
{
"Start-up cost": "10.0",
"Total cost": "1537.0",
},
],
)
assert result == [
{"Start-up cost": "0.0", "Total cost": "1537.91"},
{"Start-up cost": "10.0", "Total cost": "1537.0"},
]
def test_extract_errors(self):
"""

View File

@@ -33,7 +33,7 @@ from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestPrestoDbEngineSpec(TestDbEngineSpec):
@skipUnless(TestDbEngineSpec.is_module_installed("pyhive"), "pyhive not installed")
def test_get_datatype_presto(self):
self.assertEqual("STRING", PrestoEngineSpec.get_datatype("string"))
assert "STRING" == PrestoEngineSpec.get_datatype("string")
def test_get_view_names_with_schema(self):
database = mock.MagicMock()
@@ -86,10 +86,10 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
row.Column, row.Type, row.Null = column
inspector.bind.execute.return_value.fetchall = mock.Mock(return_value=[row])
results = PrestoEngineSpec.get_columns(inspector, Table("", ""))
self.assertEqual(len(expected_results), len(results))
assert len(expected_results) == len(results)
for expected_result, result in zip(expected_results, results):
self.assertEqual(expected_result[0], result["column_name"])
self.assertEqual(expected_result[1], str(result["type"]))
assert expected_result[0] == result["column_name"]
assert expected_result[1] == str(result["type"])
def test_presto_get_column(self):
presto_column = ("column_name", "boolean", "")
@@ -192,8 +192,8 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
},
]
for actual_result, expected_result in zip(actual_results, expected_results):
self.assertEqual(actual_result.element.name, expected_result["column_name"])
self.assertEqual(actual_result.name, expected_result["label"])
assert actual_result.element.name == expected_result["column_name"]
assert actual_result.name == expected_result["label"]
@mock.patch.dict(
"superset.extensions.feature_flag_manager._feature_flags",
@@ -260,9 +260,9 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
"is_dttm": False,
}
]
self.assertEqual(actual_cols, expected_cols)
self.assertEqual(actual_data, expected_data)
self.assertEqual(actual_expanded_cols, expected_expanded_cols)
assert actual_cols == expected_cols
assert actual_data == expected_data
assert actual_expanded_cols == expected_expanded_cols
@mock.patch.dict(
"superset.extensions.feature_flag_manager._feature_flags",
@@ -343,9 +343,9 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
"is_dttm": False,
},
]
self.assertEqual(actual_cols, expected_cols)
self.assertEqual(actual_data, expected_data)
self.assertEqual(actual_expanded_cols, expected_expanded_cols)
assert actual_cols == expected_cols
assert actual_data == expected_data
assert actual_expanded_cols == expected_expanded_cols
@mock.patch.dict(
"superset.extensions.feature_flag_manager._feature_flags",
@@ -427,9 +427,9 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
"is_dttm": False,
},
]
self.assertEqual(actual_cols, expected_cols)
self.assertEqual(actual_data, expected_data)
self.assertEqual(actual_expanded_cols, expected_expanded_cols)
assert actual_cols == expected_cols
assert actual_data == expected_data
assert actual_expanded_cols == expected_expanded_cols
@mock.patch.dict(
"superset.extensions.feature_flag_manager._feature_flags",
@@ -548,9 +548,9 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
"is_dttm": False,
},
]
self.assertEqual(actual_cols, expected_cols)
self.assertEqual(actual_data, expected_data)
self.assertEqual(actual_expanded_cols, expected_expanded_cols)
assert actual_cols == expected_cols
assert actual_data == expected_data
assert actual_expanded_cols == expected_expanded_cols
def test_presto_get_extra_table_metadata(self):
database = mock.Mock()
@@ -582,7 +582,7 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
columns,
)
query_result = str(result.compile(compile_kwargs={"literal_binds": True}))
self.assertEqual("SELECT \nWHERE ds = '01-01-19' AND hour = 1", query_result)
assert "SELECT \nWHERE ds = '01-01-19' AND hour = 1" == query_result
def test_query_cost_formatter(self):
raw_cost = [
@@ -645,7 +645,7 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
"Network cost": "354 G",
}
]
self.assertEqual(formatted_cost, expected)
assert formatted_cost == expected
@mock.patch.dict(
"superset.extensions.feature_flag_manager._feature_flags",
@@ -752,9 +752,9 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
},
]
self.assertEqual(actual_cols, expected_cols)
self.assertEqual(actual_data, expected_data)
self.assertEqual(actual_expanded_cols, expected_expanded_cols)
assert actual_cols == expected_cols
assert actual_data == expected_data
assert actual_expanded_cols == expected_expanded_cols
@mock.patch("superset.db_engine_specs.base.BaseEngineSpec.get_table_names")
@mock.patch("superset.db_engine_specs.presto.PrestoEngineSpec.get_view_names")