[mypy] Enforcing typing for db_engine_specs (#9138)

This commit is contained in:
John Bodley
2020-02-17 23:08:11 -08:00
committed by GitHub
parent 3149d8ebc0
commit 9f5f8e5d92
17 changed files with 173 additions and 104 deletions

View File

@@ -22,15 +22,19 @@ from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from urllib import parse
import pandas as pd
from sqlalchemy import Column
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import make_url
from sqlalchemy.engine.url import make_url, URL
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import ColumnClause, Select
from wtforms.form import Form
from superset import app, cache, conf
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.presto import PrestoEngineSpec
from superset.models.sql_lab import Query
from superset.utils import core as utils
if TYPE_CHECKING:
@@ -67,7 +71,7 @@ class HiveEngineSpec(PrestoEngineSpec):
)
@classmethod
def patch(cls):
def patch(cls) -> None:
from pyhive import hive # pylint: disable=no-name-in-module
from superset.db_engines import hive as patched_hive
from TCLIService import (
@@ -83,12 +87,12 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod
def get_all_datasource_names(
cls, database, datasource_type: str
cls, database: "Database", datasource_type: str
) -> List[utils.DatasourceName]:
return BaseEngineSpec.get_all_datasource_names(database, datasource_type)
@classmethod
def fetch_data(cls, cursor, limit: int) -> List[Tuple]:
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
import pyhive
from TCLIService import ttypes
@@ -102,11 +106,11 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod
def create_table_from_csv( # pylint: disable=too-many-locals
cls, form, database
cls, form: Form, database: "Database"
) -> None:
"""Uploads a csv file and creates a superset datasource in Hive."""
def convert_to_hive_type(col_type):
def convert_to_hive_type(col_type: str) -> str:
"""maps tableschema's types to hive types"""
tableschema_to_hive_types = {
"boolean": "BOOLEAN",
@@ -192,13 +196,14 @@ class HiveEngineSpec(PrestoEngineSpec):
return None
@classmethod
def adjust_database_uri(cls, uri, selected_schema=None):
def adjust_database_uri(
cls, uri: URL, selected_schema: Optional[str] = None
) -> None:
if selected_schema:
uri.database = parse.quote(selected_schema, safe="")
return uri
@classmethod
def _extract_error_message(cls, e):
def _extract_error_message(cls, e: Exception) -> str:
msg = str(e)
match = re.search(r'errorMessage="(.*?)(?<!\\)"', msg)
if match:
@@ -206,10 +211,10 @@ class HiveEngineSpec(PrestoEngineSpec):
return msg
@classmethod
def progress(cls, log_lines):
def progress(cls, log_lines: List[str]) -> int:
total_jobs = 1 # assuming there's at least 1 job
current_job = 1
stages = {}
stages: Dict[int, float] = {}
for line in log_lines:
match = cls.jobs_stats_r.match(line)
if match:
@@ -237,15 +242,17 @@ class HiveEngineSpec(PrestoEngineSpec):
return int(progress)
@classmethod
def get_tracking_url(cls, log_lines):
def get_tracking_url(cls, log_lines: List[str]) -> Optional[str]:
lkp = "Tracking URL = "
for line in log_lines:
if lkp in line:
return line.split(lkp)[1]
return None
return None
@classmethod
def handle_cursor(cls, cursor, query, session): # pylint: disable=too-many-locals
def handle_cursor( # pylint: disable=too-many-locals
cls, cursor: Any, query: Query, session: Session
) -> None:
"""Updates progress information"""
from pyhive import hive # pylint: disable=no-name-in-module
@@ -310,7 +317,7 @@ class HiveEngineSpec(PrestoEngineSpec):
cls,
table_name: str,
schema: Optional[str],
database,
database: "Database",
query: Select,
columns: Optional[List] = None,
) -> Optional[Select]:
@@ -335,12 +342,14 @@ class HiveEngineSpec(PrestoEngineSpec):
return BaseEngineSpec._get_fields(cols) # pylint: disable=protected-access
@classmethod
def latest_sub_partition(cls, table_name, schema, database, **kwargs):
def latest_sub_partition(
cls, table_name: str, schema: Optional[str], database: "Database", **kwargs: Any
) -> str:
# TODO(bogdan): implement
pass
@classmethod
def _latest_partition_from_df(cls, df) -> Optional[List[str]]:
def _latest_partition_from_df(cls, df: pd.DataFrame) -> Optional[List[str]]:
"""Hive partitions look like ds={partition name}"""
if not df.empty:
return [df.ix[:, 0].max().split("=")[1]]
@@ -348,14 +357,19 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod
def _partition_query( # pylint: disable=too-many-arguments
cls, table_name, database, limit=0, order_by=None, filters=None
):
cls,
table_name: str,
database: "Database",
limit: int = 0,
order_by: Optional[List[Tuple[str, bool]]] = None,
filters: Optional[Dict[Any, Any]] = None,
) -> str:
return f"SHOW PARTITIONS {table_name}"
@classmethod
def select_star( # pylint: disable=too-many-arguments
cls,
database,
database: "Database",
table_name: str,
engine: Engine,
schema: Optional[str] = None,
@@ -381,8 +395,8 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod
def modify_url_for_impersonation(
cls, url, impersonate_user: bool, username: Optional[str]
):
cls, url: URL, impersonate_user: bool, username: Optional[str]
) -> None:
"""
Modify the SQL Alchemy URL object with the user to impersonate if applicable.
:param url: SQLAlchemy URL object