[mypy] Enforcing typing for db_engine_specs (#9138)

This commit is contained in:
John Bodley
2020-02-17 23:08:11 -08:00
committed by GitHub
parent 3149d8ebc0
commit 9f5f8e5d92
17 changed files with 173 additions and 104 deletions

View File

@@ -17,13 +17,17 @@
import hashlib
import re
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
import pandas as pd
from sqlalchemy import literal_column
from sqlalchemy.sql.expression import ColumnClause
from superset.db_engine_specs.base import BaseEngineSpec
if TYPE_CHECKING:
from superset.models.core import Database # pylint: disable=unused-import
class BigQueryEngineSpec(BaseEngineSpec):
"""Engine spec for Google's BigQuery
@@ -69,7 +73,7 @@ class BigQueryEngineSpec(BaseEngineSpec):
return None
@classmethod
def fetch_data(cls, cursor, limit: int) -> List[Tuple]:
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
data = super(BigQueryEngineSpec, cls).fetch_data(cursor, limit)
if data and type(data[0]).__name__ == "Row":
data = [r.values() for r in data] # type: ignore
@@ -112,7 +116,7 @@ class BigQueryEngineSpec(BaseEngineSpec):
@classmethod
def extra_table_metadata(
cls, database, table_name: str, schema_name: str
cls, database: "Database", table_name: str, schema_name: str
) -> Dict[str, Any]:
indexes = database.get_indexes(table_name, schema_name)
if not indexes:
@@ -133,7 +137,7 @@ class BigQueryEngineSpec(BaseEngineSpec):
}
@classmethod
def _get_fields(cls, cols):
def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[ColumnClause]:
"""
BigQuery dialect requires us to not use backtick in the fieldname which are
nested.
@@ -143,8 +147,7 @@ class BigQueryEngineSpec(BaseEngineSpec):
column names in the result.
"""
return [
literal_column(c.get("name")).label(c.get("name").replace(".", "__"))
for c in cols
literal_column(c["name"]).label(c["name"].replace(".", "__")) for c in cols
]
@classmethod
@@ -156,7 +159,7 @@ class BigQueryEngineSpec(BaseEngineSpec):
return "TIMESTAMP_MILLIS({col})"
@classmethod
def df_to_sql(cls, df: pd.DataFrame, **kwargs):
def df_to_sql(cls, df: pd.DataFrame, **kwargs: Any) -> None:
"""
Upload data from a Pandas DataFrame to BigQuery. Calls
`DataFrame.to_gbq()` which requires `pandas_gbq` to be installed.