chore(pre-commit): Add pyupgrade and pycln hooks (#24197)

This commit is contained in:
John Bodley
2023-06-01 12:01:10 -07:00
committed by GitHub
parent 7d7ce63970
commit a4d5d7c6b9
448 changed files with 3084 additions and 3305 deletions

View File

@@ -23,19 +23,9 @@ import time
from abc import ABCMeta
from collections import defaultdict, deque
from datetime import datetime
from re import Pattern
from textwrap import dedent
from typing import (
Any,
cast,
Dict,
List,
Optional,
Pattern,
Set,
Tuple,
TYPE_CHECKING,
Union,
)
from typing import Any, cast, Optional, TYPE_CHECKING
from urllib import parse
import pandas as pd
@@ -78,7 +68,7 @@ if TYPE_CHECKING:
# need try/catch because pyhive may not be installed
try:
from pyhive.presto import Cursor # pylint: disable=unused-import
from pyhive.presto import Cursor
except ImportError:
pass
@@ -107,7 +97,7 @@ CONNECTION_UNKNOWN_DATABASE_ERROR = re.compile(
logger = logging.getLogger(__name__)
def get_children(column: ResultSetColumnType) -> List[ResultSetColumnType]:
def get_children(column: ResultSetColumnType) -> list[ResultSetColumnType]:
"""
Get the children of a complex Presto type (row or array).
@@ -276,8 +266,8 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
@classmethod
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None
) -> str | None:
"""
Convert a Python `datetime` object to a SQL expression.
:param target_type: The target type of expression
@@ -304,10 +294,10 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
def adjust_engine_params(
cls,
uri: URL,
connect_args: Dict[str, Any],
catalog: Optional[str] = None,
schema: Optional[str] = None,
) -> Tuple[URL, Dict[str, Any]]:
connect_args: dict[str, Any],
catalog: str | None = None,
schema: str | None = None,
) -> tuple[URL, dict[str, Any]]:
database = uri.database
if schema and database:
schema = parse.quote(schema, safe="")
@@ -323,8 +313,8 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
def get_schema_from_engine_params(
cls,
sqlalchemy_uri: URL,
connect_args: Dict[str, Any],
) -> Optional[str]:
connect_args: dict[str, Any],
) -> str | None:
"""
Return the configured schema.
@@ -341,7 +331,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
return parse.unquote(database.split("/")[1])
@classmethod
def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]:
def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]:
"""
Run a SQL query that estimates the cost of a given statement.
:param statement: A single SQL statement
@@ -369,8 +359,8 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
@classmethod
def query_cost_formatter(
cls, raw_cost: List[Dict[str, Any]]
) -> List[Dict[str, str]]:
cls, raw_cost: list[dict[str, Any]]
) -> list[dict[str, str]]:
"""
Format cost estimate.
:param raw_cost: JSON estimate from Trino
@@ -401,7 +391,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
("networkCost", "Network cost", ""),
]
for row in raw_cost:
estimate: Dict[str, float] = row.get("estimate", {})
estimate: dict[str, float] = row.get("estimate", {})
statement_cost = {}
for key, label, suffix in columns:
if key in estimate:
@@ -412,7 +402,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
@classmethod
@cache_manager.data_cache.memoize()
def get_function_names(cls, database: Database) -> List[str]:
def get_function_names(cls, database: Database) -> list[str]:
"""
Get a list of function names that are able to be called on the database.
Used for SQL Lab autocomplete.
@@ -426,12 +416,12 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
def _partition_query( # pylint: disable=too-many-arguments,too-many-locals,unused-argument
cls,
table_name: str,
schema: Optional[str],
indexes: List[Dict[str, Any]],
schema: str | None,
indexes: list[dict[str, Any]],
database: Database,
limit: int = 0,
order_by: Optional[List[Tuple[str, bool]]] = None,
filters: Optional[Dict[Any, Any]] = None,
order_by: list[tuple[str, bool]] | None = None,
filters: dict[Any, Any] | None = None,
) -> str:
"""
Return a partition query.
@@ -449,7 +439,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
order
:param filters: dict of field name and filter value combinations
"""
limit_clause = "LIMIT {}".format(limit) if limit else ""
limit_clause = f"LIMIT {limit}" if limit else ""
order_by_clause = ""
if order_by:
l = []
@@ -492,11 +482,11 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
def where_latest_partition( # pylint: disable=too-many-arguments
cls,
table_name: str,
schema: Optional[str],
schema: str | None,
database: Database,
query: Select,
columns: Optional[List[Dict[str, Any]]] = None,
) -> Optional[Select]:
columns: list[dict[str, Any]] | None = None,
) -> Select | None:
try:
col_names, values = cls.latest_partition(
table_name, schema, database, show_first=True
@@ -525,7 +515,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
return query
@classmethod
def _latest_partition_from_df(cls, df: pd.DataFrame) -> Optional[List[str]]:
def _latest_partition_from_df(cls, df: pd.DataFrame) -> list[str] | None:
if not df.empty:
return df.to_records(index=False)[0].item()
return None
@@ -535,10 +525,10 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
def latest_partition(
cls,
table_name: str,
schema: Optional[str],
schema: str | None,
database: Database,
show_first: bool = False,
) -> Tuple[List[str], Optional[List[str]]]:
) -> tuple[list[str], list[str] | None]:
"""Returns col name and the latest (max) partition value for a table
:param table_name: the name of the table
@@ -589,7 +579,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
@classmethod
def latest_sub_partition(
cls, table_name: str, schema: Optional[str], database: Database, **kwargs: Any
cls, table_name: str, schema: str | None, database: Database, **kwargs: Any
) -> Any:
"""Returns the latest (max) partition value for a table
@@ -652,7 +642,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
engine_name = "Presto"
allows_alias_to_source_column = False
custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = {
custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = {
COLUMN_DOES_NOT_EXIST_REGEX: (
__(
'We can\'t seem to resolve the column "%(column_name)s" at '
@@ -708,16 +698,16 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
}
@classmethod
def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool:
def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool:
version = extra.get("version")
return version is not None and Version(version) >= Version("0.319")
@classmethod
def update_impersonation_config(
cls,
connect_args: Dict[str, Any],
connect_args: dict[str, Any],
uri: str,
username: Optional[str],
username: str | None,
) -> None:
"""
Update a configuration dictionary
@@ -741,8 +731,8 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
cls,
database: Database,
inspector: Inspector,
schema: Optional[str],
) -> Set[str]:
schema: str | None,
) -> set[str]:
"""
Get all the real table names within the specified schema.
@@ -769,8 +759,8 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
cls,
database: Database,
inspector: Inspector,
schema: Optional[str],
) -> Set[str]:
schema: str | None,
) -> set[str]:
"""
Get all the view names within the specified schema.
@@ -817,7 +807,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
cls,
database: Database,
inspector: Inspector,
) -> List[str]:
) -> list[str]:
"""
Get all catalogs.
"""
@@ -826,7 +816,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
@classmethod
def _create_column_info(
cls, name: str, data_type: types.TypeEngine
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Create column info object
:param name: column name
@@ -836,7 +826,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
return {"name": name, "type": f"{data_type}"}
@classmethod
def _get_full_name(cls, names: List[Tuple[str, str]]) -> str:
def _get_full_name(cls, names: list[tuple[str, str]]) -> str:
"""
Get the full column name
:param names: list of all individual column names
@@ -860,7 +850,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
)
@classmethod
def _split_data_type(cls, data_type: str, delimiter: str) -> List[str]:
def _split_data_type(cls, data_type: str, delimiter: str) -> list[str]:
"""
Split data type based on given delimiter. Do not split the string if the
delimiter is enclosed in quotes
@@ -869,16 +859,14 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
comma, whitespace)
:return: list of strings after breaking it by the delimiter
"""
return re.split(
r"{}(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)".format(delimiter), data_type
)
return re.split(rf"{delimiter}(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)", data_type)
@classmethod
def _parse_structural_column( # pylint: disable=too-many-locals
cls,
parent_column_name: str,
parent_data_type: str,
result: List[Dict[str, Any]],
result: list[dict[str, Any]],
) -> None:
"""
Parse a row or array column
@@ -893,7 +881,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
# split on open parenthesis ( to get the structural
# data type and its component types
data_types = cls._split_data_type(full_data_type, r"\(")
stack: List[Tuple[str, str]] = []
stack: list[tuple[str, str]] = []
for data_type in data_types:
# split on closed parenthesis ) to track which component
# types belong to what structural data type
@@ -962,8 +950,8 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
@classmethod
def _show_columns(
cls, inspector: Inspector, table_name: str, schema: Optional[str]
) -> List[ResultRow]:
cls, inspector: Inspector, table_name: str, schema: str | None
) -> list[ResultRow]:
"""
Show presto column names
:param inspector: object that performs database schema inspection
@@ -974,13 +962,13 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
quote = inspector.engine.dialect.identifier_preparer.quote_identifier
full_table = quote(table_name)
if schema:
full_table = "{}.{}".format(quote(schema), full_table)
full_table = f"{quote(schema)}.{full_table}"
return inspector.bind.execute(f"SHOW COLUMNS FROM {full_table}").fetchall()
@classmethod
def get_columns(
cls, inspector: Inspector, table_name: str, schema: Optional[str]
) -> List[Dict[str, Any]]:
cls, inspector: Inspector, table_name: str, schema: str | None
) -> list[dict[str, Any]]:
"""
Get columns from a Presto data source. This includes handling row and
array data types
@@ -991,7 +979,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
(i.e. column name and data type)
"""
columns = cls._show_columns(inspector, table_name, schema)
result: List[Dict[str, Any]] = []
result: list[dict[str, Any]] = []
for column in columns:
# parse column if it is a row or array
if is_feature_enabled("PRESTO_EXPAND_DATA") and (
@@ -1031,7 +1019,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
return column_name.startswith('"') and column_name.endswith('"')
@classmethod
def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[ColumnClause]:
def _get_fields(cls, cols: list[dict[str, Any]]) -> list[ColumnClause]:
"""
Format column clauses where names are in quotes and labels are specified
:param cols: columns
@@ -1053,7 +1041,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
# quote each column name if it is not already quoted
for index, col_name in enumerate(col_names):
if not cls._is_column_name_quoted(col_name):
col_names[index] = '"{}"'.format(col_name)
col_names[index] = f'"{col_name}"'
quoted_col_name = ".".join(
col_name if cls._is_column_name_quoted(col_name) else f'"{col_name}"'
for col_name in col_names
@@ -1069,12 +1057,12 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
database: Database,
table_name: str,
engine: Engine,
schema: Optional[str] = None,
schema: str | None = None,
limit: int = 100,
show_cols: bool = False,
indent: bool = True,
latest_partition: bool = True,
cols: Optional[List[Dict[str, Any]]] = None,
cols: list[dict[str, Any]] | None = None,
) -> str:
"""
Include selecting properties of row objects. We cannot easily break arrays into
@@ -1102,9 +1090,9 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
@classmethod
def expand_data( # pylint: disable=too-many-locals
cls, columns: List[ResultSetColumnType], data: List[Dict[Any, Any]]
) -> Tuple[
List[ResultSetColumnType], List[Dict[Any, Any]], List[ResultSetColumnType]
cls, columns: list[ResultSetColumnType], data: list[dict[Any, Any]]
) -> tuple[
list[ResultSetColumnType], list[dict[Any, Any]], list[ResultSetColumnType]
]:
"""
We do not immediately display rows and arrays clearly in the data grid. This
@@ -1133,7 +1121,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
# process each column, unnesting ARRAY types and
# expanding ROW types into new columns
to_process = deque((column, 0) for column in columns)
all_columns: List[ResultSetColumnType] = []
all_columns: list[ResultSetColumnType] = []
expanded_columns = []
current_array_level = None
while to_process:
@@ -1147,11 +1135,11 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
# added by the first. every time we change a level in the nested arrays
# we reinitialize this.
if level != current_array_level:
unnested_rows: Dict[int, int] = defaultdict(int)
unnested_rows: dict[int, int] = defaultdict(int)
current_array_level = level
name = column["name"]
values: Optional[Union[str, List[Any]]]
values: str | list[Any] | None
if column["type"] and column["type"].startswith("ARRAY("):
# keep processing array children; we append to the right so that
@@ -1198,7 +1186,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
for row in data:
values = row.get(name) or []
if isinstance(values, str):
values = cast(Optional[List[Any]], destringify(values))
values = cast(Optional[list[Any]], destringify(values))
row[name] = values
for value, col in zip(values or [], expanded):
row[col["name"]] = value
@@ -1211,8 +1199,8 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
@classmethod
def extra_table_metadata(
cls, database: Database, table_name: str, schema_name: Optional[str]
) -> Dict[str, Any]:
cls, database: Database, table_name: str, schema_name: str | None
) -> dict[str, Any]:
metadata = {}
if indexes := database.get_indexes(table_name, schema_name):
@@ -1243,8 +1231,8 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
@classmethod
def get_create_view(
cls, database: Database, schema: Optional[str], table: str
) -> Optional[str]:
cls, database: Database, schema: str | None, table: str
) -> str | None:
"""
Return a CREATE VIEW statement, or `None` if not a view.
@@ -1267,7 +1255,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
return rows[0][0]
@classmethod
def get_tracking_url(cls, cursor: "Cursor") -> Optional[str]:
def get_tracking_url(cls, cursor: Cursor) -> str | None:
try:
if cursor.last_query_id:
# pylint: disable=protected-access, line-too-long
@@ -1277,7 +1265,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
return None
@classmethod
def handle_cursor(cls, cursor: "Cursor", query: Query, session: Session) -> None:
def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None:
"""Updates progress information"""
if tracking_url := cls.get_tracking_url(cursor):
query.tracking_url = tracking_url