fix(type): improve type casting

This commit is contained in:
alexandrusoare
2026-05-12 16:09:13 +03:00
parent 752434ce9a
commit 794ed48e1e

View File

@@ -56,6 +56,7 @@ from superset.mcp_service.utils.oauth2_utils import (
OAUTH2_CONFIG_ERROR_MESSAGE,
)
from superset.mcp_service.utils.url_utils import get_superset_base_url
from superset.superset_typing import Column, Metric
logger = logging.getLogger(__name__)
@@ -145,7 +146,7 @@ class ChartLike(Protocol):
uuid: Any
def _build_query_columns(form_data: Dict[str, Any]) -> list[str]:
def _build_query_columns(form_data: Dict[str, Any]) -> list[Column]:
"""Build query columns list from form_data, including both x_axis and groupby.
Handles chart-type-specific keys:
@@ -153,42 +154,45 @@ def _build_query_columns(form_data: Dict[str, Any]) -> list[str]:
- Pivot tables: ``groupbyColumns`` + ``groupbyRows`` (when ``groupby`` is absent)
- Mixed timeseries: ``groupby_b`` (secondary groupby)
"""
x_axis_config = form_data.get("x_axis")
groupby_columns: list[Any] = form_data.get("groupby") or []
x_axis_config: Column | None = form_data.get("x_axis")
groupby_columns: list[Column] = form_data.get("groupby") or []
# Pivot tables store dimensions under groupbyColumns / groupbyRows
if not groupby_columns:
pivot_rows: list[Any] = form_data.get("groupbyRows") or []
pivot_cols: list[Any] = form_data.get("groupbyColumns") or []
pivot_rows: list[Column] = form_data.get("groupbyRows") or []
pivot_cols: list[Column] = form_data.get("groupbyColumns") or []
groupby_columns = list(pivot_rows) + list(pivot_cols)
# Mixed timeseries stores secondary groupby under groupby_b
groupby_b: list[Any] = form_data.get("groupby_b") or []
groupby_b: list[Column] = form_data.get("groupby_b") or []
for col in groupby_b:
if col not in groupby_columns:
groupby_columns.append(col)
# Deduplicate while preserving order
seen: set[Any] = set()
columns: list[Any] = []
seen: set[str] = set()
columns: list[Column] = []
def _add_unique(col: Column) -> None:
key = col if isinstance(col, str) else col.get("label", str(col))
if key not in seen:
columns.append(col)
seen.add(key)
if x_axis_config and isinstance(x_axis_config, str):
columns.append(x_axis_config)
seen.add(x_axis_config)
_add_unique(x_axis_config)
elif x_axis_config and isinstance(x_axis_config, dict):
col_name = x_axis_config.get("column_name")
if col_name:
columns.append(col_name)
seen.add(col_name)
if col_name and isinstance(col_name, str):
_add_unique(col_name)
for col in groupby_columns:
if col not in seen:
columns.append(col)
seen.add(col)
_add_unique(col)
return columns
def _build_query_metrics(form_data: Dict[str, Any]) -> list[Any]:
def _build_query_metrics(form_data: Dict[str, Any]) -> list[Metric]:
"""Extract metrics from form_data, handling chart-type variations.
Handles:
@@ -196,14 +200,14 @@ def _build_query_metrics(form_data: Dict[str, Any]) -> list[Any]:
- ``metric`` (singular) — Pie charts
- ``metrics_b`` — secondary y-axis in Mixed Timeseries charts
"""
metrics: list[Any] = list(form_data.get("metrics") or [])
metrics: list[Metric] = list(form_data.get("metrics") or [])
if not metrics:
singular = form_data.get("metric")
singular: Metric | None = form_data.get("metric")
if singular:
metrics = [singular]
# Mixed timeseries stores the second y-axis metrics under metrics_b
metrics_b: list[Any] = form_data.get("metrics_b") or []
metrics_b: list[Metric] = form_data.get("metrics_b") or []
for m in metrics_b:
if m not in metrics:
metrics.append(m)