diff --git a/superset/commands/explore/get.py b/superset/commands/explore/get.py index df71a931b2c..70de735d7b0 100644 --- a/superset/commands/explore/get.py +++ b/superset/commands/explore/get.py @@ -37,6 +37,7 @@ from superset.exceptions import SupersetException from superset.explore.exceptions import WrongEndpointError from superset.explore.permalink.exceptions import ExplorePermalinkGetFailedError from superset.extensions import security_manager +from superset.superset_typing import BaseDatasourceData, QueryData from superset.utils import core as utils, json from superset.views.utils import ( get_datasource_info, @@ -135,9 +136,8 @@ class GetExploreCommand(BaseCommand, ABC): utils.merge_extra_filters(form_data) utils.merge_request_params(form_data, request.args) - # TODO: this is a dummy placeholder - should be refactored to being just `None` - datasource_data: dict[str, Any] = { - "type": self._datasource_type, + datasource_data: BaseDatasourceData | QueryData = { + "type": self._datasource_type or "unknown", "name": datasource_name, "columns": [], "metrics": [], diff --git a/superset/common/query_object.py b/superset/common/query_object.py index 0740ca889da..9ef1d10ba92 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -36,7 +36,7 @@ from superset.exceptions import ( ) from superset.extensions import event_logger from superset.sql.parse import sanitize_clause -from superset.superset_typing import Column, Metric, OrderBy +from superset.superset_typing import Column, Metric, OrderBy, QueryObjectDict from superset.utils import json, pandas_postprocessing from superset.utils.core import ( DTTM_ALIAS, @@ -370,8 +370,8 @@ class QueryObject: # pylint: disable=too-many-instance-attributes ) ) - def to_dict(self) -> dict[str, Any]: - query_object_dict = { + def to_dict(self) -> QueryObjectDict: + query_object_dict: QueryObjectDict = { "apply_fetch_values_predicate": self.apply_fetch_values_predicate, "columns": self.columns, "extras": self.extras, @@ -412,7 +412,8 @@ class QueryObject: # pylint: disable=too-many-instance-attributes the use-provided inputs to bounds, which may be time-relative (as in "5 days ago" or "now"). """ - cache_dict = self.to_dict() + # Cast to dict[str, Any] for mutation operations + cache_dict: dict[str, Any] = dict(self.to_dict()) cache_dict.update(extra) # TODO: the below KVs can all be cleaned up and moved to `to_dict()` at some diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index e0fbfbd665c..3afe9fabeae 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -101,12 +101,14 @@ from superset.models.helpers import ( ExploreMixin, ImportExportMixin, QueryResult, + SQLA_QUERY_KEYS, ) from superset.models.slice import Slice from superset.sql.parse import Table from superset.superset_typing import ( AdhocColumn, AdhocMetric, + BaseDatasourceData, Metric, QueryObjectDict, ResultSetColumnType, @@ -135,8 +137,6 @@ class MetadataResult: modified: list[str] = field(default_factory=list) -logger = logging.getLogger(__name__) - METRIC_FORM_DATA_PARAMS = [ "metric", "metric_2", @@ -166,7 +166,10 @@ class DatasourceKind(StrEnum): PHYSICAL = "physical" -class BaseDatasource(AuditMixinNullable, ImportExportMixin): # pylint: disable=too-many-public-methods +class BaseDatasource( + AuditMixinNullable, + ImportExportMixin, +): # pylint: disable=too-many-public-methods """A common interface to objects that are queryable (tables and datasources)""" @@ -363,7 +366,7 @@ class BaseDatasource(AuditMixinNullable, ImportExportMixin): # pylint: disable= return verb_map @property - def data(self) -> dict[str, Any]: + def data(self) -> BaseDatasourceData: """Data representation of the datasource sent to the frontend""" return { # simple fields @@ -408,7 +411,8 @@ class BaseDatasource(AuditMixinNullable, ImportExportMixin): # pylint: disable= Used to reduce the payload when loading a dashboard. """ - data = self.data + # Cast to dict[str, Any] since we'll be mutating with del and .update() + data = cast(dict[str, Any], self.data) metric_names = set() column_names = set() for slc in slices: @@ -471,14 +475,15 @@ class BaseDatasource(AuditMixinNullable, ImportExportMixin): # pylint: disable= or metric["verbose_name"] in metric_names ] - filtered_columns: list[Column] = [] + filtered_columns: list[dict[str, Any]] = [] column_types: set[utils.GenericDataType] = set() - for column_ in data["columns"]: - generic_type = column_.get("type_generic") + for column_ in cast(list[dict[str, Any]], data["columns"]): # type: ignore[assignment] + column_dict = cast(dict[str, Any], column_) + generic_type = column_dict.get("type_generic") if generic_type is not None: column_types.add(generic_type) - if column_["column_name"] in column_names: - filtered_columns.append(column_) + if column_dict["column_name"] in column_names: + filtered_columns.append(column_dict) data["column_types"] = list(column_types) del data["description"] @@ -510,7 +515,8 @@ class BaseDatasource(AuditMixinNullable, ImportExportMixin): # pylint: disable= """Returns a query as a string This is used to be displayed to the user so that they can - understand what is taking place behind the scene""" + understand what is taking place behind the scene + """ raise NotImplementedError() def query(self, query_obj: QueryObjectDict) -> QueryResult: @@ -614,7 +620,8 @@ class BaseDatasource(AuditMixinNullable, ImportExportMixin): # pylint: disable= """If a datasource needs to provide additional keys for calculation of cache keys, those can be provided via this method - :param query_obj: The dict representation of a query object + :param query_obj: The dict representation of a query object (QueryObjectDict + structure expected) :return: list of keys """ return [] @@ -1352,7 +1359,7 @@ class SqlaTable( return [(g.duration, g.name) for g in self.database.grains() or []] @property - def data(self) -> dict[str, Any]: + def data(self) -> BaseDatasourceData: data_ = super().data if self.type == "table": data_["granularity_sqla"] = self.granularity_sqla @@ -1867,7 +1874,7 @@ class SqlaTable( template code unnecessarily, as it may contain expensive calls, e.g. to extract the latest partition of a database. - :param query_obj: query object to analyze + :param query_obj: query object to analyze (QueryObjectDict structure expected) :return: True if there are call(s) to an `ExtraCache` method, False otherwise """ templatable_statements: list[str] = [] @@ -1924,7 +1931,11 @@ class SqlaTable( extra_cache_keys = super().get_extra_cache_keys(query_obj) if self.has_extra_cache_key_calls(query_obj): - sqla_query = self.get_sqla_query(**query_obj) + # Filter out keys that aren't parameters to get_sqla_query + filtered_query_obj = { + k: v for k, v in query_obj.items() if k in SQLA_QUERY_KEYS + } + sqla_query = self.get_sqla_query(**cast(Any, filtered_query_obj)) extra_cache_keys += sqla_query.extra_cache_keys # For virtual datasets, include RLS predicates in the cache key diff --git a/superset/jinja_context.py b/superset/jinja_context.py index 1634af83929..673e6d3815f 100644 --- a/superset/jinja_context.py +++ b/superset/jinja_context.py @@ -46,6 +46,7 @@ from superset.exceptions import ( ) from superset.extensions import feature_flag_manager from superset.sql.parse import Table +from superset.superset_typing import Column, QueryObjectDict from superset.utils import json from superset.utils.core import ( AdhocFilterClause, @@ -1007,11 +1008,11 @@ def dataset_macro( columns = columns or [column.column_name for column in dataset.columns] metrics = [metric.metric_name for metric in dataset.metrics] - query_obj = { + query_obj: QueryObjectDict = { "is_timeseries": False, "filter": [], "metrics": metrics if include_metrics else None, - "columns": columns, + "columns": cast(list[Column], columns), "from_dttm": from_dttm, "to_dttm": to_dttm, } diff --git a/superset/models/helpers.py b/superset/models/helpers.py index ea8526bd732..85bc4a7c980 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -120,6 +120,34 @@ logger = logging.getLogger(__name__) VIRTUAL_TABLE_ALIAS = "virtual_table" SERIES_LIMIT_SUBQ_ALIAS = "series_limit" +# Keys used to filter QueryObjectDict for get_sqla_query parameters +SQLA_QUERY_KEYS = { + "apply_fetch_values_predicate", + "columns", + "extras", + "filter", + "from_dttm", + "granularity", + "groupby", + "inner_from_dttm", + "inner_to_dttm", + "is_rowcount", + "is_timeseries", + "metrics", + "orderby", + "order_desc", + "to_dttm", + "series_columns", + "series_limit", + "series_limit_metric", + "group_others_when_limit_reached", + "row_limit", + "row_offset", + "timeseries_limit", + "timeseries_limit_metric", + "time_shift", +} + def validate_adhoc_subquery( sql: str, @@ -824,7 +852,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods def columns(self) -> list[Any]: raise NotImplementedError() - def get_extra_cache_keys(self, query_obj: dict[str, Any]) -> list[Hashable]: + def get_extra_cache_keys(self, query_obj: QueryObjectDict) -> list[Hashable]: raise NotImplementedError() def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor: @@ -974,7 +1002,11 @@ class ExploreMixin: # pylint: disable=too-many-public-methods query_obj: QueryObjectDict, mutate: bool = True, ) -> QueryStringExtended: - sqlaq = self.get_sqla_query(**query_obj) + # Filter out keys that aren't parameters to get_sqla_query + filtered_query_obj = { + k: v for k, v in query_obj.items() if k in SQLA_QUERY_KEYS + } + sqlaq = self.get_sqla_query(**cast(Any, filtered_query_obj)) sql = self.database.compile_sqla_query( sqlaq.sqla_query, catalog=self.catalog, @@ -2491,7 +2523,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods ] # run prequery to get top groups - prequery_obj = { + prequery_obj: QueryObjectDict = { "is_timeseries": False, "row_limit": series_limit, "metrics": metrics, @@ -2499,7 +2531,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods "groupby": groupby, "from_dttm": inner_from_dttm or from_dttm, "to_dttm": inner_to_dttm or to_dttm, - "filter": filter, + "filter": filter or [], "orderby": orderby, "extras": extras, "columns": get_non_base_axis_columns(columns), diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index f83ec05be53..3c3acb2570c 100644 --- a/superset/models/sql_lab.py +++ b/superset/models/sql_lab.py @@ -63,6 +63,7 @@ from superset.sql.parse import ( Table, ) from superset.sqllab.limiting_factor import LimitingFactor +from superset.superset_typing import QueryData, QueryObjectDict from superset.utils import json from superset.utils.core import ( get_column_name, @@ -238,7 +239,8 @@ class Query( return None @property - def data(self) -> dict[str, Any]: + def data(self) -> QueryData: + """Returns query data for the frontend""" order_by_choices = [] for col in self.columns: column_name = str(col.column_name or "") @@ -330,7 +332,7 @@ class Query( def default_endpoint(self) -> str: return "" - def get_extra_cache_keys(self, query_obj: dict[str, Any]) -> list[Hashable]: + def get_extra_cache_keys(self, query_obj: QueryObjectDict) -> list[Hashable]: return [] @property diff --git a/superset/superset_typing.py b/superset/superset_typing.py index c3460d80823..d182e5eb490 100644 --- a/superset/superset_typing.py +++ b/superset/superset_typing.py @@ -14,59 +14,61 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from collections.abc import Sequence +from __future__ import annotations + +from collections.abc import Hashable, Sequence from datetime import datetime -from typing import Any, Literal, Optional, TYPE_CHECKING, TypedDict, Union +from typing import Any, Literal, TYPE_CHECKING, TypeAlias, TypedDict from sqlalchemy.sql.type_api import TypeEngine from typing_extensions import NotRequired from werkzeug.wrappers import Response if TYPE_CHECKING: - from superset.utils.core import GenericDataType + from superset.utils.core import GenericDataType, QueryObjectFilterClause -SQLType = Union[TypeEngine, type[TypeEngine]] +SQLType: TypeAlias = TypeEngine | type[TypeEngine] class LegacyMetric(TypedDict): - label: Optional[str] + label: str | None class AdhocMetricColumn(TypedDict, total=False): - column_name: Optional[str] - description: Optional[str] - expression: Optional[str] + column_name: str | None + description: str | None + expression: str | None filterable: bool groupby: bool id: int is_dttm: bool - python_date_format: Optional[str] + python_date_format: str | None type: str type_generic: "GenericDataType" - verbose_name: Optional[str] + verbose_name: str | None class AdhocMetric(TypedDict, total=False): aggregate: str - column: Optional[AdhocMetricColumn] + column: AdhocMetricColumn | None expressionType: Literal["SIMPLE", "SQL"] - hasCustomLabel: Optional[bool] - label: Optional[str] - sqlExpression: Optional[str] + hasCustomLabel: bool | None + label: str | None + sqlExpression: str | None class AdhocColumn(TypedDict, total=False): - hasCustomLabel: Optional[bool] + hasCustomLabel: bool | None label: str sqlExpression: str - isColumnReference: Optional[bool] - columnType: Optional[Literal["BASE_AXIS", "SERIES"]] - timeGrain: Optional[str] + isColumnReference: bool | None + columnType: Literal["BASE_AXIS", "SERIES"] | None + timeGrain: str | None class SQLAColumnType(TypedDict): name: str - type: Optional[str] + type: str | None is_dttm: bool @@ -77,9 +79,9 @@ class ResultSetColumnType(TypedDict): name: str # legacy naming convention keeping this for backwards compatibility column_name: str - type: Optional[Union[SQLType, str]] - is_dttm: Optional[bool] - type_generic: NotRequired[Optional["GenericDataType"]] + type: SQLType | str | None + is_dttm: bool | None + type_generic: NotRequired["GenericDataType" | None] nullable: NotRequired[Any] default: NotRequired[Any] @@ -91,40 +93,258 @@ class ResultSetColumnType(TypedDict): query_as: NotRequired[Any] -CacheConfig = dict[str, Any] -DbapiDescriptionRow = tuple[ - Union[str, bytes], +CacheConfig: TypeAlias = dict[str, Any] +DbapiDescriptionRow: TypeAlias = tuple[ + str | bytes, str, - Optional[str], - Optional[str], - Optional[int], - Optional[int], + str | None, + str | None, + int | None, + int | None, bool, ] -DbapiDescription = Union[list[DbapiDescriptionRow], tuple[DbapiDescriptionRow, ...]] -DbapiResult = Sequence[Union[list[Any], tuple[Any, ...]]] -FilterValue = Union[bool, datetime, float, int, str] -FilterValues = Union[FilterValue, list[FilterValue], tuple[FilterValue]] -FormData = dict[str, Any] -Granularity = Union[str, dict[str, Union[str, float]]] -Column = Union[AdhocColumn, str] -Metric = Union[AdhocMetric, str] -OrderBy = tuple[Union[Metric, Column], bool] -QueryObjectDict = dict[str, Any] -VizData = Optional[Union[list[Any], dict[Any, Any]]] -VizPayload = dict[str, Any] +DbapiDescription: TypeAlias = ( + list[DbapiDescriptionRow] | tuple[DbapiDescriptionRow, ...] +) +DbapiResult: TypeAlias = Sequence[list[Any] | tuple[Any, ...]] +FilterValue: TypeAlias = bool | datetime | float | int | str +FilterValues: TypeAlias = FilterValue | list[FilterValue] | tuple[FilterValue] +FormData: TypeAlias = dict[str, Any] +Granularity: TypeAlias = str | dict[str, str | float] +Column: TypeAlias = AdhocColumn | str +Metric: TypeAlias = AdhocMetric | str +OrderBy: TypeAlias = tuple[Metric | Column, bool] + + +class QueryObjectDict(TypedDict, total=False): + """ + TypedDict representation of query objects used throughout Superset. + + This represents the dictionary output from QueryObject.to_dict() and is used + in datasource query methods throughout Superset. + + Core fields from QueryObject.to_dict(): + apply_fetch_values_predicate: Whether to apply fetch values predicate + columns: List of columns to include + extras: Additional options and parameters + filter: List of filter clauses + from_dttm: Start datetime for time range + granularity: Time grain/granularity + inner_from_dttm: Inner start datetime for nested queries + inner_to_dttm: Inner end datetime for nested queries + is_rowcount: Whether this is a row count query + is_timeseries: Whether this is a timeseries query + metrics: List of metrics to compute + order_desc: Whether to order descending + orderby: List of order by clauses + row_limit: Maximum number of rows + row_offset: Number of rows to skip + series_columns: Columns to use for series + series_limit: Maximum number of series + series_limit_metric: Metric to use for series limiting + group_others_when_limit_reached: Whether to group remaining items as "Others" + to_dttm: End datetime for time range + time_shift: Time shift specification + + Additional fields used throughout the codebase: + time_range: Human-readable time range string + datasource: BaseDatasource instance + extra_cache_keys: Additional keys for caching + rls: Row level security filters + changed_on: Last modified timestamp + + Deprecated fields (still in use): + groupby: Columns to group by (use columns instead) + timeseries_limit: Series limit (use series_limit instead) + timeseries_limit_metric: Series limit metric (use series_limit_metric instead) + """ + + # Core fields from QueryObject.to_dict() + apply_fetch_values_predicate: bool + columns: list[Column] + extras: dict[str, Any] + filter: list["QueryObjectFilterClause"] + from_dttm: datetime | None + granularity: str | None + inner_from_dttm: datetime | None + inner_to_dttm: datetime | None + is_rowcount: bool + is_timeseries: bool + metrics: list[Metric] | None + order_desc: bool + orderby: list[OrderBy] + row_limit: int | None + row_offset: int + series_columns: list[Column] + series_limit: int + series_limit_metric: Metric | None + group_others_when_limit_reached: bool + to_dttm: datetime | None + time_shift: str | None + + # Additional fields used throughout the codebase + time_range: str | None + datasource: Any # BaseDatasource instance + extra_cache_keys: list[Hashable] + rls: list[Any] + changed_on: datetime | None + + # Deprecated fields (still in use) + groupby: list[Column] + timeseries_limit: int + timeseries_limit_metric: Metric | None + + +class BaseDatasourceData(TypedDict, total=False): + """ + TypedDict for datasource data returned to the frontend. + + This represents the structure of the dictionary returned from BaseDatasource.data + property. It provides datasource information to the frontend for visualization + and querying. + + Core fields from BaseDatasource.data: + id: Unique identifier for the datasource + uid: Unique identifier including type (e.g., "1__table") + column_formats: D3 format strings for columns + description: Human-readable description + database: Database connection information + default_endpoint: Default URL endpoint for this datasource + filter_select: Whether filter select is enabled (deprecated) + filter_select_enabled: Whether filter select is enabled + name: Display name of the datasource + datasource_name: Name of the underlying table/query + table_name: Table name (same as datasource_name) + type: Datasource type (e.g., "table", "query") + catalog: Catalog name if applicable + schema: Schema name if applicable + offset: Default row offset + cache_timeout: Cache timeout in seconds + params: Additional parameters as JSON string + perm: Permission string + edit_url: URL to edit this datasource + sql: SQL query for virtual datasets + columns: List of column definitions + metrics: List of metric definitions + folders: Folder structure (JSON field) + order_by_choices: Available ordering options + owners: List of owner IDs or owner details + verbose_map: Mapping of column/metric names to verbose names + select_star: SELECT * query for this datasource + + Additional fields from SqlaTable and data_for_slices: + column_types: List of column data types + column_names: Set of column names + granularity_sqla: Available time granularities + time_grain_sqla: Available time grains + main_dttm_col: Main datetime column + fetch_values_predicate: Predicate for fetching filter values + template_params: Template parameters for Jinja + is_sqllab_view: Whether this is a SQL Lab view + health_check_message: Health check status message + extra: Extra configuration as JSON string + always_filter_main_dttm: Whether to always filter on main datetime + normalize_columns: Whether to normalize column names + """ + + # Core fields from BaseDatasource.data + id: int + uid: str + column_formats: dict[str, str | None] + description: str | None + database: dict[str, Any] + default_endpoint: str | None + filter_select: bool + filter_select_enabled: bool + name: str + datasource_name: str + table_name: str + type: str + catalog: str | None + schema: str | None + offset: int + cache_timeout: int | None + params: str | None + perm: str | None + edit_url: str + sql: str | None + columns: list[dict[str, Any]] + metrics: list[dict[str, Any]] + folders: Any # JSON field, can be list or dict + order_by_choices: list[tuple[str, str]] + owners: list[int] | list[dict[str, Any]] # Can be either format + verbose_map: dict[str, str] + select_star: str | None + + # Additional fields from SqlaTable and data_for_slices + column_types: list[Any] + column_names: set[str] | set[Any] + granularity_sqla: list[tuple[Any, Any]] + time_grain_sqla: list[tuple[Any, Any]] + main_dttm_col: str | None + fetch_values_predicate: str | None + template_params: str | None + is_sqllab_view: bool + health_check_message: str | None + extra: str | None + always_filter_main_dttm: bool + normalize_columns: bool + + +class QueryData(TypedDict, total=False): + """ + TypedDict for SQL Lab query data returned to the frontend. + + This represents the structure of the dictionary returned from Query.data property + in SQL Lab. It provides query information to the frontend for execution and display. + + Fields: + time_grain_sqla: Available time grains for this database + filter_select: Whether filter select is enabled + name: Query tab name + columns: List of column definitions + metrics: List of metrics (always empty for queries) + id: Query ID + type: Object type (always "query") + sql: SQL query text + owners: List of owner information + database: Database connection details + order_by_choices: Available ordering options + catalog: Catalog name if applicable + schema: Schema name if applicable + verbose_map: Mapping of column names to verbose names (empty for queries) + """ + + time_grain_sqla: list[tuple[Any, Any]] + filter_select: bool + name: str | None + columns: list[dict[str, Any]] + metrics: list[Any] + id: int + type: str + sql: str | None + owners: list[dict[str, Any]] + database: dict[str, Any] + order_by_choices: list[tuple[str, str]] + catalog: str | None + schema: str | None + verbose_map: dict[str, str] + + +VizData: TypeAlias = list[Any] | dict[Any, Any] | None +VizPayload: TypeAlias = dict[str, Any] # Flask response. -Base = Union[bytes, str] -Status = Union[int, str] -Headers = dict[str, Any] -FlaskResponse = Union[ - Response, - Base, - tuple[Base, Status], - tuple[Base, Status, Headers], - tuple[Response, Status], -] +Base: TypeAlias = bytes | str +Status: TypeAlias = int | str +Headers: TypeAlias = dict[str, Any] +FlaskResponse: TypeAlias = ( + Response + | Base + | tuple[Base, Status] + | tuple[Base, Status, Headers] + | tuple[Response, Status] +) class OAuth2ClientConfig(TypedDict): diff --git a/superset/utils/core.py b/superset/utils/core.py index aba71e2b3eb..9acd7e03c7f 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -1618,7 +1618,9 @@ def get_column_name_from_metric(metric: Metric) -> str | None: if is_adhoc_metric(metric): metric = cast(AdhocMetric, metric) if metric["expressionType"] == AdhocMetricExpressionType.SIMPLE: - return cast(dict[str, Any], metric["column"])["column_name"] + column = metric["column"] + if column: + return column["column_name"] return None diff --git a/superset/views/core.py b/superset/views/core.py index 3c67d21ffa8..f888e889b75 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -76,9 +76,12 @@ from superset.extensions import async_query_manager, cache_manager from superset.models.core import Database from superset.models.dashboard import Dashboard from superset.models.slice import Slice -from superset.models.sql_lab import Query from superset.models.user_attributes import UserAttribute -from superset.superset_typing import FlaskResponse +from superset.superset_typing import ( + BaseDatasourceData, + FlaskResponse, + QueryData, +) from superset.tasks.utils import get_current_user from superset.utils import core as utils, json from superset.utils.cache import etag_cache @@ -528,13 +531,14 @@ class Superset(BaseSupersetView): ) standalone_mode = ReservedUrlParameters.is_standalone_mode() force = request.args.get("force") in {"force", "1", "true"} - dummy_datasource_data: dict[str, Any] = { - "type": datasource_type, + dummy_datasource_data: BaseDatasourceData = { + "type": datasource_type or "unknown", "name": datasource_name, "columns": [], "metrics": [], "database": {"id": 0, "backend": ""}, } + datasource_data: BaseDatasourceData | QueryData try: datasource_data = datasource.data if datasource else dummy_datasource_data except (SupersetException, SQLAlchemyError): @@ -542,8 +546,6 @@ class Superset(BaseSupersetView): if datasource: datasource_data["owners"] = datasource.owners_data - if isinstance(datasource, Query): - datasource_data["columns"] = datasource.columns bootstrap_data = { "can_add": slice_add_perm, diff --git a/superset/views/utils.py b/superset/views/utils.py index 86f72ef376b..d30e8c1c2b7 100644 --- a/superset/views/utils.py +++ b/superset/views/utils.py @@ -45,7 +45,12 @@ from superset.models.core import Database from superset.models.dashboard import Dashboard from superset.models.slice import Slice from superset.models.sql_lab import Query -from superset.superset_typing import FlaskResponse, FormData +from superset.superset_typing import ( + BaseDatasourceData, + FlaskResponse, + FormData, + QueryData, +) from superset.utils import json from superset.utils.core import DatasourceType from superset.utils.decorators import stats_timing @@ -86,13 +91,20 @@ def redirect_to_login(next_target: str | None = None) -> FlaskResponse: return redirect(redirect_url) -def sanitize_datasource_data(datasource_data: dict[str, Any]) -> dict[str, Any]: +def sanitize_datasource_data( + datasource_data: BaseDatasourceData | QueryData, +) -> dict[str, Any]: + """ + Sanitize datasource data by removing sensitive database parameters. + + Accepts TypedDict types (BaseDatasourceData, QueryData). + """ if datasource_data: datasource_database = datasource_data.get("database") if datasource_database: datasource_database["parameters"] = {} - return datasource_data + return datasource_data # type: ignore[return-value] def bootstrap_user_data(user: User, include_perms: bool = False) -> dict[str, Any]: diff --git a/superset/viz.py b/superset/viz.py index b9c242628cb..2e6086ab51b 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -458,7 +458,9 @@ class BaseViz: # pylint: disable=too-many-public-methods different time shifts will differ only in the `from_dttm`, `to_dttm`, `inner_from_dttm`, and `inner_to_dttm` values which are stripped. """ - cache_dict = copy.copy(query_obj) + # Cast to dict[str, Any] to allow mutable operations (update, del) + # since TypedDict doesn't support these operations in the same way + cache_dict: dict[str, Any] = copy.copy(cast(dict[str, Any], query_obj)) cache_dict.update(extra) for k in ["from_dttm", "to_dttm", "inner_from_dttm", "inner_to_dttm"]: diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py index 0852c62362d..f770b727319 100644 --- a/tests/integration_tests/sqla_models_tests.py +++ b/tests/integration_tests/sqla_models_tests.py @@ -19,7 +19,7 @@ from __future__ import annotations import re from datetime import datetime -from typing import Any, Literal, NamedTuple, Optional, Union +from typing import Any, cast, Literal, NamedTuple, Optional, Union from re import Pattern from unittest.mock import Mock, patch import pytest @@ -35,6 +35,7 @@ from sqlalchemy.sql.elements import TextClause from superset import db from superset.connectors.sqla.models import SqlaTable, TableColumn, SqlMetric from superset.constants import EMPTY_STRING, NULL_STRING +from superset.superset_typing import QueryObjectDict from superset.db_engine_specs.bigquery import BigQueryEngineSpec from superset.db_engine_specs.druid import DruidEngineSpec from superset.exceptions import ( @@ -975,8 +976,11 @@ def test_extra_cache_keys_in_adhoc_metrics_and_columns( query_obj = {**base_query_obj, **items} - extra_cache_keys = table.get_extra_cache_keys(query_obj) - assert table.has_extra_cache_key_calls(query_obj) == has_extra_cache_keys + extra_cache_keys = table.get_extra_cache_keys(cast(QueryObjectDict, query_obj)) + assert ( + table.has_extra_cache_key_calls(cast(QueryObjectDict, query_obj)) + == has_extra_cache_keys + ) assert extra_cache_keys == expected_cache_keys @@ -1017,8 +1021,8 @@ def test_extra_cache_keys_in_dataset_metrics_and_columns( "filter": [], } - extra_cache_keys = table.get_extra_cache_keys(query_obj) - assert table.has_extra_cache_key_calls(query_obj) is True + extra_cache_keys = table.get_extra_cache_keys(cast(QueryObjectDict, query_obj)) + assert table.has_extra_cache_key_calls(cast(QueryObjectDict, query_obj)) is True assert set(extra_cache_keys) == {"abc", None}