Small fixes

This commit is contained in:
Beto Dealmeida
2025-11-26 14:46:44 -05:00
parent 99525c1ce9
commit a36bbf8ffd
10 changed files with 31 additions and 32 deletions

View File

@@ -68,14 +68,7 @@ class StreamingCSVExportCommand(BaseStreamingCSVExportCommand):
query_obj = self._query_context.queries[0] query_obj = self._query_context.queries[0]
sql_query = datasource.get_query_str(query_obj.to_dict()) sql_query = datasource.get_query_str(query_obj.to_dict())
# Chart export is SQL-specific, so we check for BaseDatasource return sql_query, getattr(datasource, "database", None)
from superset.connectors.sqla.models import BaseDatasource
database = (
datasource.database if isinstance(datasource, BaseDatasource) else None
)
return sql_query, database
def _get_row_limit(self) -> int | None: def _get_row_limit(self) -> int | None:
""" """

View File

@@ -84,10 +84,7 @@ class ExportDatasetsCommand(ExportModelsCommand):
# SQLAlchemy returns column names as quoted_name objects which PyYAML cannot # SQLAlchemy returns column names as quoted_name objects which PyYAML cannot
# serialize. Convert all keys to regular strings to fix YAML serialization. # serialize. Convert all keys to regular strings to fix YAML serialization.
try: try:
from sqlalchemy.sql.elements import quoted_name payload = {str(key): value for key, value in payload.items()}
if any(isinstance(key, quoted_name) for key in payload.keys()):
payload = {str(key): value for key, value in payload.items()}
except ImportError: except ImportError:
pass pass

View File

@@ -73,7 +73,7 @@ def _get_query(
_: bool, _: bool,
) -> dict[str, Any]: ) -> dict[str, Any]:
datasource = _get_datasource(query_context, query_obj) datasource = _get_datasource(query_context, query_obj)
result = {"language": getattr(datasource, "query_language", None)} result = {"language": datasource.query_language}
try: try:
result["query"] = datasource.get_query_str(query_obj.to_dict()) result["query"] = datasource.get_query_str(query_obj.to_dict())
except QueryObjectValidationError as err: except QueryObjectValidationError as err:

View File

@@ -483,6 +483,6 @@ class QueryContextProcessor:
query.validate() query.validate()
if self._qc_datasource.type == DatasourceType.QUERY: if self._qc_datasource.type == DatasourceType.QUERY:
security_manager.raise_for_access(datasource=self._qc_datasource) security_manager.raise_for_access(query=self._qc_datasource)
else: else:
security_manager.raise_for_access(query_context=self._query_context) security_manager.raise_for_access(query_context=self._query_context)

View File

@@ -225,7 +225,7 @@ class BaseDatasource(
This allows each datasource to override caching, while falling back This allows each datasource to override caching, while falling back
to database-level defaults when appropriate. to database-level defaults when appropriate.
""" """
if self._cache_timeout: if self._cache_timeout is not None:
return self._cache_timeout return self._cache_timeout
# database should always be set, but that's not true for v0 import # database should always be set, but that's not true for v0 import

View File

@@ -110,6 +110,21 @@ class Explorable(Protocol):
:return: Type identifier string :return: Type identifier string
""" """
@property
def metrics(self) -> list[Any]:
"""
List of metric metadata objects.
Each object should provide at minimum:
- metric_name: str - the metric's name
- expression: str - the metric's calculation expression
Used for validation, autocomplete, and query building.
:return: List of metric metadata objects
"""
# TODO: rename to dimensions
@property @property
def columns(self) -> list[Any]: def columns(self) -> list[Any]:
""" """
@@ -125,6 +140,7 @@ class Explorable(Protocol):
:return: List of column metadata objects :return: List of column metadata objects
""" """
# TODO: remove and use columns instead
@property @property
def column_names(self) -> list[str]: def column_names(self) -> list[str]:
""" """
@@ -136,6 +152,7 @@ class Explorable(Protocol):
:return: List of column name strings :return: List of column name strings
""" """
# TODO: use TypedDict for return type
@property @property
def data(self) -> dict[str, Any]: def data(self) -> dict[str, Any]:
""" """
@@ -289,11 +306,7 @@ class Explorable(Protocol):
""" """
# ========================================================================= # =========================================================================
# Optional Properties # Drilling
# =========================================================================
# =========================================================================
# Required Methods
# ========================================================================= # =========================================================================
def has_drill_by_columns(self, column_names: list[str]) -> bool: def has_drill_by_columns(self, column_names: list[str]) -> bool:

View File

@@ -467,7 +467,9 @@ class ImportExportMixin(UUIDMixin):
if parent_ref: if parent_ref:
parent_excludes = {c.name for c in parent_ref.local_columns} parent_excludes = {c.name for c in parent_ref.local_columns}
dict_rep = { dict_rep = {
c.name: getattr(self, c.name) # Convert c.name to str to handle SQLAlchemy's quoted_name type
# which is not YAML-serializable
str(c.name): getattr(self, c.name)
for c in cls.__table__.columns # type: ignore for c in cls.__table__.columns # type: ignore
if ( if (
c.name in export_fields c.name in export_fields

View File

@@ -61,7 +61,6 @@ def _adjust_string_with_rls(
""" """
Add the RLS filters to the unique string based on current executor. Add the RLS filters to the unique string based on current executor.
""" """
from superset.connectors.sqla.models import BaseDatasource
user = ( user = (
security_manager.find_user(executor) security_manager.find_user(executor)
@@ -72,11 +71,7 @@ def _adjust_string_with_rls(
stringified_rls = "" stringified_rls = ""
with override_user(user): with override_user(user):
for datasource in datasources: for datasource in datasources:
if ( if datasource and getattr(datasource, "is_rls_supported", False):
datasource
and isinstance(datasource, BaseDatasource)
and datasource.is_rls_supported
):
rls_filters = datasource.get_sqla_row_level_filters() rls_filters = datasource.get_sqla_row_level_filters()
if len(rls_filters) > 0: if len(rls_filters) > 0:

View File

@@ -1677,10 +1677,6 @@ def get_metric_type_from_column(
from superset.connectors.sqla.models import SqlMetric from superset.connectors.sqla.models import SqlMetric
# Explorable datasources may not have metrics attribute
if datasource is None or not hasattr(datasource, "metrics"):
return ""
metric: SqlMetric = next( metric: SqlMetric = next(
(metric for metric in datasource.metrics if metric.metric_name == column), (metric for metric in datasource.metrics if metric.metric_name == column),
SqlMetric(metric_name=""), SqlMetric(metric_name=""),

View File

@@ -16,6 +16,8 @@
# #
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest
from superset.connectors.sqla.models import SqlMetric from superset.connectors.sqla.models import SqlMetric
from superset.utils.core import ( from superset.utils.core import (
get_metric_type_from_column, get_metric_type_from_column,
@@ -60,7 +62,8 @@ def test_column_is_none():
def test_datasource_is_none(): def test_datasource_is_none():
datasource = None datasource = None
column = "my_column" column = "my_column"
assert get_metric_type_from_column(column, datasource) == "" with pytest.raises(AttributeError):
get_metric_type_from_column(column, datasource)
def test_none_input(): def test_none_input():