mirror of
https://github.com/apache/superset.git
synced 2026-04-19 16:14:52 +00:00
[mypy] Enforcing typing for db_engine_specs (#9138)
This commit is contained in:
@@ -22,15 +22,19 @@ from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
from urllib import parse
|
||||
|
||||
import pandas as pd
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy.engine.base import Engine
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.engine.url import make_url
|
||||
from sqlalchemy.engine.url import make_url, URL
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql.expression import ColumnClause, Select
|
||||
from wtforms.form import Form
|
||||
|
||||
from superset import app, cache, conf
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
from superset.db_engine_specs.presto import PrestoEngineSpec
|
||||
from superset.models.sql_lab import Query
|
||||
from superset.utils import core as utils
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -67,7 +71,7 @@ class HiveEngineSpec(PrestoEngineSpec):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def patch(cls):
|
||||
def patch(cls) -> None:
|
||||
from pyhive import hive # pylint: disable=no-name-in-module
|
||||
from superset.db_engines import hive as patched_hive
|
||||
from TCLIService import (
|
||||
@@ -83,12 +87,12 @@ class HiveEngineSpec(PrestoEngineSpec):
|
||||
|
||||
@classmethod
|
||||
def get_all_datasource_names(
|
||||
cls, database, datasource_type: str
|
||||
cls, database: "Database", datasource_type: str
|
||||
) -> List[utils.DatasourceName]:
|
||||
return BaseEngineSpec.get_all_datasource_names(database, datasource_type)
|
||||
|
||||
@classmethod
|
||||
def fetch_data(cls, cursor, limit: int) -> List[Tuple]:
|
||||
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
|
||||
import pyhive
|
||||
from TCLIService import ttypes
|
||||
|
||||
@@ -102,11 +106,11 @@ class HiveEngineSpec(PrestoEngineSpec):
|
||||
|
||||
@classmethod
|
||||
def create_table_from_csv( # pylint: disable=too-many-locals
|
||||
cls, form, database
|
||||
cls, form: Form, database: "Database"
|
||||
) -> None:
|
||||
"""Uploads a csv file and creates a superset datasource in Hive."""
|
||||
|
||||
def convert_to_hive_type(col_type):
|
||||
def convert_to_hive_type(col_type: str) -> str:
|
||||
"""maps tableschema's types to hive types"""
|
||||
tableschema_to_hive_types = {
|
||||
"boolean": "BOOLEAN",
|
||||
@@ -192,13 +196,14 @@ class HiveEngineSpec(PrestoEngineSpec):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def adjust_database_uri(cls, uri, selected_schema=None):
|
||||
def adjust_database_uri(
|
||||
cls, uri: URL, selected_schema: Optional[str] = None
|
||||
) -> None:
|
||||
if selected_schema:
|
||||
uri.database = parse.quote(selected_schema, safe="")
|
||||
return uri
|
||||
|
||||
@classmethod
|
||||
def _extract_error_message(cls, e):
|
||||
def _extract_error_message(cls, e: Exception) -> str:
|
||||
msg = str(e)
|
||||
match = re.search(r'errorMessage="(.*?)(?<!\\)"', msg)
|
||||
if match:
|
||||
@@ -206,10 +211,10 @@ class HiveEngineSpec(PrestoEngineSpec):
|
||||
return msg
|
||||
|
||||
@classmethod
|
||||
def progress(cls, log_lines):
|
||||
def progress(cls, log_lines: List[str]) -> int:
|
||||
total_jobs = 1 # assuming there's at least 1 job
|
||||
current_job = 1
|
||||
stages = {}
|
||||
stages: Dict[int, float] = {}
|
||||
for line in log_lines:
|
||||
match = cls.jobs_stats_r.match(line)
|
||||
if match:
|
||||
@@ -237,15 +242,17 @@ class HiveEngineSpec(PrestoEngineSpec):
|
||||
return int(progress)
|
||||
|
||||
@classmethod
|
||||
def get_tracking_url(cls, log_lines):
|
||||
def get_tracking_url(cls, log_lines: List[str]) -> Optional[str]:
|
||||
lkp = "Tracking URL = "
|
||||
for line in log_lines:
|
||||
if lkp in line:
|
||||
return line.split(lkp)[1]
|
||||
return None
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def handle_cursor(cls, cursor, query, session): # pylint: disable=too-many-locals
|
||||
def handle_cursor( # pylint: disable=too-many-locals
|
||||
cls, cursor: Any, query: Query, session: Session
|
||||
) -> None:
|
||||
"""Updates progress information"""
|
||||
from pyhive import hive # pylint: disable=no-name-in-module
|
||||
|
||||
@@ -310,7 +317,7 @@ class HiveEngineSpec(PrestoEngineSpec):
|
||||
cls,
|
||||
table_name: str,
|
||||
schema: Optional[str],
|
||||
database,
|
||||
database: "Database",
|
||||
query: Select,
|
||||
columns: Optional[List] = None,
|
||||
) -> Optional[Select]:
|
||||
@@ -335,12 +342,14 @@ class HiveEngineSpec(PrestoEngineSpec):
|
||||
return BaseEngineSpec._get_fields(cols) # pylint: disable=protected-access
|
||||
|
||||
@classmethod
|
||||
def latest_sub_partition(cls, table_name, schema, database, **kwargs):
|
||||
def latest_sub_partition(
|
||||
cls, table_name: str, schema: Optional[str], database: "Database", **kwargs: Any
|
||||
) -> str:
|
||||
# TODO(bogdan): implement
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _latest_partition_from_df(cls, df) -> Optional[List[str]]:
|
||||
def _latest_partition_from_df(cls, df: pd.DataFrame) -> Optional[List[str]]:
|
||||
"""Hive partitions look like ds={partition name}"""
|
||||
if not df.empty:
|
||||
return [df.ix[:, 0].max().split("=")[1]]
|
||||
@@ -348,14 +357,19 @@ class HiveEngineSpec(PrestoEngineSpec):
|
||||
|
||||
@classmethod
|
||||
def _partition_query( # pylint: disable=too-many-arguments
|
||||
cls, table_name, database, limit=0, order_by=None, filters=None
|
||||
):
|
||||
cls,
|
||||
table_name: str,
|
||||
database: "Database",
|
||||
limit: int = 0,
|
||||
order_by: Optional[List[Tuple[str, bool]]] = None,
|
||||
filters: Optional[Dict[Any, Any]] = None,
|
||||
) -> str:
|
||||
return f"SHOW PARTITIONS {table_name}"
|
||||
|
||||
@classmethod
|
||||
def select_star( # pylint: disable=too-many-arguments
|
||||
cls,
|
||||
database,
|
||||
database: "Database",
|
||||
table_name: str,
|
||||
engine: Engine,
|
||||
schema: Optional[str] = None,
|
||||
@@ -381,8 +395,8 @@ class HiveEngineSpec(PrestoEngineSpec):
|
||||
|
||||
@classmethod
|
||||
def modify_url_for_impersonation(
|
||||
cls, url, impersonate_user: bool, username: Optional[str]
|
||||
):
|
||||
cls, url: URL, impersonate_user: bool, username: Optional[str]
|
||||
) -> None:
|
||||
"""
|
||||
Modify the SQL Alchemy URL object with the user to impersonate if applicable.
|
||||
:param url: SQLAlchemy URL object
|
||||
|
||||
Reference in New Issue
Block a user