Compare commits

..

9 Commits

Author SHA1 Message Date
Beto Dealmeida
508aad1603 Enable extension in Docker 2026-02-09 14:57:40 -05:00
Beto Dealmeida
954cf32ca4 Temporary registry 2026-02-09 14:57:10 -05:00
Beto Dealmeida
552c685a6b Fix models 2026-02-09 13:55:34 -05:00
Beto Dealmeida
a26c91c4e2 Fix mapper 2026-02-09 13:54:11 -05:00
Beto Dealmeida
3c8835bd75 Fix migration 2026-02-09 10:48:05 -05:00
Beto Dealmeida
955d8bc205 Frontend support 2026-02-06 19:02:51 -05:00
Beto Dealmeida
cd8e27d33c API integration 2026-02-06 16:28:08 -05:00
Beto Dealmeida
d0962bd32f feat: models and DAOs 2026-02-06 16:27:58 -05:00
Beto Dealmeida
28870168cd feat: semantic layer extension 2026-02-06 13:42:39 -05:00
21 changed files with 156 additions and 304 deletions

View File

@@ -21,6 +21,7 @@ import enum
from typing import Protocol, runtime_checkable
from superset_core.semantic_layers.types import (
AdhocFilter,
Dimension,
Filter,
GroupLimit,
@@ -68,7 +69,7 @@ class SemanticView(Protocol):
def get_values(
self,
dimension: Dimension,
filters: set[Filter] | None = None,
filters: set[Filter | AdhocFilter] | None = None,
) -> SemanticResult:
"""
Return distinct values for a dimension.
@@ -78,7 +79,7 @@ class SemanticView(Protocol):
self,
metrics: list[Metric],
dimensions: list[Dimension],
filters: set[Filter] | None = None,
filters: set[Filter | AdhocFilter] | None = None,
order: list[OrderTuple] | None = None,
limit: int | None = None,
offset: int | None = None,
@@ -93,7 +94,7 @@ class SemanticView(Protocol):
self,
metrics: list[Metric],
dimensions: list[Dimension],
filters: set[Filter] | None = None,
filters: set[Filter | AdhocFilter] | None = None,
order: list[OrderTuple] | None = None,
limit: int | None = None,
offset: int | None = None,

View File

@@ -239,7 +239,6 @@ class Operator(str, enum.Enum):
NOT_LIKE = "NOT LIKE"
IS_NULL = "IS NULL"
IS_NOT_NULL = "IS NOT NULL"
ADHOC = "ADHOC"
FilterValues = str | int | float | bool | datetime | date | time | timedelta | None
@@ -253,11 +252,19 @@ class PredicateType(enum.Enum):
@dataclass(frozen=True, order=True)
class Filter:
type: PredicateType
column: Dimension | Metric | None
column: Dimension | Metric
operator: Operator
value: FilterValues | frozenset[FilterValues]
# TODO (betodealmeida): convert into Operator:
# Filter(type=..., column=None, operator=Operator.AdHoc, value="some definition")
@dataclass(frozen=True, order=True)
class AdhocFilter:
type: PredicateType
definition: str
class OrderDirection(enum.Enum):
ASC = "ASC"
DESC = "DESC"
@@ -284,7 +291,7 @@ class GroupLimit:
metric: Metric | None
direction: OrderDirection = OrderDirection.DESC
group_others: bool = False
filters: set[Filter] | None = None
filters: set[Filter | AdhocFilter] | None = None
@dataclass(frozen=True)
@@ -321,7 +328,7 @@ class SemanticQuery:
metrics: list[Metric]
dimensions: list[Dimension]
filters: set[Filter] | None = None
filters: set[Filter | AdhocFilter] | None = None
order: list[OrderTuple] | None = None
limit: int | None = None
offset: int | None = None

View File

@@ -29,13 +29,14 @@ const DATASOURCE_TYPE_MAP: Record<string, DatasourceType> = {
};
export default class DatasourceKey {
readonly id: number;
readonly id: number | string;
readonly type: DatasourceType;
constructor(key: string) {
const [idStr, typeStr] = key.split('__');
this.id = parseInt(idStr, 10);
const isNumeric = /^\d+$/.test(idStr);
this.id = isNumeric ? parseInt(idStr, 10) : idStr;
this.type = DATASOURCE_TYPE_MAP[typeStr] ?? DatasourceType.Table;
}

View File

@@ -38,7 +38,7 @@ export interface Currency {
* Datasource metadata.
*/
export interface Datasource {
id: number;
id: number | string;
name: string;
type: DatasourceType;
columns: Column[];

View File

@@ -159,7 +159,7 @@ export interface QueryObject
export interface QueryContext {
datasource: {
id: number;
id: number | string;
type: DatasourceType;
};
/** Force refresh of all queries */

View File

@@ -90,11 +90,15 @@ const ModalFooter = ({ formData, closeModal }: ModalFooterProps) => {
findPermission('can_explore', 'Superset', state.user?.roles),
);
const [datasource_id, datasource_type] = formData.datasource.split('__');
const [datasourceIdStr, datasource_type] = formData.datasource.split('__');
const isNumeric = /^\d+$/.test(datasourceIdStr);
const datasource_id = isNumeric
? parseInt(datasourceIdStr, 10)
: datasourceIdStr;
useEffect(() => {
// short circuit if the user is embedded as explore is not available
if (isEmbedded()) return;
postFormData(Number(datasource_id), datasource_type, formData, 0)
postFormData(datasource_id, datasource_type, formData, 0)
.then(key => {
setUrl(
`/explore/?form_data_key=${key}&dashboard_page_id=${dashboardPageId}`,

View File

@@ -272,7 +272,7 @@ export type Slice = {
changed_on: number;
changed_on_humanized: string;
modified: string;
datasource_id: number;
datasource_id: number | string;
datasource_type: DatasourceType;
datasource_url: string;
datasource_name: string;

View File

@@ -144,15 +144,19 @@ export const getSlicePayload = async (
...adhocFilters,
dashboards,
};
let datasourceId = 0;
let datasourceId: number | string = 0;
let datasourceType: DatasourceType = DatasourceType.Table;
if (formData.datasource) {
const [id, typeString] = formData.datasource.split('__');
datasourceId = parseInt(id, 10);
const isNumeric = /^\d+$/.test(id);
datasourceId = isNumeric ? parseInt(id, 10) : id;
if (Object.values(DatasourceType).includes(typeString as DatasourceType)) {
datasourceType = typeString as DatasourceType;
const formattedTypeString =
typeString.charAt(0).toUpperCase() + typeString.slice(1);
if (formattedTypeString in DatasourceType) {
datasourceType =
DatasourceType[formattedTypeString as keyof typeof DatasourceType];
}
}

View File

@@ -20,7 +20,7 @@ import { SupersetClient, JsonObject, JsonResponse } from '@superset-ui/core';
import { sanitizeFormData } from 'src/utils/sanitizeFormData';
type Payload = {
datasource_id: number;
datasource_id: number | string;
datasource_type: string;
form_data: string;
chart_id?: number;
@@ -36,7 +36,7 @@ const assembleEndpoint = (key?: string, tabId?: string) => {
};
const assemblePayload = (
datasourceId: number,
datasourceId: number | string,
datasourceType: string,
formData: JsonObject,
chartId?: number,
@@ -53,7 +53,7 @@ const assemblePayload = (
};
export const postFormData = (
datasourceId: number,
datasourceId: number | string,
datasourceType: string,
formData: JsonObject,
chartId?: number,
@@ -70,7 +70,7 @@ export const postFormData = (
}).then((r: JsonResponse) => r.json.key);
export const putFormData = (
datasourceId: number,
datasourceId: number | string,
datasourceType: string,
key: string,
formData: JsonObject,

View File

@@ -18,6 +18,7 @@ import contextlib
import logging
from abc import ABC
from typing import Any, cast, Optional
from uuid import UUID
from flask import request
from flask_babel import lazy_gettext as _
@@ -100,9 +101,12 @@ class GetExploreCommand(BaseCommand, ABC):
use_slice_data=True,
initial_form_data=initial_form_data,
)
ds_id: int | UUID | None = None
try:
self._datasource_id, self._datasource_type = get_datasource_info(
self._datasource_id, self._datasource_type, form_data
ds_id, self._datasource_type = get_datasource_info(
self._datasource_id,
self._datasource_type,
form_data,
)
except SupersetException:
self._datasource_id = None
@@ -111,10 +115,11 @@ class GetExploreCommand(BaseCommand, ABC):
datasource: Optional[BaseDatasource] = None
if self._datasource_id is not None:
if ds_id is not None:
with contextlib.suppress(DatasourceNotFound):
datasource = DatasourceDAO.get_datasource(
cast(str, self._datasource_type), self._datasource_id
cast(str, self._datasource_type),
ds_id,
)
datasource_name = _("[Missing Dataset]")
@@ -124,7 +129,11 @@ class GetExploreCommand(BaseCommand, ABC):
security_manager.raise_for_access(datasource=datasource)
viz_type = form_data.get("viz_type")
if not viz_type and datasource and getattr(datasource, "default_endpoint", None):
if (
not viz_type
and datasource
and getattr(datasource, "default_endpoint", None)
):
raise WrongEndpointError(redirect=datasource.default_endpoint)
form_data["datasource"] = (

View File

@@ -14,14 +14,17 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
@dataclass
class CommandParameters:
permalink_key: Optional[str]
form_data_key: Optional[str]
datasource_id: Optional[int]
datasource_type: Optional[str]
permalink_key: str | None
form_data_key: str | None
datasource_id: int | str | None
datasource_type: str | None
slice_id: Optional[int]

View File

@@ -1,99 +0,0 @@
"""
Script to create a Pandas semantic layer and Sales semantic view in Superset.
Run this inside the superset_app container:
python /app/superset/create_pandas_semantic_layer.py
"""
from __future__ import annotations
import logging
import sys
from typing import TYPE_CHECKING
# Add the Superset application directory to the Python path
sys.path.insert(0, "/app")
from superset.app import create_app
from superset.extensions import db
from superset.utils import json
if TYPE_CHECKING:
from superset.semantic_layers.models import SemanticLayer, SemanticView
app = create_app()
app.app_context().push()
# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
def create_pandas_semantic_layer() -> SemanticLayer:
"""Create a Pandas semantic layer with minimal configuration."""
from superset.semantic_layers.models import SemanticLayer
logger.info("Creating Pandas semantic layer...")
configuration = {
"dataset": "sales",
}
semantic_layer = SemanticLayer(
name="Pandas Semantic Layer",
description="In-memory semantic layer backed by a Pandas DataFrame",
type="pandas",
configuration=json.dumps(configuration),
cache_timeout=3600,
)
db.session.add(semantic_layer)
db.session.commit()
logger.info("Created semantic layer:")
logger.info(" Name: %s", semantic_layer.name)
logger.info(" UUID: %s", semantic_layer.uuid)
logger.info(" Type: %s", semantic_layer.type)
return semantic_layer
def create_sales_semantic_view(semantic_layer: SemanticLayer) -> SemanticView:
"""Create the Sales semantic view."""
from superset.semantic_layers.models import SemanticView
logger.info("Creating Sales semantic view...")
semantic_view = SemanticView(
name="sales",
configuration="{}",
cache_timeout=1800,
semantic_layer_uuid=semantic_layer.uuid,
)
db.session.add(semantic_view)
db.session.commit()
logger.info("Created semantic view:")
logger.info(" Name: %s", semantic_view.name)
logger.info(" UUID: %s", semantic_view.uuid)
logger.info(" Semantic Layer UUID: %s", semantic_view.semantic_layer_uuid)
return semantic_view
def main() -> None:
"""Main script execution."""
logger.info("=" * 60)
logger.info("Creating Pandas Semantic Layer and Sales Semantic View")
logger.info("=" * 60)
semantic_layer = create_pandas_semantic_layer()
create_sales_semantic_view(semantic_layer)
if __name__ == "__main__":
main()

View File

@@ -16,8 +16,8 @@
# under the License.
import logging
import uuid
from typing import Union
from uuid import UUID
from superset import db
from superset.connectors.sqla.models import SqlaTable
@@ -48,7 +48,7 @@ class DatasourceDAO(BaseDAO[Datasource]):
def get_datasource(
cls,
datasource_type: Union[DatasourceType, str],
database_id_or_uuid: int | str,
database_id_or_uuid: int | str | UUID,
) -> Datasource:
if datasource_type not in cls.sources:
raise DatasourceTypeNotSupportedError()
@@ -59,7 +59,7 @@ class DatasourceDAO(BaseDAO[Datasource]):
filter = model.id == int(database_id_or_uuid)
else:
try:
uuid.UUID(str(database_id_or_uuid)) # uuid validation
UUID(str(database_id_or_uuid)) # uuid validation
filter = model.uuid == database_id_or_uuid
except ValueError as err:
logger.warning(

View File

@@ -109,7 +109,7 @@ class ExploreRestApi(BaseSupersetApi):
params = CommandParameters(
permalink_key=request.args.get("permalink_key", type=str),
form_data_key=request.args.get("form_data_key", type=str),
datasource_id=request.args.get("datasource_id", type=int),
datasource_id=request.args.get("datasource_id"),
datasource_type=request.args.get("datasource_type", type=str),
slice_id=request.args.get("slice_id", type=int),
)

View File

@@ -25,7 +25,6 @@ Create Date: 2025-11-04 11:26:00.000000
import uuid
import sqlalchemy as sa
from alembic import op
from sqlalchemy_utils import UUIDType
from sqlalchemy_utils.types.json import JSONType
@@ -83,7 +82,6 @@ def upgrade():
create_table(
"semantic_views",
sa.Column("uuid", UUIDType(binary=True), default=uuid.uuid4, nullable=False),
sa.Column("id", sa.Integer(), sa.Identity(), unique=True, nullable=False),
sa.Column("created_on", sa.DateTime(), nullable=True),
sa.Column("changed_on", sa.DateTime(), nullable=True),
sa.Column("name", sa.String(length=250), nullable=False),
@@ -123,22 +121,6 @@ def upgrade():
)
# Update chart datasource constraint to allow semantic_view
with op.batch_alter_table("slices") as batch_op:
batch_op.drop_constraint("ck_chart_datasource", type_="check")
batch_op.create_check_constraint(
"ck_chart_datasource",
"datasource_type in ('table', 'semantic_view')",
)
def downgrade():
# Restore original constraint
with op.batch_alter_table("slices") as batch_op:
batch_op.drop_constraint("ck_chart_datasource", type_="check")
batch_op.create_check_constraint(
"ck_chart_datasource", "datasource_type in ('table')"
)
drop_table("semantic_views")
drop_table("semantic_layers")

View File

@@ -32,6 +32,7 @@ import numpy as np
from superset_core.semantic_layers.semantic_view import SemanticViewFeature
from superset_core.semantic_layers.types import (
AdhocExpression,
AdhocFilter,
Day,
Dimension,
Filter,
@@ -369,14 +370,14 @@ def _get_filters_from_query_object(
query_object: ValidatedQueryObject,
time_offset: str | None,
all_dimensions: dict[str, Dimension],
) -> set[Filter]:
) -> set[Filter | AdhocFilter]:
"""
Extract all filters from the query object, including time range filters.
This simplifies the complexity of from_dttm/to_dttm/inner_from_dttm/inner_to_dttm
by converting all time constraints into filters.
"""
filters: set[Filter] = set()
filters: set[Filter | AdhocFilter] = set()
# 1. Add fetch values predicate if present
if (
@@ -384,11 +385,9 @@ def _get_filters_from_query_object(
and query_object.datasource.fetch_values_predicate
):
filters.add(
Filter(
AdhocFilter(
type=PredicateType.WHERE,
column=None,
operator=Operator.ADHOC,
value=query_object.datasource.fetch_values_predicate,
definition=query_object.datasource.fetch_values_predicate,
)
)
@@ -416,7 +415,7 @@ def _get_filters_from_query_object(
return filters
def _get_filters_from_extras(extras: dict[str, Any]) -> set[Filter]:
def _get_filters_from_extras(extras: dict[str, Any]) -> set[AdhocFilter]:
"""
Extract filters from the extras dict.
@@ -431,29 +430,25 @@ def _get_filters_from_extras(extras: dict[str, Any]) -> set[Filter]:
Handled in _convert_time_grain() and used for dimension grain matching
Note: The WHERE and HAVING clauses from extras are SQL expressions that
are passed through as-is to the semantic layer as adhoc Filter objects.
are passed through as-is to the semantic layer as AdhocFilter objects.
"""
filters: set[Filter] = set()
filters: set[AdhocFilter] = set()
# Add WHERE clause from extras
if where_clause := extras.get("where"):
filters.add(
Filter(
AdhocFilter(
type=PredicateType.WHERE,
column=None,
operator=Operator.ADHOC,
value=where_clause,
definition=where_clause,
)
)
# Add HAVING clause from extras
if having_clause := extras.get("having"):
filters.add(
Filter(
AdhocFilter(
type=PredicateType.HAVING,
column=None,
operator=Operator.ADHOC,
value=having_clause,
definition=having_clause,
)
)
@@ -545,7 +540,7 @@ def _convert_query_object_filter(
all_dimensions: dict[str, Dimension],
) -> set[Filter] | None:
"""
Convert a QueryObject filter dict to a semantic layer Filter.
Convert a QueryObject filter dict to a semantic layer Filter or AdhocFilter.
"""
operator_str = filter_["op"]
@@ -569,6 +564,7 @@ def _convert_query_object_filter(
if operator_str == FilterOperator.TEMPORAL_RANGE.value:
if not isinstance(value, str) or value == NO_TIME_RANGE:
return None
start, end = value.split(" : ")
return {
Filter(
@@ -681,7 +677,7 @@ def _get_group_limit_from_query_object(
def _get_group_limit_filters(
query_object: ValidatedQueryObject,
all_dimensions: dict[str, Dimension],
) -> set[Filter] | None:
) -> set[Filter | AdhocFilter] | None:
"""
Get separate filters for the group limit subquery if needed.
@@ -704,7 +700,7 @@ def _get_group_limit_filters(
return None
# Create separate filters for the group limit subquery
filters: set[Filter] = set()
filters: set[Filter | AdhocFilter] = set()
# Add time range filter using inner bounds
if query_object.granularity:
@@ -737,11 +733,9 @@ def _get_group_limit_filters(
and query_object.datasource.fetch_values_predicate
):
filters.add(
Filter(
AdhocFilter(
type=PredicateType.WHERE,
column=None,
operator=Operator.ADHOC,
value=query_object.datasource.fetch_values_predicate,
definition=query_object.datasource.fetch_values_predicate,
)
)

View File

@@ -26,7 +26,7 @@ from functools import cached_property
from typing import Any, TYPE_CHECKING
from flask_appbuilder import Model
from sqlalchemy import Column, ForeignKey, Identity, Integer, String, Text
from sqlalchemy import Column, ForeignKey, Integer, String, Text
from sqlalchemy.orm import relationship
from sqlalchemy_utils import UUIDType
from sqlalchemy_utils.types.json import JSONType
@@ -161,7 +161,6 @@ class SemanticView(AuditMixinNullable, Model):
__tablename__ = "semantic_views"
uuid = Column(UUIDType(binary=True), primary_key=True, default=uuid.uuid4)
id = Column(Integer, Identity(), unique=True)
# Core fields
name = Column(String(250), nullable=False)
@@ -248,7 +247,7 @@ class SemanticView(AuditMixinNullable, Model):
def data(self) -> ExplorableData:
return {
# core
"id": self.id,
"id": self.uuid.hex,
"uid": self.uid,
"type": "semantic_view",
"name": self.name,
@@ -336,9 +335,6 @@ class SemanticView(AuditMixinNullable, Model):
"health_check_message": None,
}
def data_for_slices(self, slices: list[Any]) -> dict[str, Any]:
return self.data
def get_extra_cache_keys(self, query_obj: QueryObjectDict) -> list[Hashable]:
return []
@@ -346,26 +342,6 @@ class SemanticView(AuditMixinNullable, Model):
def perm(self) -> str:
return self.semantic_layer_uuid.hex + "::" + self.uuid.hex
@property
def catalog_perm(self) -> str | None:
return None
@property
def schema_perm(self) -> str | None:
return None
@property
def schema(self) -> str | None:
return None
@property
def url(self) -> str:
return f"/semantic_view/{self.uuid}/"
@property
def explore_url(self) -> str:
return f"/explore/?datasource_type=semantic_view&datasource_id={self.id}"
@property
def offset(self) -> int:
# always return datetime as UTC

File diff suppressed because one or more lines are too long

View File

@@ -24,6 +24,7 @@ import re
from datetime import datetime
from typing import Any, Callable, cast
from urllib import parse
from uuid import UUID
from flask import (
abort,
@@ -169,9 +170,9 @@ class Superset(BaseSupersetView):
if viz_obj.has_error(payload):
return json_error_response(payload=payload, status=400)
response = {
"data": payload["df"].to_dict("records")
if payload["df"] is not None
else [],
"data": (
payload["df"].to_dict("records") if payload["df"] is not None else []
),
"colnames": payload.get("colnames"),
"coltypes": payload.get("coltypes"),
"rowcount": payload.get("rowcount"),
@@ -268,7 +269,9 @@ class Superset(BaseSupersetView):
@check_resource_permissions(check_datasource_perms)
@deprecated(eol_version="5.0.0")
def explore_json(
self, datasource_type: str | None = None, datasource_id: int | None = None
self,
datasource_type: str | None = None,
datasource_id: int | str | None = None,
) -> FlaskResponse:
"""Serves all request that GET or POST form_data
@@ -302,8 +305,10 @@ class Superset(BaseSupersetView):
form_data = get_form_data()[0]
try:
datasource_id, datasource_type = get_datasource_info(
datasource_id, datasource_type, form_data
ds_id, datasource_type = get_datasource_info(
datasource_id,
datasource_type,
form_data,
)
force = request.args.get("force") == "true"
@@ -316,7 +321,7 @@ class Superset(BaseSupersetView):
with contextlib.suppress(CacheLoadError):
viz_obj = get_viz(
datasource_type=cast(str, datasource_type),
datasource_id=datasource_id,
datasource_id=ds_id,
form_data=form_data,
force_cached=True,
force=force,
@@ -343,7 +348,7 @@ class Superset(BaseSupersetView):
viz_obj = get_viz(
datasource_type=cast(str, datasource_type),
datasource_id=datasource_id,
datasource_id=ds_id,
form_data=form_data,
force=force,
)
@@ -407,7 +412,7 @@ class Superset(BaseSupersetView):
def explore( # noqa: C901
self,
datasource_type: str | None = None,
datasource_id: int | None = None,
datasource_id: int | str | None = None,
key: str | None = None,
) -> FlaskResponse:
if request.method == "GET":
@@ -451,21 +456,23 @@ class Superset(BaseSupersetView):
query_context = request.form.get("query_context")
ds_id: int | UUID | None = None
try:
datasource_id, datasource_type = get_datasource_info(
datasource_id, datasource_type, form_data
ds_id, datasource_type = get_datasource_info(
datasource_id,
datasource_type,
form_data,
)
except SupersetException:
datasource_id = None
# fallback unknown datasource to table type
datasource_type = SqlaTable.type
datasource: BaseDatasource | None = None
if datasource_id is not None:
if ds_id is not None:
with contextlib.suppress(DatasetNotFoundError):
datasource = DatasourceDAO.get_datasource(
DatasourceType("table"),
datasource_id,
ds_id,
)
datasource_name = datasource.name if datasource else _("[Missing Dataset]")

View File

@@ -14,12 +14,16 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import contextlib
import logging
from collections import defaultdict
from functools import wraps
from typing import Any, Callable, DefaultDict, Optional, Union
from typing import Any, Callable, DefaultDict
from urllib import parse
from uuid import UUID
import msgpack
import pyarrow as pa
@@ -163,7 +167,7 @@ def get_permissions(
def get_viz(
form_data: FormData,
datasource_type: str,
datasource_id: int,
datasource_id: int | UUID,
force: bool = False,
force_cached: bool = False,
) -> BaseViz:
@@ -186,10 +190,10 @@ def loads_request_json(request_json_data: str) -> dict[Any, Any]:
def get_form_data(
slice_id: Optional[int] = None,
slice_id: int | None = None,
use_slice_data: bool = False,
initial_form_data: Optional[dict[str, Any]] = None,
) -> tuple[dict[str, Any], Optional[Slice]]:
initial_form_data: dict[str, Any] | None = None,
) -> tuple[dict[str, Any], Slice | None]:
form_data: dict[str, Any] = initial_form_data or {}
if has_request_context():
@@ -272,8 +276,10 @@ def add_sqllab_custom_filters(form_data: dict[Any, Any]) -> Any:
def get_datasource_info(
datasource_id: Optional[int], datasource_type: Optional[str], form_data: FormData
) -> tuple[int, Optional[str]]:
datasource_id: int | str | None,
datasource_type: str | None,
form_data: FormData,
) -> tuple[int | UUID, str | None]:
"""
Compatibility layer for handling of datasource info
@@ -300,12 +306,16 @@ def get_datasource_info(
_("The dataset associated with this chart no longer exists")
)
datasource_id = int(datasource_id)
return datasource_id, datasource_type
# Convert datasource_id to appropriate type
if isinstance(datasource_id, int):
return datasource_id, datasource_type
if datasource_id.isdigit():
return int(datasource_id), datasource_type
return UUID(datasource_id), datasource_type
def apply_display_max_row_limit(
sql_results: dict[str, Any], rows: Optional[int] = None
sql_results: dict[str, Any], rows: int | None = None
) -> dict[str, Any]:
"""
Given a `sql_results` nested structure, applies a limit to the number of rows
@@ -482,8 +492,8 @@ def check_explore_cache_perms(_self: Any, cache_key: str) -> None:
def check_datasource_perms(
_self: Any,
datasource_type: Optional[str] = None,
datasource_id: Optional[int] = None,
datasource_type: str | None = None,
datasource_id: int | str | None = None,
**kwargs: Any,
) -> None:
"""
@@ -500,8 +510,10 @@ def check_datasource_perms(
form_data = kwargs["form_data"] if "form_data" in kwargs else get_form_data()[0]
try:
datasource_id, datasource_type = get_datasource_info(
datasource_id, datasource_type, form_data
ds_id, datasource_type = get_datasource_info(
datasource_id,
datasource_type,
form_data,
)
except SupersetException as ex:
raise SupersetSecurityException(
@@ -524,7 +536,7 @@ def check_datasource_perms(
try:
viz_obj = get_viz(
datasource_type=datasource_type,
datasource_id=datasource_id,
datasource_id=ds_id,
form_data=form_data,
force=False,
)
@@ -541,7 +553,9 @@ def check_datasource_perms(
def _deserialize_results_payload(
payload: Union[bytes, str], query: Query, use_msgpack: Optional[bool] = False
payload: bytes | str,
query: Query,
use_msgpack: bool | None = False,
) -> dict[str, Any]:
logger.debug("Deserializing from msgpack: %r", use_msgpack)
if use_msgpack:
@@ -579,9 +593,12 @@ def _deserialize_results_payload(
def get_cta_schema_name(
database: Database, user: ab_models.User, schema: str, sql: str
) -> Optional[str]:
func: Optional[Callable[[Database, ab_models.User, str, str], str]] = app.config[
database: Database,
user: ab_models.User,
schema: str,
sql: str,
) -> str | None:
func: Callable[[Database, ab_models.User, str, str], str] | None = app.config[
"SQLLAB_CTAS_SCHEMA_NAME_FUNC"
]
if not func:

View File

@@ -24,6 +24,7 @@ from pytest_mock import MockerFixture
from superset_core.semantic_layers.semantic_view import SemanticViewFeature
from superset_core.semantic_layers.types import (
AdhocExpression,
AdhocFilter,
Day,
Dimension,
Filter,
@@ -201,11 +202,9 @@ def test_get_filters_from_extras_where() -> None:
assert len(result) == 1
filter_ = next(iter(result))
assert isinstance(filter_, Filter)
assert isinstance(filter_, AdhocFilter)
assert filter_.type == PredicateType.WHERE
assert filter_.column is None
assert filter_.operator == Operator.ADHOC
assert filter_.value == "customer_id > 100"
assert filter_.definition == "customer_id > 100"
def test_get_filters_from_extras_having() -> None:
@@ -216,12 +215,7 @@ def test_get_filters_from_extras_having() -> None:
result = _get_filters_from_extras(extras)
assert result == {
Filter(
type=PredicateType.HAVING,
column=None,
operator=Operator.ADHOC,
value="SUM(sales) > 1000",
),
AdhocFilter(type=PredicateType.HAVING, definition="SUM(sales) > 1000"),
}
@@ -236,18 +230,8 @@ def test_get_filters_from_extras_both() -> None:
result = _get_filters_from_extras(extras)
assert result == {
Filter(
type=PredicateType.WHERE,
column=None,
operator=Operator.ADHOC,
value="region = 'US'",
),
Filter(
type=PredicateType.HAVING,
column=None,
operator=Operator.ADHOC,
value="COUNT(*) > 10",
),
AdhocFilter(type=PredicateType.WHERE, definition="region = 'US'"),
AdhocFilter(type=PredicateType.HAVING, definition="COUNT(*) > 10"),
}
@@ -466,11 +450,9 @@ def test_get_filters_from_query_object_with_extras(mock_datasource: MagicMock) -
operator=Operator.LESS_THAN,
value=datetime(2025, 10, 22),
),
Filter(
AdhocFilter(
type=PredicateType.WHERE,
column=None,
operator=Operator.ADHOC,
value="customer_id > 100",
definition="customer_id > 100",
),
}
@@ -512,11 +494,9 @@ def test_get_filters_from_query_object_with_fetch_values(
operator=Operator.LESS_THAN,
value=datetime(2025, 10, 22),
),
Filter(
AdhocFilter(
type=PredicateType.WHERE,
column=None,
operator=Operator.ADHOC,
value="tenant_id = 123",
definition="tenant_id = 123",
),
}
@@ -816,11 +796,9 @@ def test_get_group_limit_filters_with_extras(mock_datasource: MagicMock) -> None
operator=Operator.LESS_THAN,
value=datetime(2025, 10, 22),
),
Filter(
AdhocFilter(
type=PredicateType.WHERE,
column=None,
operator=Operator.ADHOC,
value="customer_id > 100",
definition="customer_id > 100",
),
}
@@ -2041,11 +2019,9 @@ def test_get_group_limit_filters_with_fetch_values_predicate(
assert result is not None
assert (
Filter(
AdhocFilter(
type=PredicateType.WHERE,
column=None,
operator=Operator.ADHOC,
value="tenant_id = 123",
definition="tenant_id = 123",
)
in result
)
@@ -2396,7 +2372,6 @@ def test_get_filters_from_query_object_with_filter_loop(
f
for f in result
if isinstance(f, Filter)
and f.column
and f.column.name == "category"
and f.operator == Operator.EQUALS
]
@@ -2469,7 +2444,6 @@ def test_get_group_limit_filters_with_filter_loop(
f
for f in result
if isinstance(f, Filter)
and f.column
and f.column.name == "category"
and f.operator == Operator.EQUALS
]
@@ -2581,7 +2555,6 @@ def test_get_filters_from_query_object_filter_returns_none(
f
for f in result
if isinstance(f, Filter)
and f.column
and f.column.name == "category"
and f.operator == Operator.EQUALS
]
@@ -2634,7 +2607,6 @@ def test_get_group_limit_filters_filter_returns_none(
f
for f in result
if isinstance(f, Filter)
and f.column
and f.column.name == "category"
and f.operator == Operator.EQUALS
]