style(mypy): Spit-and-polish pass (#10001)

Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
John Bodley
2020-06-07 08:53:46 -07:00
committed by GitHub
parent 656cdfb867
commit 91517a56a3
56 changed files with 243 additions and 207 deletions

View File

@@ -27,8 +27,9 @@ def view_cache_key(*args: Any, **kwargs: Any) -> str: # pylint: disable=unused-
def memoized_func(
key: Callable = view_cache_key, attribute_in_key: Optional[str] = None
) -> Callable:
key: Callable[..., str] = view_cache_key, # pylint: disable=bad-whitespace
attribute_in_key: Optional[str] = None,
) -> Callable[..., Any]:
"""Use this decorator to cache functions that have predefined first arg.
enable_cache is treated as True by default,
@@ -45,7 +46,7 @@ def memoized_func(
returns the caching key.
"""
def wrap(f: Callable) -> Callable:
def wrap(f: Callable[..., Any]) -> Callable[..., Any]:
if cache_manager.tables_cache:
def wrapped_f(self: Any, *args: Any, **kwargs: Any) -> Any:

View File

@@ -85,7 +85,7 @@ from superset.exceptions import (
SupersetException,
SupersetTimeoutException,
)
from superset.typing import FormData, Metric
from superset.typing import FlaskResponse, FormData, Metric
from superset.utils.dates import datetime_to_epoch, EPOCH
try:
@@ -147,7 +147,9 @@ class _memoized:
should account for instance variable changes.
"""
def __init__(self, func: Callable, watch: Optional[Tuple[str, ...]] = None) -> None:
def __init__(
self, func: Callable[..., Any], watch: Optional[Tuple[str, ...]] = None
) -> None:
self.func = func
self.cache: Dict[Any, Any] = {}
self.is_method = False
@@ -173,7 +175,7 @@ class _memoized:
"""Return the function's docstring."""
return self.func.__doc__ or ""
def __get__(self, obj: Any, objtype: Type) -> functools.partial:
def __get__(self, obj: Any, objtype: Type[Any]) -> functools.partial: # type: ignore
if not self.is_method:
self.is_method = True
"""Support instance methods."""
@@ -181,13 +183,13 @@ class _memoized:
def memoized(
func: Optional[Callable] = None, watch: Optional[Tuple[str, ...]] = None
) -> Callable:
func: Optional[Callable[..., Any]] = None, watch: Optional[Tuple[str, ...]] = None
) -> Callable[..., Any]:
if func:
return _memoized(func)
else:
def wrapper(f: Callable) -> Callable:
def wrapper(f: Callable[..., Any]) -> Callable[..., Any]:
return _memoized(f, watch)
return wrapper
@@ -1241,7 +1243,9 @@ def create_ssl_cert_file(certificate: str) -> str:
return path
def time_function(func: Callable, *args: Any, **kwargs: Any) -> Tuple[float, Any]:
def time_function(
func: Callable[..., FlaskResponse], *args: Any, **kwargs: Any
) -> Tuple[float, Any]:
"""
Measures the amount of time a function takes to execute in ms

View File

@@ -29,7 +29,7 @@ def convert_filter_scopes(
) -> Dict[int, Dict[str, Dict[str, Any]]]:
filter_scopes = {}
immuned_by_id: List[int] = json_metadata.get("filter_immune_slices") or []
immuned_by_column: Dict = defaultdict(list)
immuned_by_column: Dict[str, List[int]] = defaultdict(list)
for slice_id, columns in json_metadata.get(
"filter_immune_slice_fields", {}
).items():
@@ -52,7 +52,7 @@ def convert_filter_scopes(
logging.info(f"slice [{filter_id}] has invalid field: {filter_field}")
for filter_slice in filters:
filter_fields: Dict = {}
filter_fields: Dict[str, Dict[str, Any]] = {}
filter_id = filter_slice.id
slice_params = json.loads(filter_slice.params or "{}")
configs = slice_params.get("filter_configs") or []
@@ -77,9 +77,10 @@ def convert_filter_scopes(
def copy_filter_scopes(
old_to_new_slc_id_dict: Dict[int, int], old_filter_scopes: Dict[str, Dict]
) -> Dict:
new_filter_scopes: Dict[str, Dict] = {}
old_to_new_slc_id_dict: Dict[int, int],
old_filter_scopes: Dict[int, Dict[str, Dict[str, Any]]],
) -> Dict[str, Dict[Any, Any]]:
new_filter_scopes: Dict[str, Dict[Any, Any]] = {}
for (filter_id, scopes) in old_filter_scopes.items():
new_filter_key = old_to_new_slc_id_dict.get(int(filter_id))
if new_filter_key:

View File

@@ -46,7 +46,7 @@ def stats_timing(stats_key: str, stats_logger: BaseStatsLogger) -> Iterator[floa
stats_logger.timing(stats_key, now_as_float() - start_ts)
def etag_cache(max_age: int, check_perms: Callable) -> Callable:
def etag_cache(max_age: int, check_perms: Callable[..., Any]) -> Callable[..., Any]:
"""
A decorator for caching views and handling etag conditional requests.
@@ -60,7 +60,7 @@ def etag_cache(max_age: int, check_perms: Callable) -> Callable:
"""
def decorator(f: Callable) -> Callable:
def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
@wraps(f)
def wrapper(*args: Any, **kwargs: Any) -> ETagResponseMixin:
# check if the user can access the resource

View File

@@ -27,8 +27,8 @@ logger = logging.getLogger(__name__)
def import_datasource(
session: Session,
i_datasource: Model,
lookup_database: Callable,
lookup_datasource: Callable,
lookup_database: Callable[[Model], Model],
lookup_datasource: Callable[[Model], Model],
import_time: Optional[int] = None,
) -> int:
"""Imports the datasource from the object to the database.
@@ -82,7 +82,9 @@ def import_datasource(
return datasource.id
def import_simple_obj(session: Session, i_obj: Model, lookup_obj: Callable) -> Model:
def import_simple_obj(
session: Session, i_obj: Model, lookup_obj: Callable[[Model], Model]
) -> Model:
make_transient(i_obj)
i_obj.id = None
i_obj.table = None

View File

@@ -35,7 +35,7 @@ class AbstractEventLogger(ABC):
) -> None:
pass
def log_this(self, f: Callable) -> Callable:
def log_this(self, f: Callable[..., Any]) -> Callable[..., Any]:
@functools.wraps(f)
def wrapper(*args: Any, **kwargs: Any) -> Any:
user_id = None
@@ -124,7 +124,7 @@ def get_event_logger_from_cfg_value(cfg_value: Any) -> AbstractEventLogger:
)
)
event_logger_type = cast(Type, cfg_value)
event_logger_type = cast(Type[Any], cfg_value)
result = event_logger_type()
# Verify that we have a valid logger impl

View File

@@ -58,7 +58,7 @@ class DefaultLoggingConfigurator(LoggingConfigurator):
if app_config["ENABLE_TIME_ROTATE"]:
logging.getLogger().setLevel(app_config["TIME_ROTATE_LOG_LEVEL"])
handler = TimedRotatingFileHandler( # type: ignore
handler = TimedRotatingFileHandler(
app_config["FILENAME"],
when=app_config["ROLLOVER"],
interval=app_config["INTERVAL"],

View File

@@ -72,8 +72,8 @@ WHITELIST_CUMULATIVE_FUNCTIONS = (
)
def validate_column_args(*argnames: str) -> Callable:
def wrapper(func: Callable) -> Callable:
def validate_column_args(*argnames: str) -> Callable[..., Any]:
def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
def wrapped(df: DataFrame, **options: Any) -> Any:
columns = df.columns.tolist()
for name in argnames:
@@ -471,7 +471,7 @@ def geodetic_parse(
Parse a string containing a geodetic point and return latitude, longitude
and altitude
"""
point = Point(location) # type: ignore
point = Point(location)
return point[0], point[1], point[2]
try:

View File

@@ -51,7 +51,7 @@ SELENIUM_HEADSTART = 3
WindowSize = Tuple[int, int]
def get_auth_cookies(user: "User") -> List[Dict]:
def get_auth_cookies(user: "User") -> List[Dict[Any, Any]]:
# Login with the user specified to get the reports
with current_app.test_request_context("/login"):
login_user(user)
@@ -101,14 +101,14 @@ class AuthWebDriverProxy:
self,
driver_type: str,
window: Optional[WindowSize] = None,
auth_func: Optional[Callable] = None,
auth_func: Optional[
Callable[..., Any]
] = None, # pylint: disable=bad-whitespace
):
self._driver_type = driver_type
self._window: WindowSize = window or (800, 600)
config_auth_func: Callable = current_app.config.get(
"WEBDRIVER_AUTH_FUNC", auth_driver
)
self._auth_func: Callable = auth_func or config_auth_func
config_auth_func = current_app.config.get("WEBDRIVER_AUTH_FUNC", auth_driver)
self._auth_func = auth_func or config_auth_func
def create(self) -> WebDriver:
if self._driver_type == "firefox":
@@ -123,7 +123,7 @@ class AuthWebDriverProxy:
raise Exception(f"Webdriver name ({self._driver_type}) not supported")
# Prepare args for the webdriver init
options.add_argument("--headless")
kwargs: Dict = dict(options=options)
kwargs: Dict[Any, Any] = dict(options=options)
kwargs.update(current_app.config["WEBDRIVER_CONFIGURATION"])
logger.info("Init selenium driver")
return driver_class(**kwargs)