mirror of
https://github.com/apache/superset.git
synced 2026-04-19 16:14:52 +00:00
chore(pre-commit): Add pyupgrade and pycln hooks (#24197)
This commit is contained in:
@@ -23,19 +23,9 @@ import time
|
||||
from abc import ABCMeta
|
||||
from collections import defaultdict, deque
|
||||
from datetime import datetime
|
||||
from re import Pattern
|
||||
from textwrap import dedent
|
||||
from typing import (
|
||||
Any,
|
||||
cast,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Pattern,
|
||||
Set,
|
||||
Tuple,
|
||||
TYPE_CHECKING,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, cast, Optional, TYPE_CHECKING
|
||||
from urllib import parse
|
||||
|
||||
import pandas as pd
|
||||
@@ -78,7 +68,7 @@ if TYPE_CHECKING:
|
||||
|
||||
# need try/catch because pyhive may not be installed
|
||||
try:
|
||||
from pyhive.presto import Cursor # pylint: disable=unused-import
|
||||
from pyhive.presto import Cursor
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
@@ -107,7 +97,7 @@ CONNECTION_UNKNOWN_DATABASE_ERROR = re.compile(
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_children(column: ResultSetColumnType) -> List[ResultSetColumnType]:
|
||||
def get_children(column: ResultSetColumnType) -> list[ResultSetColumnType]:
|
||||
"""
|
||||
Get the children of a complex Presto type (row or array).
|
||||
|
||||
@@ -276,8 +266,8 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
|
||||
|
||||
@classmethod
|
||||
def convert_dttm(
|
||||
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
|
||||
) -> Optional[str]:
|
||||
cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None
|
||||
) -> str | None:
|
||||
"""
|
||||
Convert a Python `datetime` object to a SQL expression.
|
||||
:param target_type: The target type of expression
|
||||
@@ -304,10 +294,10 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
|
||||
def adjust_engine_params(
|
||||
cls,
|
||||
uri: URL,
|
||||
connect_args: Dict[str, Any],
|
||||
catalog: Optional[str] = None,
|
||||
schema: Optional[str] = None,
|
||||
) -> Tuple[URL, Dict[str, Any]]:
|
||||
connect_args: dict[str, Any],
|
||||
catalog: str | None = None,
|
||||
schema: str | None = None,
|
||||
) -> tuple[URL, dict[str, Any]]:
|
||||
database = uri.database
|
||||
if schema and database:
|
||||
schema = parse.quote(schema, safe="")
|
||||
@@ -323,8 +313,8 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
|
||||
def get_schema_from_engine_params(
|
||||
cls,
|
||||
sqlalchemy_uri: URL,
|
||||
connect_args: Dict[str, Any],
|
||||
) -> Optional[str]:
|
||||
connect_args: dict[str, Any],
|
||||
) -> str | None:
|
||||
"""
|
||||
Return the configured schema.
|
||||
|
||||
@@ -341,7 +331,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
|
||||
return parse.unquote(database.split("/")[1])
|
||||
|
||||
@classmethod
|
||||
def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]:
|
||||
def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]:
|
||||
"""
|
||||
Run a SQL query that estimates the cost of a given statement.
|
||||
:param statement: A single SQL statement
|
||||
@@ -369,8 +359,8 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
|
||||
|
||||
@classmethod
|
||||
def query_cost_formatter(
|
||||
cls, raw_cost: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, str]]:
|
||||
cls, raw_cost: list[dict[str, Any]]
|
||||
) -> list[dict[str, str]]:
|
||||
"""
|
||||
Format cost estimate.
|
||||
:param raw_cost: JSON estimate from Trino
|
||||
@@ -401,7 +391,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
|
||||
("networkCost", "Network cost", ""),
|
||||
]
|
||||
for row in raw_cost:
|
||||
estimate: Dict[str, float] = row.get("estimate", {})
|
||||
estimate: dict[str, float] = row.get("estimate", {})
|
||||
statement_cost = {}
|
||||
for key, label, suffix in columns:
|
||||
if key in estimate:
|
||||
@@ -412,7 +402,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
|
||||
|
||||
@classmethod
|
||||
@cache_manager.data_cache.memoize()
|
||||
def get_function_names(cls, database: Database) -> List[str]:
|
||||
def get_function_names(cls, database: Database) -> list[str]:
|
||||
"""
|
||||
Get a list of function names that are able to be called on the database.
|
||||
Used for SQL Lab autocomplete.
|
||||
@@ -426,12 +416,12 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
|
||||
def _partition_query( # pylint: disable=too-many-arguments,too-many-locals,unused-argument
|
||||
cls,
|
||||
table_name: str,
|
||||
schema: Optional[str],
|
||||
indexes: List[Dict[str, Any]],
|
||||
schema: str | None,
|
||||
indexes: list[dict[str, Any]],
|
||||
database: Database,
|
||||
limit: int = 0,
|
||||
order_by: Optional[List[Tuple[str, bool]]] = None,
|
||||
filters: Optional[Dict[Any, Any]] = None,
|
||||
order_by: list[tuple[str, bool]] | None = None,
|
||||
filters: dict[Any, Any] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Return a partition query.
|
||||
@@ -449,7 +439,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
|
||||
order
|
||||
:param filters: dict of field name and filter value combinations
|
||||
"""
|
||||
limit_clause = "LIMIT {}".format(limit) if limit else ""
|
||||
limit_clause = f"LIMIT {limit}" if limit else ""
|
||||
order_by_clause = ""
|
||||
if order_by:
|
||||
l = []
|
||||
@@ -492,11 +482,11 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
|
||||
def where_latest_partition( # pylint: disable=too-many-arguments
|
||||
cls,
|
||||
table_name: str,
|
||||
schema: Optional[str],
|
||||
schema: str | None,
|
||||
database: Database,
|
||||
query: Select,
|
||||
columns: Optional[List[Dict[str, Any]]] = None,
|
||||
) -> Optional[Select]:
|
||||
columns: list[dict[str, Any]] | None = None,
|
||||
) -> Select | None:
|
||||
try:
|
||||
col_names, values = cls.latest_partition(
|
||||
table_name, schema, database, show_first=True
|
||||
@@ -525,7 +515,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
|
||||
return query
|
||||
|
||||
@classmethod
|
||||
def _latest_partition_from_df(cls, df: pd.DataFrame) -> Optional[List[str]]:
|
||||
def _latest_partition_from_df(cls, df: pd.DataFrame) -> list[str] | None:
|
||||
if not df.empty:
|
||||
return df.to_records(index=False)[0].item()
|
||||
return None
|
||||
@@ -535,10 +525,10 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
|
||||
def latest_partition(
|
||||
cls,
|
||||
table_name: str,
|
||||
schema: Optional[str],
|
||||
schema: str | None,
|
||||
database: Database,
|
||||
show_first: bool = False,
|
||||
) -> Tuple[List[str], Optional[List[str]]]:
|
||||
) -> tuple[list[str], list[str] | None]:
|
||||
"""Returns col name and the latest (max) partition value for a table
|
||||
|
||||
:param table_name: the name of the table
|
||||
@@ -589,7 +579,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
|
||||
|
||||
@classmethod
|
||||
def latest_sub_partition(
|
||||
cls, table_name: str, schema: Optional[str], database: Database, **kwargs: Any
|
||||
cls, table_name: str, schema: str | None, database: Database, **kwargs: Any
|
||||
) -> Any:
|
||||
"""Returns the latest (max) partition value for a table
|
||||
|
||||
@@ -652,7 +642,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
||||
engine_name = "Presto"
|
||||
allows_alias_to_source_column = False
|
||||
|
||||
custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = {
|
||||
custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = {
|
||||
COLUMN_DOES_NOT_EXIST_REGEX: (
|
||||
__(
|
||||
'We can\'t seem to resolve the column "%(column_name)s" at '
|
||||
@@ -708,16 +698,16 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool:
|
||||
def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool:
|
||||
version = extra.get("version")
|
||||
return version is not None and Version(version) >= Version("0.319")
|
||||
|
||||
@classmethod
|
||||
def update_impersonation_config(
|
||||
cls,
|
||||
connect_args: Dict[str, Any],
|
||||
connect_args: dict[str, Any],
|
||||
uri: str,
|
||||
username: Optional[str],
|
||||
username: str | None,
|
||||
) -> None:
|
||||
"""
|
||||
Update a configuration dictionary
|
||||
@@ -741,8 +731,8 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
||||
cls,
|
||||
database: Database,
|
||||
inspector: Inspector,
|
||||
schema: Optional[str],
|
||||
) -> Set[str]:
|
||||
schema: str | None,
|
||||
) -> set[str]:
|
||||
"""
|
||||
Get all the real table names within the specified schema.
|
||||
|
||||
@@ -769,8 +759,8 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
||||
cls,
|
||||
database: Database,
|
||||
inspector: Inspector,
|
||||
schema: Optional[str],
|
||||
) -> Set[str]:
|
||||
schema: str | None,
|
||||
) -> set[str]:
|
||||
"""
|
||||
Get all the view names within the specified schema.
|
||||
|
||||
@@ -817,7 +807,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
||||
cls,
|
||||
database: Database,
|
||||
inspector: Inspector,
|
||||
) -> List[str]:
|
||||
) -> list[str]:
|
||||
"""
|
||||
Get all catalogs.
|
||||
"""
|
||||
@@ -826,7 +816,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
||||
@classmethod
|
||||
def _create_column_info(
|
||||
cls, name: str, data_type: types.TypeEngine
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Create column info object
|
||||
:param name: column name
|
||||
@@ -836,7 +826,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
||||
return {"name": name, "type": f"{data_type}"}
|
||||
|
||||
@classmethod
|
||||
def _get_full_name(cls, names: List[Tuple[str, str]]) -> str:
|
||||
def _get_full_name(cls, names: list[tuple[str, str]]) -> str:
|
||||
"""
|
||||
Get the full column name
|
||||
:param names: list of all individual column names
|
||||
@@ -860,7 +850,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _split_data_type(cls, data_type: str, delimiter: str) -> List[str]:
|
||||
def _split_data_type(cls, data_type: str, delimiter: str) -> list[str]:
|
||||
"""
|
||||
Split data type based on given delimiter. Do not split the string if the
|
||||
delimiter is enclosed in quotes
|
||||
@@ -869,16 +859,14 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
||||
comma, whitespace)
|
||||
:return: list of strings after breaking it by the delimiter
|
||||
"""
|
||||
return re.split(
|
||||
r"{}(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)".format(delimiter), data_type
|
||||
)
|
||||
return re.split(rf"{delimiter}(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)", data_type)
|
||||
|
||||
@classmethod
|
||||
def _parse_structural_column( # pylint: disable=too-many-locals
|
||||
cls,
|
||||
parent_column_name: str,
|
||||
parent_data_type: str,
|
||||
result: List[Dict[str, Any]],
|
||||
result: list[dict[str, Any]],
|
||||
) -> None:
|
||||
"""
|
||||
Parse a row or array column
|
||||
@@ -893,7 +881,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
||||
# split on open parenthesis ( to get the structural
|
||||
# data type and its component types
|
||||
data_types = cls._split_data_type(full_data_type, r"\(")
|
||||
stack: List[Tuple[str, str]] = []
|
||||
stack: list[tuple[str, str]] = []
|
||||
for data_type in data_types:
|
||||
# split on closed parenthesis ) to track which component
|
||||
# types belong to what structural data type
|
||||
@@ -962,8 +950,8 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
||||
|
||||
@classmethod
|
||||
def _show_columns(
|
||||
cls, inspector: Inspector, table_name: str, schema: Optional[str]
|
||||
) -> List[ResultRow]:
|
||||
cls, inspector: Inspector, table_name: str, schema: str | None
|
||||
) -> list[ResultRow]:
|
||||
"""
|
||||
Show presto column names
|
||||
:param inspector: object that performs database schema inspection
|
||||
@@ -974,13 +962,13 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
||||
quote = inspector.engine.dialect.identifier_preparer.quote_identifier
|
||||
full_table = quote(table_name)
|
||||
if schema:
|
||||
full_table = "{}.{}".format(quote(schema), full_table)
|
||||
full_table = f"{quote(schema)}.{full_table}"
|
||||
return inspector.bind.execute(f"SHOW COLUMNS FROM {full_table}").fetchall()
|
||||
|
||||
@classmethod
|
||||
def get_columns(
|
||||
cls, inspector: Inspector, table_name: str, schema: Optional[str]
|
||||
) -> List[Dict[str, Any]]:
|
||||
cls, inspector: Inspector, table_name: str, schema: str | None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Get columns from a Presto data source. This includes handling row and
|
||||
array data types
|
||||
@@ -991,7 +979,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
||||
(i.e. column name and data type)
|
||||
"""
|
||||
columns = cls._show_columns(inspector, table_name, schema)
|
||||
result: List[Dict[str, Any]] = []
|
||||
result: list[dict[str, Any]] = []
|
||||
for column in columns:
|
||||
# parse column if it is a row or array
|
||||
if is_feature_enabled("PRESTO_EXPAND_DATA") and (
|
||||
@@ -1031,7 +1019,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
||||
return column_name.startswith('"') and column_name.endswith('"')
|
||||
|
||||
@classmethod
|
||||
def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[ColumnClause]:
|
||||
def _get_fields(cls, cols: list[dict[str, Any]]) -> list[ColumnClause]:
|
||||
"""
|
||||
Format column clauses where names are in quotes and labels are specified
|
||||
:param cols: columns
|
||||
@@ -1053,7 +1041,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
||||
# quote each column name if it is not already quoted
|
||||
for index, col_name in enumerate(col_names):
|
||||
if not cls._is_column_name_quoted(col_name):
|
||||
col_names[index] = '"{}"'.format(col_name)
|
||||
col_names[index] = f'"{col_name}"'
|
||||
quoted_col_name = ".".join(
|
||||
col_name if cls._is_column_name_quoted(col_name) else f'"{col_name}"'
|
||||
for col_name in col_names
|
||||
@@ -1069,12 +1057,12 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
||||
database: Database,
|
||||
table_name: str,
|
||||
engine: Engine,
|
||||
schema: Optional[str] = None,
|
||||
schema: str | None = None,
|
||||
limit: int = 100,
|
||||
show_cols: bool = False,
|
||||
indent: bool = True,
|
||||
latest_partition: bool = True,
|
||||
cols: Optional[List[Dict[str, Any]]] = None,
|
||||
cols: list[dict[str, Any]] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Include selecting properties of row objects. We cannot easily break arrays into
|
||||
@@ -1102,9 +1090,9 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
||||
|
||||
@classmethod
|
||||
def expand_data( # pylint: disable=too-many-locals
|
||||
cls, columns: List[ResultSetColumnType], data: List[Dict[Any, Any]]
|
||||
) -> Tuple[
|
||||
List[ResultSetColumnType], List[Dict[Any, Any]], List[ResultSetColumnType]
|
||||
cls, columns: list[ResultSetColumnType], data: list[dict[Any, Any]]
|
||||
) -> tuple[
|
||||
list[ResultSetColumnType], list[dict[Any, Any]], list[ResultSetColumnType]
|
||||
]:
|
||||
"""
|
||||
We do not immediately display rows and arrays clearly in the data grid. This
|
||||
@@ -1133,7 +1121,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
||||
# process each column, unnesting ARRAY types and
|
||||
# expanding ROW types into new columns
|
||||
to_process = deque((column, 0) for column in columns)
|
||||
all_columns: List[ResultSetColumnType] = []
|
||||
all_columns: list[ResultSetColumnType] = []
|
||||
expanded_columns = []
|
||||
current_array_level = None
|
||||
while to_process:
|
||||
@@ -1147,11 +1135,11 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
||||
# added by the first. every time we change a level in the nested arrays
|
||||
# we reinitialize this.
|
||||
if level != current_array_level:
|
||||
unnested_rows: Dict[int, int] = defaultdict(int)
|
||||
unnested_rows: dict[int, int] = defaultdict(int)
|
||||
current_array_level = level
|
||||
|
||||
name = column["name"]
|
||||
values: Optional[Union[str, List[Any]]]
|
||||
values: str | list[Any] | None
|
||||
|
||||
if column["type"] and column["type"].startswith("ARRAY("):
|
||||
# keep processing array children; we append to the right so that
|
||||
@@ -1198,7 +1186,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
||||
for row in data:
|
||||
values = row.get(name) or []
|
||||
if isinstance(values, str):
|
||||
values = cast(Optional[List[Any]], destringify(values))
|
||||
values = cast(Optional[list[Any]], destringify(values))
|
||||
row[name] = values
|
||||
for value, col in zip(values or [], expanded):
|
||||
row[col["name"]] = value
|
||||
@@ -1211,8 +1199,8 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
||||
|
||||
@classmethod
|
||||
def extra_table_metadata(
|
||||
cls, database: Database, table_name: str, schema_name: Optional[str]
|
||||
) -> Dict[str, Any]:
|
||||
cls, database: Database, table_name: str, schema_name: str | None
|
||||
) -> dict[str, Any]:
|
||||
metadata = {}
|
||||
|
||||
if indexes := database.get_indexes(table_name, schema_name):
|
||||
@@ -1243,8 +1231,8 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
||||
|
||||
@classmethod
|
||||
def get_create_view(
|
||||
cls, database: Database, schema: Optional[str], table: str
|
||||
) -> Optional[str]:
|
||||
cls, database: Database, schema: str | None, table: str
|
||||
) -> str | None:
|
||||
"""
|
||||
Return a CREATE VIEW statement, or `None` if not a view.
|
||||
|
||||
@@ -1267,7 +1255,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
||||
return rows[0][0]
|
||||
|
||||
@classmethod
|
||||
def get_tracking_url(cls, cursor: "Cursor") -> Optional[str]:
|
||||
def get_tracking_url(cls, cursor: Cursor) -> str | None:
|
||||
try:
|
||||
if cursor.last_query_id:
|
||||
# pylint: disable=protected-access, line-too-long
|
||||
@@ -1277,7 +1265,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def handle_cursor(cls, cursor: "Cursor", query: Query, session: Session) -> None:
|
||||
def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None:
|
||||
"""Updates progress information"""
|
||||
if tracking_url := cls.get_tracking_url(cursor):
|
||||
query.tracking_url = tracking_url
|
||||
|
||||
Reference in New Issue
Block a user