chore(pre-commit): Add pyupgrade and pycln hooks (#24197)

This commit is contained in:
John Bodley
2023-06-01 12:01:10 -07:00
committed by GitHub
parent 7d7ce63970
commit a4d5d7c6b9
448 changed files with 3084 additions and 3305 deletions

View File

@@ -30,19 +30,7 @@ import re
from collections import defaultdict, OrderedDict
from datetime import date, datetime, timedelta
from itertools import product
from typing import (
Any,
Callable,
cast,
Dict,
List,
Optional,
Set,
Tuple,
Type,
TYPE_CHECKING,
Union,
)
from typing import Any, Callable, cast, Optional, TYPE_CHECKING
import geohash
import numpy as np
@@ -124,7 +112,7 @@ class BaseViz: # pylint: disable=too-many-public-methods
"""All visualizations derive this base class"""
viz_type: Optional[str] = None
viz_type: str | None = None
verbose_name = "Base Viz"
credits = ""
is_timeseries = False
@@ -134,8 +122,8 @@ class BaseViz: # pylint: disable=too-many-public-methods
@deprecated(deprecated_in="3.0")
def __init__(
self,
datasource: "BaseDatasource",
form_data: Dict[str, Any],
datasource: BaseDatasource,
form_data: dict[str, Any],
force: bool = False,
force_cached: bool = False,
) -> None:
@@ -150,25 +138,25 @@ class BaseViz: # pylint: disable=too-many-public-methods
self.query = ""
self.token = utils.get_form_data_token(form_data)
self.groupby: List[Column] = self.form_data.get("groupby") or []
self.groupby: list[Column] = self.form_data.get("groupby") or []
self.time_shift = timedelta()
self.status: Optional[str] = None
self.status: str | None = None
self.error_msg = ""
self.results: Optional[QueryResult] = None
self.applied_filter_columns: List[Column] = []
self.rejected_filter_columns: List[Column] = []
self.errors: List[Dict[str, Any]] = []
self.results: QueryResult | None = None
self.applied_filter_columns: list[Column] = []
self.rejected_filter_columns: list[Column] = []
self.errors: list[dict[str, Any]] = []
self.force = force
self._force_cached = force_cached
self.from_dttm: Optional[datetime] = None
self.to_dttm: Optional[datetime] = None
self._extra_chart_data: List[Tuple[str, pd.DataFrame]] = []
self.from_dttm: datetime | None = None
self.to_dttm: datetime | None = None
self._extra_chart_data: list[tuple[str, pd.DataFrame]] = []
self.process_metrics()
self.applied_filters: List[Dict[str, str]] = []
self.rejected_filters: List[Dict[str, str]] = []
self.applied_filters: list[dict[str, str]] = []
self.rejected_filters: list[dict[str, str]] = []
@property
@deprecated(deprecated_in="3.0")
@@ -196,8 +184,8 @@ class BaseViz: # pylint: disable=too-many-public-methods
@staticmethod
@deprecated(deprecated_in="3.0")
def handle_js_int_overflow(
data: Dict[str, List[Dict[str, Any]]]
) -> Dict[str, List[Dict[str, Any]]]:
data: dict[str, list[dict[str, Any]]]
) -> dict[str, list[dict[str, Any]]]:
for record in data.get("records", {}):
for k, v in list(record.items()):
if isinstance(v, int):
@@ -259,7 +247,7 @@ class BaseViz: # pylint: disable=too-many-public-methods
return df
@deprecated(deprecated_in="3.0")
def get_samples(self) -> Dict[str, Any]:
def get_samples(self) -> dict[str, Any]:
query_obj = self.query_obj()
query_obj.update(
{
@@ -281,7 +269,7 @@ class BaseViz: # pylint: disable=too-many-public-methods
}
@deprecated(deprecated_in="3.0")
def get_df(self, query_obj: Optional[QueryObjectDict] = None) -> pd.DataFrame:
def get_df(self, query_obj: QueryObjectDict | None = None) -> pd.DataFrame:
"""Returns a pandas dataframe based on the query object"""
if not query_obj:
query_obj = self.query_obj()
@@ -346,10 +334,10 @@ class BaseViz: # pylint: disable=too-many-public-methods
@staticmethod
@deprecated(deprecated_in="3.0")
def dedup_columns(*columns_args: Optional[List[Column]]) -> List[Column]:
def dedup_columns(*columns_args: list[Column] | None) -> list[Column]:
# dedup groupby and columns while preserving order
labels: List[str] = []
deduped_columns: List[Column] = []
labels: list[str] = []
deduped_columns: list[Column] = []
for columns in columns_args:
for column in columns or []:
label = get_column_name(column)
@@ -492,7 +480,7 @@ class BaseViz: # pylint: disable=too-many-public-methods
return md5_sha_from_str(json_data)
@deprecated(deprecated_in="3.0")
def get_payload(self, query_obj: Optional[QueryObjectDict] = None) -> VizPayload:
def get_payload(self, query_obj: QueryObjectDict | None = None) -> VizPayload:
"""Returns a payload of metadata and data"""
try:
@@ -534,8 +522,8 @@ class BaseViz: # pylint: disable=too-many-public-methods
@deprecated(deprecated_in="3.0")
def get_df_payload( # pylint: disable=too-many-statements
self, query_obj: Optional[QueryObjectDict] = None, **kwargs: Any
) -> Dict[str, Any]:
self, query_obj: QueryObjectDict | None = None, **kwargs: Any
) -> dict[str, Any]:
"""Handles caching around the df payload retrieval"""
if not query_obj:
query_obj = self.query_obj()
@@ -587,7 +575,7 @@ class BaseViz: # pylint: disable=too-many-public-methods
)
+ get_column_names_from_columns(query_obj.get("groupby") or [])
+ utils.get_column_names_from_metrics(
cast(List[Metric], query_obj.get("metrics") or [])
cast(list[Metric], query_obj.get("metrics") or [])
)
if col not in self.datasource.column_names
]
@@ -676,12 +664,12 @@ class BaseViz: # pylint: disable=too-many-public-methods
)
@deprecated(deprecated_in="3.0")
def payload_json_and_has_error(self, payload: VizPayload) -> Tuple[str, bool]:
def payload_json_and_has_error(self, payload: VizPayload) -> tuple[str, bool]:
return self.json_dumps(payload), self.has_error(payload)
@property
@deprecated(deprecated_in="3.0")
def data(self) -> Dict[str, Any]:
def data(self) -> dict[str, Any]:
"""This is the data object serialized to the js layer"""
content = {
"form_data": self.form_data,
@@ -692,7 +680,7 @@ class BaseViz: # pylint: disable=too-many-public-methods
return content
@deprecated(deprecated_in="3.0")
def get_csv(self) -> Optional[str]:
def get_csv(self) -> str | None:
df = self.get_df_payload()["df"] # leverage caching logic
include_index = not isinstance(df.index, pd.RangeIndex)
return csv.df_to_escaped_csv(df, index=include_index, **config["CSV_EXPORT"])
@@ -766,8 +754,8 @@ class TableViz(BaseViz):
else QueryMode.AGGREGATE
)
columns: List[str] # output columns sans time and percent_metric column
percent_columns: List[str] = [] # percent columns that needs extra computation
columns: list[str] # output columns sans time and percent_metric column
percent_columns: list[str] = [] # percent columns that needs extra computation
if self.query_mode == QueryMode.RAW:
columns = get_metric_names(self.form_data.get("all_columns"))
@@ -906,7 +894,7 @@ class TimeTableViz(BaseViz):
return None
columns = None
values: Union[List[str], str] = self.metric_labels
values: list[str] | str = self.metric_labels
if self.form_data.get("groupby"):
values = self.metric_labels[0]
columns = get_column_names(self.form_data.get("groupby"))
@@ -948,10 +936,8 @@ class PivotTableViz(BaseViz):
if transpose and not columns:
raise QueryObjectValidationError(
_(
(
"Please choose at least one 'Columns' field when "
"select 'Transpose Pivot' option"
)
"Please choose at least one 'Columns' field when "
"select 'Transpose Pivot' option"
)
)
if not metrics:
@@ -973,8 +959,8 @@ class PivotTableViz(BaseViz):
@staticmethod
@deprecated(deprecated_in="3.0")
def get_aggfunc(
metric: str, df: pd.DataFrame, form_data: Dict[str, Any]
) -> Union[str, Callable[[Any], Any]]:
metric: str, df: pd.DataFrame, form_data: dict[str, Any]
) -> str | Callable[[Any], Any]:
aggfunc = form_data.get("pandas_aggfunc") or "sum"
if pd.api.types.is_numeric_dtype(df[metric]):
# Ensure that Pandas's sum function mimics that of SQL.
@@ -985,7 +971,7 @@ class PivotTableViz(BaseViz):
@staticmethod
@deprecated(deprecated_in="3.0")
def _format_datetime(value: Union[pd.Timestamp, datetime, date, str]) -> str:
def _format_datetime(value: pd.Timestamp | datetime | date | str) -> str:
"""
Format a timestamp in such a way that the viz will be able to apply
the correct formatting in the frontend.
@@ -994,7 +980,7 @@ class PivotTableViz(BaseViz):
:return: formatted timestamp if it is a valid timestamp, otherwise
the original value
"""
tstamp: Optional[pd.Timestamp] = None
tstamp: pd.Timestamp | None = None
if isinstance(value, pd.Timestamp):
tstamp = value
if isinstance(value, (date, datetime)):
@@ -1018,7 +1004,7 @@ class PivotTableViz(BaseViz):
del df[DTTM_ALIAS]
metrics = [utils.get_metric_name(m) for m in self.form_data["metrics"]]
aggfuncs: Dict[str, Union[str, Callable[[Any], Any]]] = {}
aggfuncs: dict[str, str | Callable[[Any], Any]] = {}
for metric in metrics:
aggfuncs[metric] = self.get_aggfunc(metric, df, self.form_data)
@@ -1088,7 +1074,7 @@ class TreemapViz(BaseViz):
return query_obj
@deprecated(deprecated_in="3.0")
def _nest(self, metric: str, df: pd.DataFrame) -> List[Dict[str, Any]]:
def _nest(self, metric: str, df: pd.DataFrame) -> list[dict[str, Any]]:
nlevels = df.index.nlevels
if nlevels == 1:
result = [{"name": n, "value": v} for n, v in zip(df.index, df[metric])]
@@ -1200,7 +1186,7 @@ class NVD3Viz(BaseViz):
"""Base class for all nvd3 vizs"""
credits = '<a href="http://nvd3.org/">NVD3.org</a>'
viz_type: Optional[str] = None
viz_type: str | None = None
verbose_name = "Base NVD3 Viz"
is_timeseries = False
@@ -1249,7 +1235,7 @@ class BubbleViz(NVD3Viz):
df["shape"] = "circle"
df["group"] = df[[get_column_name(self.series)]] # type: ignore
series: Dict[Any, List[Any]] = defaultdict(list)
series: dict[Any, list[Any]] = defaultdict(list)
for row in df.to_dict(orient="records"):
series[row["group"]].append(row)
chart_data = []
@@ -1357,7 +1343,7 @@ class NVD3TimeSeriesViz(NVD3Viz):
verbose_name = _("Time Series - Line Chart")
sort_series = False
is_timeseries = True
pivot_fill_value: Optional[int] = None
pivot_fill_value: int | None = None
@deprecated(deprecated_in="3.0")
def query_obj(self) -> QueryObjectDict:
@@ -1376,7 +1362,7 @@ class NVD3TimeSeriesViz(NVD3Viz):
@deprecated(deprecated_in="3.0")
def to_series( # pylint: disable=too-many-branches
self, df: pd.DataFrame, classed: str = "", title_suffix: str = ""
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
cols = []
for col in df.columns:
if col == "":
@@ -1393,7 +1379,7 @@ class NVD3TimeSeriesViz(NVD3Viz):
ys = series[name]
if df[name].dtype.kind not in "biufc":
continue
series_title: Union[List[str], str, Tuple[str, ...]]
series_title: list[str] | str | tuple[str, ...]
if isinstance(name, list):
series_title = [str(title) for title in name]
elif isinstance(name, tuple):
@@ -1510,7 +1496,7 @@ class NVD3TimeSeriesViz(NVD3Viz):
dttm_series = df2[DTTM_ALIAS] + delta
df2 = df2.drop(DTTM_ALIAS, axis=1)
df2 = pd.concat([dttm_series, df2], axis=1)
label = "{} offset".format(option)
label = f"{option} offset"
df2 = self.process_data(df2)
self._extra_chart_data.append((label, df2))
@@ -1524,9 +1510,7 @@ class NVD3TimeSeriesViz(NVD3Viz):
for i, (label, df2) in enumerate(self._extra_chart_data):
chart_data.extend(
self.to_series(
df2, classed="time-shift-{}".format(i), title_suffix=label
)
self.to_series(df2, classed=f"time-shift-{i}", title_suffix=label)
)
else:
chart_data = []
@@ -1547,16 +1531,14 @@ class NVD3TimeSeriesViz(NVD3Viz):
diff = df / df2
else:
raise QueryObjectValidationError(
"Invalid `comparison_type`: {0}".format(comparison_type)
f"Invalid `comparison_type`: {comparison_type}"
)
# remove leading/trailing NaNs from the time shift difference
diff = diff[diff.first_valid_index() : diff.last_valid_index()]
chart_data.extend(
self.to_series(
diff, classed="time-shift-{}".format(i), title_suffix=label
)
self.to_series(diff, classed=f"time-shift-{i}", title_suffix=label)
)
if not self.sort_series:
@@ -1670,7 +1652,7 @@ class NVD3DualLineViz(NVD3Viz):
return query_obj
@deprecated(deprecated_in="3.0")
def to_series(self, df: pd.DataFrame, classed: str = "") -> List[Dict[str, Any]]:
def to_series(self, df: pd.DataFrame, classed: str = "") -> list[dict[str, Any]]:
cols = []
for col in df.columns:
if col == "":
@@ -1823,7 +1805,7 @@ class HistogramViz(BaseViz):
return query_obj
@deprecated(deprecated_in="3.0")
def labelify(self, keys: Union[List[str], str], column: str) -> str:
def labelify(self, keys: list[str] | str, column: str) -> str:
if isinstance(keys, str):
keys = [keys]
# removing undesirable characters
@@ -2033,17 +2015,17 @@ class SankeyViz(BaseViz):
df["target"] = df["target"].astype(str)
recs = df.to_dict(orient="records")
hierarchy: Dict[str, Set[str]] = defaultdict(set)
hierarchy: dict[str, set[str]] = defaultdict(set)
for row in recs:
hierarchy[row["source"]].add(row["target"])
@deprecated(deprecated_in="3.0")
def find_cycle(graph: Dict[str, Set[str]]) -> Optional[Tuple[str, str]]:
def find_cycle(graph: dict[str, set[str]]) -> tuple[str, str] | None:
"""Whether there's a cycle in a directed graph"""
path = set()
@deprecated(deprecated_in="3.0")
def visit(vertex: str) -> Optional[Tuple[str, str]]:
def visit(vertex: str) -> tuple[str, str] | None:
path.add(vertex)
for neighbour in graph.get(vertex, ()):
if neighbour in path or visit(neighbour):
@@ -2214,7 +2196,7 @@ class FilterBoxViz(BaseViz):
"""A multi filter, multi-choice filter box to make dashboards interactive"""
query_context_factory: Optional[QueryContextFactory] = None
query_context_factory: QueryContextFactory | None = None
viz_type = "filter_box"
verbose_name = _("Filters")
is_timeseries = False
@@ -2581,20 +2563,20 @@ class BaseDeckGLViz(BaseViz):
is_timeseries = False
credits = '<a href="https://uber.github.io/deck.gl/">deck.gl</a>'
spatial_control_keys: List[str] = []
spatial_control_keys: list[str] = []
@deprecated(deprecated_in="3.0")
def get_metrics(self) -> List[str]:
def get_metrics(self) -> list[str]:
# pylint: disable=attribute-defined-outside-init
self.metric = self.form_data.get("size")
return [self.metric] if self.metric else []
@deprecated(deprecated_in="3.0")
def process_spatial_query_obj(self, key: str, group_by: List[str]) -> None:
def process_spatial_query_obj(self, key: str, group_by: list[str]) -> None:
group_by.extend(self.get_spatial_columns(key))
@deprecated(deprecated_in="3.0")
def get_spatial_columns(self, key: str) -> List[str]:
def get_spatial_columns(self, key: str) -> list[str]:
spatial = self.form_data.get(key)
if spatial is None:
raise ValueError(_("Bad spatial key"))
@@ -2611,7 +2593,7 @@ class BaseDeckGLViz(BaseViz):
@staticmethod
@deprecated(deprecated_in="3.0")
def parse_coordinates(latlog: Any) -> Optional[Tuple[float, float]]:
def parse_coordinates(latlog: Any) -> tuple[float, float] | None:
if not latlog:
return None
try:
@@ -2624,7 +2606,7 @@ class BaseDeckGLViz(BaseViz):
@staticmethod
@deprecated(deprecated_in="3.0")
def reverse_geohash_decode(geohash_code: str) -> Tuple[str, str]:
def reverse_geohash_decode(geohash_code: str) -> tuple[str, str]:
lat, lng = geohash.decode(geohash_code)
return (lng, lat)
@@ -2692,7 +2674,7 @@ class BaseDeckGLViz(BaseViz):
self.add_null_filters()
query_obj = super().query_obj()
group_by: List[str] = []
group_by: list[str] = []
for key in self.spatial_control_keys:
self.process_spatial_query_obj(key, group_by)
@@ -2720,7 +2702,7 @@ class BaseDeckGLViz(BaseViz):
return query_obj
@deprecated(deprecated_in="3.0")
def get_js_columns(self, data: Dict[str, Any]) -> Dict[str, Any]:
def get_js_columns(self, data: dict[str, Any]) -> dict[str, Any]:
cols = self.form_data.get("js_columns") or []
return {col: data.get(col) for col in cols}
@@ -2748,7 +2730,7 @@ class BaseDeckGLViz(BaseViz):
}
@deprecated(deprecated_in="3.0")
def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]:
def get_properties(self, data: dict[str, Any]) -> dict[str, Any]:
raise NotImplementedError()
@@ -2774,7 +2756,7 @@ class DeckScatterViz(BaseDeckGLViz):
return super().query_obj()
@deprecated(deprecated_in="3.0")
def get_metrics(self) -> List[str]:
def get_metrics(self) -> list[str]:
# pylint: disable=attribute-defined-outside-init
self.metric = None
if self.point_radius_fixed.get("type") == "metric":
@@ -2783,7 +2765,7 @@ class DeckScatterViz(BaseDeckGLViz):
return []
@deprecated(deprecated_in="3.0")
def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]:
def get_properties(self, data: dict[str, Any]) -> dict[str, Any]:
return {
"metric": data.get(self.metric_label) if self.metric_label else None,
"radius": self.fixed_value
@@ -2825,7 +2807,7 @@ class DeckScreengrid(BaseDeckGLViz):
return super().query_obj()
@deprecated(deprecated_in="3.0")
def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]:
def get_properties(self, data: dict[str, Any]) -> dict[str, Any]:
return {
"position": data.get("spatial"),
"weight": (data.get(self.metric_label) if self.metric_label else None) or 1,
@@ -2849,7 +2831,7 @@ class DeckGrid(BaseDeckGLViz):
spatial_control_keys = ["spatial"]
@deprecated(deprecated_in="3.0")
def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]:
def get_properties(self, data: dict[str, Any]) -> dict[str, Any]:
return {
"position": data.get("spatial"),
"weight": (data.get(self.metric_label) if self.metric_label else None) or 1,
@@ -2864,7 +2846,7 @@ class DeckGrid(BaseDeckGLViz):
@deprecated(deprecated_in="3.0")
def geohash_to_json(geohash_code: str) -> List[List[float]]:
def geohash_to_json(geohash_code: str) -> list[list[float]]:
bbox = geohash.bbox(geohash_code)
return [
[bbox.get("w"), bbox.get("n")],
@@ -2907,7 +2889,7 @@ class DeckPathViz(BaseDeckGLViz):
return query_obj
@deprecated(deprecated_in="3.0")
def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]:
def get_properties(self, data: dict[str, Any]) -> dict[str, Any]:
line_type = self.form_data["line_type"]
deser = self.deser_map[line_type]
line_column = self.form_data["line_column"]
@@ -2946,14 +2928,14 @@ class DeckPolygon(DeckPathViz):
return super().query_obj()
@deprecated(deprecated_in="3.0")
def get_metrics(self) -> List[str]:
def get_metrics(self) -> list[str]:
metrics = [self.form_data.get("metric")]
if self.elevation.get("type") == "metric":
metrics.append(self.elevation.get("value"))
return [metric for metric in metrics if metric]
@deprecated(deprecated_in="3.0")
def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]:
def get_properties(self, data: dict[str, Any]) -> dict[str, Any]:
super().get_properties(data)
elevation = self.form_data["point_radius_fixed"]["value"]
type_ = self.form_data["point_radius_fixed"]["type"]
@@ -2974,7 +2956,7 @@ class DeckHex(BaseDeckGLViz):
spatial_control_keys = ["spatial"]
@deprecated(deprecated_in="3.0")
def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]:
def get_properties(self, data: dict[str, Any]) -> dict[str, Any]:
return {
"position": data.get("spatial"),
"weight": (data.get(self.metric_label) if self.metric_label else None) or 1,
@@ -2996,7 +2978,7 @@ class DeckHeatmap(BaseDeckGLViz):
verbose_name = _("Deck.gl - Heatmap")
spatial_control_keys = ["spatial"]
def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]:
def get_properties(self, data: dict[str, Any]) -> dict[str, Any]:
return {
"position": data.get("spatial"),
"weight": (data.get(self.metric_label) if self.metric_label else None) or 1,
@@ -3025,7 +3007,7 @@ class DeckGeoJson(BaseDeckGLViz):
return query_obj
@deprecated(deprecated_in="3.0")
def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]:
def get_properties(self, data: dict[str, Any]) -> dict[str, Any]:
geojson = data[get_column_name(self.form_data["geojson"])]
return json.loads(geojson)
@@ -3047,7 +3029,7 @@ class DeckArc(BaseDeckGLViz):
return super().query_obj()
@deprecated(deprecated_in="3.0")
def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]:
def get_properties(self, data: dict[str, Any]) -> dict[str, Any]:
dim = self.form_data.get("dimension")
return {
"sourcePosition": data.get("start_spatial"),
@@ -3153,7 +3135,7 @@ class PairedTTestViz(BaseViz):
else:
cols.append(col)
df.columns = cols
data: Dict[str, List[Dict[str, Any]]] = {}
data: dict[str, list[dict[str, Any]]] = {}
series = df.to_dict("series")
for name_set in df.columns:
# If no groups are defined, nameSet will be the metric name
@@ -3188,7 +3170,7 @@ class RoseViz(NVD3TimeSeriesViz):
return None
data = super().get_data(df)
result: Dict[str, List[Dict[str, str]]] = {}
result: dict[str, list[dict[str, str]]] = {}
for datum in data:
key = datum["key"]
for val in datum["values"]:
@@ -3227,8 +3209,8 @@ class PartitionViz(NVD3TimeSeriesViz):
@staticmethod
@deprecated(deprecated_in="3.0")
def levels_for(
time_op: str, groups: List[str], df: pd.DataFrame
) -> Dict[int, pd.Series]:
time_op: str, groups: list[str], df: pd.DataFrame
) -> dict[int, pd.Series]:
"""
Compute the partition at each `level` from the dataframe.
"""
@@ -3245,8 +3227,8 @@ class PartitionViz(NVD3TimeSeriesViz):
@staticmethod
@deprecated(deprecated_in="3.0")
def levels_for_diff(
time_op: str, groups: List[str], df: pd.DataFrame
) -> Dict[int, pd.DataFrame]:
time_op: str, groups: list[str], df: pd.DataFrame
) -> dict[int, pd.DataFrame]:
# Obtain a unique list of the time grains
times = list(set(df[DTTM_ALIAS]))
times.sort()
@@ -3282,8 +3264,8 @@ class PartitionViz(NVD3TimeSeriesViz):
@deprecated(deprecated_in="3.0")
def levels_for_time(
self, groups: List[str], df: pd.DataFrame
) -> Dict[int, VizData]:
self, groups: list[str], df: pd.DataFrame
) -> dict[int, VizData]:
procs = {}
for i in range(0, len(groups) + 1):
self.form_data["groupby"] = groups[:i]
@@ -3295,11 +3277,11 @@ class PartitionViz(NVD3TimeSeriesViz):
@deprecated(deprecated_in="3.0")
def nest_values(
self,
levels: Dict[int, pd.DataFrame],
levels: dict[int, pd.DataFrame],
level: int = 0,
metric: Optional[str] = None,
dims: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
metric: str | None = None,
dims: list[str] | None = None,
) -> list[dict[str, Any]]:
"""
Nest values at each level on the back-end with
access and setting, instead of summing from the bottom.
@@ -3340,11 +3322,11 @@ class PartitionViz(NVD3TimeSeriesViz):
@deprecated(deprecated_in="3.0")
def nest_procs(
self,
procs: Dict[int, pd.DataFrame],
procs: dict[int, pd.DataFrame],
level: int = -1,
dims: Optional[Tuple[str, ...]] = None,
dims: tuple[str, ...] | None = None,
time: Any = None,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
if dims is None:
dims = ()
if level == -1:
@@ -3395,7 +3377,7 @@ class PartitionViz(NVD3TimeSeriesViz):
@deprecated(deprecated_in="3.0")
def get_subclasses(cls: Type[BaseViz]) -> Set[Type[BaseViz]]:
def get_subclasses(cls: type[BaseViz]) -> set[type[BaseViz]]:
return set(cls.__subclasses__()).union(
[sc for c in cls.__subclasses__() for sc in get_subclasses(c)]
)