chore: improve types (#36367)

This commit is contained in:
Beto Dealmeida
2025-12-04 13:51:35 -05:00
committed by GitHub
parent 16e6452b8c
commit 482c674a0f
8 changed files with 83 additions and 78 deletions

View File

@@ -37,7 +37,7 @@ from superset.exceptions import SupersetException
from superset.explore.exceptions import WrongEndpointError
from superset.explore.permalink.exceptions import ExplorePermalinkGetFailedError
from superset.extensions import security_manager
from superset.superset_typing import BaseDatasourceData, QueryData
from superset.superset_typing import ExplorableData
from superset.utils import core as utils, json
from superset.views.utils import (
get_datasource_info,
@@ -136,7 +136,7 @@ class GetExploreCommand(BaseCommand, ABC):
utils.merge_extra_filters(form_data)
utils.merge_request_params(form_data, request.args)
datasource_data: BaseDatasourceData | QueryData = {
datasource_data: ExplorableData = {
"type": self._datasource_type or "unknown",
"name": datasource_name,
"columns": [],

View File

@@ -85,6 +85,7 @@ from superset.exceptions import (
SupersetSecurityException,
SupersetSyntaxErrorException,
)
from superset.explorables.base import TimeGrainDict
from superset.jinja_context import (
BaseTemplateProcessor,
ExtraCache,
@@ -105,7 +106,7 @@ from superset.sql.parse import Table
from superset.superset_typing import (
AdhocColumn,
AdhocMetric,
BaseDatasourceData,
ExplorableData,
Metric,
QueryObjectDict,
ResultSetColumnType,
@@ -265,7 +266,7 @@ class BaseDatasource(
# Check if all requested columns are drillable
return set(column_names).issubset(drillable_columns)
def get_time_grains(self) -> list[dict[str, Any]]:
def get_time_grains(self) -> list[TimeGrainDict]:
"""
Get available time granularities from the database.
@@ -435,7 +436,7 @@ class BaseDatasource(
return verb_map
@property
def data(self) -> BaseDatasourceData:
def data(self) -> ExplorableData:
"""Data representation of the datasource sent to the frontend"""
return {
# simple fields
@@ -1441,7 +1442,7 @@ class SqlaTable(
return [(g.duration, g.name) for g in self.database.grains() or []]
@property
def data(self) -> BaseDatasourceData:
def data(self) -> ExplorableData:
data_ = super().data
if self.type == "table":
data_["granularity_sqla"] = self.granularity_sqla

View File

@@ -25,11 +25,32 @@ from __future__ import annotations
from collections.abc import Hashable
from datetime import datetime
from typing import Any, Protocol, runtime_checkable
from typing import Any, Protocol, runtime_checkable, TYPE_CHECKING, TypedDict
from superset.common.query_object import QueryObject
from superset.models.helpers import QueryResult
from superset.superset_typing import QueryObjectDict
if TYPE_CHECKING:
from superset.common.query_object import QueryObject
from superset.models.helpers import QueryResult
from superset.superset_typing import ExplorableData, QueryObjectDict
class TimeGrainDict(TypedDict):
    """
    Shape of a single time-grain option returned by ``get_time_grains``.

    A time grain describes one way of bucketing temporal data when
    grouping rows by a time column.

    Keys:
        name: Human-readable label for the grain (e.g. "Hour", "Day",
            "Week").
        function: Implementation-specific expression that applies the
            grain. For SQL datasources this is typically a SQL template
            such as ``DATE_TRUNC('hour', {col})``.
        duration: ISO 8601 duration string (e.g. "PT1H", "P1D", "P1W"),
            or ``None`` when no duration is defined.
    """

    name: str
    function: str
    duration: str | None
@runtime_checkable
@@ -152,9 +173,8 @@ class Explorable(Protocol):
:return: List of column name strings
"""
# TODO: use TypedDict for return type
@property
def data(self) -> dict[str, Any]:
def data(self) -> ExplorableData:
"""
Full metadata representation sent to the frontend.
@@ -257,7 +277,7 @@ class Explorable(Protocol):
# Time Granularity
# =========================================================================
def get_time_grains(self) -> list[dict[str, Any]]:
def get_time_grains(self) -> list[TimeGrainDict]:
"""
Get available time granularities for temporal grouping.

View File

@@ -50,6 +50,7 @@ from superset_core.api.models import Query as CoreQuery, SavedQuery as CoreSaved
from superset import security_manager
from superset.exceptions import SupersetParseError, SupersetSecurityException
from superset.explorables.base import TimeGrainDict
from superset.jinja_context import BaseTemplateProcessor, get_template_processor
from superset.models.helpers import (
AuditMixinNullable,
@@ -63,7 +64,7 @@ from superset.sql.parse import (
Table,
)
from superset.sqllab.limiting_factor import LimitingFactor
from superset.superset_typing import QueryData, QueryObjectDict
from superset.superset_typing import ExplorableData, QueryObjectDict
from superset.utils import json
from superset.utils.core import (
get_column_name,
@@ -239,7 +240,7 @@ class Query(
return None
@property
def data(self) -> QueryData:
def data(self) -> ExplorableData:
"""Returns query data for the frontend"""
order_by_choices = []
for col in self.columns:
@@ -335,6 +336,32 @@ class Query(
def get_extra_cache_keys(self, query_obj: QueryObjectDict) -> list[Hashable]:
return []
def get_time_grains(self) -> list[TimeGrainDict]:
    """
    Return the time granularities available for this query.

    The grains are read from the underlying database's configured grain
    definitions; each one is reported as a dict with ``name``,
    ``function`` and ``duration`` keys. An empty list is returned when
    the database defines no grains.
    """
    available = self.database.grains() or []
    result: list[TimeGrainDict] = []
    for grain in available:
        result.append(
            {
                "name": grain.name,
                "function": grain.function,
                "duration": grain.duration,
            }
        )
    return result
def has_drill_by_columns(self, column_names: list[str]) -> bool:
    """
    Report whether every requested column supports drill-by.

    Ad-hoc SQL queries carry no column-level metadata, so a column is
    considered drillable exactly when it appears among the query's
    result columns. An empty request is never drillable.
    """
    if not column_names:
        return False
    available = set(self.column_names)
    return all(name in available for name in column_names)
@property
def tracking_url(self) -> Optional[str]:
"""

View File

@@ -196,15 +196,19 @@ class QueryObjectDict(TypedDict, total=False):
timeseries_limit_metric: Metric | None
class BaseDatasourceData(TypedDict, total=False):
class ExplorableData(TypedDict, total=False):
"""
TypedDict for datasource data returned to the frontend.
TypedDict for explorable data returned to the frontend.
This represents the structure of the dictionary returned from BaseDatasource.data
property. It provides datasource information to the frontend for visualization
and querying.
This represents the structure of the dictionary returned from the `data` property
of any Explorable (BaseDatasource, Query, etc.). It provides datasource/query
information to the frontend for visualization and querying.
Core fields from BaseDatasource.data:
All fields are optional (total=False) since different explorable types provide
different subsets of these fields. Query objects provide a minimal subset while
SqlaTable provides the full set.
Core fields:
id: Unique identifier for the datasource
uid: Unique identifier including type (e.g., "1__table")
column_formats: D3 format strings for columns
@@ -292,46 +296,6 @@ class BaseDatasourceData(TypedDict, total=False):
normalize_columns: bool
class QueryData(TypedDict, total=False):
    """
    Shape of the SQL Lab query payload sent to the frontend.

    Produced by the ``Query.data`` property; gives the frontend the
    information it needs to display and re-run a SQL Lab query.

    Keys:
        time_grain_sqla: Time grains available on the query's database.
        filter_select: Whether filter select is enabled.
        name: Query tab name.
        columns: Column definitions.
        metrics: Metrics (always empty for queries).
        id: Query ID.
        type: Object type (always "query").
        sql: SQL text of the query.
        owners: Owner information.
        database: Database connection details.
        order_by_choices: Available ordering options.
        catalog: Catalog name, if applicable.
        schema: Schema name, if applicable.
        verbose_map: Column name to verbose name (empty for queries).
    """

    time_grain_sqla: list[tuple[Any, Any]]
    filter_select: bool
    name: str | None
    columns: list[dict[str, Any]]
    metrics: list[Any]
    id: int
    type: str
    sql: str | None
    owners: list[dict[str, Any]]
    database: dict[str, Any]
    order_by_choices: list[tuple[str, str]]
    catalog: str | None
    schema: str | None
    verbose_map: dict[str, str]
VizData: TypeAlias = list[Any] | dict[Any, Any] | None
VizPayload: TypeAlias = dict[str, Any]

View File

@@ -96,6 +96,7 @@ from superset.exceptions import (
SupersetException,
SupersetTimeoutException,
)
from superset.explorables.base import Explorable
from superset.sql.parse import sanitize_clause
from superset.superset_typing import (
AdhocColumn,
@@ -114,10 +115,8 @@ from superset.utils.hashing import md5_sha_from_dict, md5_sha_from_str
from superset.utils.pandas import detect_datetime_format
if TYPE_CHECKING:
from superset.connectors.sqla.models import BaseDatasource, TableColumn
from superset.explorables.base import Explorable
from superset.connectors.sqla.models import TableColumn
from superset.models.core import Database
from superset.models.sql_lab import Query
logging.getLogger("MARKDOWN").setLevel(logging.INFO)
logger = logging.getLogger(__name__)
@@ -1657,9 +1656,7 @@ def map_sql_type_to_inferred_type(sql_type: Optional[str]) -> str:
return "string" # If no match is found, return "string" as default
def get_metric_type_from_column(
column: Any, datasource: BaseDatasource | Explorable | Query
) -> str:
def get_metric_type_from_column(column: Any, datasource: Explorable) -> str:
"""
Determine the metric type from a given column in a datasource.
@@ -1701,7 +1698,7 @@ def get_metric_type_from_column(
def extract_dataframe_dtypes(
df: pd.DataFrame,
datasource: BaseDatasource | Explorable | Query | None = None,
datasource: Explorable | None = None,
) -> list[GenericDataType]:
"""Serialize pandas/numpy dtypes to generic types"""
@@ -1772,7 +1769,7 @@ def is_test() -> bool:
def get_time_filter_status(
datasource: BaseDatasource | Explorable,
datasource: Explorable,
applied_time_extras: dict[str, str],
) -> tuple[list[dict[str, str]], list[dict[str, str]]]:
temporal_columns: set[Any] = {

View File

@@ -78,9 +78,8 @@ from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.models.user_attributes import UserAttribute
from superset.superset_typing import (
BaseDatasourceData,
ExplorableData,
FlaskResponse,
QueryData,
)
from superset.tasks.utils import get_current_user
from superset.utils import core as utils, json
@@ -531,14 +530,14 @@ class Superset(BaseSupersetView):
)
standalone_mode = ReservedUrlParameters.is_standalone_mode()
force = request.args.get("force") in {"force", "1", "true"}
dummy_datasource_data: BaseDatasourceData = {
dummy_datasource_data: ExplorableData = {
"type": datasource_type or "unknown",
"name": datasource_name,
"columns": [],
"metrics": [],
"database": {"id": 0, "backend": ""},
}
datasource_data: BaseDatasourceData | QueryData
datasource_data: ExplorableData
try:
datasource_data = datasource.data if datasource else dummy_datasource_data
except (SupersetException, SQLAlchemyError):

View File

@@ -46,10 +46,9 @@ from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.models.sql_lab import Query
from superset.superset_typing import (
BaseDatasourceData,
ExplorableData,
FlaskResponse,
FormData,
QueryData,
)
from superset.utils import json
from superset.utils.core import DatasourceType
@@ -92,12 +91,10 @@ def redirect_to_login(next_target: str | None = None) -> FlaskResponse:
def sanitize_datasource_data(
datasource_data: BaseDatasourceData | QueryData,
datasource_data: ExplorableData,
) -> dict[str, Any]:
"""
Sanitize datasource data by removing sensitive database parameters.
Accepts TypedDict types (BaseDatasourceData, QueryData).
"""
if datasource_data:
datasource_database = datasource_data.get("database")