test: autouse app_context in unit tests (#20911)

This commit is contained in:
Jesse Yang
2022-08-02 15:42:50 -07:00
committed by GitHub
parent c06d5eb70c
commit 7e836e9b04
37 changed files with 142 additions and 212 deletions

View File

@@ -18,8 +18,6 @@
import re
from datetime import datetime
from flask.ctx import AppContext
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from tests.unit_tests.fixtures.common import dttm
@@ -28,7 +26,7 @@ SYNTAX_ERROR_REGEX = re.compile(
)
def test_convert_dttm(app_context: AppContext, dttm: datetime) -> None:
def test_convert_dttm(dttm: datetime) -> None:
"""
Test that date objects are converted correctly.
"""
@@ -43,7 +41,7 @@ def test_convert_dttm(app_context: AppContext, dttm: datetime) -> None:
)
def test_extract_errors(app_context: AppContext) -> None:
def test_extract_errors() -> None:
"""
Test that custom error messages are extracted correctly.
"""
@@ -70,7 +68,7 @@ def test_extract_errors(app_context: AppContext) -> None:
]
def test_get_text_clause_with_colon(app_context: AppContext) -> None:
def test_get_text_clause_with_colon() -> None:
"""
Make sure text clauses don't escape the colon character
"""

View File

@@ -19,11 +19,10 @@
from textwrap import dedent
import pytest
from flask.ctx import AppContext
from sqlalchemy.types import TypeEngine
def test_get_text_clause_with_colon(app_context: AppContext) -> None:
def test_get_text_clause_with_colon() -> None:
"""
Make sure text clauses are correctly escaped
"""
@@ -36,7 +35,7 @@ def test_get_text_clause_with_colon(app_context: AppContext) -> None:
assert text_clause.text == "SELECT foo FROM tbl WHERE foo = '123\\:456')"
def test_parse_sql_single_statement(app_context: AppContext) -> None:
def test_parse_sql_single_statement() -> None:
"""
`parse_sql` should properly strip leading and trailing spaces and semicolons
"""
@@ -47,7 +46,7 @@ def test_parse_sql_single_statement(app_context: AppContext) -> None:
assert queries == ["SELECT foo FROM tbl"]
def test_parse_sql_multi_statement(app_context: AppContext) -> None:
def test_parse_sql_multi_statement() -> None:
"""
For string with multiple SQL-statements `parse_sql` method should return list
where each element represents the single SQL-statement
@@ -95,9 +94,7 @@ select 'USD' as cur
),
],
)
def test_cte_query_parsing(
app_context: AppContext, original: TypeEngine, expected: str
) -> None:
def test_cte_query_parsing(original: TypeEngine, expected: str) -> None:
from superset.db_engine_specs.base import BaseEngineSpec
actual = BaseEngineSpec.get_cte_query(original)

View File

@@ -16,14 +16,13 @@
# under the License.
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
from flask.ctx import AppContext
from pybigquery.sqlalchemy_bigquery import BigQueryDialect
from pytest_mock import MockFixture
from sqlalchemy import select
from sqlalchemy.sql import sqltypes
def test_get_fields(app_context: AppContext) -> None:
def test_get_fields() -> None:
"""
Test the custom ``_get_fields`` method.
@@ -66,7 +65,7 @@ def test_get_fields(app_context: AppContext) -> None:
)
def test_select_star(mocker: MockFixture, app_context: AppContext) -> None:
def test_select_star(mocker: MockFixture) -> None:
"""
Test the ``select_star`` method.

View File

@@ -16,11 +16,10 @@
# under the License.
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
from flask.ctx import AppContext
from pytest import raises
def test_odbc_impersonation(app_context: AppContext) -> None:
def test_odbc_impersonation() -> None:
"""
Test ``get_url_for_impersonation`` method when driver == odbc.
@@ -36,7 +35,7 @@ def test_odbc_impersonation(app_context: AppContext) -> None:
assert url.query["DelegationUID"] == username
def test_jdbc_impersonation(app_context: AppContext) -> None:
def test_jdbc_impersonation() -> None:
"""
Test ``get_url_for_impersonation`` method when driver == jdbc.
@@ -52,7 +51,7 @@ def test_jdbc_impersonation(app_context: AppContext) -> None:
assert url.query["impersonation_target"] == username
def test_sadrill_impersonation(app_context: AppContext) -> None:
def test_sadrill_impersonation() -> None:
"""
Test ``get_url_for_impersonation`` method when driver == sadrill.
@@ -68,7 +67,7 @@ def test_sadrill_impersonation(app_context: AppContext) -> None:
assert url.query["impersonation_target"] == username
def test_invalid_impersonation(app_context: AppContext) -> None:
def test_invalid_impersonation() -> None:
"""
Test ``get_url_for_impersonation`` method when driver == foobar.

View File

@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from flask.ctx import AppContext
from pytest_mock import MockFixture
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
@@ -28,7 +27,6 @@ class ProgrammingError(Exception):
def test_validate_parameters_simple(
mocker: MockFixture,
app_context: AppContext,
) -> None:
from superset.db_engine_specs.gsheets import (
GSheetsEngineSpec,
@@ -52,7 +50,6 @@ def test_validate_parameters_simple(
def test_validate_parameters_catalog(
mocker: MockFixture,
app_context: AppContext,
) -> None:
from superset.db_engine_specs.gsheets import (
GSheetsEngineSpec,
@@ -143,7 +140,6 @@ def test_validate_parameters_catalog(
def test_validate_parameters_catalog_and_credentials(
mocker: MockFixture,
app_context: AppContext,
) -> None:
from superset.db_engine_specs.gsheets import (
GSheetsEngineSpec,

View File

@@ -18,7 +18,6 @@
from datetime import datetime
import pytest
from flask.ctx import AppContext
from tests.unit_tests.fixtures.common import dttm
@@ -32,9 +31,7 @@ from tests.unit_tests.fixtures.common import dttm
("INSERT INTO tbl (foo) VALUES (1)", False),
],
)
def test_sql_is_readonly_query(
app_context: AppContext, sql: str, expected: bool
) -> None:
def test_sql_is_readonly_query(sql: str, expected: bool) -> None:
"""
Make sure that SQL dialect consider only SELECT statements as read-only
"""
@@ -56,7 +53,7 @@ def test_sql_is_readonly_query(
(".show tables", False),
],
)
def test_kql_is_select_query(app_context: AppContext, kql: str, expected: bool) -> None:
def test_kql_is_select_query(kql: str, expected: bool) -> None:
"""
Make sure that KQL dialect consider only statements that do not start with "." (dot)
as a SELECT statements
@@ -83,9 +80,7 @@ def test_kql_is_select_query(app_context: AppContext, kql: str, expected: bool)
(".set-or-append table foo <| bar", False),
],
)
def test_kql_is_readonly_query(
app_context: AppContext, kql: str, expected: bool
) -> None:
def test_kql_is_readonly_query(kql: str, expected: bool) -> None:
"""
Make sure that KQL dialect consider only SELECT statements as read-only
"""
@@ -99,7 +94,7 @@ def test_kql_is_readonly_query(
assert expected == is_readonly
def test_kql_parse_sql(app_context: AppContext) -> None:
def test_kql_parse_sql() -> None:
"""
parse_sql method should always return a list with a single element
which is an original query
@@ -121,7 +116,6 @@ def test_kql_parse_sql(app_context: AppContext) -> None:
],
)
def test_kql_convert_dttm(
app_context: AppContext,
target_type: str,
expected_dttm: str,
dttm: datetime,
@@ -145,7 +139,6 @@ def test_kql_convert_dttm(
],
)
def test_sql_convert_dttm(
app_context: AppContext,
target_type: str,
expected_dttm: str,
dttm: datetime,

View File

@@ -19,7 +19,6 @@ from datetime import datetime
from textwrap import dedent
import pytest
from flask.ctx import AppContext
from sqlalchemy import column, table
from sqlalchemy.dialects import mssql
from sqlalchemy.dialects.mssql import DATE, NTEXT, NVARCHAR, TEXT, VARCHAR
@@ -44,7 +43,6 @@ from tests.unit_tests.fixtures.common import dttm
],
)
def test_mssql_column_types(
app_context: AppContext,
type_string: str,
type_expected: TypeEngine,
generic_type_expected: GenericDataType,
@@ -61,7 +59,7 @@ def test_mssql_column_types(
assert column_spec.generic_type == generic_type_expected
def test_where_clause_n_prefix(app_context: AppContext) -> None:
def test_where_clause_n_prefix() -> None:
from superset.db_engine_specs.mssql import MssqlEngineSpec
dialect = mssql.dialect()
@@ -95,7 +93,7 @@ def test_where_clause_n_prefix(app_context: AppContext) -> None:
assert query == query_expected
def test_time_exp_mixd_case_col_1y(app_context: AppContext) -> None:
def test_time_exp_mixd_case_col_1y() -> None:
from superset.db_engine_specs.mssql import MssqlEngineSpec
col = column("MixedCase")
@@ -122,7 +120,6 @@ def test_time_exp_mixd_case_col_1y(app_context: AppContext) -> None:
],
)
def test_convert_dttm(
app_context: AppContext,
actual: str,
expected: str,
dttm: datetime,
@@ -132,7 +129,7 @@ def test_convert_dttm(
assert MssqlEngineSpec.convert_dttm(actual, dttm) == expected
def test_extract_error_message(app_context: AppContext) -> None:
def test_extract_error_message() -> None:
from superset.db_engine_specs.mssql import MssqlEngineSpec
test_mssql_exception = Exception(
@@ -158,7 +155,7 @@ def test_extract_error_message(app_context: AppContext) -> None:
assert expected_message == error_message
def test_fetch_data(app_context: AppContext) -> None:
def test_fetch_data() -> None:
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.mssql import MssqlEngineSpec
@@ -185,9 +182,7 @@ def test_fetch_data(app_context: AppContext) -> None:
(NTEXT(collation="utf8_general_ci"), "NTEXT"),
],
)
def test_column_datatype_to_string(
app_context: AppContext, original: TypeEngine, expected: str
) -> None:
def test_column_datatype_to_string(original: TypeEngine, expected: str) -> None:
from superset.db_engine_specs.mssql import MssqlEngineSpec
actual = MssqlEngineSpec.column_datatype_to_string(original, mssql.dialect())
@@ -239,9 +234,7 @@ select 'USD' as cur
),
],
)
def test_cte_query_parsing(
app_context: AppContext, original: TypeEngine, expected: str
) -> None:
def test_cte_query_parsing(original: TypeEngine, expected: str) -> None:
from superset.db_engine_specs.mssql import MssqlEngineSpec
actual = MssqlEngineSpec.get_cte_query(original)
@@ -270,16 +263,14 @@ select TOP 100 * from currency""",
),
],
)
def test_top_query_parsing(
app_context: AppContext, original: TypeEngine, expected: str, top: int
) -> None:
def test_top_query_parsing(original: TypeEngine, expected: str, top: int) -> None:
from superset.db_engine_specs.mssql import MssqlEngineSpec
actual = MssqlEngineSpec.apply_top_to_sql(original, top)
assert actual == expected
def test_extract_errors(app_context: AppContext) -> None:
def test_extract_errors() -> None:
"""
Test that custom error messages are extracted correctly.
"""

View File

@@ -19,7 +19,6 @@ from typing import Optional
import pytest
import pytz
from flask.ctx import AppContext
@pytest.mark.parametrize(
@@ -45,7 +44,6 @@ from flask.ctx import AppContext
],
)
def test_convert_dttm(
app_context: AppContext,
target_type: str,
dttm: datetime,
result: Optional[str],

View File

@@ -19,7 +19,6 @@ from datetime import datetime
from unittest import mock
import pytest
from flask.ctx import AppContext
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from tests.unit_tests.fixtures.common import dttm
@@ -33,15 +32,13 @@ from tests.unit_tests.fixtures.common import dttm
("TIMESTAMP", "TO_TIMESTAMP('2019-01-02T03:04:05.678900')"),
],
)
def test_convert_dttm(
app_context: AppContext, actual: str, expected: str, dttm: datetime
) -> None:
def test_convert_dttm(actual: str, expected: str, dttm: datetime) -> None:
from superset.db_engine_specs.snowflake import SnowflakeEngineSpec
assert SnowflakeEngineSpec.convert_dttm(actual, dttm) == expected
def test_database_connection_test_mutator(app_context: AppContext) -> None:
def test_database_connection_test_mutator() -> None:
from superset.db_engine_specs.snowflake import SnowflakeEngineSpec
from superset.models.core import Database
@@ -54,7 +51,7 @@ def test_database_connection_test_mutator(app_context: AppContext) -> None:
} == engine_params
def test_extract_errors(app_context: AppContext) -> None:
def test_extract_errors() -> None:
from superset.db_engine_specs.snowflake import SnowflakeEngineSpec
msg = "Object dumbBrick does not exist or not authorized."

View File

@@ -19,31 +19,30 @@ from datetime import datetime
from unittest import mock
import pytest
from flask.ctx import AppContext
from sqlalchemy.engine import create_engine
from tests.unit_tests.fixtures.common import dttm
def test_convert_dttm(app_context: AppContext, dttm: datetime) -> None:
def test_convert_dttm(dttm: datetime) -> None:
from superset.db_engine_specs.sqlite import SqliteEngineSpec
assert SqliteEngineSpec.convert_dttm("TEXT", dttm) == "'2019-01-02 03:04:05.678900'"
def test_convert_dttm_lower(app_context: AppContext, dttm: datetime) -> None:
def test_convert_dttm_lower(dttm: datetime) -> None:
from superset.db_engine_specs.sqlite import SqliteEngineSpec
assert SqliteEngineSpec.convert_dttm("text", dttm) == "'2019-01-02 03:04:05.678900'"
def test_convert_dttm_invalid_type(app_context: AppContext, dttm: datetime) -> None:
def test_convert_dttm_invalid_type(dttm: datetime) -> None:
from superset.db_engine_specs.sqlite import SqliteEngineSpec
assert SqliteEngineSpec.convert_dttm("other", dttm) is None
def test_get_all_datasource_names_table(app_context: AppContext) -> None:
def test_get_all_datasource_names_table() -> None:
from superset.db_engine_specs.sqlite import SqliteEngineSpec
database = mock.MagicMock()
@@ -62,7 +61,7 @@ def test_get_all_datasource_names_table(app_context: AppContext) -> None:
)
def test_get_all_datasource_names_view(app_context: AppContext) -> None:
def test_get_all_datasource_names_view() -> None:
from superset.db_engine_specs.sqlite import SqliteEngineSpec
database = mock.MagicMock()
@@ -81,7 +80,7 @@ def test_get_all_datasource_names_view(app_context: AppContext) -> None:
)
def test_get_all_datasource_names_invalid_type(app_context: AppContext) -> None:
def test_get_all_datasource_names_invalid_type() -> None:
from superset.db_engine_specs.sqlite import SqliteEngineSpec
database = mock.MagicMock()
@@ -132,9 +131,7 @@ def test_get_all_datasource_names_invalid_type(app_context: AppContext) -> None:
("2022-12-04T05:06:07.89Z", "P3M", "2022-10-01 00:00:00"),
],
)
def test_time_grain_expressions(
dttm: str, grain: str, expected: str, app_context: AppContext
) -> None:
def test_time_grain_expressions(dttm: str, grain: str, expected: str) -> None:
from superset.db_engine_specs.sqlite import SqliteEngineSpec
engine = create_engine("sqlite://")

View File

@@ -16,7 +16,6 @@
# under the License.
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
import pytest
from flask.ctx import AppContext
@pytest.mark.parametrize(
@@ -32,7 +31,6 @@ from flask.ctx import AppContext
],
)
def test_apply_top_to_sql_limit(
app_context: AppContext,
limit: int,
original: str,
expected: str,