style(mypy): Enforcing typing for superset (#9943)

Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
John Bodley
2020-06-03 15:26:12 -07:00
committed by GitHub
parent dcac860f3e
commit 244677cf5e
15 changed files with 393 additions and 313 deletions

View File

@@ -19,7 +19,7 @@ import uuid
from contextlib import closing
from datetime import datetime
from sys import getsizeof
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, cast, Dict, Iterator, List, Optional, Tuple, Union
import backoff
import msgpack
@@ -27,9 +27,10 @@ import pyarrow as pa
import simplejson as json
import sqlalchemy
from celery.exceptions import SoftTimeLimitExceeded
from celery.task.base import Task
from contextlib2 import contextmanager
from flask_babel import lazy_gettext as _
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import NullPool
from superset import (
@@ -77,7 +78,9 @@ class SqlLabTimeoutException(SqlLabException):
pass
def handle_query_error(msg, query, session, payload=None):
def handle_query_error(
msg: str, query: Query, session: Session, payload: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""Local method handling error while processing the SQL"""
payload = payload or {}
troubleshooting_link = config["TROUBLESHOOTING_LINK"]
@@ -91,14 +94,14 @@ def handle_query_error(msg, query, session, payload=None):
return payload
def get_query_backoff_handler(details):
def get_query_backoff_handler(details: Dict[Any, Any]) -> None:
query_id = details["kwargs"]["query_id"]
logger.error(f"Query with id `{query_id}` could not be retrieved")
stats_logger.incr("error_attempting_orm_query_{}".format(details["tries"] - 1))
logger.error(f"Query {query_id}: Sleeping for a sec before retrying...")
def get_query_giveup_handler(_):
def get_query_giveup_handler(_: Any) -> None:
stats_logger.incr("error_failed_at_getting_orm_query")
@@ -110,7 +113,7 @@ def get_query_giveup_handler(_):
on_giveup=get_query_giveup_handler,
max_tries=5,
)
def get_query(query_id, session):
def get_query(query_id: int, session: Session) -> Query:
"""attempts to get the query and retry if it cannot"""
try:
return session.query(Query).filter_by(id=query_id).one()
@@ -119,7 +122,7 @@ def get_query(query_id, session):
@contextmanager
def session_scope(nullpool):
def session_scope(nullpool: bool) -> Iterator[Session]:
"""Provide a transactional scope around a series of operations."""
database_uri = app.config["SQLALCHEMY_DATABASE_URI"]
if "sqlite" in database_uri:
@@ -154,16 +157,16 @@ def session_scope(nullpool):
soft_time_limit=SQLLAB_TIMEOUT,
)
def get_sql_results( # pylint: disable=too-many-arguments
ctask,
query_id,
rendered_query,
return_results=True,
store_results=False,
user_name=None,
start_time=None,
expand_data=False,
log_params=None,
):
ctask: Task,
query_id: int,
rendered_query: str,
return_results: bool = True,
store_results: bool = False,
user_name: Optional[str] = None,
start_time: Optional[float] = None,
expand_data: bool = False,
log_params: Optional[Dict[str, Any]] = None,
) -> Optional[Dict[str, Any]]:
"""Executes the sql query returns the results."""
with session_scope(not ctask.request.called_directly) as session:
@@ -188,7 +191,14 @@ def get_sql_results( # pylint: disable=too-many-arguments
# pylint: disable=too-many-arguments
def execute_sql_statement(sql_statement, query, user_name, session, cursor, log_params):
def execute_sql_statement(
sql_statement: str,
query: Query,
user_name: Optional[str],
session: Session,
cursor: Any,
log_params: Optional[Dict[str, Any]],
) -> SupersetResultSet:
"""Executes a single SQL statement"""
database = query.database
db_engine_spec = database.db_engine_spec
@@ -275,7 +285,7 @@ def execute_sql_statement(sql_statement, query, user_name, session, cursor, log_
def _serialize_payload(
payload: dict, use_msgpack: Optional[bool] = False
payload: Dict[Any, Any], use_msgpack: Optional[bool] = False
) -> Union[bytes, str]:
logger.debug(f"Serializing to msgpack: {use_msgpack}")
if use_msgpack:
@@ -321,24 +331,24 @@ def _serialize_and_expand_data(
return (data, selected_columns, all_columns, expanded_columns)
def execute_sql_statements(
query_id,
rendered_query,
return_results=True,
store_results=False,
user_name=None,
session=None,
start_time=None,
expand_data=False,
log_params=None,
): # pylint: disable=too-many-arguments, too-many-locals, too-many-statements
def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-locals, too-many-statements
query_id: int,
rendered_query: str,
return_results: bool,
store_results: bool,
user_name: Optional[str],
session: Session,
start_time: Optional[float],
expand_data: bool,
log_params: Optional[Dict[str, Any]],
) -> Optional[Dict[str, Any]]:
"""Executes the sql query returns the results."""
if store_results and start_time:
# only asynchronous queries
stats_logger.timing("sqllab.query.time_pending", now_as_float() - start_time)
query = get_query(query_id, session)
payload = dict(query_id=query_id)
payload: Dict[str, Any] = dict(query_id=query_id)
database = query.database
db_engine_spec = database.db_engine_spec
db_engine_spec.patch()
@@ -406,7 +416,7 @@ def execute_sql_statements(
)
query.end_time = now_as_float()
use_arrow_data = store_results and results_backend_use_msgpack
use_arrow_data = store_results and cast(bool, results_backend_use_msgpack)
data, selected_columns, all_columns, expanded_columns = _serialize_and_expand_data(
result_set, db_engine_spec, use_arrow_data, expand_data
)
@@ -432,7 +442,7 @@ def execute_sql_statements(
"sqllab.query.results_backend_write_serialization", stats_logger
):
serialized_payload = _serialize_payload(
payload, results_backend_use_msgpack
payload, cast(bool, results_backend_use_msgpack)
)
cache_timeout = database.cache_timeout
if cache_timeout is None: