Small fixes
@@ -68,14 +68,7 @@ class StreamingCSVExportCommand(BaseStreamingCSVExportCommand):
         query_obj = self._query_context.queries[0]
         sql_query = datasource.get_query_str(query_obj.to_dict())
 
-        # Chart export is SQL-specific, so we check for BaseDatasource
-        from superset.connectors.sqla.models import BaseDatasource
-
-        database = (
-            datasource.database if isinstance(datasource, BaseDatasource) else None
-        )
-
-        return sql_query, database
+        return sql_query, getattr(datasource, "database", None)
 
     def _get_row_limit(self) -> int | None:
         """
@@ -84,10 +84,7 @@ class ExportDatasetsCommand(ExportModelsCommand):
         # SQLAlchemy returns column names as quoted_name objects which PyYAML cannot
         # serialize. Convert all keys to regular strings to fix YAML serialization.
         try:
-            from sqlalchemy.sql.elements import quoted_name
-
-            if any(isinstance(key, quoted_name) for key in payload.keys()):
-                payload = {str(key): value for key, value in payload.items()}
+            payload = {str(key): value for key, value in payload.items()}
         except ImportError:
             pass
 
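The comment in this hunk captures the underlying problem: SQLAlchemy reflection can hand back column names as quoted_name, a str subclass, and PyYAML's safe dumper dispatches on exact builtin types, so the unknown subclass typically fails with a RepresenterError. A minimal sketch of the failure and of the str() workaround the diff applies, assuming SQLAlchemy and PyYAML are installed:

import yaml
from sqlalchemy.sql.elements import quoted_name

# quoted_name is a str subclass, but safe_dump dispatches on the exact type,
# so an unrecognized subclass falls through to "cannot represent an object".
payload = {quoted_name("user_id", quote=True): "BIGINT"}
try:
    yaml.safe_dump(payload)
except yaml.representer.RepresenterError as exc:
    print(f"safe_dump failed: {exc}")

# Converting keys to plain str keeps the value and makes the dict serializable.
payload = {str(key): value for key, value in payload.items()}
print(yaml.safe_dump(payload))  # user_id: BIGINT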
@@ -73,7 +73,7 @@ def _get_query(
     _: bool,
 ) -> dict[str, Any]:
     datasource = _get_datasource(query_context, query_obj)
-    result = {"language": getattr(datasource, "query_language", None)}
+    result = {"language": datasource.query_language}
     try:
         result["query"] = datasource.get_query_str(query_obj.to_dict())
     except QueryObjectValidationError as err:
@@ -483,6 +483,6 @@ class QueryContextProcessor:
         query.validate()
 
         if self._qc_datasource.type == DatasourceType.QUERY:
-            security_manager.raise_for_access(datasource=self._qc_datasource)
+            security_manager.raise_for_access(query=self._qc_datasource)
         else:
             security_manager.raise_for_access(query_context=self._query_context)
@@ -225,7 +225,7 @@ class BaseDatasource(
         This allows each datasource to override caching, while falling back
         to database-level defaults when appropriate.
         """
-        if self._cache_timeout:
+        if self._cache_timeout is not None:
             return self._cache_timeout
 
         # database should always be set, but that's not true for v0 import
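The is not None change matters when a datasource explicitly sets its cache timeout to 0: 0 is falsy, so the old truthiness check would ignore the override and fall back to the database default. A minimal sketch of the difference, using a hypothetical database_timeout stand-in for the fallback value:

def resolve_timeout_truthy(cache_timeout, database_timeout):
    # Old behaviour: a 0 override is falsy and gets ignored.
    if cache_timeout:
        return cache_timeout
    return database_timeout


def resolve_timeout_is_not_none(cache_timeout, database_timeout):
    # New behaviour: only an unset (None) override falls back.
    if cache_timeout is not None:
        return cache_timeout
    return database_timeout


print(resolve_timeout_truthy(0, 3600))       # 3600 - explicit override silently lost
print(resolve_timeout_is_not_none(0, 3600))  # 0 - explicit override respected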
@@ -110,6 +110,21 @@ class Explorable(Protocol):
         :return: Type identifier string
         """
 
+    @property
+    def metrics(self) -> list[Any]:
+        """
+        List of metric metadata objects.
+
+        Each object should provide at minimum:
+        - metric_name: str - the metric's name
+        - expression: str - the metric's calculation expression
+
+        Used for validation, autocomplete, and query building.
+
+        :return: List of metric metadata objects
+        """
+
+    # TODO: rename to dimensions
     @property
     def columns(self) -> list[Any]:
         """
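For a concrete sense of the shape the new metrics property expects, here is a minimal, hypothetical object that would satisfy it; the names MetricRecord and AdHocExplorable are illustrative only and are not part of the commit:

from dataclasses import dataclass, field
from typing import Any


@dataclass
class MetricRecord:
    # The two fields the protocol docstring names as the minimum contract.
    metric_name: str
    expression: str


@dataclass
class AdHocExplorable:
    # A duck-typed datasource: anything exposing these attributes can be
    # treated as Explorable for metric lookup and query building.
    _metrics: list[MetricRecord] = field(default_factory=list)

    @property
    def metrics(self) -> list[Any]:
        return list(self._metrics)


ds = AdHocExplorable([MetricRecord("count", "COUNT(*)")])
print([m.metric_name for m in ds.metrics])  # ['count']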
@@ -125,6 +140,7 @@ class Explorable(Protocol):
         :return: List of column metadata objects
         """
 
+    # TODO: remove and use columns instead
     @property
     def column_names(self) -> list[str]:
         """
@@ -136,6 +152,7 @@ class Explorable(Protocol):
         :return: List of column name strings
         """
 
+    # TODO: use TypedDict for return type
     @property
     def data(self) -> dict[str, Any]:
         """
@@ -289,11 +306,7 @@ class Explorable(Protocol):
         """
 
     # =========================================================================
-    # Optional Properties
-    # =========================================================================
-
-    # =========================================================================
-    # Required Methods
+    # Drilling
     # =========================================================================
 
     def has_drill_by_columns(self, column_names: list[str]) -> bool:
@@ -467,7 +467,9 @@ class ImportExportMixin(UUIDMixin):
         if parent_ref:
             parent_excludes = {c.name for c in parent_ref.local_columns}
         dict_rep = {
-            c.name: getattr(self, c.name)
+            # Convert c.name to str to handle SQLAlchemy's quoted_name type
+            # which is not YAML-serializable
+            str(c.name): getattr(self, c.name)
             for c in cls.__table__.columns  # type: ignore
             if (
                 c.name in export_fields
@@ -61,7 +61,6 @@ def _adjust_string_with_rls(
     """
     Add the RLS filters to the unique string based on current executor.
     """
-    from superset.connectors.sqla.models import BaseDatasource
 
     user = (
         security_manager.find_user(executor)
@@ -72,11 +71,7 @@ def _adjust_string_with_rls(
     stringified_rls = ""
     with override_user(user):
         for datasource in datasources:
-            if (
-                datasource
-                and isinstance(datasource, BaseDatasource)
-                and datasource.is_rls_supported
-            ):
+            if datasource and getattr(datasource, "is_rls_supported", False):
                 rls_filters = datasource.get_sqla_row_level_filters()
 
                 if len(rls_filters) > 0:
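The rewritten guard swaps the isinstance check against BaseDatasource for duck typing: any datasource that advertises is_rls_supported participates, and anything else, including objects that lack the attribute entirely, is skipped. A small sketch of how the getattr guard behaves, using made-up stand-in classes:

class SqlaLikeDatasource:
    is_rls_supported = True

    def get_sqla_row_level_filters(self):
        return ["user_id = 42"]


class QueryLikeDatasource:
    # No is_rls_supported attribute at all; getattr's False default skips it
    # instead of raising AttributeError.
    pass


for datasource in (SqlaLikeDatasource(), QueryLikeDatasource(), None):
    if datasource and getattr(datasource, "is_rls_supported", False):
        print(datasource.get_sqla_row_level_filters())
    else:
        print("skipped")
# ['user_id = 42']
# skipped
# skipped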
@@ -1677,10 +1677,6 @@ def get_metric_type_from_column(
 
     from superset.connectors.sqla.models import SqlMetric
 
-    # Explorable datasources may not have metrics attribute
-    if datasource is None or not hasattr(datasource, "metrics"):
-        return ""
-
     metric: SqlMetric = next(
         (metric for metric in datasource.metrics if metric.metric_name == column),
         SqlMetric(metric_name=""),
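With the defensive hasattr check removed, two behaviours carry the weight here: a None datasource now surfaces as an AttributeError at datasource.metrics (the test change below mirrors this), while a column with no matching metric still yields an empty result because next() is given a sentinel SqlMetric(metric_name="") as its default. A stripped-down sketch of that next()-with-default pattern; StubMetric, StubDatasource, and metric_type_for are stand-ins, not the real Superset code:

class StubMetric:
    # Stand-in for SqlMetric; only the fields the lookup touches.
    def __init__(self, metric_name="", metric_type=""):
        self.metric_name = metric_name
        self.metric_type = metric_type


class StubDatasource:
    metrics = [StubMetric("count", "count"), StubMetric("revenue", "sum")]


def metric_type_for(column, datasource):
    # next() returns the sentinel when no metric name matches, so the
    # happy path and the miss path share one code shape.
    metric = next(
        (m for m in datasource.metrics if m.metric_name == column),
        StubMetric(metric_name=""),
    )
    return metric.metric_type or ""


print(metric_type_for("count", StubDatasource()))    # count
print(metric_type_for("missing", StubDatasource()))  # (empty string)
# A None datasource now fails loudly instead of returning "":
# metric_type_for("count", None)  -> AttributeError on None.metrics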
@@ -16,6 +16,8 @@
 #
 from unittest.mock import MagicMock, patch
 
+import pytest
+
 from superset.connectors.sqla.models import SqlMetric
 from superset.utils.core import (
     get_metric_type_from_column,
@@ -60,7 +62,8 @@ def test_column_is_none():
 def test_datasource_is_none():
     datasource = None
     column = "my_column"
-    assert get_metric_type_from_column(column, datasource) == ""
+    with pytest.raises(AttributeError):
+        get_metric_type_from_column(column, datasource)
 
 
 def test_none_input():