Mirror of https://github.com/apache/superset.git, synced 2026-04-29 04:54:21 +00:00

Compare commits: docs/mcp-s... → backup/sem... (27 commits)

5e046a857c
36554237aa
6f93e1cbb1
913259299e
2351e0ead7
8c6f211003
0e3d78817f
f0c8304e24
80233aed46
6f350428df
548ccfde44
596008203c
ff46c86df3
4e30638024
efa9159cc8
14668f37bd
27a2466855
e35c6946ec
12c5bfa0a5
0303a234a3
09e9927652
3f9ea361bb
f1047140ee
15e3ab4493
755aa2e32f
17d1ed7353
9c1bcb70d0
```diff
@@ -52,6 +52,7 @@ jobs:
           SUPERSET_SECRET_KEY: not-a-secret
         run: |
           pytest --durations-min=0.5 --cov=superset/sql/ ./tests/unit_tests/sql/ --cache-clear --cov-fail-under=100
+          pytest --durations-min=0.5 --cov=superset/semantic_layers/ ./tests/unit_tests/semantic_layers/ --cache-clear --cov-fail-under=100
       - name: Upload code coverage
         uses: codecov/codecov-action@v5
         with:
```
@@ -24,6 +24,14 @@ assists people when migrating to a new version.

## Next

### Combined datasource list endpoint

Added a new combined datasource list endpoint at `GET /api/v1/datasource/` to serve datasets and semantic views in one response; a request sketch follows the list below.

- The endpoint is available to users with at least one of `can_read` on `Dataset` or `SemanticView`.
- Semantic views are included only when the `SEMANTIC_LAYERS` feature flag is enabled.
- The endpoint enforces strict `order_column` validation and returns `400` for invalid sort columns.
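A minimal sketch of querying the new endpoint. The base URL, token, and Rison-encoded `q` keys are assumptions modeled on Superset's other list endpoints, not details confirmed by this change:

```python
# Hedged sketch: call the combined datasource list endpoint.
import requests

BASE_URL = "http://localhost:8088"  # assumption: local Superset instance
HEADERS = {"Authorization": "Bearer <access-token>"}  # placeholder token

# order_column is strictly validated; an unknown column returns HTTP 400.
params = {"q": "(order_column:changed_on,order_direction:desc,page:0,page_size:25)"}
response = requests.get(
    f"{BASE_URL}/api/v1/datasource/", headers=HEADERS, params=params
)
response.raise_for_status()
for datasource in response.json()["result"]:  # datasets and semantic views together
    print(datasource)
```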
### ClickHouse minimum driver version bump

The minimum required version of `clickhouse-connect` has been raised to `>=0.13.0`. If you are using the ClickHouse connector, please upgrade your `clickhouse-connect` package. The `_mutate_label` workaround that appended hash suffixes to column aliases has also been removed, as it is no longer needed with modern versions of the driver.
@@ -224,3 +224,52 @@ async def analysis_guide(ctx: Context) -> str:
```

See [MCP Integration](./mcp) for implementation details.

### Semantic Layers

Extensions can register custom semantic layer implementations that allow Superset to connect to external data modeling frameworks. Each semantic layer defines how to authenticate, discover semantic views (tables/metrics/dimensions), and execute queries against the external system.

```python
from superset_core.semantic_layers.decorators import semantic_layer
from superset_core.semantic_layers.layer import SemanticLayer

from my_extension.config import MyConfig
from my_extension.view import MySemanticView


@semantic_layer(
    id="my_platform",
    name="My Data Platform",
    description="Connect to My Data Platform's semantic layer",
)
class MySemanticLayer(SemanticLayer[MyConfig, MySemanticView]):
    configuration_class = MyConfig

    @classmethod
    def from_configuration(cls, configuration: dict) -> "MySemanticLayer":
        config = MyConfig.model_validate(configuration)
        return cls(config)

    @classmethod
    def get_configuration_schema(cls, configuration=None) -> dict:
        return MyConfig.model_json_schema()

    @classmethod
    def get_runtime_schema(cls, configuration=None, runtime_data=None) -> dict:
        return {"type": "object", "properties": {}}

    def get_semantic_views(self, runtime_configuration: dict) -> set[MySemanticView]:
        # Return available views from the external platform
        ...

    def get_semantic_view(self, name: str, additional_configuration: dict) -> MySemanticView:
        # Return a specific view by name
        ...
```

**Note**: The `@semantic_layer` decorator automatically detects context and applies appropriate ID prefixing:

- **Extension context**: ID prefixed as `extensions.{publisher}.{name}.{id}`
- **Host context**: Original ID used as-is

The decorator registers the class in the semantic layers registry, making it available in the UI for users to create connections. The `configuration_class` should be a Pydantic model that defines the fields needed to connect (credentials, project, database, etc.). Superset uses the model's JSON schema to render the configuration form dynamically.
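For reference, a hedged sketch of what `MyConfig` might look like; the specific fields are illustrative assumptions, not part of the contract:

```python
from pydantic import BaseModel, Field, SecretStr


class MyConfig(BaseModel):
    """Connection settings for My Data Platform (fields are illustrative)."""

    account: str = Field(description="Account identifier")
    token: SecretStr = Field(description="API token used to authenticate")
    database: str | None = Field(
        default=None, description="Optional default database"
    )


# Superset renders the configuration form from the model's JSON schema:
form_schema = MyConfig.model_json_schema()
```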
docs/static/feature-flags.json (vendored)

```diff
@@ -75,6 +75,12 @@
     "lifecycle": "development",
     "description": "Expand nested types in Presto into extra columns/arrays. Experimental, doesn't work with all nested types."
   },
+  {
+    "name": "SEMANTIC_LAYERS",
+    "default": false,
+    "lifecycle": "development",
+    "description": "Enable semantic layers and show semantic views alongside datasets"
+  },
   {
     "name": "TABLE_V2_TIME_COMPARISON_ENABLED",
     "default": false,
```
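To try semantic views locally, the flag can be enabled in `superset_config.py` using Superset's standard feature-flag mechanism (it defaults to `false` while in development):

```python
# superset_config.py — opt in to the in-development semantic layers feature.
FEATURE_FLAGS = {
    "SEMANTIC_LAYERS": True,
}
```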
```diff
@@ -285,6 +285,7 @@ module = [
     "superset.tags.filters",
     "superset.commands.security.update",
     "superset.commands.security.create",
+    "superset.semantic_layers.api",
 ]
 warn_unused_ignores = false
```
```diff
@@ -43,6 +43,8 @@ classifiers = [
 ]
 dependencies = [
     "flask-appbuilder>=5.0.2,<6",
     "isodate>=0.7.0",
     "pyarrow>=16.0.0",
     "pydantic>=2.8.0",
     "sqlalchemy>=1.4.0,<2.0",
     "sqlalchemy-utils>=0.38.0, <0.43",  # expanding lowerbound to work with pydoris
```
superset-core/src/superset_core/semantic_layers/daos.py (new file, 169 lines)

```python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Semantic layer DAO interfaces for superset-core.

Provides abstract DAO classes for semantic layers and views that define the
interface contract. Host implementations replace these with concrete classes
backed by SQLAlchemy during initialization.

Usage:
    from superset_core.semantic_layers.daos import (
        AbstractSemanticLayerDAO,
        AbstractSemanticViewDAO,
    )
"""

from __future__ import annotations

from abc import abstractmethod
from typing import Any, ClassVar

from superset_core.common.daos import BaseDAO
from superset_core.semantic_layers.models import SemanticLayerModel, SemanticViewModel


class AbstractSemanticLayerDAO(BaseDAO[SemanticLayerModel]):
    """
    Abstract DAO interface for SemanticLayer.

    Host implementations will replace this class during initialization
    with a concrete DAO providing actual database access.
    """

    model_cls: ClassVar[type[Any] | None] = None
    base_filter = None
    id_column_name = "uuid"
    uuid_column_name = "uuid"

    @classmethod
    @abstractmethod
    def validate_uniqueness(cls, name: str) -> bool:
        """
        Validate that a semantic layer name is unique.

        :param name: Semantic layer name to validate
        :return: True if the name is unique, False otherwise
        """
        ...

    @classmethod
    @abstractmethod
    def validate_update_uniqueness(cls, layer_uuid: str, name: str) -> bool:
        """
        Validate that a semantic layer name is unique for an update operation,
        excluding the layer being updated.

        :param layer_uuid: UUID of the semantic layer being updated
        :param name: New name to validate
        :return: True if the name is unique, False otherwise
        """
        ...

    @classmethod
    @abstractmethod
    def find_by_name(cls, name: str) -> SemanticLayerModel | None:
        """
        Find a semantic layer by name.

        :param name: Semantic layer name
        :return: SemanticLayerModel instance or None
        """
        ...

    @classmethod
    @abstractmethod
    def get_semantic_views(cls, layer_uuid: str) -> list[SemanticViewModel]:
        """
        Get all semantic views associated with a semantic layer.

        :param layer_uuid: UUID of the semantic layer
        :return: List of SemanticViewModel instances
        """
        ...


class AbstractSemanticViewDAO(BaseDAO[SemanticViewModel]):
    """
    Abstract DAO interface for SemanticView.

    Host implementations will replace this class during initialization
    with a concrete DAO providing actual database access.
    """

    model_cls: ClassVar[type[Any] | None] = None
    base_filter = None
    id_column_name = "id"
    uuid_column_name = "uuid"

    @classmethod
    @abstractmethod
    def validate_uniqueness(
        cls,
        name: str,
        layer_uuid: str,
        configuration: dict[str, Any],
    ) -> bool:
        """
        Validate that a semantic view is unique within a semantic layer.

        Uniqueness is determined by the combination of name, layer UUID, and
        configuration.

        :param name: View name
        :param layer_uuid: UUID of the parent semantic layer
        :param configuration: Configuration dict to compare
        :return: True if unique, False otherwise
        """
        ...

    @classmethod
    @abstractmethod
    def validate_update_uniqueness(
        cls,
        view_uuid: str,
        name: str,
        layer_uuid: str,
        configuration: dict[str, Any],
    ) -> bool:
        """
        Validate that a semantic view is unique within a semantic layer for an
        update operation, excluding the view being updated.

        :param view_uuid: UUID of the view being updated
        :param name: New name to validate
        :param layer_uuid: UUID of the parent semantic layer
        :param configuration: Configuration dict to compare
        :return: True if unique, False otherwise
        """
        ...

    @classmethod
    @abstractmethod
    def find_by_name(cls, name: str, layer_uuid: str) -> SemanticViewModel | None:
        """
        Find a semantic view by name within a semantic layer.

        :param name: View name
        :param layer_uuid: UUID of the parent semantic layer
        :return: SemanticViewModel instance or None
        """
        ...


__all__ = ["AbstractSemanticLayerDAO", "AbstractSemanticViewDAO"]
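A hedged sketch of how a host might back these abstract methods with SQLAlchemy; `db.session` and the concrete model import path are assumptions about the host application, not part of the `superset-core` contract, and the remaining abstract methods are omitted:

```python
from superset import db  # assumption: host application's session factory
from superset.models.semantic_layers import SemanticLayer  # hypothetical model path

from superset_core.semantic_layers.daos import AbstractSemanticLayerDAO


class SemanticLayerDAO(AbstractSemanticLayerDAO):
    model_cls = SemanticLayer

    @classmethod
    def validate_uniqueness(cls, name: str) -> bool:
        # The name is unique when no existing layer already uses it.
        return db.session.query(SemanticLayer).filter_by(name=name).first() is None

    @classmethod
    def find_by_name(cls, name: str) -> SemanticLayer | None:
        return db.session.query(SemanticLayer).filter_by(name=name).one_or_none()
```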
superset-core/src/superset_core/semantic_layers/decorators.py (new file, 102 lines)

```python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Semantic layer registration decorator for Superset.

This module provides a decorator interface to register semantic layer
implementations with the host application, enabling automatic discovery
by the extensions framework.

Usage:
    from superset_core.semantic_layers.decorators import semantic_layer

    @semantic_layer(
        id="snowflake",
        name="Snowflake Cortex",
        description="Snowflake semantic layer via Cortex Analyst",
    )
    class SnowflakeSemanticLayer(SemanticLayer[SnowflakeConfig, SnowflakeView]):
        ...

    # Or with minimal arguments:
    @semantic_layer(id="dbt", name="dbt Semantic Layer")
    class DbtSemanticLayer(SemanticLayer[DbtConfig, DbtView]):
        ...
"""

from __future__ import annotations

from typing import Callable, TypeVar

# Type variable for decorated semantic layer classes
T = TypeVar("T")


def semantic_layer(
    id: str,
    name: str,
    description: str | None = None,
) -> Callable[[T], T]:
    """
    Decorator to register a semantic layer implementation.

    Automatically detects extension context and applies appropriate
    namespacing to prevent ID conflicts between host and extension
    semantic layers.

    Host implementations will replace this function during initialization
    with a concrete implementation providing actual functionality.

    Args:
        id: Unique semantic layer type identifier (e.g., "snowflake",
            "dbt"). Used as the key in the semantic layers registry and
            stored in the ``type`` column of the ``SemanticLayer`` model.
        name: Human-readable display name (e.g., "Snowflake Cortex").
            Shown in the UI when listing available semantic layer types.
        description: Optional description for documentation and UI
            tooltips.

    Returns:
        Decorated semantic layer class registered with the host
        application.

    Raises:
        NotImplementedError: If called before host implementation is
            initialized.

    Example:
        from superset_core.semantic_layers.decorators import semantic_layer
        from superset_core.semantic_layers.layer import SemanticLayer

        @semantic_layer(
            id="snowflake",
            name="Snowflake Cortex",
            description="Connect to Snowflake Cortex Analyst",
        )
        class SnowflakeSemanticLayer(
            SemanticLayer[SnowflakeConfig, SnowflakeView]
        ):
            ...
    """
    raise NotImplementedError(
        "Semantic layer decorator not initialized. "
        "This decorator should be replaced during Superset startup."
    )


__all__ = ["semantic_layer"]
```
superset-core/src/superset_core/semantic_layers/layer.py (new file, 127 lines)

```python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Generic, TypeVar

from pydantic import BaseModel
from superset_core.semantic_layers.view import SemanticView

ConfigT = TypeVar("ConfigT", bound=BaseModel)
SemanticViewT = TypeVar("SemanticViewT", bound="SemanticView")


class SemanticLayer(ABC, Generic[ConfigT, SemanticViewT]):
    """
    Abstract base class for semantic layers.
    """

    @classmethod
    @abstractmethod
    def from_configuration(
        cls,
        configuration: dict[str, Any],
    ) -> SemanticLayer[ConfigT, SemanticViewT]:
        """
        Create a semantic layer from its configuration.
        """
        raise NotImplementedError(
            "Semantic layers must implement the from_configuration method"
        )

    @classmethod
    @abstractmethod
    def get_configuration_schema(
        cls,
        configuration: ConfigT | None = None,
    ) -> dict[str, Any]:
        """
        Get the JSON schema for the configuration needed to add the semantic layer.

        A partial configuration `configuration` can be sent to improve the schema,
        allowing for progressive validation and better UX. For example, a semantic
        layer might require:

        - auth information
        - a database

        If the user provides the auth information, a client can send the partial
        configuration to this method, and the resulting JSON schema would include
        the list of databases the user has access to, allowing a dropdown to be
        populated.

        The Snowflake semantic layer has an example implementation of this method,
        where database and schema names are populated based on the provided
        connection info.
        """
        raise NotImplementedError(
            "Semantic layers must implement the get_configuration_schema method"
        )

    @classmethod
    @abstractmethod
    def get_runtime_schema(
        cls,
        configuration: ConfigT,
        runtime_data: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """
        Get the JSON schema for the runtime parameters needed to load semantic views.

        This returns the schema needed to connect to a semantic view given the
        configuration for the semantic layer. For example, a semantic layer might
        be configured by:

        - auth information
        - an optional database

        If the user does not provide a database when creating the semantic layer,
        the runtime schema would require the database name to be provided before
        loading any semantic views. This lets users create semantic layers that
        connect to a specific database (or project, account, etc.), or that defer
        the selection to query time.

        The Snowflake semantic layer has an example implementation of this method,
        where database and schema names are required if they were not provided in
        the initial configuration.
        """
        raise NotImplementedError(
            "Semantic layers must implement the get_runtime_schema method"
        )

    @abstractmethod
    def get_semantic_views(
        self,
        runtime_configuration: dict[str, Any],
    ) -> set[SemanticViewT]:
        """
        Get the semantic views available in the semantic layer.

        The runtime configuration can provide information like a given project or
        schema, used to restrict the semantic views returned.
        """

    @abstractmethod
    def get_semantic_view(
        self,
        name: str,
        additional_configuration: dict[str, Any],
    ) -> SemanticViewT:
        """
        Get a specific semantic view by its name and additional configuration.
        """
```
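To make the progressive-schema idea concrete, a hedged sketch of a `get_configuration_schema` override; `MyConfig` and `list_databases` are hypothetical helpers, and the other abstract methods are omitted:

```python
from typing import Any

from superset_core.semantic_layers.layer import SemanticLayer


class MySemanticLayer(SemanticLayer):  # generics omitted for brevity
    @classmethod
    def get_configuration_schema(cls, configuration: Any = None) -> dict[str, Any]:
        # Start from the static schema derived from the Pydantic config model.
        schema = MyConfig.model_json_schema()  # hypothetical config model
        if configuration and configuration.get("token"):
            # Auth is available: enrich the schema so a client can render a
            # dropdown of the databases this user can actually access.
            schema["properties"]["database"]["enum"] = list_databases(
                configuration["token"]  # hypothetical helper
            )
        return schema
```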
superset-core/src/superset_core/semantic_layers/models.py (new file, 85 lines)

```python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Semantic layer model interfaces for superset-core.

Provides abstract model classes for semantic layers and views that will be
replaced by the host implementation's concrete SQLAlchemy models during
initialization.

Usage:
    from superset_core.semantic_layers.models import (
        SemanticLayerModel,
        SemanticViewModel,
    )
"""

from __future__ import annotations

from datetime import datetime
from uuid import UUID

from superset_core.common.models import CoreModel


class SemanticLayerModel(CoreModel):
    """
    Abstract interface for the SemanticLayer database model.

    Host implementations will replace this class during initialization
    with a concrete SQLAlchemy model providing actual persistence.
    """

    __abstract__ = True

    # Type hints for expected column attributes
    uuid: UUID
    name: str
    description: str | None
    type: str
    configuration: str
    configuration_version: int
    cache_timeout: int | None
    created_on: datetime | None
    changed_on: datetime | None


class SemanticViewModel(CoreModel):
    """
    Abstract interface for the SemanticView database model.

    Host implementations will replace this class during initialization
    with a concrete SQLAlchemy model providing actual persistence.
    """

    __abstract__ = True

    # Type hints for expected column attributes
    id: int
    uuid: UUID
    name: str
    description: str | None
    configuration: str
    configuration_version: int
    cache_timeout: int | None
    semantic_layer_uuid: UUID
    created_on: datetime | None
    changed_on: datetime | None


__all__ = ["SemanticLayerModel", "SemanticViewModel"]
```
superset-core/src/superset_core/semantic_layers/types.py (new file, 209 lines)

```python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import enum
from dataclasses import dataclass
from datetime import date, datetime, time, timedelta

import isodate
import pyarrow as pa


@dataclass(frozen=True)
class Grain:
    """
    Represents a time grain (e.g., day, month, year).

    Attributes:
        name: Human-readable name of the grain (e.g., "Second")
        representation: ISO 8601 duration (e.g., "PT1S", "P1D", "P1M")
    """

    name: str
    representation: str

    def __post_init__(self) -> None:
        isodate.parse_duration(self.representation)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Grain):
            return self.representation == other.representation
        return NotImplemented

    def __hash__(self) -> int:
        return hash(self.representation)


class Grains:
    """Pre-defined common grains and factory for custom ones."""

    SECOND = Grain("Second", "PT1S")
    MINUTE = Grain("Minute", "PT1M")
    HOUR = Grain("Hour", "PT1H")
    DAY = Grain("Day", "P1D")
    WEEK = Grain("Week", "P1W")
    MONTH = Grain("Month", "P1M")
    QUARTER = Grain("Quarter", "P3M")
    YEAR = Grain("Year", "P1Y")

    _REGISTRY: dict[str, Grain] = {
        "PT1S": SECOND,
        "PT1M": MINUTE,
        "PT1H": HOUR,
        "P1D": DAY,
        "P1W": WEEK,
        "P1M": MONTH,
        "P3M": QUARTER,
        "P1Y": YEAR,
    }

    @classmethod
    def get(cls, representation: str, name: str | None = None) -> Grain:
        """Return a pre-defined grain or create a custom one."""
        if grain := cls._REGISTRY.get(representation):
            return grain
        return Grain(name or representation, representation)


@dataclass(frozen=True)
class Dimension:
    id: str
    name: str
    type: pa.DataType

    definition: str | None = None
    description: str | None = None
    grain: Grain | None = None


@dataclass(frozen=True)
class Metric:
    id: str
    name: str
    type: pa.DataType

    definition: str
    description: str | None = None


@dataclass(frozen=True)
class AdhocExpression:
    id: str
    definition: str


class Operator(str, enum.Enum):
    EQUALS = "="
    NOT_EQUALS = "!="
    GREATER_THAN = ">"
    LESS_THAN = "<"
    GREATER_THAN_OR_EQUAL = ">="
    LESS_THAN_OR_EQUAL = "<="
    IN = "IN"
    NOT_IN = "NOT IN"
    LIKE = "LIKE"
    NOT_LIKE = "NOT LIKE"
    IS_NULL = "IS NULL"
    IS_NOT_NULL = "IS NOT NULL"
    ADHOC = "ADHOC"


FilterValues = str | int | float | bool | datetime | date | time | timedelta | None


class PredicateType(enum.Enum):
    WHERE = "WHERE"
    HAVING = "HAVING"


@dataclass(frozen=True, order=True)
class Filter:
    type: PredicateType
    column: Dimension | Metric | None
    operator: Operator
    value: FilterValues | frozenset[FilterValues]


class OrderDirection(enum.Enum):
    ASC = "ASC"
    DESC = "DESC"


OrderTuple = tuple[Metric | Dimension | AdhocExpression, OrderDirection]


@dataclass(frozen=True)
class GroupLimit:
    """
    Limit query to top/bottom N combinations of specified dimensions.

    The `filters` parameter allows specifying separate filter constraints for the
    group limit subquery. This is useful when you want to determine the top N groups
    using different criteria (e.g., a different time range) than the main query.

    For example, you might want to find the top 10 products by sales over the last
    30 days, but then show daily sales for those products over the last 7 days.
    """

    dimensions: list[Dimension]
    top: int
    metric: Metric | None
    direction: OrderDirection = OrderDirection.DESC
    group_others: bool = False
    filters: set[Filter] | None = None


@dataclass(frozen=True)
class SemanticRequest:
    """
    Represents a request made to obtain semantic results.

    This could be a SQL query, an HTTP request, etc.
    """

    type: str
    definition: str


@dataclass(frozen=True)
class SemanticResult:
    """
    Represents the results of a semantic query.

    This includes any requests (SQL queries, HTTP requests) that were performed
    to obtain the results, to aid troubleshooting.
    """

    requests: list[SemanticRequest]
    results: pa.Table


@dataclass(frozen=True)
class SemanticQuery:
    """
    Represents a semantic query.
    """

    metrics: list[Metric]
    dimensions: list[Dimension]
    filters: set[Filter] | None = None
    order: list[OrderTuple] | None = None
    limit: int | None = None
    offset: int | None = None
    group_limit: GroupLimit | None = None
```
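A short sketch of how these types compose into a query; the metric, dimension, and values are illustrative only:

```python
import pyarrow as pa

from superset_core.semantic_layers.types import (
    Dimension,
    Grains,
    GroupLimit,
    Metric,
    OrderDirection,
    SemanticQuery,
)

sales = Metric(id="sales", name="Sales", type=pa.float64(), definition="SUM(amount)")
day = Dimension(
    id="order_day", name="Order day", type=pa.timestamp("us"), grain=Grains.DAY
)
product = Dimension(id="product", name="Product", type=pa.string())

# Grains.get returns a pre-defined grain for a known ISO 8601 duration and
# builds a custom one otherwise (the name defaults to the representation).
fortnight = Grains.get("P2W", "Fortnight")

# Daily sales per product, limited to the top 5 products by total sales.
query = SemanticQuery(
    metrics=[sales],
    dimensions=[day, product],
    group_limit=GroupLimit(dimensions=[product], top=5, metric=sales),
    order=[(day, OrderDirection.ASC)],
)
```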
superset-core/src/superset_core/semantic_layers/view.py (new file, 108 lines)

```python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import enum
from abc import ABC, abstractmethod

from superset_core.semantic_layers.types import (
    Dimension,
    Filter,
    Metric,
    SemanticQuery,
    SemanticResult,
)


# TODO (betodealmeida): move to the extension JSON
class SemanticViewFeature(enum.Enum):
    """
    Custom features supported by semantic layers.
    """

    ADHOC_EXPRESSIONS_IN_ORDERBY = "ADHOC_EXPRESSIONS_IN_ORDERBY"
    GROUP_LIMIT = "GROUP_LIMIT"
    GROUP_OTHERS = "GROUP_OTHERS"


class SemanticView(ABC):
    """
    Abstract base class for semantic views.
    """

    features: frozenset[SemanticViewFeature]

    @abstractmethod
    def uid(self) -> str:
        """
        Return a unique identifier for the semantic view.
        """

    @abstractmethod
    def get_dimensions(self) -> set[Dimension]:
        """
        Get the dimensions defined in the semantic view.
        """

    @abstractmethod
    def get_metrics(self) -> set[Metric]:
        """
        Get the metrics defined in the semantic view.
        """

    @abstractmethod
    def get_values(
        self,
        dimension: Dimension,
        filters: set[Filter] | None = None,
    ) -> SemanticResult:
        """
        Return distinct values for a dimension.
        """

    @abstractmethod
    def get_table(self, query: SemanticQuery) -> SemanticResult:
        """
        Execute a semantic query and return the results.
        """

    @abstractmethod
    def get_row_count(self, query: SemanticQuery) -> SemanticResult:
        """
        Execute a query and return the number of rows the result would have.
        """

    @abstractmethod
    def get_compatible_metrics(
        self,
        selected_metrics: set[Metric],
        selected_dimensions: set[Dimension],
    ) -> set[Metric]:
        """
        Return metrics compatible with the selected dimensions.
        """

    @abstractmethod
    def get_compatible_dimensions(
        self,
        selected_metrics: set[Metric],
        selected_dimensions: set[Dimension],
    ) -> set[Dimension]:
        """
        Return dimensions compatible with the selected metrics.
        """
```
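To ground the contract, a hedged sketch of a minimal in-memory implementation; the hard-coded data is purely illustrative:

```python
import pyarrow as pa

from superset_core.semantic_layers.types import (
    Dimension,
    Metric,
    SemanticQuery,
    SemanticRequest,
    SemanticResult,
)
from superset_core.semantic_layers.view import SemanticView, SemanticViewFeature


class InMemorySemanticView(SemanticView):
    """Toy view backed by hard-coded data, for illustration only."""

    features = frozenset({SemanticViewFeature.GROUP_LIMIT})

    def uid(self) -> str:
        return "demo.orders"

    def get_dimensions(self) -> set[Dimension]:
        return {Dimension(id="product", name="Product", type=pa.string())}

    def get_metrics(self) -> set[Metric]:
        return {
            Metric(id="sales", name="Sales", type=pa.float64(), definition="SUM(amount)")
        }

    def get_values(self, dimension, filters=None) -> SemanticResult:
        # Distinct dimension values; no underlying requests were issued.
        return SemanticResult(
            requests=[], results=pa.table({dimension.id: ["widget", "gadget"]})
        )

    def get_table(self, query: SemanticQuery) -> SemanticResult:
        # Record the (pretend) request so callers can troubleshoot it.
        request = SemanticRequest(type="sql", definition="SELECT ...")
        return SemanticResult(requests=[request], results=pa.table({"sales": [42.0]}))

    def get_row_count(self, query: SemanticQuery) -> SemanticResult:
        return SemanticResult(requests=[], results=pa.table({"count": [1]}))

    def get_compatible_metrics(self, selected_metrics, selected_dimensions):
        return self.get_metrics()

    def get_compatible_dimensions(self, selected_metrics, selected_dimensions):
        return self.get_dimensions()
```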
```diff
@@ -23,7 +23,7 @@ import { Label } from '..';

 // Define the prop types for DatasetTypeLabel
 interface DatasetTypeLabelProps {
-  datasetType: 'physical' | 'virtual'; // Accepts only 'physical' or 'virtual'
+  datasetType: 'physical' | 'virtual' | 'semantic_view';
 }

 const SIZE = 's'; // Define the size as a constant

@@ -32,6 +32,24 @@ export const DatasetTypeLabel: React.FC<DatasetTypeLabelProps> = ({
   datasetType,
 }) => {
   const theme = useTheme();

+  if (datasetType === 'semantic_view') {
+    return (
+      <Label
+        icon={
+          <Icons.ApartmentOutlined
+            iconSize={SIZE}
+            iconColor={theme.colorInfo}
+          />
+        }
+        type="info"
+        style={{ color: theme.colorInfo }}
+      >
+        {t('Semantic')}
+      </Label>
+    );
+  }
+
   const label: string =
     datasetType === 'physical' ? t('Physical') : t('Virtual');
   const icon =
```
```diff
@@ -19,6 +19,15 @@

 import { DatasourceType } from './types/Datasource';

+const DATASOURCE_TYPE_MAP: Record<string, DatasourceType> = {
+  table: DatasourceType.Table,
+  query: DatasourceType.Query,
+  dataset: DatasourceType.Dataset,
+  sl_table: DatasourceType.SlTable,
+  saved_query: DatasourceType.SavedQuery,
+  semantic_view: DatasourceType.SemanticView,
+};
+
 export default class DatasourceKey {
   readonly id: number;

@@ -27,8 +36,7 @@ export default class DatasourceKey {
   constructor(key: string) {
     const [idStr, typeStr] = key.split('__');
     this.id = parseInt(idStr, 10);
-    this.type = DatasourceType.Table; // default to SqlaTable model
-    this.type = typeStr === 'query' ? DatasourceType.Query : this.type;
+    this.type = DATASOURCE_TYPE_MAP[typeStr] ?? DatasourceType.Table;
   }

   public toString() {
```
```diff
@@ -26,6 +26,7 @@ export enum DatasourceType {
   Dataset = 'dataset',
   SlTable = 'sl_table',
   SavedQuery = 'saved_query',
+  SemanticView = 'semantic_view',
 }

 export interface Currency {
```
```diff
@@ -60,6 +60,7 @@ export enum FeatureFlag {
   ListviewsDefaultCardView = 'LISTVIEWS_DEFAULT_CARD_VIEW',
   Matrixify = 'MATRIXIFY',
   ScheduledQueries = 'SCHEDULED_QUERIES',
+  SemanticLayers = 'SEMANTIC_LAYERS',
   SqllabBackendPersistence = 'SQLLAB_BACKEND_PERSISTENCE',
   SqlValidatorsByEngine = 'SQL_VALIDATORS_BY_ENGINE',
   SshTunneling = 'SSH_TUNNELING',
```
```diff
@@ -28,10 +28,11 @@ test('DEFAULT_METRICS', () => {
 });

 test('DatasourceType', () => {
-  expect(Object.keys(DatasourceType).length).toBe(5);
+  expect(Object.keys(DatasourceType).length).toBe(6);
   expect(DatasourceType.Table).toBe('table');
   expect(DatasourceType.Query).toBe('query');
   expect(DatasourceType.Dataset).toBe('dataset');
   expect(DatasourceType.SlTable).toBe('sl_table');
   expect(DatasourceType.SavedQuery).toBe('saved_query');
+  expect(DatasourceType.SemanticView).toBe('semantic_view');
 });
```
```diff
@@ -151,11 +151,8 @@ export const getSlicePayload = async (
   const [id, typeString] = formData.datasource.split('__');
   datasourceId = parseInt(id, 10);

-  const formattedTypeString =
-    typeString.charAt(0).toUpperCase() + typeString.slice(1);
-  if (formattedTypeString in DatasourceType) {
-    datasourceType =
-      DatasourceType[formattedTypeString as keyof typeof DatasourceType];
+  if (Object.values(DatasourceType).includes(typeString as DatasourceType)) {
+    datasourceType = typeString as DatasourceType;
   }
 }
```
New file (+95 lines):

```tsx
/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
import userEvent from '@testing-library/user-event';
import { render, screen, waitFor } from 'spec/helpers/testing-library';
import { SupersetClient, getClientErrorObject } from '@superset-ui/core';

import SemanticViewEditModal from './SemanticViewEditModal';

jest.mock('@superset-ui/core', () => ({
  ...jest.requireActual('@superset-ui/core'),
  SupersetClient: {
    ...jest.requireActual('@superset-ui/core').SupersetClient,
    put: jest.fn(),
  },
  getClientErrorObject: jest.fn(() => Promise.resolve({ error: '' })),
}));

const mockedPut = SupersetClient.put as jest.Mock;
const mockedGetClientErrorObject = getClientErrorObject as jest.Mock;

const createProps = () => ({
  show: true,
  onHide: jest.fn(),
  onSave: jest.fn(),
  addDangerToast: jest.fn(),
  addSuccessToast: jest.fn(),
  semanticView: {
    id: 7,
    table_name: 'orders_semantic_view',
    description: 'old description',
    cache_timeout: 60,
  },
});

beforeEach(() => {
  mockedPut.mockReset();
  mockedGetClientErrorObject.mockReset();
  mockedGetClientErrorObject.mockResolvedValue({ error: '' });
});

test('saves semantic view and refreshes list', async () => {
  mockedPut.mockResolvedValue({});
  const props = createProps();

  render(<SemanticViewEditModal {...props} />);

  await userEvent.click(screen.getByRole('button', { name: /save/i }));

  await waitFor(() => {
    expect(mockedPut).toHaveBeenCalledWith({
      endpoint: '/api/v1/semantic_view/7',
      jsonPayload: {
        description: 'old description',
        cache_timeout: 60,
      },
    });
  });
  expect(props.addSuccessToast).toHaveBeenCalledWith('Semantic view updated');
  expect(props.onSave).toHaveBeenCalled();
  expect(props.onHide).toHaveBeenCalled();
});

test('shows backend error toast when save fails', async () => {
  mockedPut.mockRejectedValue(new Error('save failed'));
  mockedGetClientErrorObject.mockResolvedValue({
    error: 'Semantic view failed to save',
  });
  const props = createProps();

  render(<SemanticViewEditModal {...props} />);

  await userEvent.click(screen.getByRole('button', { name: /save/i }));

  await waitFor(() => {
    expect(props.addDangerToast).toHaveBeenCalledWith(
      'Semantic view failed to save',
    );
  });
});
```
New file (+119 lines):

```tsx
/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
import { useState, useEffect } from 'react';
import { t } from '@apache-superset/core/translation';
import { SupersetClient, getClientErrorObject } from '@superset-ui/core';
import { Input, InputNumber } from '@superset-ui/core/components';
import { Icons } from '@superset-ui/core/components/Icons';
import {
  StandardModal,
  ModalFormField,
  MODAL_STANDARD_WIDTH,
} from 'src/components/Modal';

type InputNumberValue = number | null;

interface SemanticViewEditModalProps {
  show: boolean;
  onHide: () => void;
  onSave: () => void;
  addDangerToast: (msg: string) => void;
  addSuccessToast: (msg: string) => void;
  semanticView: {
    id: number;
    table_name: string;
    description?: string | null;
    cache_timeout?: number | null;
  } | null;
}

export default function SemanticViewEditModal({
  show,
  onHide,
  onSave,
  addDangerToast,
  addSuccessToast,
  semanticView,
}: SemanticViewEditModalProps) {
  const [description, setDescription] = useState<string>('');
  const [cacheTimeout, setCacheTimeout] = useState<number | null>(null);
  const [saving, setSaving] = useState(false);

  useEffect(() => {
    if (semanticView) {
      setDescription(semanticView.description || '');
      setCacheTimeout(semanticView.cache_timeout ?? null);
    }
  }, [semanticView]);

  const handleSave = async () => {
    if (!semanticView) return;
    setSaving(true);
    try {
      await SupersetClient.put({
        endpoint: `/api/v1/semantic_view/${semanticView.id}`,
        jsonPayload: {
          description: description || null,
          cache_timeout: cacheTimeout,
        },
      });
      addSuccessToast(t('Semantic view updated'));
      onSave();
      onHide();
    } catch (error) {
      const clientError = await getClientErrorObject(error);
      addDangerToast(
        clientError.error ||
          t('An error occurred while saving the semantic view'),
      );
    } finally {
      setSaving(false);
    }
  };

  return (
    <StandardModal
      show={show}
      onHide={onHide}
      onSave={handleSave}
      title={t('Edit %s', semanticView?.table_name || '')}
      icon={<Icons.EditOutlined />}
      isEditMode
      width={MODAL_STANDARD_WIDTH}
      saveLoading={saving}
    >
      <ModalFormField label={t('Description')}>
        <Input.TextArea
          value={description}
          onChange={e => setDescription(e.target.value)}
          rows={4}
        />
      </ModalFormField>
      <ModalFormField label={t('Cache timeout')}>
        <InputNumber
          value={cacheTimeout}
          onChange={value => setCacheTimeout(value as InputNumberValue)}
          min={0}
          placeholder={t('Duration in seconds')}
          style={{ width: '100%' }}
        />
      </ModalFormField>
    </StandardModal>
  );
}
```
@@ -31,6 +31,7 @@ import {
|
||||
mockRelatedCharts,
|
||||
mockRelatedDashboards,
|
||||
mockHandleResourceExport,
|
||||
mockDatasetListEndpoints,
|
||||
API_ENDPOINTS,
|
||||
} from './DatasetList.testHelpers';
|
||||
|
||||
@@ -98,7 +99,7 @@ test('typing in search triggers debounced API call with search filter', async ()
|
||||
|
||||
// Record initial API calls
|
||||
const initialCallCount = fetchMock.callHistory.calls(
|
||||
API_ENDPOINTS.DATASETS,
|
||||
API_ENDPOINTS.DATASOURCE_COMBINED,
|
||||
).length;
|
||||
|
||||
// Type search query and submit with Enter to trigger the debounced fetch
|
||||
@@ -107,14 +108,16 @@ test('typing in search triggers debounced API call with search filter', async ()
|
||||
// Wait for debounced API call
|
||||
await waitFor(
|
||||
() => {
|
||||
const calls = fetchMock.callHistory.calls(API_ENDPOINTS.DATASETS);
|
||||
const calls = fetchMock.callHistory.calls(
|
||||
API_ENDPOINTS.DATASOURCE_COMBINED,
|
||||
);
|
||||
expect(calls.length).toBeGreaterThan(initialCallCount);
|
||||
},
|
||||
{ timeout: 5000 },
|
||||
);
|
||||
|
||||
// Verify the latest API call includes search filter in URL
|
||||
const calls = fetchMock.callHistory.calls(API_ENDPOINTS.DATASETS);
|
||||
const calls = fetchMock.callHistory.calls(API_ENDPOINTS.DATASOURCE_COMBINED);
|
||||
const latestCall = calls[calls.length - 1];
|
||||
const { url } = latestCall;
|
||||
|
||||
@@ -136,8 +139,7 @@ test('typing in search triggers debounced API call with search filter', async ()
|
||||
test('500 error triggers danger toast with error message', async () => {
|
||||
const addDangerToast = jest.fn();
|
||||
|
||||
fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
|
||||
fetchMock.get(API_ENDPOINTS.DATASETS, {
|
||||
mockDatasetListEndpoints({
|
||||
status: 500,
|
||||
body: { message: 'Internal Server Error' },
|
||||
});
|
||||
@@ -173,8 +175,7 @@ test('500 error triggers danger toast with error message', async () => {
|
||||
test('network timeout triggers danger toast', async () => {
|
||||
const addDangerToast = jest.fn();
|
||||
|
||||
fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
|
||||
fetchMock.get(API_ENDPOINTS.DATASETS, {
|
||||
mockDatasetListEndpoints({
|
||||
throws: new Error('Network timeout'),
|
||||
});
|
||||
|
||||
@@ -213,8 +214,7 @@ test('clicking delete opens modal with related objects count', async () => {
|
||||
// Set up delete mocks
|
||||
setupDeleteMocks(datasetToDelete.id);
|
||||
|
||||
fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
|
||||
fetchMock.get(API_ENDPOINTS.DATASETS, {
|
||||
mockDatasetListEndpoints({
|
||||
result: [datasetToDelete],
|
||||
count: 1,
|
||||
});
|
||||
@@ -254,8 +254,7 @@ test('clicking delete opens modal with related objects count', async () => {
|
||||
test('clicking export calls handleResourceExport with dataset ID', async () => {
|
||||
const datasetToExport = mockDatasets[0];
|
||||
|
||||
fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
|
||||
fetchMock.get(API_ENDPOINTS.DATASETS, {
|
||||
mockDatasetListEndpoints({
|
||||
result: [datasetToExport],
|
||||
count: 1,
|
||||
});
|
||||
@@ -288,8 +287,7 @@ test('clicking duplicate opens modal and submits duplicate request', async () =>
|
||||
kind: 'virtual',
|
||||
};
|
||||
|
||||
fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
|
||||
fetchMock.get(API_ENDPOINTS.DATASETS, {
|
||||
mockDatasetListEndpoints({
|
||||
result: [datasetToDuplicate],
|
||||
count: 1,
|
||||
});
|
||||
@@ -312,7 +310,7 @@ test('clicking duplicate opens modal and submits duplicate request', async () =>
|
||||
|
||||
// Track initial dataset list API calls BEFORE duplicate action
|
||||
const initialDatasetCallCount = fetchMock.callHistory.calls(
|
||||
API_ENDPOINTS.DATASETS,
|
||||
API_ENDPOINTS.DATASOURCE_COMBINED,
|
||||
).length;
|
||||
|
||||
const row = screen.getByText(datasetToDuplicate.table_name).closest('tr');
|
||||
@@ -355,7 +353,9 @@ test('clicking duplicate opens modal and submits duplicate request', async () =>
|
||||
// Verify refreshData() is called (observable via new dataset list API call)
|
||||
await waitFor(
|
||||
() => {
|
||||
const datasetCalls = fetchMock.callHistory.calls(API_ENDPOINTS.DATASETS);
|
||||
const datasetCalls = fetchMock.callHistory.calls(
|
||||
API_ENDPOINTS.DATASOURCE_COMBINED,
|
||||
);
|
||||
expect(datasetCalls.length).toBeGreaterThan(initialDatasetCallCount);
|
||||
},
|
||||
{ timeout: 3000 },
|
||||
@@ -376,8 +376,7 @@ test('certified dataset shows badge and tooltip with certification details', asy
|
||||
}),
|
||||
};
|
||||
|
||||
fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
|
||||
fetchMock.get(API_ENDPOINTS.DATASETS, {
|
||||
mockDatasetListEndpoints({
|
||||
result: [certifiedDataset],
|
||||
count: 1,
|
||||
});
|
||||
@@ -417,8 +416,7 @@ test('dataset with warning shows icon and tooltip with markdown content', async
|
||||
}),
|
||||
};
|
||||
|
||||
fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
|
||||
fetchMock.get(API_ENDPOINTS.DATASETS, {
|
||||
mockDatasetListEndpoints({
|
||||
result: [datasetWithWarning],
|
||||
count: 1,
|
||||
});
|
||||
@@ -452,8 +450,7 @@ test('dataset with warning shows icon and tooltip with markdown content', async
|
||||
test('dataset name links to Explore with correct URL and accessible label', async () => {
|
||||
const dataset = mockDatasets[0];
|
||||
|
||||
fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
|
||||
fetchMock.get(API_ENDPOINTS.DATASETS, { result: [dataset], count: 1 });
|
||||
mockDatasetListEndpoints({ result: [dataset], count: 1 });
|
||||
|
||||
renderDatasetList(mockAdminUser);
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ import {
|
||||
mockAdminUser,
|
||||
mockDatasets,
|
||||
setupBulkDeleteMocks,
|
||||
mockDatasetListEndpoints,
|
||||
API_ENDPOINTS,
|
||||
} from './DatasetList.testHelpers';
|
||||
|
||||
@@ -72,8 +73,7 @@ test('ListView provider correctly merges filter + sort + pagination state on ref
|
||||
// the ListView provider correctly merges them for the API call.
|
||||
// Component tests verify individual pieces persist; this verifies they COMBINE correctly.
|
||||
|
||||
fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
|
||||
fetchMock.get(API_ENDPOINTS.DATASETS, {
|
||||
mockDatasetListEndpoints({
|
||||
result: mockDatasets,
|
||||
count: mockDatasets.length,
|
||||
});
|
||||
@@ -91,31 +91,33 @@ test('ListView provider correctly merges filter + sort + pagination state on ref
|
||||
});
|
||||
|
||||
const callsBeforeSort = fetchMock.callHistory.calls(
|
||||
API_ENDPOINTS.DATASETS,
|
||||
API_ENDPOINTS.DATASOURCE_COMBINED,
|
||||
).length;
|
||||
await userEvent.click(nameHeader);
|
||||
|
||||
// Wait for sort-triggered refetch to complete before applying filter
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
fetchMock.callHistory.calls(API_ENDPOINTS.DATASETS).length,
|
||||
fetchMock.callHistory.calls(API_ENDPOINTS.DATASOURCE_COMBINED).length,
|
||||
).toBeGreaterThan(callsBeforeSort);
|
||||
});
|
||||
|
||||
// 2. Apply a filter using selectOption helper
|
||||
const beforeFilterCallCount = fetchMock.callHistory.calls(
|
||||
API_ENDPOINTS.DATASETS,
|
||||
API_ENDPOINTS.DATASOURCE_COMBINED,
|
||||
).length;
|
||||
await selectOption('Virtual', 'Type');
|
||||
|
||||
// Wait for filter API call to complete
|
||||
await waitFor(() => {
|
||||
const calls = fetchMock.callHistory.calls(API_ENDPOINTS.DATASETS);
|
||||
const calls = fetchMock.callHistory.calls(
|
||||
API_ENDPOINTS.DATASOURCE_COMBINED,
|
||||
);
|
||||
expect(calls.length).toBeGreaterThan(beforeFilterCallCount);
|
||||
});
|
||||
|
||||
// 3. Verify the final API call contains ALL three state pieces merged correctly
|
||||
const calls = fetchMock.callHistory.calls(API_ENDPOINTS.DATASETS);
|
||||
const calls = fetchMock.callHistory.calls(API_ENDPOINTS.DATASOURCE_COMBINED);
|
||||
const latestCall = calls[calls.length - 1];
|
||||
const { url } = latestCall;
|
||||
|
||||
@@ -151,8 +153,7 @@ test('bulk action orchestration: selection → action → cleanup cycle works co
|
||||
|
||||
setupBulkDeleteMocks();
|
||||
|
||||
fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
|
||||
fetchMock.get(API_ENDPOINTS.DATASETS, {
|
||||
mockDatasetListEndpoints({
|
||||
result: mockDatasets,
|
||||
count: mockDatasets.length,
|
||||
});
|
||||
@@ -218,7 +219,7 @@ test('bulk action orchestration: selection → action → cleanup cycle works co
|
||||
|
||||
// Capture datasets call count before confirming
|
||||
const datasetsCallCountBeforeDelete = fetchMock.callHistory.calls(
|
||||
API_ENDPOINTS.DATASETS,
|
||||
API_ENDPOINTS.DATASOURCE_COMBINED,
|
||||
).length;
|
||||
|
||||
const confirmButton = within(modal)
|
||||
@@ -242,7 +243,7 @@ test('bulk action orchestration: selection → action → cleanup cycle works co
|
||||
// Wait for datasets refetch after delete
|
||||
await waitFor(() => {
|
||||
const datasetsCallCount = fetchMock.callHistory.calls(
|
||||
API_ENDPOINTS.DATASETS,
|
||||
API_ENDPOINTS.DATASOURCE_COMBINED,
|
||||
).length;
|
||||
expect(datasetsCallCount).toBeGreaterThan(datasetsCallCountBeforeDelete);
|
||||
});
|
||||
|
||||
@@ -33,6 +33,7 @@ import {
|
||||
mockHandleResourceExport,
|
||||
assertOnlyExpectedCalls,
|
||||
API_ENDPOINTS,
|
||||
mockDatasetListEndpoints,
|
||||
getDeleteRouteName,
|
||||
} from './DatasetList.testHelpers';
@@ -113,8 +114,7 @@ const setupErrorTestScenario = ({
  });

  // Configure fetchMock to return single dataset
- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, { result: [dataset], count: 1 });
+ mockDatasetListEndpoints({ result: [dataset], count: 1 });

  // Render component with toast mocks
  renderDatasetList(mockAdminUser, {
@@ -157,7 +157,7 @@ test('required API endpoints are called and no unmocked calls on initial render'
  // assertOnlyExpectedCalls checks: 1) no unmatched calls, 2) each expected endpoint was called
  assertOnlyExpectedCalls([
    API_ENDPOINTS.DATASETS_INFO, // Permission check
-   API_ENDPOINTS.DATASETS, // Main dataset list data
+   API_ENDPOINTS.DATASOURCE_COMBINED, // Main dataset list data
  ]);
});

@@ -197,8 +197,7 @@ test('renders all required column headers', async () => {
test('displays dataset name in Name column', async () => {
  const dataset = mockDatasets[0];

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, { result: [dataset], count: 1 });
+ mockDatasetListEndpoints({ result: [dataset], count: 1 });

  renderDatasetList(mockAdminUser);

@@ -211,8 +210,7 @@ test('displays dataset type as Physical or Virtual', async () => {
  const physicalDataset = mockDatasets[0]; // kind: 'physical'
  const virtualDataset = mockDatasets[1]; // kind: 'virtual'

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, {
+ mockDatasetListEndpoints({
    result: [physicalDataset, virtualDataset],
    count: 2,
  });
@@ -229,8 +227,7 @@ test('displays dataset type as Physical or Virtual', async () => {
test('displays database name in Database column', async () => {
  const dataset = mockDatasets[0];

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, { result: [dataset], count: 1 });
+ mockDatasetListEndpoints({ result: [dataset], count: 1 });

  renderDatasetList(mockAdminUser);

@@ -244,8 +241,7 @@ test('displays database name in Database column', async () => {
test('displays schema name in Schema column', async () => {
  const dataset = mockDatasets[0];

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, { result: [dataset], count: 1 });
+ mockDatasetListEndpoints({ result: [dataset], count: 1 });

  renderDatasetList(mockAdminUser);

@@ -257,8 +253,7 @@ test('displays schema name in Schema column', async () => {
test('displays last modified date in humanized format', async () => {
  const dataset = mockDatasets[0];

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, { result: [dataset], count: 1 });
+ mockDatasetListEndpoints({ result: [dataset], count: 1 });

  renderDatasetList(mockAdminUser);

@@ -283,7 +278,7 @@ test('sorting by Name column updates API call with sort parameter', async () =>

  // Record initial calls
  const initialCalls = fetchMock.callHistory.calls(
-   API_ENDPOINTS.DATASETS,
+   API_ENDPOINTS.DATASOURCE_COMBINED,
  ).length;

  // Click Name header to sort
@@ -291,12 +286,14 @@ test('sorting by Name column updates API call with sort parameter', async () =>

  // Wait for new API call
  await waitFor(() => {
-   const calls = fetchMock.callHistory.calls(API_ENDPOINTS.DATASETS);
+   const calls = fetchMock.callHistory.calls(
+     API_ENDPOINTS.DATASOURCE_COMBINED,
+   );
    expect(calls.length).toBeGreaterThan(initialCalls);
  });

  // Verify latest call includes sort parameter
- const calls = fetchMock.callHistory.calls(API_ENDPOINTS.DATASETS);
+ const calls = fetchMock.callHistory.calls(API_ENDPOINTS.DATASOURCE_COMBINED);
  const latestCall = calls[calls.length - 1];
  const { url } = latestCall;

@@ -317,17 +314,19 @@ test('sorting by Database column updates sort parameter', async () => {
  });

  const initialCalls = fetchMock.callHistory.calls(
-   API_ENDPOINTS.DATASETS,
+   API_ENDPOINTS.DATASOURCE_COMBINED,
  ).length;

  await userEvent.click(databaseHeader);

  await waitFor(() => {
-   const calls = fetchMock.callHistory.calls(API_ENDPOINTS.DATASETS);
+   const calls = fetchMock.callHistory.calls(
+     API_ENDPOINTS.DATASOURCE_COMBINED,
+   );
    expect(calls.length).toBeGreaterThan(initialCalls);
  });

- const calls = fetchMock.callHistory.calls(API_ENDPOINTS.DATASETS);
+ const calls = fetchMock.callHistory.calls(API_ENDPOINTS.DATASOURCE_COMBINED);
  const { url } = calls[calls.length - 1];
  expect(url).toMatch(/order_column|sort/);
});
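
These sort assertions only check that the URL mentions a sort key. A test that needs the actual column can decode the rison payload from the captured call, as the filter tests later in this diff do. A hedged sketch; the expected column id (`database.database_name`) is an assumption based on the column definitions in `DatasetList.tsx` below:

```typescript
import rison from 'rison';

// Grab the latest combined-endpoint call and decode its `q` payload.
const sortCalls = fetchMock.callHistory.calls(API_ENDPOINTS.DATASOURCE_COMBINED);
const lastUrl = new URL(sortCalls[sortCalls.length - 1].url, window.location.origin);
// searchParams.get() already URL-decodes, so pass the value straight to rison.
const query = rison.decode(lastUrl.searchParams.get('q')!) as {
  order_column: string;
  order_direction: 'asc' | 'desc';
};
expect(query.order_column).toBe('database.database_name');
```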
@@ -345,17 +344,19 @@ test('sorting by Last modified column updates sort parameter', async () => {
  });

  const initialCalls = fetchMock.callHistory.calls(
-   API_ENDPOINTS.DATASETS,
+   API_ENDPOINTS.DATASOURCE_COMBINED,
  ).length;

  await userEvent.click(modifiedHeader);

  await waitFor(() => {
-   const calls = fetchMock.callHistory.calls(API_ENDPOINTS.DATASETS);
+   const calls = fetchMock.callHistory.calls(
+     API_ENDPOINTS.DATASOURCE_COMBINED,
+   );
    expect(calls.length).toBeGreaterThan(initialCalls);
  });

- const calls = fetchMock.callHistory.calls(API_ENDPOINTS.DATASETS);
+ const calls = fetchMock.callHistory.calls(API_ENDPOINTS.DATASOURCE_COMBINED);
  const { url } = calls[calls.length - 1];
  expect(url).toMatch(/order_column|sort/);
});
@@ -363,8 +364,7 @@ test('sorting by Last modified column updates sort parameter', async () => {
test('export button triggers handleResourceExport with dataset ID', async () => {
  const dataset = mockDatasets[0];

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, { result: [dataset], count: 1 });
+ mockDatasetListEndpoints({ result: [dataset], count: 1 });

  renderDatasetList(mockAdminUser);

@@ -392,8 +392,7 @@ test('delete button opens modal with dataset details', async () => {

  setupDeleteMocks(dataset.id);

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, { result: [dataset], count: 1 });
+ mockDatasetListEndpoints({ result: [dataset], count: 1 });

  renderDatasetList(mockAdminUser);

@@ -415,8 +414,7 @@ test('delete action successfully deletes dataset and refreshes list', async () =
  const datasetToDelete = mockDatasets[0];
  setupDeleteMocks(datasetToDelete.id);

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, {
+ mockDatasetListEndpoints({
    result: [datasetToDelete],
    count: 1,
  });
@@ -442,7 +440,7 @@ test('delete action successfully deletes dataset and refreshes list', async () =

  // Track API calls before confirm
  const callsBefore = fetchMock.callHistory.calls(
-   API_ENDPOINTS.DATASETS,
+   API_ENDPOINTS.DATASOURCE_COMBINED,
  ).length;

  // Click confirm - find the danger button (last delete button in modal)
@@ -468,7 +466,7 @@ test('delete action successfully deletes dataset and refreshes list', async () =
  // List refreshes
  await waitFor(() => {
    expect(
-     fetchMock.callHistory.calls(API_ENDPOINTS.DATASETS).length,
+     fetchMock.callHistory.calls(API_ENDPOINTS.DATASOURCE_COMBINED).length,
    ).toBeGreaterThan(callsBefore);
  });
});
@@ -477,8 +475,7 @@ test('delete action cancel closes modal without deleting', async () => {
  const dataset = mockDatasets[0];
  setupDeleteMocks(dataset.id);

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, { result: [dataset], count: 1 });
+ mockDatasetListEndpoints({ result: [dataset], count: 1 });

  renderDatasetList(mockAdminUser);

@@ -518,8 +515,7 @@ test('duplicate action successfully duplicates virtual dataset', async () => {
  const virtualDataset = mockDatasets[1]; // Virtual dataset (kind: 'virtual')
  setupDuplicateMocks();

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, { result: [virtualDataset], count: 1 });
+ mockDatasetListEndpoints({ result: [virtualDataset], count: 1 });

  renderDatasetList(mockAdminUser, {
    addSuccessToast: mockAddSuccessToast,
@@ -542,7 +538,7 @@ test('duplicate action successfully duplicates virtual dataset', async () => {

  // Track API calls before submit
  const callsBefore = fetchMock.callHistory.calls(
-   API_ENDPOINTS.DATASETS,
+   API_ENDPOINTS.DATASOURCE_COMBINED,
  ).length;

  // Submit
@@ -564,7 +560,7 @@ test('duplicate action successfully duplicates virtual dataset', async () => {
  // List refreshes
  await waitFor(() => {
    expect(
-     fetchMock.callHistory.calls(API_ENDPOINTS.DATASETS).length,
+     fetchMock.callHistory.calls(API_ENDPOINTS.DATASOURCE_COMBINED).length,
    ).toBeGreaterThan(callsBefore);
  });
});
@@ -573,8 +569,7 @@ test('duplicate button visible only for virtual datasets', async () => {
  const physicalDataset = mockDatasets[0]; // kind: 'physical'
  const virtualDataset = mockDatasets[1]; // kind: 'virtual'

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, {
+ mockDatasetListEndpoints({
    result: [physicalDataset, virtualDataset],
    count: 2,
  });
@@ -633,8 +628,7 @@ test('bulk select enables checkboxes', async () => {
}, 30000);

test('selecting all datasets shows correct count in toolbar', async () => {
- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, {
+ mockDatasetListEndpoints({
    result: mockDatasets,
    count: mockDatasets.length,
  });
@@ -673,8 +667,7 @@ test('selecting all datasets shows correct count in toolbar', async () => {
}, 30000);

test('bulk export triggers export with selected IDs', async () => {
- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, {
+ mockDatasetListEndpoints({
    result: [mockDatasets[0]],
    count: 1,
  });
@@ -716,8 +709,7 @@ test('bulk export triggers export with selected IDs', async () => {
test('bulk delete opens confirmation modal', async () => {
  setupBulkDeleteMocks();

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, {
+ mockDatasetListEndpoints({
    result: [mockDatasets[0]],
    count: 1,
  });
@@ -823,8 +815,7 @@ test('certified badge appears for certified datasets', async () => {
    }),
  };

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, {
+ mockDatasetListEndpoints({
    result: [certifiedDataset],
    count: 1,
  });
@@ -854,8 +845,7 @@ test('warning icon appears for datasets with warnings', async () => {
    }),
  };

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, {
+ mockDatasetListEndpoints({
    result: [datasetWithWarning],
    count: 1,
  });
@@ -883,8 +873,7 @@ test('info tooltip appears for datasets with descriptions', async () => {
    description: 'Sales data from Q4 2024',
  };

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, {
+ mockDatasetListEndpoints({
    result: [datasetWithDescription],
    count: 1,
  });
@@ -909,8 +898,7 @@ test('info tooltip appears for datasets with descriptions', async () => {
test('dataset name links to Explore page', async () => {
  const dataset = mockDatasets[0];

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, { result: [dataset], count: 1 });
+ mockDatasetListEndpoints({ result: [dataset], count: 1 });

  renderDatasetList(mockAdminUser);

@@ -930,8 +918,7 @@ test('dataset name links to Explore page', async () => {
test('physical dataset shows delete, export, and edit actions (no duplicate)', async () => {
  const physicalDataset = mockDatasets[0]; // kind: 'physical'

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, {
+ mockDatasetListEndpoints({
    result: [physicalDataset],
    count: 1,
  });
@@ -962,8 +949,7 @@ test('physical dataset shows delete, export, and edit actions (no duplicate)', a
test('virtual dataset shows delete, export, edit, and duplicate actions', async () => {
  const virtualDataset = mockDatasets[1]; // kind: 'virtual'

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, { result: [virtualDataset], count: 1 });
+ mockDatasetListEndpoints({ result: [virtualDataset], count: 1 });

  renderDatasetList(mockAdminUser);

@@ -992,8 +978,7 @@ test('edit action is enabled for dataset owner', async () => {
    owners: [{ id: mockAdminUser.userId, username: 'admin' }],
  };

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, { result: [dataset], count: 1 });
+ mockDatasetListEndpoints({ result: [dataset], count: 1 });

  renderDatasetList(mockAdminUser);

@@ -1016,8 +1001,7 @@ test('edit action is disabled for non-owner', async () => {
    owners: [{ id: 999, username: 'other_user' }], // Different user
  };

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, { result: [dataset], count: 1 });
+ mockDatasetListEndpoints({ result: [dataset], count: 1 });

  // Use a non-admin user to test ownership check
  const regularUser = {
@@ -1046,8 +1030,7 @@ test('all action buttons are clickable and enabled for admin user', async () =>
    owners: [{ id: mockAdminUser.userId, username: 'admin' }],
  };

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, { result: [virtualDataset], count: 1 });
+ mockDatasetListEndpoints({ result: [virtualDataset], count: 1 });

  renderDatasetList(mockAdminUser);

@@ -1082,8 +1065,7 @@ test('all action buttons are clickable and enabled for admin user', async () =>
});

test('displays error when initial dataset fetch fails with 500', async () => {
- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, {
+ mockDatasetListEndpoints({
    status: 500,
    body: { message: 'Internal Server Error' },
  });
@@ -1104,8 +1086,7 @@ test('displays error when initial dataset fetch fails with 500', async () => {
});

test('displays error when initial dataset fetch fails with 403 permission denied', async () => {
- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, {
+ mockDatasetListEndpoints({
    status: 403,
    body: { message: 'Access Denied' },
  });
@@ -1119,9 +1100,9 @@ test('displays error when initial dataset fetch fails with 403 permission denied
    expect(mockAddDangerToast).toHaveBeenCalled();
  });

- // Verify toast message contains the 403-specific "Access Denied" text
+ // Verify toast message contains the generic error text
  const toastMessage = String(mockAddDangerToast.mock.calls[0][0]);
- expect(toastMessage).toContain('Access Denied');
+ expect(toastMessage).toContain('An error occurred while fetching datasets');

  // No dataset names from mockDatasets should appear in the document
  mockDatasets.forEach(dataset => {
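
The assertion changes because the new `fetchData` in `DatasetList.tsx` (later in this diff) discards the response body and raises one generic toast for any failure. A minimal sketch of the failure path this test exercises, with the handler inlined for illustration:

```typescript
// Any non-2xx response from the combined endpoint ends up in the catch;
// the status-specific message ('Access Denied') is never surfaced.
SupersetClient.get({ endpoint: `/api/v1/datasource/?q=${queryParams}` })
  .then(({ json = {} }) => {
    setDatasets(json.result);
    setDatasetCount(json.count);
  })
  .catch(() => {
    addDangerToast(t('An error occurred while fetching datasets'));
  });
```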
@@ -1373,7 +1354,7 @@ test('sort order persists after deleting a dataset', async () => {

  // Record initial API calls count
  const initialCalls = fetchMock.callHistory.calls(
-   API_ENDPOINTS.DATASETS,
+   API_ENDPOINTS.DATASOURCE_COMBINED,
  ).length;

  // Click Name header to sort
@@ -1381,12 +1362,16 @@ test('sort order persists after deleting a dataset', async () => {

  // Wait for new API call with sort parameter
  await waitFor(() => {
-   const calls = fetchMock.callHistory.calls(API_ENDPOINTS.DATASETS);
+   const calls = fetchMock.callHistory.calls(
+     API_ENDPOINTS.DATASOURCE_COMBINED,
+   );
    expect(calls.length).toBeGreaterThan(initialCalls);
  });

  // Record the sort parameter from the API call after sorting
- const callsAfterSort = fetchMock.callHistory.calls(API_ENDPOINTS.DATASETS);
+ const callsAfterSort = fetchMock.callHistory.calls(
+   API_ENDPOINTS.DATASOURCE_COMBINED,
+ );
  const sortedUrl = callsAfterSort[callsAfterSort.length - 1].url;
  expect(sortedUrl).toMatch(/order_column|sort/);

@@ -1406,7 +1391,7 @@ test('sort order persists after deleting a dataset', async () => {

  // Record call count before delete to track refetch
  const callsBeforeDelete = fetchMock.callHistory.calls(
-   API_ENDPOINTS.DATASETS,
+   API_ENDPOINTS.DATASOURCE_COMBINED,
  ).length;

  const confirmButton = within(modal)
@@ -1427,7 +1412,7 @@ test('sort order persists after deleting a dataset', async () => {
  // Wait for list refetch to complete (prevents async cleanup error)
  await waitFor(() => {
    const currentCalls = fetchMock.callHistory.calls(
-     API_ENDPOINTS.DATASETS,
+     API_ENDPOINTS.DATASOURCE_COMBINED,
    ).length;
    expect(currentCalls).toBeGreaterThan(callsBeforeDelete);
  });
@@ -1452,8 +1437,7 @@ test('sort order persists after deleting a dataset', async () => {
// test. Component tests here focus on individual behaviors.

test('bulk selection clears when filter changes', async () => {
- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, {
+ mockDatasetListEndpoints({
    result: mockDatasets,
    count: mockDatasets.length,
  });
@@ -1505,7 +1489,7 @@ test('bulk selection clears when filter changes', async () => {

  // Record API call count before filter
  const beforeFilterCallCount = fetchMock.callHistory.calls(
-   API_ENDPOINTS.DATASETS,
+   API_ENDPOINTS.DATASOURCE_COMBINED,
  ).length;

  // Wait for filter combobox to be ready before applying filter
@@ -1516,13 +1500,15 @@ test('bulk selection clears when filter changes', async () => {

  // Wait for filter API call to complete
  await waitFor(() => {
-   const calls = fetchMock.callHistory.calls(API_ENDPOINTS.DATASETS);
+   const calls = fetchMock.callHistory.calls(
+     API_ENDPOINTS.DATASOURCE_COMBINED,
+   );
    expect(calls.length).toBeGreaterThan(beforeFilterCallCount);
  });

  // Verify filter was applied by decoding URL payload
  const urlAfterFilter = fetchMock.callHistory
-   .calls(API_ENDPOINTS.DATASETS)
+   .calls(API_ENDPOINTS.DATASOURCE_COMBINED)
    .at(-1)?.url;
  const risonAfterFilter = new URL(
    urlAfterFilter!,
@@ -1557,7 +1543,7 @@ test('type filter API call includes correct filter parameter', async () => {

  // Snapshot call count before filter
  const callsBeforeFilter = fetchMock.callHistory.calls(
-   API_ENDPOINTS.DATASETS,
+   API_ENDPOINTS.DATASOURCE_COMBINED,
  ).length;

  // Apply Type filter
@@ -1565,12 +1551,16 @@ test('type filter API call includes correct filter parameter', async () => {

  // Wait for filter API call to complete
  await waitFor(() => {
-   const calls = fetchMock.callHistory.calls(API_ENDPOINTS.DATASETS);
+   const calls = fetchMock.callHistory.calls(
+     API_ENDPOINTS.DATASOURCE_COMBINED,
+   );
    expect(calls.length).toBeGreaterThan(callsBeforeFilter);
  });

  // Verify the latest API call includes the Type filter
- const url = fetchMock.callHistory.calls(API_ENDPOINTS.DATASETS).at(-1)?.url;
+ const url = fetchMock.callHistory
+   .calls(API_ENDPOINTS.DATASOURCE_COMBINED)
+   .at(-1)?.url;
  expect(url).toContain('filters');

  // searchParams.get() already URL-decodes, so pass directly to rison.decode
@@ -1603,7 +1593,7 @@ test('type filter persists after duplicating a dataset', async () => {

  // Snapshot call count before filter
  const callsBeforeFilter = fetchMock.callHistory.calls(
-   API_ENDPOINTS.DATASETS,
+   API_ENDPOINTS.DATASOURCE_COMBINED,
  ).length;

  // Apply Type filter
@@ -1611,13 +1601,15 @@ test('type filter persists after duplicating a dataset', async () => {

  // Wait for filter API call to complete
  await waitFor(() => {
-   const calls = fetchMock.callHistory.calls(API_ENDPOINTS.DATASETS);
+   const calls = fetchMock.callHistory.calls(
+     API_ENDPOINTS.DATASOURCE_COMBINED,
+   );
    expect(calls.length).toBeGreaterThan(callsBeforeFilter);
  });

  // Verify filter is present by checking the latest API call
  const urlAfterFilter = fetchMock.callHistory
-   .calls(API_ENDPOINTS.DATASETS)
+   .calls(API_ENDPOINTS.DATASOURCE_COMBINED)
    .at(-1)?.url;
  const risonAfterFilter = new URL(
    urlAfterFilter!,
@@ -1637,7 +1629,7 @@ test('type filter persists after duplicating a dataset', async () => {

  // Capture datasets API call count BEFORE any duplicate operations
  const datasetsCallCountBeforeDuplicate = fetchMock.callHistory.calls(
-   API_ENDPOINTS.DATASETS,
+   API_ENDPOINTS.DATASOURCE_COMBINED,
  ).length;

  // Now duplicate the dataset
@@ -1673,14 +1665,14 @@ test('type filter persists after duplicating a dataset', async () => {
  // Wait for datasets refetch to occur (proves duplicate triggered a refresh)
  await waitFor(() => {
    const datasetsCallCount = fetchMock.callHistory.calls(
-     API_ENDPOINTS.DATASETS,
+     API_ENDPOINTS.DATASOURCE_COMBINED,
    ).length;
    expect(datasetsCallCount).toBeGreaterThan(datasetsCallCountBeforeDuplicate);
  });

  // Verify Type filter persisted in the NEW datasets API call after duplication
  const urlAfterDuplicate = fetchMock.callHistory
-   .calls(API_ENDPOINTS.DATASETS)
+   .calls(API_ENDPOINTS.DATASOURCE_COMBINED)
    .at(-1)?.url;
  const risonAfterDuplicate = new URL(
    urlAfterDuplicate!,
@@ -1715,8 +1707,7 @@ test('edit action shows error toast when dataset fetch fails', async () => {
    ],
  };

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, { result: [ownedDataset], count: 1 });
+ mockDatasetListEndpoints({ result: [ownedDataset], count: 1 });

  // Mock SupersetClient.get to fail for the specific dataset endpoint
  jest.spyOn(SupersetClient, 'get').mockImplementation(async request => {
@@ -1759,8 +1750,7 @@ test('bulk export error shows toast and clears loading state', async () => {
  // Mock handleResourceExport to throw an error
  mockHandleResourceExport.mockRejectedValueOnce(new Error('Export failed'));

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, {
+ mockDatasetListEndpoints({
    result: [mockDatasets[0]],
    count: 1,
  });
@@ -1824,8 +1814,7 @@ test('bulk delete error shows toast without refreshing list', async () => {
    body: { message: 'Bulk delete failed' },
  });

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, {
+ mockDatasetListEndpoints({
    result: [mockDatasets[0]],
    count: 1,
  });
@@ -1901,8 +1890,7 @@ test('bulk select shows "N Selected (Virtual)" for virtual-only selection', asyn
  // Use only virtual datasets
  const virtualDatasets = mockDatasets.filter(d => d.kind === 'virtual');

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, {
+ mockDatasetListEndpoints({
    result: virtualDatasets,
    count: virtualDatasets.length,
  });
@@ -1948,8 +1936,7 @@ test('bulk select shows "N Selected (Physical)" for physical-only selection', as
  // Use only physical datasets
  const physicalDatasets = mockDatasets.filter(d => d.kind === 'physical');

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, {
+ mockDatasetListEndpoints({
    result: physicalDatasets,
    count: physicalDatasets.length,
  });
@@ -1999,8 +1986,7 @@ test('bulk select shows mixed count for virtual and physical selection', async (
    mockDatasets.find(d => d.kind === 'virtual')!,
  ];

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, {
+ mockDatasetListEndpoints({
    result: mixedDatasets,
    count: mixedDatasets.length,
  });
@@ -2063,8 +2049,7 @@ test('delete modal shows affected dashboards with overflow for >10 items', async
    title: `Dashboard ${i + 1}`,
  }));

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, { result: [dataset], count: 1 });
+ mockDatasetListEndpoints({ result: [dataset], count: 1 });

  fetchMock.get(`glob:*/api/v1/dataset/${dataset.id}/related_objects*`, {
    charts: { count: 0, result: [] },
@@ -2101,8 +2086,7 @@ test('delete modal shows affected dashboards with overflow for >10 items', async
test('delete modal hides affected dashboards section when count is zero', async () => {
  const dataset = mockDatasets[0];

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, { result: [dataset], count: 1 });
+ mockDatasetListEndpoints({ result: [dataset], count: 1 });

  fetchMock.get(`glob:*/api/v1/dataset/${dataset.id}/related_objects*`, {
    charts: { count: 2, result: [{ id: 1, slice_name: 'Chart 1' }] },
@@ -2140,8 +2124,7 @@ test('delete modal shows affected charts with overflow for >10 items', async ()
    slice_name: `Chart ${i + 1}`,
  }));

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, { result: [dataset], count: 1 });
+ mockDatasetListEndpoints({ result: [dataset], count: 1 });

  fetchMock.get(`glob:*/api/v1/dataset/${dataset.id}/related_objects*`, {
    charts: { count: 12, result: manyCharts },

@@ -27,7 +27,7 @@ import {
  mockWriteUser,
  mockExportOnlyUser,
  mockDatasets,
  API_ENDPOINTS,
+ mockDatasetListEndpoints,
} from './DatasetList.testHelpers';

// Increase default timeout for tests that involve multiple async operations
@@ -238,8 +238,7 @@ test('action buttons respect user permissions', async () => {

  const dataset = mockDatasets[0];

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, { result: [dataset], count: 1 });
+ mockDatasetListEndpoints({ result: [dataset], count: 1 });

  renderDatasetList(mockAdminUser);

@@ -265,8 +264,7 @@ test('read-only user sees no delete or duplicate buttons in row', async () => {

  const dataset = mockDatasets[0];

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, { result: [dataset], count: 1 });
+ mockDatasetListEndpoints({ result: [dataset], count: 1 });

  renderDatasetList(mockReadOnlyUser);

@@ -301,8 +299,7 @@ test('write user sees edit, delete, and export actions', async () => {
    owners: [{ id: mockWriteUser.userId, username: 'writeuser' }],
  };

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, { result: [dataset], count: 1 });
+ mockDatasetListEndpoints({ result: [dataset], count: 1 });

  renderDatasetList(mockWriteUser);

@@ -337,8 +334,7 @@ test('export-only user has no Actions column (no write/duplicate permissions)',

  const dataset = mockDatasets[0];

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, { result: [dataset], count: 1 });
+ mockDatasetListEndpoints({ result: [dataset], count: 1 });

  renderDatasetList(mockExportOnlyUser);

@@ -371,8 +367,7 @@ test('user with can_duplicate sees duplicate button only for virtual datasets',
  const physicalDataset = mockDatasets[0]; // kind: 'physical'
  const virtualDataset = mockDatasets[1]; // kind: 'virtual'

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, {
+ mockDatasetListEndpoints({
    result: [physicalDataset, virtualDataset],
    count: 2,
  });

@@ -29,6 +29,7 @@ import {
  mockExportOnlyUser,
  mockDatasets,
  mockApiError403,
+ mockDatasetListEndpoints,
  API_ENDPOINTS,
  RisonFilter,
} from './DatasetList.testHelpers';
@@ -68,13 +69,17 @@ test('shows loading state during initial data fetch', () => {
  // Use fake timers to avoid leaving real timers running after test
  jest.useFakeTimers();

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(
-   API_ENDPOINTS.DATASETS,
-   new Promise(resolve =>
-     setTimeout(() => resolve({ result: [], count: 0 }), 10000),
-   ),
+ const delayedResponse = new Promise(resolve =>
+   setTimeout(() => resolve({ result: [], count: 0 }), 10000),
+ );
+ fetchMock.removeRoutes({
+   names: [API_ENDPOINTS.DATASETS, API_ENDPOINTS.DATASOURCE_COMBINED],
+ });
+ fetchMock.get(API_ENDPOINTS.DATASETS, delayedResponse);
+ fetchMock.get(API_ENDPOINTS.DATASOURCE_COMBINED, delayedResponse);

  renderDatasetList(mockAdminUser);

@@ -87,13 +92,17 @@ test('maintains component structure during loading', () => {
  // Use fake timers to avoid leaving real timers running after test
  jest.useFakeTimers();

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(
-   API_ENDPOINTS.DATASETS,
-   new Promise(resolve =>
-     setTimeout(() => resolve({ result: [], count: 0 }), 10000),
-   ),
+ const delayedResponse = new Promise(resolve =>
+   setTimeout(() => resolve({ result: [], count: 0 }), 10000),
+ );
+ fetchMock.removeRoutes({
+   names: [API_ENDPOINTS.DATASETS, API_ENDPOINTS.DATASOURCE_COMBINED],
+ });
+ fetchMock.get(API_ENDPOINTS.DATASETS, delayedResponse);
+ fetchMock.get(API_ENDPOINTS.DATASOURCE_COMBINED, delayedResponse);

  renderDatasetList(mockAdminUser);

@@ -214,8 +223,7 @@ test('handles datasets with missing fields and renders gracefully', async () =>
    sql: null,
  };

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, {
+ mockDatasetListEndpoints({
    result: [datasetWithMissingFields],
    count: 1,
  });
@@ -241,8 +249,7 @@ test('handles datasets with missing fields and renders gracefully', async () =>
});

test('handles empty results (shows empty state)', async () => {
- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, { result: [], count: 0 });
+ mockDatasetListEndpoints({ result: [], count: 0 });

  renderDatasetList(mockAdminUser);

@@ -254,7 +261,9 @@ test('makes correct initial API call on load', async () => {
  renderDatasetList(mockAdminUser);

  await waitFor(() => {
-   const calls = fetchMock.callHistory.calls(API_ENDPOINTS.DATASETS);
+   const calls = fetchMock.callHistory.calls(
+     API_ENDPOINTS.DATASOURCE_COMBINED,
+   );
    expect(calls.length).toBeGreaterThan(0);
  });
});
@@ -263,7 +272,9 @@ test('API call includes correct page size', async () => {
  renderDatasetList(mockAdminUser);

  await waitFor(() => {
-   const calls = fetchMock.callHistory.calls(API_ENDPOINTS.DATASETS);
+   const calls = fetchMock.callHistory.calls(
+     API_ENDPOINTS.DATASOURCE_COMBINED,
+   );
    expect(calls.length).toBeGreaterThan(0);
    const { url } = calls[0];
    expect(url).toContain('page_size');
@@ -278,7 +289,7 @@ test('typing in name filter updates input value and triggers API with decoded se

  // Record initial API calls
  const initialCallCount = fetchMock.callHistory.calls(
-   API_ENDPOINTS.DATASETS,
+   API_ENDPOINTS.DATASOURCE_COMBINED,
  ).length;

  // Type in search box and press Enter to trigger search
@@ -292,7 +303,9 @@ test('typing in name filter updates input value and triggers API with decoded se
  // Wait for API call after Enter key press
  await waitFor(
    () => {
-     const calls = fetchMock.callHistory.calls(API_ENDPOINTS.DATASETS);
+     const calls = fetchMock.callHistory.calls(
+       API_ENDPOINTS.DATASOURCE_COMBINED,
+     );
      expect(calls.length).toBeGreaterThan(initialCallCount);

      // Get latest API call
@@ -346,8 +359,7 @@ test('toggling bulk select mode shows checkboxes', async () => {
}, 30000);

test('handles 500 error on initial load without crashing', async () => {
- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, {
+ mockDatasetListEndpoints({
    throws: new Error('Internal Server Error'),
  });

@@ -385,8 +397,7 @@ test('handles 403 error on _info endpoint and disables create actions', async ()
});

test('handles network timeout without crashing', async () => {
- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, {
+ mockDatasetListEndpoints({
    throws: new Error('Network timeout'),
  });

@@ -414,7 +425,9 @@ test('component requires explicit mocks for all API endpoints', async () => {
  await waitForDatasetsPageReady();

  // Verify that critical endpoints were called and had mocks available
- const newDatasetsCalls = fetchMock.callHistory.calls(API_ENDPOINTS.DATASETS);
+ const newDatasetsCalls = fetchMock.callHistory.calls(
+   API_ENDPOINTS.DATASOURCE_COMBINED,
+ );
  const newInfoCalls = fetchMock.callHistory.calls(API_ENDPOINTS.DATASETS_INFO);

  // These should have been called during render
@@ -446,8 +459,7 @@ test('renders datasets with certification data', async () => {
    }),
  };

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, {
+ mockDatasetListEndpoints({
    result: [certifiedDataset],
    count: 1,
  });
@@ -474,8 +486,7 @@ test('displays datasets with warning_markdown', async () => {
    }),
  };

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, {
+ mockDatasetListEndpoints({
    result: [datasetWithWarning],
    count: 1,
  });
@@ -496,8 +507,7 @@ test('displays datasets with warning_markdown', async () => {
test('displays dataset with multiple owners', async () => {
  const datasetWithOwners = mockDatasets[1]; // Has 2 owners: Jane Smith, Bob Jones

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, {
+ mockDatasetListEndpoints({
    result: [datasetWithOwners],
    count: 1,
  });
@@ -518,8 +528,7 @@ test('displays dataset with multiple owners', async () => {
test('displays ModifiedInfo with humanized date', async () => {
  const datasetWithModified = mockDatasets[0]; // changed_by_name: 'John Doe', changed_on: '1 day ago'

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, {
+ mockDatasetListEndpoints({
    result: [datasetWithModified],
    count: 1,
  });
@@ -541,8 +550,7 @@ test('displays ModifiedInfo with humanized date', async () => {
test('dataset name links to Explore with correct explore_url', async () => {
  const dataset = mockDatasets[0]; // explore_url: '/explore/?datasource=1__table'

- fetchMock.removeRoutes({ names: [API_ENDPOINTS.DATASETS] });
- fetchMock.get(API_ENDPOINTS.DATASETS, { result: [dataset], count: 1 });
+ mockDatasetListEndpoints({ result: [dataset], count: 1 });

  renderDatasetList(mockAdminUser);


@@ -318,6 +318,7 @@ export const mockApiError404 = {
export const API_ENDPOINTS = {
  DATASETS_INFO: 'glob:*/api/v1/dataset/_info*',
  DATASETS: 'glob:*/api/v1/dataset/?*',
+ DATASOURCE_COMBINED: 'glob:*/api/v1/datasource/?*',
  DATASET_GET: 'glob:*/api/v1/dataset/[0-9]*',
  DATASET_RELATED_OBJECTS: 'glob:*/api/v1/dataset/*/related_objects*',
  DATASET_DELETE: 'glob:*/api/v1/dataset/[0-9]*',
@@ -499,6 +500,24 @@ export const assertOnlyExpectedCalls = (expectedEndpoints: string[]) => {
  });
};

+/**
+ * Helper to mock the dataset list endpoints.
+ * The component fetches from /api/v1/datasource/ (combined endpoint).
+ * Some tests also need the legacy /api/v1/dataset/ endpoint for
+ * other operations (delete, bulk delete) that still use it.
+ */
+export const mockDatasetListEndpoints = (response: Record<string, unknown>) => {
+  fetchMock.removeRoutes({
+    names: [API_ENDPOINTS.DATASETS, API_ENDPOINTS.DATASOURCE_COMBINED],
+  });
+  fetchMock.get(API_ENDPOINTS.DATASETS, response, {
+    name: API_ENDPOINTS.DATASETS,
+  });
+  fetchMock.get(API_ENDPOINTS.DATASOURCE_COMBINED, response, {
+    name: API_ENDPOINTS.DATASOURCE_COMBINED,
+  });
+};
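
Since the helper registers named routes for both globs, a test can re-mock the list response in one call and later remove the routes again by name. A small usage sketch, with payload shapes taken from the tests above:

```typescript
// Seed the list with a single dataset for this test...
mockDatasetListEndpoints({ result: [mockDatasets[0]], count: 1 });

// ...or simulate a failure; the object is forwarded to fetchMock.get()
// as the mock response, so status/body/throws all work.
mockDatasetListEndpoints({
  status: 500,
  body: { message: 'Internal Server Error' },
});
```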

// MSW setup using fetch-mock (following ChartList pattern)
// Routes are named using the API_ENDPOINTS constant values so they can be
// removed by name using removeRoutes({ names: [API_ENDPOINTS.X] })
@@ -511,11 +530,10 @@ export const setupMocks = () => {
    { name: API_ENDPOINTS.DATASETS_INFO },
  );

- fetchMock.get(
-   API_ENDPOINTS.DATASETS,
-   { result: mockDatasets, count: mockDatasets.length },
-   { name: API_ENDPOINTS.DATASETS },
- );
+ mockDatasetListEndpoints({
+   result: mockDatasets,
+   count: mockDatasets.length,
+ });

  fetchMock.get(
    API_ENDPOINTS.DATASET_FAVORITE_STATUS,

@@ -17,9 +17,15 @@
 * under the License.
 */
import { t } from '@apache-superset/core/translation';
- import { getExtensionsRegistry, SupersetClient } from '@superset-ui/core';
+ import {
+   getExtensionsRegistry,
+   SupersetClient,
+   isFeatureEnabled,
+   FeatureFlag,
+ } from '@superset-ui/core';
import { styled, useTheme, css } from '@apache-superset/core/theme';
import { FunctionComponent, useState, useMemo, useCallback, Key } from 'react';
+ import type { CellProps } from 'react-table';
import { Link, useHistory } from 'react-router-dom';
import rison from 'rison';
import {
@@ -41,8 +47,9 @@ import {
  Loading,
  List,
} from '@superset-ui/core/components';
- import { DatasourceModal, GenericLink } from 'src/components';
+ import {
+   DatasourceModal,
+   GenericLink,
  FacePile,
  ImportModal as ImportModelsModal,
  ModifiedInfo,
@@ -50,6 +57,7 @@ import {
  ListViewFilterOperator as FilterOperator,
  type ListViewProps,
  type ListViewFilters,
+ type ListViewFetchDataConfig,
} from 'src/components';
import { Typography } from '@superset-ui/core/components/Typography';
import handleResourceExport from 'src/utils/export';
@@ -67,9 +75,12 @@ import {
  CONFIRM_OVERWRITE_MESSAGE,
} from 'src/features/datasets/constants';
import DuplicateDatasetModal from 'src/features/datasets/DuplicateDatasetModal';
+ import type DatasetType from 'src/types/Dataset';
+ import SemanticViewEditModal from 'src/features/semanticViews/SemanticViewEditModal';
import { useSelector } from 'react-redux';
import { QueryObjectColumns } from 'src/views/CRUD/types';
import { WIDER_DROPDOWN_WIDTH } from 'src/components/ListView/utils';
+ import type { BootstrapData } from 'src/types/bootstrapTypes';

const extensionsRegistry = getExtensionsRegistry();
const DatasetDeleteRelatedExtension = extensionsRegistry.get(
@@ -115,22 +126,28 @@ const Actions = styled.div`

type Dataset = {
  changed_by_name: string;
- changed_by: string;
+ changed_by: Owner;
  changed_on_delta_humanized: string;
  database: {
    id: string;
    database_name: string;
- };
- kind: string;
+ } | null;
+ kind: 'physical' | 'virtual' | 'semantic_view';
+ source_type?: 'database' | 'semantic_layer';
  explore_url: string;
  id: number;
  owners: Array<Owner>;
- schema: string;
+ schema: string | null;
  table_name: string;
+ description?: string | null;
+ cache_timeout?: number | null;
+ extra?: string | Record<string, any> | null;
+ sql?: string | null;
};

interface VirtualDataset extends Dataset {
- extra: Record<string, any>;
+ kind: 'virtual';
+ extra: string | Record<string, any>;
  sql: string;
}
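
With `kind` narrowed to a string-literal union, call sites can use it as a discriminant instead of casting. A hedged sketch of the kind of guard this enables; `isVirtualDataset` is illustrative and not part of this diff:

```typescript
// Hypothetical guard: narrows Dataset to VirtualDataset via the 'virtual'
// literal plus the sql field the interface requires. Mirrors the
// handleDuplicate check in the actions cell below.
const isVirtualDataset = (d: Dataset): d is VirtualDataset =>
  d.kind === 'virtual' && typeof d.sql === 'string';

// Usage inside the actions Cell, where `original` is the row's Dataset:
if (isVirtualDataset(original)) {
  openDatasetDuplicateModal(original); // no `as VirtualDataset` cast needed
}
```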

@@ -152,18 +169,86 @@ const DatasetList: FunctionComponent<DatasetListProps> = ({
  const history = useHistory();
  const theme = useTheme();
  const {
-   state: {
-     loading,
-     resourceCount: datasetCount,
-     resourceCollection: datasets,
-     bulkSelectEnabled,
-   },
+   state: { bulkSelectEnabled },
    hasPerm,
-   fetchData,
    toggleBulkSelect,
-   refreshData,
  } = useListViewResource<Dataset>('dataset', t('dataset'), addDangerToast);

+ // Combined endpoint state
+ const [datasets, setDatasets] = useState<Dataset[]>([]);
+ const [datasetCount, setDatasetCount] = useState(0);
+ const [loading, setLoading] = useState(true);
+ const [lastFetchConfig, setLastFetchConfig] =
+   useState<ListViewFetchDataConfig | null>(null);
+
+ const fetchData = useCallback(
+   (config: ListViewFetchDataConfig) => {
+     setLastFetchConfig(config);
+     setLoading(true);
+     const { pageIndex, pageSize, sortBy, filters: filterValues } = config;
+
+     // Separate source_type filter from other filters
+     const sourceTypeFilter = filterValues.find(f => f.id === 'source_type');
+
+     const otherFilters = filterValues
+       .filter(f => f.id !== 'source_type')
+       .filter(
+         ({ value }) => value !== '' && value !== null && value !== undefined,
+       )
+       .map(({ id, operator: opr, value }) => ({
+         col: id,
+         opr,
+         value:
+           value && typeof value === 'object' && 'value' in value
+             ? value.value
+             : value,
+       }));
+
+     // Add source_type filter for the combined endpoint
+     const sourceTypeValue =
+       sourceTypeFilter?.value && typeof sourceTypeFilter.value === 'object'
+         ? (sourceTypeFilter.value as { value: string }).value
+         : (sourceTypeFilter?.value as string | undefined);
+     if (sourceTypeValue) {
+       otherFilters.push({
+         col: 'source_type',
+         opr: 'eq',
+         value: sourceTypeValue,
+       });
+     }
+
+     const queryParams = rison.encode_uri({
+       order_column: sortBy[0].id,
+       order_direction: sortBy[0].desc ? 'desc' : 'asc',
+       page: pageIndex,
+       page_size: pageSize,
+       ...(otherFilters.length ? { filters: otherFilters } : {}),
+     });
+
+     return SupersetClient.get({
+       endpoint: `/api/v1/datasource/?q=${queryParams}`,
+     })
+       .then(({ json = {} }) => {
+         setDatasets(json.result);
+         setDatasetCount(json.count);
+       })
+       .catch(() => {
+         addDangerToast(t('An error occurred while fetching datasets'));
+       })
+       .finally(() => {
+         setLoading(false);
+       });
+   },
+   [addDangerToast],
+ );
+
+ const refreshData = useCallback(() => {
+   if (lastFetchConfig) {
+     return fetchData(lastFetchConfig);
+   }
+   return undefined;
+ }, [lastFetchConfig, fetchData]);
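
For a concrete sense of what `fetchData` produces: a first-page fetch sorted by last modified, with a Source filter applied, collapses into a single rison-encoded `q` parameter. A sketch with example values; the literal URL shown is approximate, since the exact encoding is whatever `rison.encode_uri` emits:

```typescript
// Illustrative input...
const config: ListViewFetchDataConfig = {
  pageIndex: 0,
  pageSize: 25,
  sortBy: [{ id: 'changed_on_delta_humanized', desc: true }],
  filters: [{ id: 'source_type', operator: 'eq', value: 'semantic_layer' }],
};

// ...yields a request roughly like:
// GET /api/v1/datasource/?q=(filters:!((col:source_type,opr:eq,value:semantic_layer)),
//       order_column:changed_on_delta_humanized,order_direction:desc,page:0,page_size:25)
```

Stashing the config in `lastFetchConfig` is what lets `refreshData` re-issue the identical request (same page, sort, and filters) after a delete or duplicate.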

  const [datasetCurrentlyDeleting, setDatasetCurrentlyDeleting] = useState<
    | (Dataset & {
        charts: any;
@@ -178,6 +263,10 @@ const DatasetList: FunctionComponent<DatasetListProps> = ({
  const [datasetCurrentlyDuplicating, setDatasetCurrentlyDuplicating] =
    useState<VirtualDataset | null>(null);

+ const [svCurrentlyEditing, setSvCurrentlyEditing] = useState<Dataset | null>(
+   null,
+ );

  const [importingDataset, showImportModal] = useState<boolean>(false);
  const [passwordFields, setPasswordFields] = useState<string[]>([]);
  const [preparingExport, setPreparingExport] = useState<boolean>(false);
@@ -192,11 +281,28 @@ const DatasetList: FunctionComponent<DatasetListProps> = ({
    setSSHTunnelPrivateKeyPasswordFields,
  ] = useState<string[]>([]);

- const PREVENT_UNSAFE_DEFAULT_URLS_ON_DATASET = useSelector<any, boolean>(
+ const PREVENT_UNSAFE_DEFAULT_URLS_ON_DATASET = useSelector<
+   BootstrapData,
+   boolean
+ >(
    state =>
      state.common?.conf?.PREVENT_UNSAFE_DEFAULT_URLS_ON_DATASET || false,
  );

+ const currentSourceFilter = useMemo(() => {
+   const sourceTypeFilter = lastFetchConfig?.filters.find(
+     filter => filter.id === 'source_type',
+   );
+   if (
+     sourceTypeFilter?.value &&
+     typeof sourceTypeFilter.value === 'object' &&
+     'value' in sourceTypeFilter.value
+   ) {
+     return sourceTypeFilter.value.value as string;
+   }
+   return (sourceTypeFilter?.value as string | undefined) ?? '';
+ }, [lastFetchConfig]);

  const openDatasetImportModal = () => {
    showImportModal(true);
  };
@@ -288,7 +394,7 @@ const DatasetList: FunctionComponent<DatasetListProps> = ({
      await handleResourceExport('dataset', ids, () => {
        setPreparingExport(false);
      });
-   } catch (error) {
+   } catch {
      setPreparingExport(false);
      addDangerToast(t('There was an issue exporting the selected datasets'));
    }
@@ -315,7 +421,7 @@ const DatasetList: FunctionComponent<DatasetListProps> = ({
          explore_url: exploreURL,
        },
      },
-   }: any) => {
+   }: CellProps<Dataset>) => {
      let titleLink: JSX.Element;
      if (PREVENT_UNSAFE_DEFAULT_URLS_ON_DATASET) {
        titleLink = (
@@ -331,7 +437,10 @@ const DatasetList: FunctionComponent<DatasetListProps> = ({
        );
      }
      try {
-       const parsedExtra = JSON.parse(extra);
+       const parsedExtra =
+         typeof extra === 'string'
+           ? JSON.parse(extra)
+           : (extra as Record<string, any> | null);
        return (
          <FlexRowContainer>
            {parsedExtra?.certification && (
@@ -364,7 +473,7 @@ const DatasetList: FunctionComponent<DatasetListProps> = ({
        row: {
          original: { kind },
        },
-     }: any) => <DatasetTypeLabel datasetType={kind} />,
+     }: CellProps<Dataset>) => <DatasetTypeLabel datasetType={kind} />,
      Header: t('Type'),
      accessor: 'kind',
      disableSortBy: true,
@@ -372,12 +481,22 @@ const DatasetList: FunctionComponent<DatasetListProps> = ({
      id: 'kind',
    },
    {
+     Cell: ({
+       row: {
+         original: { database },
+       },
+     }: CellProps<Dataset>) => database?.database_name || '-',
      Header: t('Database'),
      accessor: 'database.database_name',
      size: 'xl',
      id: 'database.database_name',
    },
    {
+     Cell: ({
+       row: {
+         original: { schema },
+       },
+     }: CellProps<Dataset>) => schema || '-',
      Header: t('Schema'),
      accessor: 'schema',
      size: 'lg',
@@ -394,7 +513,7 @@ const DatasetList: FunctionComponent<DatasetListProps> = ({
        row: {
          original: { owners = [] },
        },
-     }: any) => <FacePile users={owners} />,
+     }: CellProps<Dataset>) => <FacePile users={owners} />,
      Header: t('Owners'),
      id: 'owners',
      disableSortBy: true,
@@ -408,7 +527,9 @@ const DatasetList: FunctionComponent<DatasetListProps> = ({
          changed_by: changedBy,
        },
      },
-   }: any) => <ModifiedInfo date={changedOn} user={changedBy} />,
+   }: CellProps<Dataset>) => (
+     <ModifiedInfo date={changedOn} user={changedBy} />
+   ),
    Header: t('Last modified'),
    accessor: 'changed_on_delta_humanized',
    size: 'xl',
@@ -421,16 +542,52 @@ const DatasetList: FunctionComponent<DatasetListProps> = ({
      id: 'sql',
    },
    {
-     Cell: ({ row: { original } }: any) => {
-       // Verify owner or isAdmin
+     accessor: 'source_type',
+     hidden: true,
+     disableSortBy: true,
+     id: 'source_type',
+   },
+   {
+     Cell: ({ row: { original } }: CellProps<Dataset>) => {
+       const isSemanticView = original.kind === 'semantic_view';
+
+       // Semantic view: only show edit button
+       if (isSemanticView) {
+         if (!canEdit) return null;
+         return (
+           <Actions className="actions">
+             <Tooltip
+               id="edit-action-tooltip"
+               title={t('Edit')}
+               placement="bottom"
+             >
+               <span
+                 role="button"
+                 tabIndex={0}
+                 className="action-button"
+                 onClick={() => setSvCurrentlyEditing(original)}
+               >
+                 <Icons.EditOutlined iconSize="l" />
+               </span>
+             </Tooltip>
+           </Actions>
+         );
+       }
+
+       // Dataset: full set of actions
        const allowEdit =
-         original.owners.map((o: Owner) => o.id).includes(user.userId) ||
-         isUserAdmin(user);
+         original.owners
+           .map((o: Owner) => o.id)
+           .includes(Number(user.userId)) || isUserAdmin(user);

        const handleEdit = () => openDatasetEditModal(original);
        const handleDelete = () => openDatasetDeleteModal(original);
        const handleExport = () => handleBulkDatasetExport([original]);
-       const handleDuplicate = () => openDatasetDuplicateModal(original);
+       const handleDuplicate = () => {
+         if (original.kind === 'virtual' && original.sql) {
+           openDatasetDuplicateModal(original as VirtualDataset);
+         }
+       };
        if (!canEdit && !canDelete && !canExport && !canDuplicate) {
          return null;
        }
@@ -536,6 +693,22 @@ const DatasetList: FunctionComponent<DatasetListProps> = ({

  const filterTypes: ListViewFilters = useMemo(
    () => [
+     ...(isFeatureEnabled(FeatureFlag.SemanticLayers)
+       ? [
+           {
+             Header: t('Source'),
+             key: 'source_type',
+             id: 'source_type',
+             input: 'select' as const,
+             operator: FilterOperator.Equals,
+             unfilteredLabel: t('All'),
+             selects: [
+               { label: t('Database'), value: 'database' },
+               { label: t('Semantic Layer'), value: 'semantic_layer' },
+             ],
+           },
+         ]
+       : []),
      {
        Header: t('Name'),
        key: 'search',
@@ -543,18 +716,42 @@ const DatasetList: FunctionComponent<DatasetListProps> = ({
        input: 'search',
        operator: FilterOperator.Contains,
      },
-     {
-       Header: t('Type'),
-       key: 'sql',
-       id: 'sql',
-       input: 'select',
-       operator: FilterOperator.DatasetIsNullOrEmpty,
-       unfilteredLabel: 'All',
-       selects: [
-         { label: t('Virtual'), value: false },
-         { label: t('Physical'), value: true },
-       ],
-     },
+     ...(isFeatureEnabled(FeatureFlag.SemanticLayers)
+       ? [
+           {
+             Header: t('Type'),
+             key: 'sql',
+             id: 'sql',
+             input: 'select' as const,
+             operator: FilterOperator.DatasetIsNullOrEmpty,
+             unfilteredLabel: 'All',
+             selects: [
+               ...(currentSourceFilter !== 'semantic_layer'
+                 ? [
+                     { label: t('Physical'), value: true },
+                     { label: t('Virtual'), value: false },
+                   ]
+                 : []),
+               ...(currentSourceFilter !== 'database'
+                 ? [{ label: t('Semantic View'), value: 'semantic_view' }]
+                 : []),
+             ],
+           },
+         ]
+       : [
+           {
+             Header: t('Type'),
+             key: 'sql',
+             id: 'sql',
+             input: 'select' as const,
+             operator: FilterOperator.DatasetIsNullOrEmpty,
+             unfilteredLabel: 'All',
+             selects: [
+               { label: t('Physical'), value: true },
+               { label: t('Virtual'), value: false },
+             ],
+           },
+         ]),
      {
        Header: t('Database'),
        key: 'database',
@@ -645,7 +842,7 @@ const DatasetList: FunctionComponent<DatasetListProps> = ({
      dropdownStyle: { minWidth: WIDER_DROPDOWN_WIDTH },
    },
  ],
- [user],
+ [user, currentSourceFilter],
);
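
The memo now depends on `currentSourceFilter` so the Type options shrink as the Source filter changes. The same logic, pulled out as a standalone function purely for illustration (the name is hypothetical):

```typescript
// Hypothetical extraction of the Type-filter options logic above:
// Physical/Virtual disappear when only semantic layers are in scope,
// and Semantic View disappears when only databases are.
const typeFilterSelects = (sourceFilter: string) => [
  ...(sourceFilter !== 'semantic_layer'
    ? [
        { label: t('Physical'), value: true },
        { label: t('Virtual'), value: false },
      ]
    : []),
  ...(sourceFilter !== 'database'
    ? [{ label: t('Semantic View'), value: 'semantic_view' }]
    : []),
];
```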

  const menuData: SubMenuProps = {
@@ -893,10 +1090,18 @@ const DatasetList: FunctionComponent<DatasetListProps> = ({
        />
      )}
      <DuplicateDatasetModal
-       dataset={datasetCurrentlyDuplicating}
+       dataset={datasetCurrentlyDuplicating as DatasetType | null}
        onHide={closeDatasetDuplicateModal}
        onDuplicate={handleDatasetDuplicate}
      />
+     <SemanticViewEditModal
+       show={!!svCurrentlyEditing}
+       onHide={() => setSvCurrentlyEditing(null)}
+       onSave={refreshData}
+       addDangerToast={addDangerToast}
+       addSuccessToast={addSuccessToast}
+       semanticView={svCurrentlyEditing}
+     />
      <ConfirmStatusChange
        title={t('Please confirm')}
        description={t(

@@ -25,11 +25,18 @@ export default interface Dataset {
  database: {
    id: string;
    database_name: string;
- };
+ } | null;
  kind: string;
+ source_type?: 'database' | 'semantic_layer';
  explore_url: string;
  id: number;
  owners: Array<Owner>;
- schema: string;
+ schema: string | null;
  catalog?: string | null;
  table_name: string;
  description?: string | null;
  cache_timeout?: number | null;
  default_endpoint?: string | null;
  is_sqllab_view?: boolean;
  is_managed_externally?: boolean;
}
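
The nullable `database` and `schema` are the load-bearing changes here: a semantic view has neither. A hedged example payload, limited to the fields visible in this hunk and with all values invented:

```typescript
// Invented example: a semantic view as the combined endpoint might list it.
const semanticViewRow: Partial<Dataset> = {
  id: 42,
  kind: 'semantic_view',
  source_type: 'semantic_layer',
  table_name: 'orders_semantic_view',
  database: null, // no backing database row
  schema: null, // schemas do not apply to semantic views
  owners: [],
};
```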

superset/commands/datasource/__init__.py (new file, 0 lines)
superset/commands/datasource/list.py (new file, 157 lines)
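
The new command below backs `GET /api/v1/datasource/`. From a client's perspective the contract mirrors Superset's other list endpoints: a rison `q` payload in, a `{count, result}` envelope out, with `result` mixing dataset and semantic-view dicts. A hedged TypeScript sketch of that contract; the response typing is an assumption read off the command's `run()`:

```typescript
// Assumed request/response shape for the combined endpoint.
type CombinedListResponse = {
  count: number; // total rows across both sources
  result: Array<Record<string, unknown>>; // dataset or semantic-view dicts
};

async function fetchCombinedList(): Promise<CombinedListResponse> {
  const q = rison.encode_uri({
    order_column: 'changed_on',
    order_direction: 'desc',
    page: 0,
    page_size: 25,
    filters: [{ col: 'source_type', opr: 'eq', value: 'semantic_layer' }],
  });
  const { json } = await SupersetClient.get({
    endpoint: `/api/v1/datasource/?q=${q}`,
  });
  return json as CombinedListResponse;
}
```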
@@ -0,0 +1,157 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Command for the combined dataset + semantic view list endpoint."""

from __future__ import annotations

import logging
from typing import Any, cast

from sqlalchemy import union_all

from superset.commands.base import BaseCommand
from superset.connectors.sqla.models import SqlaTable
from superset.daos.datasource import DatasourceDAO
from superset.datasource.schemas import DatasetListSchema, SemanticViewListSchema
from superset.semantic_layers.models import SemanticView

logger = logging.getLogger(__name__)

_dataset_schema = DatasetListSchema()
_semantic_view_schema = SemanticViewListSchema()


class GetCombinedDatasourceListCommand(BaseCommand):
    """
    Fetch and serialize a paginated, combined list of datasets and semantic views.

    Callers are responsible for checking access permissions before constructing
    this command and for passing the appropriate ``can_read_*`` flags.
    """

    def __init__(
        self,
        args: dict[str, Any],
        can_read_datasets: bool,
        can_read_semantic_views: bool,
    ) -> None:
        self._args = args
        self._can_read_datasets = can_read_datasets
        self._can_read_semantic_views = can_read_semantic_views

    def run(self) -> dict[str, Any]:
        self.validate()

        page = self._args.get("page", 0)
        page_size = self._args.get("page_size", 25)
        order_column = self._args.get("order_column", "changed_on")
        order_direction = self._args.get("order_direction", "desc")
        filters = self._args.get("filters", [])

        source_type, name_filter, sql_filter, type_filter = self._parse_filters(filters)
        source_type = self._resolve_source_type(source_type, sql_filter, type_filter)

        ds_q = DatasourceDAO.build_dataset_query(name_filter, sql_filter)
        sv_q = DatasourceDAO.build_semantic_view_query(name_filter)

        if source_type == "database":
            combined = ds_q.subquery()
        elif source_type == "semantic_layer":
            combined = sv_q.subquery()
        else:
            combined = union_all(ds_q, sv_q).subquery()

        total_count, rows = DatasourceDAO.paginate_combined_query(
            combined, order_column, order_direction, page, page_size
        )

        datasets_map = DatasourceDAO.fetch_datasets_by_ids(
            [r.item_id for r in rows if r.source_type == "database"]
        )
        sv_map = DatasourceDAO.fetch_semantic_views_by_ids(
            [r.item_id for r in rows if r.source_type == "semantic_layer"]
        )

        result: list[dict[str, Any]] = []
        for row in rows:
            if row.source_type == "database":
                ds_obj = cast(SqlaTable | None, datasets_map.get(row.item_id))
                if ds_obj:
                    result.append(_dataset_schema.dump(ds_obj))
            else:
                sv_obj = cast(SemanticView | None, sv_map.get(row.item_id))
                if sv_obj:
                    result.append(_semantic_view_schema.dump(sv_obj))

        return {"count": total_count, "result": result}

    def validate(self) -> None:
        pass  # access checks are performed by the caller (API layer)

    def _resolve_source_type(
        self,
        source_type: str,
        sql_filter: bool | None,
        type_filter: str | None,
    ) -> str:
        """Narrow source_type based on access flags, sql filter, and type filter."""
        if not self._can_read_semantic_views:
            return "database"
        if not self._can_read_datasets:
            return "semantic_layer"
        # sql_filter (physical/virtual toggle) only applies to datasets
        if sql_filter is not None:
            return "database"
        # Explicit semantic-view type filter
        if type_filter == "semantic_view":
            return "semantic_layer"
        return source_type

    @staticmethod
    def _parse_filters(
        filters: list[dict[str, Any]],
    ) -> tuple[str, str | None, bool | None, str | None]:
        """
        Translate raw rison filter dicts into typed query parameters.

        Returns:
            source_type: "all" | "database" | "semantic_layer"
            name_filter: substring to match against name/table_name
            sql_filter: True → physical only, False → virtual only, None → both
            type_filter: "semantic_view" when the caller wants only semantic views
        """
        source_type = "all"
        name_filter: str | None = None
        sql_filter: bool | None = None
        type_filter: str | None = None

        for f in filters:
            col = f.get("col")
            opr = f.get("opr")
            value = f.get("value")

            if col == "source_type":
                source_type = value or "all"
            elif col == "table_name" and opr == "ct":
                name_filter = value
            elif col == "sql":
                if opr == "dataset_is_null_or_empty" and value == "semantic_view":
                    type_filter = "semantic_view"
                elif opr == "dataset_is_null_or_empty" and isinstance(value, bool):
                    sql_filter = value

        return source_type, name_filter, sql_filter, type_filter
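The command reads as a small pipeline: parse the rison filters, narrow the source type, build the two SELECTs, union and paginate them, then hydrate ORM objects for serialization. A rough sketch of how it might be driven (values are hypothetical; this assumes a Flask app context, and the `can_read_*` flags are normally resolved by the API layer):

```python
# Sketch only; argument values are hypothetical.
from superset.commands.datasource.list import GetCombinedDatasourceListCommand

args = {
    "page": 0,
    "page_size": 25,
    "order_column": "changed_on_delta_humanized",  # mapped onto changed_on
    "order_direction": "desc",
    "filters": [{"col": "table_name", "opr": "ct", "value": "sales"}],
}
payload = GetCombinedDatasourceListCommand(
    args=args,
    can_read_datasets=True,        # resolved via security_manager in the API layer
    can_read_semantic_views=True,  # gated on the SEMANTIC_LAYERS feature flag
).run()
print(payload["count"], [row["table_name"] for row in payload["result"]])
```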
@@ -124,7 +124,11 @@ class GetExploreCommand(BaseCommand, ABC):
        security_manager.raise_for_access(datasource=datasource)

        viz_type = form_data.get("viz_type")
-       if not viz_type and datasource and datasource.default_endpoint:
+       if (
+           not viz_type
+           and datasource
+           and getattr(datasource, "default_endpoint", None)
+       ):
            raise WrongEndpointError(redirect=datasource.default_endpoint)

        form_data["datasource"] = (
superset/commands/semantic_layer/__init__.py (new file, 16 lines)
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
superset/commands/semantic_layer/create.py (new file, 64 lines)
@@ -0,0 +1,64 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import logging
from functools import partial
from typing import Any

from flask_appbuilder.models.sqla import Model
from sqlalchemy.exc import SQLAlchemyError

from superset.commands.base import BaseCommand
from superset.commands.semantic_layer.exceptions import (
    SemanticLayerCreateFailedError,
    SemanticLayerInvalidError,
)
from superset.daos.semantic_layer import SemanticLayerDAO
from superset.semantic_layers.registry import registry
from superset.utils.decorators import on_error, transaction

logger = logging.getLogger(__name__)


class CreateSemanticLayerCommand(BaseCommand):
    def __init__(self, data: dict[str, Any]):
        self._properties = data.copy()

    @transaction(
        on_error=partial(
            on_error,
            catches=(SQLAlchemyError, ValueError),
            reraise=SemanticLayerCreateFailedError,
        )
    )
    def run(self) -> Model:
        self.validate()
        return SemanticLayerDAO.create(attributes=self._properties)

    def validate(self) -> None:
        sl_type = self._properties.get("type")
        if sl_type not in registry:
            raise SemanticLayerInvalidError(f"Unknown type: {sl_type}")

        name: str = self._properties.get("name", "")
        if not SemanticLayerDAO.validate_uniqueness(name):
            raise SemanticLayerInvalidError(f"Name already exists: {name}")

        # Validate configuration against the plugin
        cls = registry[sl_type]
        cls.from_configuration(self._properties["configuration"])
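Validation runs strictly before the write: an unknown `type`, a duplicate name, or a configuration rejected by the plugin's `from_configuration` all abort the transaction. A minimal sketch, assuming a hypothetical `my_platform` plugin is already registered:

```python
# Sketch only; "my_platform" and the configuration payload are hypothetical.
from superset.commands.semantic_layer.create import CreateSemanticLayerCommand

layer = CreateSemanticLayerCommand(
    {
        "name": "Sales semantic layer",  # must be unique across layers
        "type": "my_platform",           # must be a key in the plugin registry
        "configuration": {"url": "https://example.invalid"},
    }
).run()
```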
superset/commands/semantic_layer/delete.py (new file, 56 lines)
@@ -0,0 +1,56 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import logging
from functools import partial

from sqlalchemy.exc import SQLAlchemyError

from superset.commands.base import BaseCommand
from superset.commands.semantic_layer.exceptions import (
    SemanticLayerDeleteFailedError,
    SemanticLayerNotFoundError,
)
from superset.daos.semantic_layer import SemanticLayerDAO
from superset.semantic_layers.models import SemanticLayer
from superset.utils.decorators import on_error, transaction

logger = logging.getLogger(__name__)


class DeleteSemanticLayerCommand(BaseCommand):
    def __init__(self, uuid: str):
        self._uuid = uuid
        self._model: SemanticLayer | None = None

    @transaction(
        on_error=partial(
            on_error,
            catches=(SQLAlchemyError,),
            reraise=SemanticLayerDeleteFailedError,
        )
    )
    def run(self) -> None:
        self.validate()
        assert self._model
        SemanticLayerDAO.delete([self._model])

    def validate(self) -> None:
        self._model = SemanticLayerDAO.find_by_uuid(self._uuid)
        if not self._model:
            raise SemanticLayerNotFoundError()
superset/commands/semantic_layer/exceptions.py (new file, 68 lines)
@@ -0,0 +1,68 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from flask_babel import lazy_gettext as _

from superset.commands.exceptions import (
    CommandException,
    CommandInvalidError,
    CreateFailedError,
    DeleteFailedError,
    ForbiddenError,
    UpdateFailedError,
)


class SemanticViewNotFoundError(CommandException):
    status = 404
    message = _("Semantic view does not exist")


class SemanticViewForbiddenError(ForbiddenError):
    message = _("Changing this semantic view is forbidden")


class SemanticViewInvalidError(CommandInvalidError):
    message = _("Semantic view parameters are invalid.")


class SemanticViewUpdateFailedError(UpdateFailedError):
    message = _("Semantic view could not be updated.")


class SemanticLayerNotFoundError(CommandException):
    status = 404
    message = _("Semantic layer does not exist")


class SemanticLayerForbiddenError(ForbiddenError):
    message = _("Changing this semantic layer is forbidden")


class SemanticLayerInvalidError(CommandInvalidError):
    message = _("Semantic layer parameters are invalid.")


class SemanticLayerCreateFailedError(CreateFailedError):
    message = _("Semantic layer could not be created.")


class SemanticLayerUpdateFailedError(UpdateFailedError):
    message = _("Semantic layer could not be updated.")


class SemanticLayerDeleteFailedError(DeleteFailedError):
    message = _("Semantic layer could not be deleted.")
superset/commands/semantic_layer/update.py (new file, 122 lines)
@@ -0,0 +1,122 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import logging
from functools import partial
from typing import Any

from flask_appbuilder.models.sqla import Model
from sqlalchemy.exc import SQLAlchemyError

from superset import security_manager
from superset.commands.base import BaseCommand
from superset.commands.semantic_layer.exceptions import (
    SemanticLayerInvalidError,
    SemanticLayerNotFoundError,
    SemanticLayerUpdateFailedError,
    SemanticViewForbiddenError,
    SemanticViewNotFoundError,
    SemanticViewUpdateFailedError,
)
from superset.daos.semantic_layer import SemanticLayerDAO, SemanticViewDAO
from superset.exceptions import SupersetSecurityException
from superset.semantic_layers.models import SemanticLayer, SemanticView
from superset.semantic_layers.registry import registry
from superset.utils import json
from superset.utils.decorators import on_error, transaction

logger = logging.getLogger(__name__)


class UpdateSemanticViewCommand(BaseCommand):
    def __init__(self, model_id: int, data: dict[str, Any]):
        self._model_id = model_id
        self._properties = data.copy()
        self._model: SemanticView | None = None

    @transaction(
        on_error=partial(
            on_error,
            catches=(SQLAlchemyError, ValueError),
            reraise=SemanticViewUpdateFailedError,
        )
    )
    def run(self) -> Model:
        self.validate()
        assert self._model
        return SemanticViewDAO.update(self._model, attributes=self._properties)

    def validate(self) -> None:
        self._model = SemanticViewDAO.find_by_id(self._model_id)
        if not self._model:
            raise SemanticViewNotFoundError()

        try:
            security_manager.raise_for_ownership(self._model)
        except SupersetSecurityException as ex:
            raise SemanticViewForbiddenError() from ex

        name = self._properties.get("name", self._model.name)
        layer_uuid = str(self._model.semantic_layer_uuid)
        configuration = self._properties.get(
            "configuration",
            json.loads(self._model.configuration),
        )
        if not SemanticViewDAO.validate_update_uniqueness(
            view_uuid=str(self._model.uuid),
            name=name,
            layer_uuid=layer_uuid,
            configuration=configuration,
        ):
            raise ValueError(
                f"A semantic view with name '{name}' and the same "
                "configuration already exists in this semantic layer."
            )


class UpdateSemanticLayerCommand(BaseCommand):
    def __init__(self, uuid: str, data: dict[str, Any]):
        self._uuid = uuid
        self._properties = data.copy()
        self._model: SemanticLayer | None = None

    @transaction(
        on_error=partial(
            on_error,
            catches=(SQLAlchemyError, ValueError),
            reraise=SemanticLayerUpdateFailedError,
        )
    )
    def run(self) -> Model:
        self.validate()
        assert self._model
        return SemanticLayerDAO.update(self._model, attributes=self._properties)

    def validate(self) -> None:
        self._model = SemanticLayerDAO.find_by_uuid(self._uuid)
        if not self._model:
            raise SemanticLayerNotFoundError()

        name = self._properties.get("name")
        if name and not SemanticLayerDAO.validate_update_uniqueness(self._uuid, name):
            raise SemanticLayerInvalidError(f"Name already exists: {name}")

        if configuration := self._properties.get("configuration"):
            sl_type = self._model.type
            cls = registry[sl_type]
            cls.from_configuration(configuration)
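Note how the view update re-derives defaults from the stored model: an omitted `name` or `configuration` falls back to the current values before the uniqueness check runs. A sketch of the call (hypothetical id; assumes an app context and ownership of the view):

```python
# Sketch only; model_id is hypothetical.
from superset.commands.semantic_layer.update import UpdateSemanticViewCommand

updated = UpdateSemanticViewCommand(
    model_id=42,
    data={"name": "Orders (curated)"},  # configuration defaults to the stored one
).run()
```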
@@ -566,6 +566,9 @@ DEFAULT_FEATURE_FLAGS: dict[str, bool] = {
    # can_copy_clipboard) instead of the single can_csv permission
    # @lifecycle: development
    "GRANULAR_EXPORT_CONTROLS": False,
+   # Enable semantic layers and show semantic views alongside datasets
+   # @lifecycle: development
+   "SEMANTIC_LAYERS": False,
    # Enables advanced data type support
    # @lifecycle: development
    "ENABLE_ADVANCED_DATA_TYPES": False,
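Since the flag defaults to `False`, deployments opt in through their configuration, along these lines in `superset_config.py`:

```python
# superset_config.py
FEATURE_FLAGS = {
    "SEMANTIC_LAYERS": True,
}
```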
@@ -108,6 +108,8 @@ from superset.sql.parse import Table
from superset.superset_typing import (
    AdhocColumn,
    AdhocMetric,
+   DatasetColumnData,
+   DatasetMetricData,
    ExplorableData,
    Metric,
    QueryObjectDict,
@@ -464,8 +466,8 @@ class BaseDatasource(
            # sqla-specific
            "sql": self.sql,
            # one to many
-           "columns": [o.data for o in self.columns],
-           "metrics": [o.data for o in self.metrics],
+           "columns": [cast(DatasetColumnData, o.data) for o in self.columns],
+           "metrics": [cast(DatasetMetricData, o.data) for o in self.metrics],
            "folders": self.folders,
            # TODO deprecate, move logic to JS
            "order_by_choices": self.order_by_choices,
@@ -229,6 +229,40 @@ def inject_model_session_implementation() -> None:
    core_models_module.get_session = get_session


def inject_semantic_layer_implementations() -> None:
    """
    Replace the abstract semantic layer decorator in
    superset_core.semantic_layers.decorators with a concrete implementation
    that registers classes in the contributions registry.
    """
    import superset_core.semantic_layers.decorators as core_sl_module

    import superset.extensions.context as context_module
    from superset.semantic_layers.registry import registry

    def semantic_layer_impl(
        id: str,
        name: str,
        description: str | None = None,
    ) -> Callable[[Any], Any]:
        def decorator(cls: Any) -> Any:
            if context := context_module.get_current_extension_context():
                manifest = context.manifest
                prefixed_id = f"extensions.{manifest.publisher}.{manifest.name}.{id}"
            else:
                prefixed_id = id

            cls.name = name
            cls.description = description
            cls._semantic_layer_id = prefixed_id
            registry[prefixed_id] = cls
            return cls

        return decorator

    core_sl_module.semantic_layer = semantic_layer_impl  # type: ignore[assignment]


def initialize_core_api_dependencies() -> None:
    """
    Initialize all dependency injections for the superset-core API.
@@ -242,3 +276,4 @@ def initialize_core_api_dependencies() -> None:
    inject_query_implementations()
    inject_task_implementations()
    inject_rest_api_implementations()
+   inject_semantic_layer_implementations()
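The id prefixing keeps extension-contributed plugins namespaced per publisher, so two extensions can both register a `my_platform` layer without colliding. A standalone sketch of the resulting registry key (manifest values are hypothetical):

```python
# Hypothetical manifest values; outside an extension context the raw id is used.
publisher, ext_name, plugin_id = "acme", "warehouse-ext", "my_platform"
prefixed_id = f"extensions.{publisher}.{ext_name}.{plugin_id}"
assert prefixed_id == "extensions.acme.warehouse-ext.my_platform"
```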
@@ -17,9 +17,14 @@
import logging
import uuid
- from typing import Union
+ from typing import Any, Union

- from superset import db
+ from sqlalchemy import and_, func, literal, or_, select
+ from sqlalchemy.orm import joinedload
+ from sqlalchemy.sql import Select

+ from superset import db, security_manager
+ from superset.connectors.sqla import models as sqla_models
from superset.connectors.sqla.models import SqlaTable
from superset.daos.base import BaseDAO
from superset.daos.exceptions import (
@@ -28,11 +33,17 @@ from superset.daos.exceptions import (
    DatasourceValueIsIncorrect,
)
from superset.models.sql_lab import Query, SavedQuery
+ from superset.semantic_layers.models import SemanticView
from superset.utils.core import DatasourceType
+ from superset.utils.filters import get_dataset_access_filters

logger = logging.getLogger(__name__)

- Datasource = Union[SqlaTable, Query, SavedQuery]
+ Datasource = Union[SqlaTable, Query, SavedQuery, SemanticView]


+ def _escape_ilike_fragment(value: str) -> str:
+     return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")


class DatasourceDAO(BaseDAO[Datasource]):
@@ -40,6 +51,7 @@ class DatasourceDAO(BaseDAO[Datasource]):
        DatasourceType.TABLE: SqlaTable,
        DatasourceType.QUERY: Query,
        DatasourceType.SAVEDQUERY: SavedQuery,
+       DatasourceType.SEMANTIC_VIEW: SemanticView,
    }

    @classmethod
@@ -78,3 +90,115 @@ class DatasourceDAO(BaseDAO[Datasource]):
            raise DatasourceNotFound()

        return datasource

    @staticmethod
    def build_dataset_query(
        name_filter: str | None,
        sql_filter: bool | None,
    ) -> Select:
        """Build a SELECT for datasets, applying access and content filters."""
        ds_q = select(
            SqlaTable.id.label("item_id"),
            literal("database").label("source_type"),
            SqlaTable.changed_on,
            SqlaTable.table_name,
        ).select_from(SqlaTable.__table__)

        if not security_manager.can_access_all_datasources():
            ds_q = ds_q.join(
                sqla_models.Database,
                sqla_models.Database.id == SqlaTable.database_id,
            )
            ds_q = ds_q.where(get_dataset_access_filters(SqlaTable))

        if name_filter:
            escaped = _escape_ilike_fragment(name_filter)
            ds_q = ds_q.where(SqlaTable.table_name.ilike(f"%{escaped}%", escape="\\"))

        if sql_filter is not None:
            if sql_filter:
                ds_q = ds_q.where(or_(SqlaTable.sql.is_(None), SqlaTable.sql == ""))
            else:
                ds_q = ds_q.where(and_(SqlaTable.sql.isnot(None), SqlaTable.sql != ""))

        return ds_q

    @staticmethod
    def build_semantic_view_query(name_filter: str | None) -> Select:
        """Build a SELECT for semantic views, applying name filter."""
        sv_q = select(
            SemanticView.id.label("item_id"),
            literal("semantic_layer").label("source_type"),
            SemanticView.changed_on,
            SemanticView.name.label("table_name"),
        ).select_from(SemanticView.__table__)

        if name_filter:
            escaped = _escape_ilike_fragment(name_filter)
            sv_q = sv_q.where(SemanticView.name.ilike(f"%{escaped}%", escape="\\"))

        return sv_q

    @staticmethod
    def paginate_combined_query(
        combined: Any,
        order_column: str,
        order_direction: str,
        page: int,
        page_size: int,
    ) -> tuple[int, list[Any]]:
        """Count, sort, and paginate the combined dataset/semantic-view query."""
        sort_col_map = {
            "changed_on": "changed_on",
            "changed_on_delta_humanized": "changed_on",
            "table_name": "table_name",
        }
        if order_column not in sort_col_map:
            raise ValueError(f"Invalid order column: {order_column}")
        sort_col_name = sort_col_map[order_column]

        total_count = (
            db.session.execute(select(func.count()).select_from(combined)).scalar() or 0
        )

        sort_col = combined.c[sort_col_name]
        ordered_col = sort_col.desc() if order_direction == "desc" else sort_col.asc()

        rows = db.session.execute(
            select(combined.c.item_id, combined.c.source_type)
            .order_by(ordered_col)
            .offset(page * page_size)
            .limit(page_size)
        ).fetchall()

        return total_count, rows

    @staticmethod
    def fetch_datasets_by_ids(ids: list[int]) -> dict[int, SqlaTable]:
        """Fetch SqlaTable objects by id with relationships eager-loaded."""
        if not ids:
            return {}
        objs = (
            db.session.query(SqlaTable)
            .options(
                joinedload(SqlaTable.database),
                joinedload(SqlaTable.owners),
                joinedload(SqlaTable.changed_by),
            )
            .filter(SqlaTable.id.in_(ids))
            .all()
        )
        return {obj.id: obj for obj in objs}

    @staticmethod
    def fetch_semantic_views_by_ids(ids: list[int]) -> dict[int, SemanticView]:
        """Fetch SemanticView objects by id with relationships eager-loaded."""
        if not ids:
            return {}
        objs = (
            db.session.query(SemanticView)
            .options(joinedload(SemanticView.changed_by))
            .filter(SemanticView.id.in_(ids))
            .all()
        )
        return {obj.id: obj for obj in objs}
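The `_escape_ilike_fragment` helper used by both query builders keeps the name filter literal: user input is escaped before being wrapped in `%...%`, so `%` and `_` in the input do not act as wildcards. A standalone check:

```python
def _escape_ilike_fragment(value: str) -> str:
    # Escape the backslash first, then the ILIKE wildcards.
    return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

assert _escape_ilike_fragment("100%_done") == "100\\%\\_done"
```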
superset/daos/semantic_layer.py (new file, 198 lines)
@@ -0,0 +1,198 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""DAOs for semantic layer models."""

from __future__ import annotations

from typing import Any

from sqlalchemy.exc import StatementError

from superset_core.semantic_layers.daos import (
    AbstractSemanticLayerDAO,
    AbstractSemanticViewDAO,
)

from superset.extensions import db
from superset.semantic_layers.models import SemanticLayer, SemanticView
from superset.utils import json


class SemanticLayerDAO(AbstractSemanticLayerDAO):
    """
    Data Access Object for the SemanticLayer model.
    """

    model_cls = SemanticLayer

    @staticmethod
    def find_by_uuid(uuid_str: str) -> SemanticLayer | None:
        try:
            return (
                db.session.query(SemanticLayer)
                .filter(SemanticLayer.uuid == uuid_str)
                .one_or_none()
            )
        except (ValueError, StatementError):
            return None

    @classmethod
    def find_all(cls, skip_base_filter: bool = False) -> list[SemanticLayer]:
        query = db.session.query(SemanticLayer)
        query = cls._apply_base_filter(query, skip_base_filter)
        return query.all()

    @classmethod
    def validate_uniqueness(cls, name: str) -> bool:
        """
        Validate that the semantic layer name is unique.

        :param name: Semantic layer name
        :return: True if name is unique, False otherwise
        """
        query = db.session.query(SemanticLayer).filter(SemanticLayer.name == name)
        return not db.session.query(query.exists()).scalar()

    @classmethod
    def validate_update_uniqueness(cls, layer_uuid: str, name: str) -> bool:
        """
        Validate that the semantic layer name is unique for updates.

        :param layer_uuid: UUID of the semantic layer being updated
        :param name: New name to validate
        :return: True if name is unique, False otherwise
        """
        query = db.session.query(SemanticLayer).filter(
            SemanticLayer.name == name,
            SemanticLayer.uuid != layer_uuid,
        )
        return not db.session.query(query.exists()).scalar()

    @classmethod
    def find_by_name(cls, name: str) -> SemanticLayer | None:
        """
        Find a semantic layer by name.

        :param name: Semantic layer name
        :return: SemanticLayer instance or None
        """
        return (
            db.session.query(SemanticLayer)
            .filter(SemanticLayer.name == name)
            .one_or_none()
        )

    @classmethod
    def get_semantic_views(cls, layer_uuid: str) -> list[SemanticView]:
        """
        Get all semantic views for a semantic layer.

        :param layer_uuid: UUID of the semantic layer
        :return: List of SemanticView instances
        """
        return (
            db.session.query(SemanticView)
            .filter(SemanticView.semantic_layer_uuid == layer_uuid)
            .all()
        )


class SemanticViewDAO(AbstractSemanticViewDAO):
    """Data Access Object for the SemanticView model."""

    model_cls = SemanticView

    @classmethod
    def validate_uniqueness(
        cls,
        name: str,
        layer_uuid: str,
        configuration: dict[str, Any],
    ) -> bool:
        """
        Validate that the view is unique within a semantic layer.

        Uniqueness is determined by name, layer, and configuration.
        The configuration column is encrypted (non-deterministic
        ciphertext), so it cannot be compared at the DB level. Instead,
        we filter by name + layer in SQL and compare decrypted
        configuration dicts in Python.

        :param name: View name
        :param layer_uuid: UUID of the semantic layer
        :param configuration: Configuration dict to compare
        :return: True if unique, False otherwise
        """
        candidates = (
            db.session.query(SemanticView)
            .filter(
                SemanticView.name == name,
                SemanticView.semantic_layer_uuid == layer_uuid,
            )
            .all()
        )
        return not any(json.loads(c.configuration) == configuration for c in candidates)

    @classmethod
    def validate_update_uniqueness(
        cls,
        view_uuid: str,
        name: str,
        layer_uuid: str,
        configuration: dict[str, Any],
    ) -> bool:
        """
        Validate that the view is unique within a semantic layer for updates.

        Same logic as ``validate_uniqueness`` but excludes the view
        being updated.

        :param view_uuid: UUID of the view being updated
        :param name: New name to validate
        :param layer_uuid: UUID of the semantic layer
        :param configuration: Configuration dict to compare
        :return: True if unique, False otherwise
        """
        candidates = (
            db.session.query(SemanticView)
            .filter(
                SemanticView.name == name,
                SemanticView.semantic_layer_uuid == layer_uuid,
                SemanticView.uuid != view_uuid,
            )
            .all()
        )
        return not any(json.loads(c.configuration) == configuration for c in candidates)

    @classmethod
    def find_by_name(cls, name: str, layer_uuid: str) -> SemanticView | None:
        """
        Find a semantic view by name within a semantic layer.

        :param name: View name
        :param layer_uuid: UUID of the semantic layer
        :return: SemanticView instance or None
        """
        return (
            db.session.query(SemanticView)
            .filter(
                SemanticView.name == name,
                SemanticView.semantic_layer_uuid == layer_uuid,
            )
            .one_or_none()
        )
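The Python-side comparison is forced by the storage format: encrypting the same configuration twice produces different ciphertext, so SQL equality on the column is meaningless. A standalone sketch of the decrypted-dict comparison the DAO relies on:

```python
import json

stored = json.dumps({"schema": "sales", "view": "orders"})  # decrypted JSON text
candidate = {"view": "orders", "schema": "sales"}           # same content, reordered
assert json.loads(stored) == candidate  # dict equality ignores key order
```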
@@ -15,11 +15,14 @@
# specific language governing permissions and limitations
# under the License.
import logging
+ from typing import Any

from flask import current_app as app, request
- from flask_appbuilder.api import expose, protect, safe
+ from flask_appbuilder.api import expose, protect, rison, safe
+ from flask_appbuilder.api.schemas import get_list_schema

- from superset import event_logger
+ from superset import event_logger, is_feature_enabled, security_manager
+ from superset.commands.datasource.list import GetCombinedDatasourceListCommand
from superset.connectors.sqla.models import BaseDatasource
from superset.daos.datasource import DatasourceDAO
from superset.daos.exceptions import DatasourceNotFound, DatasourceTypeNotSupportedError
@@ -303,3 +306,53 @@ class DatasourceRestApi(BaseSupersetApi):
                f"Invalid expression type: {expression_type}. "
                f"Valid types are: column, metric, where, having"
            ) from None

    @expose("/", methods=("GET",))
    @protect()
    @safe
    @statsd_metrics
    @rison(get_list_schema)
    @event_logger.log_this_with_context(
        action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.combined_list",
        log_to_statsd=False,
    )
    def combined_list(self, **kwargs: Any) -> FlaskResponse:
        """List datasets and semantic views combined.
        ---
        get:
          summary: List datasets and semantic views combined
          parameters:
          - in: query
            name: q
            content:
              application/json:
                schema:
                  $ref: '#/components/schemas/get_list_schema'
          responses:
            200:
              description: Combined list of datasets and semantic views
            401:
              $ref: '#/components/responses/401'
            403:
              $ref: '#/components/responses/403'
            500:
              $ref: '#/components/responses/500'
        """
        can_read_datasets = security_manager.can_access("can_read", "Dataset")
        can_read_sv = is_feature_enabled(
            "SEMANTIC_LAYERS"
        ) and security_manager.can_access("can_read", "SemanticView")

        if not can_read_datasets and not can_read_sv:
            return self.response(403, message="Access denied")

        try:
            result = GetCombinedDatasourceListCommand(
                args=kwargs.get("rison", {}),
                can_read_datasets=can_read_datasets,
                can_read_semantic_views=can_read_sv,
            ).run()
        except ValueError as ex:
            return self.response(400, message=str(ex))

        return self.response(200, **result)
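From a client's perspective this is an ordinary rison-parameterized list endpoint. A sketch of a call (base URL and credentials are placeholders):

```python
import requests

q = "(page:0,page_size:25,order_column:changed_on_delta_humanized,order_direction:desc)"
resp = requests.get(
    "http://localhost:8088/api/v1/datasource/",
    params={"q": q},
    headers={"Authorization": "Bearer <token>"},  # placeholder credentials
)
resp.raise_for_status()
print(resp.json()["count"])
```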
superset/datasource/schemas.py (new file, 147 lines)
@@ -0,0 +1,147 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Marshmallow schemas for the combined datasource list endpoint."""

from __future__ import annotations

from marshmallow import fields, Schema

from superset.connectors.sqla.models import SqlaTable
from superset.semantic_layers.models import SemanticView


class _ChangedBySchema(Schema):
    first_name = fields.String()
    last_name = fields.String()


class _OwnerSchema(Schema):
    id = fields.Integer()
    first_name = fields.String()
    last_name = fields.String()


class _DatabaseSchema(Schema):
    id = fields.Integer()
    database_name = fields.String()


class DatasetListSchema(Schema):
    """Serializes a SqlaTable ORM object for the combined list response."""

    id = fields.Integer()
    uuid = fields.Method("get_uuid")
    table_name = fields.String()
    kind = fields.String()
    source_type = fields.Constant("database")
    description = fields.String(allow_none=True)
    explore_url = fields.String()
    database = fields.Method("get_database")
    catalog = fields.String(allow_none=True)
    schema = fields.String(allow_none=True)
    sql = fields.String(allow_none=True)
    extra = fields.Raw(allow_none=True)
    default_endpoint = fields.String(allow_none=True)
    is_sqllab_view = fields.Boolean(allow_none=True)
    is_managed_externally = fields.Boolean(allow_none=True)
    owners = fields.Method("get_owners")
    changed_by_name = fields.String()
    changed_by = fields.Method("get_changed_by")
    changed_on_delta_humanized = fields.Method("get_changed_on_delta_humanized")
    changed_on_utc = fields.Method("get_changed_on_utc")

    def get_uuid(self, obj: SqlaTable) -> str:
        return str(obj.uuid)

    def get_database(self, obj: SqlaTable) -> dict[str, object] | None:
        if not obj.database:
            return None
        return _DatabaseSchema().dump(
            {"id": obj.database_id, "database_name": obj.database.database_name}
        )

    def get_owners(self, obj: SqlaTable) -> list[dict[str, object]]:
        return _OwnerSchema(many=True).dump(
            [
                {"id": o.id, "first_name": o.first_name, "last_name": o.last_name}
                for o in obj.owners
            ]
        )

    def get_changed_by(self, obj: SqlaTable) -> dict[str, object] | None:
        if not obj.changed_by:
            return None
        return _ChangedBySchema().dump(
            {
                "first_name": obj.changed_by.first_name,
                "last_name": obj.changed_by.last_name,
            }
        )

    def get_changed_on_delta_humanized(self, obj: SqlaTable) -> str:
        return obj.changed_on_delta_humanized()

    def get_changed_on_utc(self, obj: SqlaTable) -> str:
        return obj.changed_on_utc()


class SemanticViewListSchema(Schema):
    """Serializes a SemanticView ORM object for the combined list response."""

    id = fields.Integer()
    uuid = fields.Method("get_uuid")
    table_name = fields.Method("get_table_name")
    kind = fields.Constant("semantic_view")
    source_type = fields.Constant("semantic_layer")
    description = fields.String(allow_none=True)
    explore_url = fields.String()
    database = fields.Constant(None)
    catalog = fields.Constant(None)
    schema = fields.Constant(None)
    sql = fields.Constant(None)
    extra = fields.Constant(None)
    default_endpoint = fields.Constant(None)
    is_sqllab_view = fields.Constant(False)
    is_managed_externally = fields.Constant(False)
    owners = fields.Constant([])
    changed_by_name = fields.String()
    changed_by = fields.Method("get_changed_by")
    changed_on_delta_humanized = fields.Method("get_changed_on_delta_humanized")
    changed_on_utc = fields.Method("get_changed_on_utc")
    cache_timeout = fields.Integer(allow_none=True)

    def get_uuid(self, obj: SemanticView) -> str:
        return str(obj.uuid)

    def get_table_name(self, obj: SemanticView) -> str:
        return obj.name

    def get_changed_by(self, obj: SemanticView) -> dict[str, object] | None:
        if not obj.changed_by:
            return None
        return _ChangedBySchema().dump(
            {
                "first_name": obj.changed_by.first_name,
                "last_name": obj.changed_by.last_name,
            }
        )

    def get_changed_on_delta_humanized(self, obj: SemanticView) -> str:
        return obj.changed_on_delta_humanized()

    def get_changed_on_utc(self, obj: SemanticView) -> str:
        return obj.changed_on_utc()
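The semantic-view schema leans on `fields.Constant` to pad out dataset-only fields, so both item types dump to the same top-level shape and the frontend can render one homogeneous table. A standalone sketch of the pattern:

```python
from marshmallow import fields, Schema

class MiniViewSchema(Schema):
    kind = fields.Constant("semantic_view")
    source_type = fields.Constant("semantic_layer")
    database = fields.Constant(None)  # datasets fill this; views never do

# Constant fields are emitted on every dump, regardless of the object.
assert MiniViewSchema().dump(object()) == {
    "kind": "semantic_view",
    "source_type": "semantic_layer",
    "database": None,
}
```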
@@ -53,6 +53,130 @@ class TimeGrainDict(TypedDict):
    duration: str | None


@runtime_checkable
class MetricMetadata(Protocol):
    """
    Protocol for metric metadata objects.

    Represents a metric that's available on an explorable data source.
    Metrics contain SQL expressions or references to semantic layer measures.

    Attributes:
        metric_name: Unique identifier for the metric
        expression: SQL expression or reference for calculating the metric
        verbose_name: Human-readable name for display in the UI
        description: Description of what the metric represents
        d3format: D3 format string for formatting numeric values
        currency: Currency configuration for the metric (JSON object)
        warning_text: Warning message to display when using this metric
        certified_by: Person or entity that certified this metric
        certification_details: Details about the certification
    """

    @property
    def metric_name(self) -> str:
        """Unique identifier for the metric."""

    @property
    def expression(self) -> str:
        """SQL expression or reference for calculating the metric."""

    @property
    def verbose_name(self) -> str | None:
        """Human-readable name for display in the UI."""

    @property
    def description(self) -> str | None:
        """Description of what the metric represents."""

    @property
    def d3format(self) -> str | None:
        """D3 format string for formatting numeric values."""

    @property
    def currency(self) -> dict[str, Any] | None:
        """Currency configuration for the metric (JSON object)."""

    @property
    def warning_text(self) -> str | None:
        """Warning message to display when using this metric."""

    @property
    def certified_by(self) -> str | None:
        """Person or entity that certified this metric."""

    @property
    def certification_details(self) -> str | None:
        """Details about the certification."""


@runtime_checkable
class ColumnMetadata(Protocol):
    """
    Protocol for column metadata objects.

    Represents a column/dimension that's available on an explorable data source.
    Used for grouping, filtering, and dimension-based analysis.

    Attributes:
        column_name: Unique identifier for the column
        type: SQL data type of the column (e.g., 'VARCHAR', 'INTEGER', 'DATETIME')
        is_dttm: Whether this column represents a date or time value
        verbose_name: Human-readable name for display in the UI
        description: Description of what the column represents
        groupby: Whether this column is allowed for grouping/aggregation
        filterable: Whether this column can be used in filters
        expression: SQL expression if this is a calculated column
        python_date_format: Python datetime format string for temporal columns
        advanced_data_type: Advanced data type classification
        extra: Additional metadata stored as JSON
    """

    @property
    def column_name(self) -> str:
        """Unique identifier for the column."""

    @property
    def type(self) -> str:
        """SQL data type of the column."""

    @property
    def is_dttm(self) -> bool:
        """Whether this column represents a date or time value."""

    @property
    def verbose_name(self) -> str | None:
        """Human-readable name for display in the UI."""

    @property
    def description(self) -> str | None:
        """Description of what the column represents."""

    @property
    def groupby(self) -> bool:
        """Whether this column is allowed for grouping/aggregation."""

    @property
    def filterable(self) -> bool:
        """Whether this column can be used in filters."""

    @property
    def expression(self) -> str | None:
        """SQL expression if this is a calculated column."""

    @property
    def python_date_format(self) -> str | None:
        """Python datetime format string for temporal columns."""

    @property
    def advanced_data_type(self) -> str | None:
        """Advanced data type classification."""

    @property
    def extra(self) -> str | None:
        """Additional metadata stored as JSON."""


@runtime_checkable
class Explorable(Protocol):
    """
@@ -144,7 +268,7 @@ class Explorable(Protocol):
    """

    @property
-   def metrics(self) -> list[Any]:
+   def metrics(self) -> list[MetricMetadata]:
        """
        List of metric metadata objects.
@@ -159,7 +283,7 @@ class Explorable(Protocol):

    # TODO: rename to dimensions
    @property
-   def columns(self) -> list[Any]:
+   def columns(self) -> list[ColumnMetadata]:
        """
        List of column metadata objects.
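Because these protocols are `runtime_checkable`, any object exposing the right attributes satisfies an `isinstance` check; note that only member presence is verified at runtime, not types. A standalone sketch with a cut-down protocol:

```python
from dataclasses import dataclass
from typing import Protocol, runtime_checkable

@runtime_checkable
class MetricLike(Protocol):
    @property
    def metric_name(self) -> str: ...
    @property
    def expression(self) -> str: ...

@dataclass
class ExternalMeasure:  # e.g. backed by a semantic layer, not SQLAlchemy
    metric_name: str
    expression: str

assert isinstance(ExternalMeasure("revenue", "SUM(amount)"), MetricLike)
```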
@@ -268,6 +268,14 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
        appbuilder.add_api(ReportExecutionLogRestApi)
        appbuilder.add_api(RLSRestApi)
        appbuilder.add_api(SavedQueryRestApi)
+       if feature_flag_manager.is_feature_enabled("SEMANTIC_LAYERS"):
+           from superset.semantic_layers.api import (
+               SemanticLayerRestApi,
+               SemanticViewRestApi,
+           )
+
+           appbuilder.add_api(SemanticLayerRestApi)
+           appbuilder.add_api(SemanticViewRestApi)
        appbuilder.add_api(TagRestApi)
        appbuilder.add_api(SqlLabRestApi)
        appbuilder.add_api(SqlLabPermalinkRestApi)
@@ -0,0 +1,166 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""add_semantic_layers_and_views

Revision ID: 33d7e0e21daa
Revises: a1b2c3d4e5f6
Create Date: 2025-11-04 11:26:00.000000

"""

import uuid

import sqlalchemy as sa
from alembic import op
from sqlalchemy_utils import UUIDType
from sqlalchemy_utils.types.json import JSONType

from superset.extensions import encrypted_field_factory
from superset.migrations.shared.utils import (
    create_fks_for_table,
    create_table,
    drop_table,
)

# revision identifiers, used by Alembic.
revision = "33d7e0e21daa"
down_revision = "a1b2c3d4e5f6"


def upgrade() -> None:
    # Create semantic_layers table
    create_table(
        "semantic_layers",
        sa.Column("uuid", UUIDType(binary=True), default=uuid.uuid4, nullable=False),
        # created_on and changed_on are nullable=True to match AuditMixinNullable
        sa.Column("created_on", sa.DateTime(), nullable=True),
        sa.Column("changed_on", sa.DateTime(), nullable=True),
        sa.Column("name", sa.String(length=250), nullable=False),
        sa.Column("description", sa.Text(), nullable=True),
        sa.Column("type", sa.String(length=250), nullable=False),
        sa.Column(
            "configuration",
            encrypted_field_factory.create(JSONType),
            nullable=True,
        ),
        # configuration_version tracks the schema version of the configuration
        # JSON field to aid with migrations as the schema evolves over time.
        sa.Column(
            "configuration_version",
            sa.Integer(),
            nullable=False,
            server_default="1",
        ),
        sa.Column("cache_timeout", sa.Integer(), nullable=True),
        sa.Column("created_by_fk", sa.Integer(), nullable=True),
        sa.Column("changed_by_fk", sa.Integer(), nullable=True),
        sa.PrimaryKeyConstraint("uuid"),
    )

    # Create foreign key constraints for semantic_layers
    create_fks_for_table(
        "fk_semantic_layers_created_by_fk_ab_user",
        "semantic_layers",
        "ab_user",
        ["created_by_fk"],
        ["id"],
    )

    create_fks_for_table(
        "fk_semantic_layers_changed_by_fk_ab_user",
        "semantic_layers",
        "ab_user",
        ["changed_by_fk"],
        ["id"],
    )

    # Create semantic_views table.
    # The integer `id` is the primary key (auto-increment across all supported
    # databases) and `uuid` is a secondary unique identifier. This follows the
    # standard Superset model pattern and avoids using sa.Identity(), which is
    # not supported in MySQL or SQLite.
    create_table(
        "semantic_views",
        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
        sa.Column("uuid", UUIDType(binary=True), default=uuid.uuid4, nullable=False),
        # created_on and changed_on are nullable=True to match AuditMixinNullable
        sa.Column("created_on", sa.DateTime(), nullable=True),
        sa.Column("changed_on", sa.DateTime(), nullable=True),
        sa.Column("name", sa.String(length=250), nullable=False),
        sa.Column("description", sa.Text(), nullable=True),
        sa.Column(
            "configuration",
            encrypted_field_factory.create(JSONType),
            nullable=True,
        ),
        # configuration_version tracks the schema version of the configuration
        # JSON field to aid with migrations as the schema evolves over time.
        sa.Column(
            "configuration_version",
            sa.Integer(),
            nullable=False,
            server_default="1",
        ),
        sa.Column("cache_timeout", sa.Integer(), nullable=True),
        sa.Column(
            "semantic_layer_uuid",
            UUIDType(binary=True),
            sa.ForeignKey("semantic_layers.uuid", ondelete="CASCADE"),
            nullable=False,
        ),
        sa.Column("created_by_fk", sa.Integer(), nullable=True),
        sa.Column("changed_by_fk", sa.Integer(), nullable=True),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("uuid"),
    )

    # Create foreign key constraints for semantic_views
    create_fks_for_table(
        "fk_semantic_views_created_by_fk_ab_user",
        "semantic_views",
        "ab_user",
        ["created_by_fk"],
        ["id"],
    )

    create_fks_for_table(
        "fk_semantic_views_changed_by_fk_ab_user",
        "semantic_views",
        "ab_user",
        ["changed_by_fk"],
        ["id"],
    )

    # Update chart datasource constraint to allow semantic_view
    with op.batch_alter_table("slices") as batch_op:
        batch_op.drop_constraint("ck_chart_datasource", type_="check")
        batch_op.create_check_constraint(
            "ck_chart_datasource",
            "datasource_type in ('table', 'semantic_view')",
        )


def downgrade() -> None:
    # Restore original constraint
    with op.batch_alter_table("slices") as batch_op:
        batch_op.drop_constraint("ck_chart_datasource", type_="check")
        batch_op.create_check_constraint(
            "ck_chart_datasource", "datasource_type in ('table')"
        )

    drop_table("semantic_views")
    drop_table("semantic_layers")
@@ -22,7 +22,7 @@ import logging
import re
from collections.abc import Hashable
from datetime import datetime
- from typing import Any, Optional, TYPE_CHECKING
+ from typing import Any, cast, Optional, TYPE_CHECKING

import sqlalchemy as sqla
from flask import current_app as app
@@ -67,7 +67,7 @@ from superset.sql.parse import (
    Table,
)
from superset.sqllab.limiting_factor import LimitingFactor
- from superset.superset_typing import ExplorableData, QueryObjectDict
+ from superset.superset_typing import DatasetColumnData, ExplorableData, QueryObjectDict
from superset.utils import json
from superset.utils.core import (
    get_column_name,
@@ -261,7 +261,7 @@ class Query(
            ],
            "filter_select": True,
            "name": self.tab_name,
-           "columns": [o.data for o in self.columns],
+           "columns": [cast(DatasetColumnData, o.data) for o in self.columns],
            "metrics": [],
            "id": self.id,
            "type": self.type,
superset/semantic_layers/__init__.py (new file, 16 lines)
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
superset/semantic_layers/api.py (new file, 490 lines)
@@ -0,0 +1,490 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import logging
from typing import Any

from flask import request, Response
from flask_appbuilder.api import expose, protect, safe
from flask_appbuilder.models.sqla.interface import SQLAInterface
from marshmallow import ValidationError

from superset import event_logger
from superset.commands.semantic_layer.create import CreateSemanticLayerCommand
from superset.commands.semantic_layer.delete import DeleteSemanticLayerCommand
from superset.commands.semantic_layer.exceptions import (
    SemanticLayerCreateFailedError,
    SemanticLayerDeleteFailedError,
    SemanticLayerInvalidError,
    SemanticLayerNotFoundError,
    SemanticLayerUpdateFailedError,
    SemanticViewForbiddenError,
    SemanticViewInvalidError,
    SemanticViewNotFoundError,
    SemanticViewUpdateFailedError,
)
from superset.commands.semantic_layer.update import (
    UpdateSemanticLayerCommand,
    UpdateSemanticViewCommand,
)
from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP
from superset.daos.semantic_layer import SemanticLayerDAO
from superset.semantic_layers.models import SemanticLayer, SemanticView
from superset.semantic_layers.registry import registry
from superset.semantic_layers.schemas import (
    SemanticLayerPostSchema,
    SemanticLayerPutSchema,
    SemanticViewPutSchema,
)
from superset.superset_typing import FlaskResponse
from superset.views.base_api import (
    BaseSupersetApi,
    BaseSupersetModelRestApi,
    requires_json,
    statsd_metrics,
)

logger = logging.getLogger(__name__)


def _serialize_layer(layer: SemanticLayer) -> dict[str, Any]:
    return {
        "uuid": str(layer.uuid),
        "name": layer.name,
        "description": layer.description,
        "type": layer.type,
        "cache_timeout": layer.cache_timeout,
    }


class SemanticViewRestApi(BaseSupersetModelRestApi):
    datamodel = SQLAInterface(SemanticView)

    resource_name = "semantic_view"
    allow_browser_login = True
    class_permission_name = "SemanticView"
    method_permission_name = MODEL_API_RW_METHOD_PERMISSION_MAP
    include_route_methods = {"put"}

    edit_model_schema = SemanticViewPutSchema()

    @expose("/<pk>", methods=("PUT",))
    @protect()
    @statsd_metrics
    @event_logger.log_this_with_context(
        action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.put",
        log_to_statsd=False,
    )
    @requires_json
    def put(self, pk: int) -> Response:
        """Update a semantic view.
        ---
        put:
          summary: Update a semantic view
          parameters:
          - in: path
            schema:
              type: integer
            name: pk
          requestBody:
            description: Semantic view schema
            required: true
            content:
              application/json:
                schema:
                  $ref: '#/components/schemas/{{self.__class__.__name__}}.put'
          responses:
            200:
              description: Semantic view changed
              content:
                application/json:
                  schema:
                    type: object
                    properties:
                      id:
                        type: number
                      result:
                        $ref: '#/components/schemas/{{self.__class__.__name__}}.put'
            400:
              $ref: '#/components/responses/400'
            401:
              $ref: '#/components/responses/401'
            403:
              $ref: '#/components/responses/403'
            404:
              $ref: '#/components/responses/404'
            422:
              $ref: '#/components/responses/422'
            500:
              $ref: '#/components/responses/500'
        """
        try:
            item = self.edit_model_schema.load(request.json)
        except ValidationError as error:
            return self.response_400(message=error.messages)
        try:
            changed_model = UpdateSemanticViewCommand(pk, item).run()
            response = self.response(200, id=changed_model.id, result=item)
        except SemanticViewNotFoundError:
            response = self.response_404()
        except SemanticViewForbiddenError:
            response = self.response_403()
        except SemanticViewInvalidError as ex:
            response = self.response_422(message=ex.normalized_messages())
        except SemanticViewUpdateFailedError as ex:
            logger.error(
                "Error updating model %s: %s",
                self.__class__.__name__,
                str(ex),
                exc_info=True,
            )
            response = self.response_422(message=str(ex))
        return response


class SemanticLayerRestApi(BaseSupersetApi):
    resource_name = "semantic_layer"
    allow_browser_login = True
    class_permission_name = "SemanticLayer"
    method_permission_name = {
        **MODEL_API_RW_METHOD_PERMISSION_MAP,
        "types": "read",
        "configuration_schema": "read",
        "runtime_schema": "read",
    }
    openapi_spec_tag = "Semantic Layers"
    add_model_schema = SemanticLayerPostSchema()
    edit_model_schema = SemanticLayerPutSchema()

    @expose("/types", methods=("GET",))
    @protect()
    @safe
    @statsd_metrics
    def types(self) -> FlaskResponse:
        """List available semantic layer types.
        ---
        get:
          summary: List available semantic layer types
          responses:
            200:
              description: A list of semantic layer types
            401:
              $ref: '#/components/responses/401'
        """
        result = [
            {"id": key, "name": cls.name, "description": cls.description}  # type: ignore[attr-defined]
            for key, cls in registry.items()
        ]
        return self.response(200, result=result)

    @expose("/schema/configuration", methods=("POST",))
    @protect()
    @safe
    @statsd_metrics
    @requires_json
    def configuration_schema(self) -> FlaskResponse:
        """Get configuration schema for a semantic layer type.
        ---
        post:
          summary: Get configuration schema for a semantic layer type
          requestBody:
            required: true
            content:
              application/json:
                schema:
                  type: object
                  properties:
                    type:
                      type: string
                    configuration:
                      type: object
          responses:
            200:
              description: Configuration JSON Schema
            400:
              $ref: '#/components/responses/400'
            401:
              $ref: '#/components/responses/401'
        """
        body = request.json or {}
        sl_type = body.get("type")

        cls = registry.get(sl_type)  # type: ignore[arg-type]
        if not cls:
            return self.response_400(message=f"Unknown type: {sl_type}")

        parsed_config = None
        if config := body.get("configuration"):
            try:
                parsed_config = cls.from_configuration(config).configuration  # type: ignore[attr-defined]
            except Exception:  # pylint: disable=broad-except
                parsed_config = None

        schema = cls.get_configuration_schema(parsed_config)
        return self.response(200, result=schema)

    @expose("/<uuid>/schema/runtime", methods=("POST",))
    @protect()
    @safe
    @statsd_metrics
    def runtime_schema(self, uuid: str) -> FlaskResponse:
        """Get runtime schema for a stored semantic layer.
        ---
        post:
          summary: Get runtime schema for a semantic layer
          parameters:
          - in: path
            schema:
              type: string
            name: uuid
          requestBody:
            content:
              application/json:
                schema:
                  type: object
                  properties:
                    runtime_data:
                      type: object
          responses:
            200:
              description: Runtime JSON Schema
            401:
              $ref: '#/components/responses/401'
            404:
              $ref: '#/components/responses/404'
        """
        layer = SemanticLayerDAO.find_by_uuid(uuid)
        if not layer:
            return self.response_404()

        body = request.get_json(silent=True) or {}
        runtime_data = body.get("runtime_data")

        cls = registry.get(layer.type)
        if not cls:
            return self.response_400(message=f"Unknown type: {layer.type}")

        try:
            schema = cls.get_runtime_schema(
                layer.implementation.configuration,  # type: ignore[attr-defined]
                runtime_data,
            )
        except Exception as ex:  # pylint: disable=broad-except
            return self.response_400(message=str(ex))

        return self.response(200, result=schema)

    @expose("/", methods=("POST",))
    @protect()
    @safe
    @statsd_metrics
    @requires_json
    def post(self) -> FlaskResponse:
        """Create a semantic layer.
        ---
        post:
          summary: Create a semantic layer
          requestBody:
            required: true
            content:
              application/json:
                schema:
                  type: object
                  properties:
                    name:
                      type: string
                    description:
                      type: string
                    type:
                      type: string
                    configuration:
                      type: object
                    cache_timeout:
                      type: integer
          responses:
            201:
              description: Semantic layer created
            400:
              $ref: '#/components/responses/400'
            401:
              $ref: '#/components/responses/401'
            422:
              $ref: '#/components/responses/422'
        """
        try:
            item = self.add_model_schema.load(request.json)
        except ValidationError as error:
            return self.response_400(message=error.messages)

        try:
            new_model = CreateSemanticLayerCommand(item).run()
            return self.response(201, result={"uuid": str(new_model.uuid)})
        except SemanticLayerInvalidError as ex:
            return self.response_422(message=str(ex))
        except SemanticLayerCreateFailedError as ex:
            logger.error(
                "Error creating semantic layer: %s",
                str(ex),
                exc_info=True,
            )
            return self.response_422(message=str(ex))

    @expose("/<uuid>", methods=("PUT",))
    @protect()
    @safe
    @statsd_metrics
    @requires_json
    def put(self, uuid: str) -> FlaskResponse:
        """Update a semantic layer.
        ---
        put:
          summary: Update a semantic layer
          parameters:
          - in: path
            schema:
              type: string
            name: uuid
          requestBody:
            required: true
            content:
              application/json:
                schema:
                  type: object
                  properties:
                    name:
                      type: string
                    description:
                      type: string
                    configuration:
                      type: object
                    cache_timeout:
                      type: integer
          responses:
            200:
              description: Semantic layer updated
            400:
              $ref: '#/components/responses/400'
            401:
              $ref: '#/components/responses/401'
            404:
              $ref: '#/components/responses/404'
            422:
              $ref: '#/components/responses/422'
        """
        try:
            item = self.edit_model_schema.load(request.json)
        except ValidationError as error:
            return self.response_400(message=error.messages)

        try:
            changed_model = UpdateSemanticLayerCommand(uuid, item).run()
            return self.response(200, result={"uuid": str(changed_model.uuid)})
        except SemanticLayerNotFoundError:
            return self.response_404()
        except SemanticLayerInvalidError as ex:
            return self.response_422(message=str(ex))
        except SemanticLayerUpdateFailedError as ex:
            logger.error(
                "Error updating semantic layer: %s",
                str(ex),
                exc_info=True,
            )
            return self.response_422(message=str(ex))

    @expose("/<uuid>", methods=("DELETE",))
    @protect()
    @safe
    @statsd_metrics
    def delete(self, uuid: str) -> FlaskResponse:
        """Delete a semantic layer.
        ---
        delete:
          summary: Delete a semantic layer
          parameters:
          - in: path
            schema:
              type: string
            name: uuid
          responses:
            200:
              description: Semantic layer deleted
            401:
              $ref: '#/components/responses/401'
            404:
              $ref: '#/components/responses/404'
            422:
              $ref: '#/components/responses/422'
        """
        try:
            DeleteSemanticLayerCommand(uuid).run()
            return self.response(200, message="OK")
        except SemanticLayerNotFoundError:
            return self.response_404()
        except SemanticLayerDeleteFailedError as ex:
            logger.error(
                "Error deleting semantic layer: %s",
                str(ex),
                exc_info=True,
            )
            return self.response_422(message=str(ex))

    @expose("/", methods=("GET",))
    @protect()
    @safe
    @statsd_metrics
    def get_list(self) -> FlaskResponse:
        """List all semantic layers.
        ---
        get:
          summary: List all semantic layers
          responses:
            200:
              description: A list of semantic layers
            401:
              $ref: '#/components/responses/401'
        """
        layers = SemanticLayerDAO.find_all()
        result = [_serialize_layer(layer) for layer in layers]
        return self.response(200, result=result)

    @expose("/<uuid>", methods=("GET",))
    @protect()
    @safe
    @statsd_metrics
    def get(self, uuid: str) -> FlaskResponse:
        """Get a single semantic layer.
        ---
        get:
          summary: Get a semantic layer by UUID
          parameters:
          - in: path
            schema:
              type: string
            name: uuid
          responses:
            200:
              description: A semantic layer
            401:
              $ref: '#/components/responses/401'
            404:
              $ref: '#/components/responses/404'
        """
        layer = SemanticLayerDAO.find_by_uuid(uuid)
        if not layer:
            return self.response_404()
        return self.response(200, result=_serialize_layer(layer))
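For orientation, here is a minimal sketch of driving the endpoints above from a script. The base URL, bearer token, and the `"my_platform"` type are assumptions for illustration; the payload mirrors the POST body documented in the OpenAPI spec above.

```python
import requests

BASE = "http://localhost:8088/api/v1"  # assumed local Superset instance
session = requests.Session()
# Assumes an access token previously obtained via /api/v1/security/login.
session.headers["Authorization"] = "Bearer <access-token>"

# Discover which semantic layer types are registered.
types = session.get(f"{BASE}/semantic_layer/types").json()["result"]

# Create a layer; the payload fields follow the POST schema above.
created = session.post(
    f"{BASE}/semantic_layer/",
    json={
        "name": "Sales layer",
        "description": "Example layer",
        "type": "my_platform",  # hypothetical registered type
        "configuration": {"account": "acme"},  # shape depends on the type
        "cache_timeout": 300,
    },
).json()

# Fetch it back using the UUID returned on creation.
layer = session.get(f"{BASE}/semantic_layer/{created['result']['uuid']}").json()
print(layer["result"]["name"])
```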

superset/semantic_layers/mapper.py (new file, 912 lines)
@@ -0,0 +1,912 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Functions for mapping `QueryObject` to semantic layers.

These functions validate and convert a `QueryObject` into one or more `SemanticQuery`,
which are then passed to semantic layer implementations for execution, returning a
single dataframe.
"""

from datetime import datetime, timedelta
from time import time
from typing import Any, cast, Sequence, TypeGuard

import isodate
import numpy as np
import pyarrow as pa
from superset_core.semantic_layers.types import (
    AdhocExpression,
    Dimension,
    Filter,
    FilterValues,
    Grain,
    Grains,
    GroupLimit,
    Metric,
    Operator,
    OrderDirection,
    OrderTuple,
    PredicateType,
    SemanticQuery,
    SemanticResult,
)
from superset_core.semantic_layers.view import SemanticViewFeature

from superset.common.db_query_status import QueryStatus
from superset.common.query_object import QueryObject
from superset.common.utils.time_range_utils import get_since_until_from_query_object
from superset.connectors.sqla.models import BaseDatasource
from superset.constants import NO_TIME_RANGE
from superset.models.helpers import QueryResult
from superset.superset_typing import AdhocColumn
from superset.utils.core import (
    FilterOperator,
    QueryObjectFilterClause,
    TIME_COMPARISON,
)
from superset.utils.date_parser import get_past_or_future


class ValidatedQueryObjectFilterClause(QueryObjectFilterClause):
    """
    A validated QueryObject filter clause with a string column name.

    The `col` in a `QueryObjectFilterClause` can be either a string (column name) or an
    adhoc column, but we only support the former in semantic layers.
    """

    # overwrite to narrow type; mypy complains about more restrictive typed dicts,
    # but the alternative would be to redefine the object
    col: str  # type: ignore[misc]
    op: str  # type: ignore[misc]


class ValidatedQueryObject(QueryObject):
    """
    A query object that has a datasource defined.
    """

    datasource: BaseDatasource

    # overwrite to narrow type; mypy complains about the assignment since the base type
    # allows adhoc filters, but we only support validated filters here
    filter: list[ValidatedQueryObjectFilterClause]  # type: ignore[assignment]
    series_columns: Sequence[str]  # type: ignore[assignment]
    series_limit_metric: str | None


def get_results(query_object: QueryObject) -> QueryResult:
    """
    Run 1+ queries based on `QueryObject` and return the results.

    :param query_object: The QueryObject containing query specifications
    :return: QueryResult compatible with Superset's query interface
    """
    if not validate_query_object(query_object):
        raise ValueError("QueryObject must have a datasource defined.")

    # Track execution time
    start_time = time()

    semantic_view = query_object.datasource.implementation
    dispatcher = (
        semantic_view.get_row_count
        if query_object.is_rowcount
        else semantic_view.get_table
    )

    # Step 1: Convert QueryObject to list of SemanticQuery objects
    # The first query is the main query, subsequent queries are for time offsets
    queries = map_query_object(query_object)

    # Step 2: Execute the main query (first in the list)
    main_query = queries[0]
    main_result = dispatcher(main_query)

    main_df = main_result.results.to_pandas()

    # Collect all requests (SQL queries, HTTP requests, etc.) for troubleshooting
    all_requests = list(main_result.requests)

    # If no time offsets, return the main result as-is
    if not query_object.time_offsets or len(queries) <= 1:
        duration = timedelta(seconds=time() - start_time)
        return map_semantic_result_to_query_result(
            main_result,
            query_object,
            duration,
        )

    # Get metric names from the main query
    # These are the columns that will be renamed with offset suffixes
    metric_names = [metric.name for metric in main_query.metrics]

    # Join keys are all columns except metrics
    # These will be used to match rows between main and offset DataFrames
    join_keys = [col for col in main_df.columns if col not in metric_names]

    # Step 3 & 4: Execute each time offset query and join results
    for offset_query, time_offset in zip(
        queries[1:],
        query_object.time_offsets,
        strict=False,
    ):
        # Execute the offset query
        result = dispatcher(offset_query)

        # Add this query's requests to the collection
        all_requests.extend(result.requests)

        offset_df = result.results.to_pandas()

        # Handle empty results - add NaN columns directly instead of merging
        # This avoids dtype mismatch issues with empty DataFrames
        if offset_df.empty:
            # Add offset metric columns with NaN values directly to main_df
            for metric in metric_names:
                offset_col_name = TIME_COMPARISON.join([metric, time_offset])
                main_df[offset_col_name] = np.nan
        else:
            # Rename metric columns with time offset suffix
            # Format: "{metric_name}__{time_offset}"
            # Example: "revenue" -> "revenue__1 week ago"
            offset_df = offset_df.rename(
                columns={
                    metric: TIME_COMPARISON.join([metric, time_offset])
                    for metric in metric_names
                }
            )

            # Step 5: Perform left join on dimension columns
            # This preserves all rows from main_df and adds offset metrics
            # where they match
            main_df = main_df.merge(
                offset_df,
                on=join_keys,
                how="left",
                suffixes=("", "__duplicate"),
            )

            # Clean up any duplicate columns that might have been created
            # (shouldn't happen with proper join keys, but defensive programming)
            duplicate_cols = [
                col for col in main_df.columns if col.endswith("__duplicate")
            ]
            if duplicate_cols:
                main_df = main_df.drop(columns=duplicate_cols)

    # Convert final result to QueryResult
    semantic_result = SemanticResult(
        requests=all_requests,
        results=pa.Table.from_pandas(main_df),
    )
    duration = timedelta(seconds=time() - start_time)
    return map_semantic_result_to_query_result(
        semantic_result,
        query_object,
        duration,
    )


def map_semantic_result_to_query_result(
    semantic_result: SemanticResult,
    query_object: ValidatedQueryObject,
    duration: timedelta,
) -> QueryResult:
    """
    Convert a SemanticResult to a QueryResult.

    :param semantic_result: Result from the semantic layer
    :param query_object: Original QueryObject (for passthrough attributes)
    :param duration: Time taken to execute the query
    :return: QueryResult compatible with Superset's query interface
    """
    # Get the query string from requests (typically one or more SQL queries)
    query_str = ""
    if semantic_result.requests:
        # Join all requests for display (could be multiple for time comparisons)
        query_str = "\n\n".join(
            f"-- {req.type}\n{req.definition}" for req in semantic_result.requests
        )

    return QueryResult(
        # Core data
        df=semantic_result.results.to_pandas(),
        query=query_str,
        duration=duration,
        # Template filters - not applicable to semantic layers
        # (semantic layers don't use Jinja templates)
        applied_template_filters=None,
        # Filter columns - not applicable to semantic layers
        # (semantic layers handle filter validation internally)
        applied_filter_columns=None,
        rejected_filter_columns=None,
        # Status - always success if we got here
        # (errors would raise exceptions before reaching this point)
        status=QueryStatus.SUCCESS,
        error_message=None,
        errors=None,
        # Time range - pass through from original query_object
        from_dttm=query_object.from_dttm,
        to_dttm=query_object.to_dttm,
    )


def _normalize_column(column: str | AdhocColumn, dimension_names: set[str]) -> str:
    """
    Normalize a column to its dimension name.

    Columns can be either:
    - A string (dimension name directly)
    - An AdhocColumn with isColumnReference=True and sqlExpression containing the
      dimension name
    """
    if isinstance(column, str):
        return column

    # Handle column references (e.g., from time-series charts)
    if column.get("isColumnReference") and (sql_expr := column.get("sqlExpression")):
        if sql_expr in dimension_names:
            return sql_expr

    raise ValueError("Adhoc dimensions are not supported in Semantic Views.")


def map_query_object(query_object: ValidatedQueryObject) -> list[SemanticQuery]:
    """
    Convert a `QueryObject` into a list of `SemanticQuery`.

    This function maps the `QueryObject` into query objects that focus less on
    visualization and more on semantics.
    """
    semantic_view = query_object.datasource.implementation

    all_metrics = {metric.name: metric for metric in semantic_view.metrics}
    all_dimensions = {
        dimension.name: dimension for dimension in semantic_view.dimensions
    }

    # Normalize columns (may be dicts with isColumnReference=True for time-series)
    dimension_names = set(all_dimensions.keys())
    normalized_columns = {
        _normalize_column(column, dimension_names) for column in query_object.columns
    }

    metrics = [all_metrics[metric] for metric in (query_object.metrics or [])]

    grain = (
        _convert_time_grain(query_object.extras["time_grain_sqla"])
        if "time_grain_sqla" in query_object.extras
        else None
    )
    dimensions = [
        dimension
        for dimension in semantic_view.dimensions
        if dimension.name in normalized_columns
        and (
            # if a grain is specified, only include the time dimension if its grain
            # matches the requested grain
            grain is None
            or dimension.name != query_object.granularity
            or dimension.grain == grain
        )
    ]

    order = _get_order_from_query_object(query_object, all_metrics, all_dimensions)
    limit = query_object.row_limit
    offset = query_object.row_offset

    group_limit = _get_group_limit_from_query_object(
        query_object,
        all_metrics,
        all_dimensions,
    )

    queries = []
    for time_offset in [None] + query_object.time_offsets:
        filters = _get_filters_from_query_object(
            query_object,
            time_offset,
            all_dimensions,
        )
        queries.append(
            SemanticQuery(
                metrics=metrics,
                dimensions=dimensions,
                filters=filters,
                order=order,
                limit=limit,
                offset=offset,
                group_limit=group_limit,
            )
        )

    return queries


def _get_filters_from_query_object(
    query_object: ValidatedQueryObject,
    time_offset: str | None,
    all_dimensions: dict[str, Dimension],
) -> set[Filter]:
    """
    Extract all filters from the query object, including time range filters.

    This simplifies the complexity of from_dttm/to_dttm/inner_from_dttm/inner_to_dttm
    by converting all time constraints into filters.
    """
    filters: set[Filter] = set()

    # 1. Add fetch values predicate if present
    if (
        query_object.apply_fetch_values_predicate
        and query_object.datasource.fetch_values_predicate
    ):
        filters.add(
            Filter(
                type=PredicateType.WHERE,
                column=None,
                operator=Operator.ADHOC,
                value=query_object.datasource.fetch_values_predicate,
            )
        )

    # 2. Add time range filter based on from_dttm/to_dttm
    # For time offsets, this automatically calculates the shifted bounds
    time_filters = _get_time_filter(query_object, time_offset, all_dimensions)
    filters.update(time_filters)

    # 3. Add filters from query_object.extras (WHERE and HAVING clauses)
    extras_filters = _get_filters_from_extras(query_object.extras)
    filters.update(extras_filters)

    # 4. Add all other filters from query_object.filter
    for filter_ in query_object.filter:
        # Skip temporal range filters - we're using inner bounds instead
        if (
            filter_.get("op") == FilterOperator.TEMPORAL_RANGE.value
            and query_object.granularity
        ):
            continue

        if converted_filters := _convert_query_object_filter(filter_, all_dimensions):
            filters.update(converted_filters)

    return filters


def _get_filters_from_extras(extras: dict[str, Any]) -> set[Filter]:
    """
    Extract filters from the extras dict.

    The extras dict can contain various keys that affect query behavior:

    Supported keys (converted to filters):
    - "where": SQL WHERE clause expression (e.g., "customer_id > 100")
    - "having": SQL HAVING clause expression (e.g., "SUM(sales) > 1000")

    Other keys in extras (handled elsewhere in the mapper):
    - "time_grain_sqla": Time granularity (e.g., "P1D", "PT1H")
      Handled in _convert_time_grain() and used for dimension grain matching

    Note: The WHERE and HAVING clauses from extras are SQL expressions that
    are passed through as-is to the semantic layer as adhoc Filter objects.
    """
    filters: set[Filter] = set()

    # Add WHERE clause from extras
    if where_clause := extras.get("where"):
        filters.add(
            Filter(
                type=PredicateType.WHERE,
                column=None,
                operator=Operator.ADHOC,
                value=where_clause,
            )
        )

    # Add HAVING clause from extras
    if having_clause := extras.get("having"):
        filters.add(
            Filter(
                type=PredicateType.HAVING,
                column=None,
                operator=Operator.ADHOC,
                value=having_clause,
            )
        )

    return filters


def _get_time_filter(
    query_object: ValidatedQueryObject,
    time_offset: str | None,
    all_dimensions: dict[str, Dimension],
) -> set[Filter]:
    """
    Create a time range filter from the query object.

    This handles both regular queries and time offset queries, simplifying the
    complexity of from_dttm/to_dttm/inner_from_dttm/inner_to_dttm by using the
    same time bounds for both the main query and series limit subqueries.
    """
    filters: set[Filter] = set()

    if not query_object.granularity:
        return filters

    time_dimension = all_dimensions.get(query_object.granularity)
    if not time_dimension:
        return filters

    # Get the appropriate time bounds based on whether this is a time offset query
    from_dttm, to_dttm = _get_time_bounds(query_object, time_offset)

    if not from_dttm or not to_dttm:
        return filters

    # Create a filter with >= and < operators
    return {
        Filter(
            type=PredicateType.WHERE,
            column=time_dimension,
            operator=Operator.GREATER_THAN_OR_EQUAL,
            value=from_dttm,
        ),
        Filter(
            type=PredicateType.WHERE,
            column=time_dimension,
            operator=Operator.LESS_THAN,
            value=to_dttm,
        ),
    }


def _get_time_bounds(
    query_object: ValidatedQueryObject,
    time_offset: str | None,
) -> tuple[datetime | None, datetime | None]:
    """
    Get the appropriate time bounds for the query.

    For regular queries (time_offset is None), returns from_dttm/to_dttm.
    For time offset queries, calculates the shifted bounds.

    This simplifies the inner_from_dttm/inner_to_dttm complexity by using
    the same bounds for both main queries and series limit subqueries (Option 1).
    """
    if time_offset is None:
        # Main query: use from_dttm/to_dttm directly
        return query_object.from_dttm, query_object.to_dttm

    # Time offset query: calculate shifted bounds
    # Use from_dttm/to_dttm if available, otherwise try to get from time_range
    outer_from = query_object.from_dttm
    outer_to = query_object.to_dttm

    if not outer_from or not outer_to:
        # Fall back to parsing time_range if from_dttm/to_dttm not set
        outer_from, outer_to = get_since_until_from_query_object(query_object)

    if not outer_from or not outer_to:
        return None, None

    # Apply the offset to both bounds
    offset_from = get_past_or_future(time_offset, outer_from)
    offset_to = get_past_or_future(time_offset, outer_to)

    return offset_from, offset_to


def _convert_query_object_filter(
    filter_: ValidatedQueryObjectFilterClause,
    all_dimensions: dict[str, Dimension],
) -> set[Filter] | None:
    """
    Convert a QueryObject filter dict to a semantic layer Filter.
    """
    operator_str = filter_["op"]

    # Handle simple column filters
    col = filter_.get("col")
    if col not in all_dimensions:
        return None

    dimension = all_dimensions[col]

    val_str = filter_["val"]
    value: FilterValues | frozenset[FilterValues]
    if val_str is None:
        value = None
    elif isinstance(val_str, (list, tuple)):
        value = frozenset(val_str)
    else:
        value = val_str

    # Special case for temporal range
    if operator_str == FilterOperator.TEMPORAL_RANGE.value:
        if not isinstance(value, str) or value == NO_TIME_RANGE:
            return None
        start, end = value.split(" : ")
        return {
            Filter(
                type=PredicateType.WHERE,
                column=dimension,
                operator=Operator.GREATER_THAN_OR_EQUAL,
                value=start,
            ),
            Filter(
                type=PredicateType.WHERE,
                column=dimension,
                operator=Operator.LESS_THAN,
                value=end,
            ),
        }

    # Map QueryObject operators to semantic layer operators
    operator_mapping = {
        FilterOperator.EQUALS.value: Operator.EQUALS,
        FilterOperator.NOT_EQUALS.value: Operator.NOT_EQUALS,
        FilterOperator.GREATER_THAN.value: Operator.GREATER_THAN,
        FilterOperator.LESS_THAN.value: Operator.LESS_THAN,
        FilterOperator.GREATER_THAN_OR_EQUALS.value: Operator.GREATER_THAN_OR_EQUAL,
        FilterOperator.LESS_THAN_OR_EQUALS.value: Operator.LESS_THAN_OR_EQUAL,
        FilterOperator.IN.value: Operator.IN,
        FilterOperator.NOT_IN.value: Operator.NOT_IN,
        FilterOperator.LIKE.value: Operator.LIKE,
        FilterOperator.NOT_LIKE.value: Operator.NOT_LIKE,
        FilterOperator.IS_NULL.value: Operator.IS_NULL,
        FilterOperator.IS_NOT_NULL.value: Operator.IS_NOT_NULL,
    }

    operator = operator_mapping.get(operator_str)
    if not operator:
        # Unknown operator - raise error to prevent unauthorized access
        raise ValueError(f"Unsupported filter operator: {operator_str}")

    return {
        Filter(
            type=PredicateType.WHERE,
            column=dimension,
            operator=operator,
            value=value,
        )
    }


def _get_order_from_query_object(
    query_object: ValidatedQueryObject,
    all_metrics: dict[str, Metric],
    all_dimensions: dict[str, Dimension],
) -> list[OrderTuple]:
    order: list[OrderTuple] = []
    for element, ascending in query_object.orderby:
        direction = OrderDirection.ASC if ascending else OrderDirection.DESC

        # adhoc
        if isinstance(element, dict):
            if element["sqlExpression"] is not None:
                order.append(
                    (
                        AdhocExpression(
                            id=element["label"] or element["sqlExpression"],
                            definition=element["sqlExpression"],
                        ),
                        direction,
                    )
                )
        elif element in all_dimensions:
            order.append((all_dimensions[element], direction))
        elif element in all_metrics:
            order.append((all_metrics[element], direction))

    return order


def _get_group_limit_from_query_object(
    query_object: ValidatedQueryObject,
    all_metrics: dict[str, Metric],
    all_dimensions: dict[str, Dimension],
) -> GroupLimit | None:
    # no limit
    if query_object.series_limit == 0 or not query_object.columns:
        return None

    dimensions = [all_dimensions[dim_id] for dim_id in query_object.series_columns]
    top = query_object.series_limit
    metric = (
        all_metrics[query_object.series_limit_metric]
        if query_object.series_limit_metric
        else None
    )
    direction = OrderDirection.DESC if query_object.order_desc else OrderDirection.ASC
    group_others = query_object.group_others_when_limit_reached

    # Check if we need separate filters for the group limit subquery
    # This happens when inner_from_dttm/inner_to_dttm differ from from_dttm/to_dttm
    group_limit_filters = _get_group_limit_filters(query_object, all_dimensions)

    return GroupLimit(
        dimensions=dimensions,
        top=top,
        metric=metric,
        direction=direction,
        group_others=group_others,
        filters=group_limit_filters,
    )


def _get_group_limit_filters(
    query_object: ValidatedQueryObject,
    all_dimensions: dict[str, Dimension],
) -> set[Filter] | None:
    """
    Get separate filters for the group limit subquery if needed.

    This is used when inner_from_dttm/inner_to_dttm differ from from_dttm/to_dttm,
    which happens during time comparison queries. The group limit subquery may need
    different time bounds to determine the top N groups.

    Returns None if the group limit should use the same filters as the main query.
    """
    # Check if inner time bounds are explicitly set and differ from outer bounds
    if (
        query_object.inner_from_dttm is None
        or query_object.inner_to_dttm is None
        or (
            query_object.inner_from_dttm == query_object.from_dttm
            and query_object.inner_to_dttm == query_object.to_dttm
        )
    ):
        # No separate bounds needed - use the same filters as the main query
        return None

    # Create separate filters for the group limit subquery
    filters: set[Filter] = set()

    # Add time range filter using inner bounds
    if query_object.granularity:
        time_dimension = all_dimensions.get(query_object.granularity)
        if (
            time_dimension
            and query_object.inner_from_dttm
            and query_object.inner_to_dttm
        ):
            filters.update(
                {
                    Filter(
                        type=PredicateType.WHERE,
                        column=time_dimension,
                        operator=Operator.GREATER_THAN_OR_EQUAL,
                        value=query_object.inner_from_dttm,
                    ),
                    Filter(
                        type=PredicateType.WHERE,
                        column=time_dimension,
                        operator=Operator.LESS_THAN,
                        value=query_object.inner_to_dttm,
                    ),
                }
            )

    # Add fetch values predicate if present
    if (
        query_object.apply_fetch_values_predicate
        and query_object.datasource.fetch_values_predicate
    ):
        filters.add(
            Filter(
                type=PredicateType.WHERE,
                column=None,
                operator=Operator.ADHOC,
                value=query_object.datasource.fetch_values_predicate,
            )
        )

    # Add filters from query_object.extras (WHERE and HAVING clauses)
    extras_filters = _get_filters_from_extras(query_object.extras)
    filters.update(extras_filters)

    # Add all other non-temporal filters from query_object.filter
    for filter_ in query_object.filter:
        # Skip temporal range filters - we're using inner bounds instead
        if (
            filter_.get("op") == FilterOperator.TEMPORAL_RANGE.value
            and query_object.granularity
        ):
            continue

        if converted_filters := _convert_query_object_filter(filter_, all_dimensions):
            filters.update(converted_filters)

    return filters if filters else None


def _convert_time_grain(time_grain: str) -> Grain | None:
    """
    Convert a time grain string (ISO 8601 duration) to a Grain instance.
    """
    try:
        return Grains.get(time_grain)
    except (ValueError, isodate.ISO8601Error):
        return None


def validate_query_object(
    query_object: QueryObject,
) -> TypeGuard[ValidatedQueryObject]:
    """
    Validate that the `QueryObject` is compatible with the `SemanticView`.

    If some semantic view implementation supports these features we should add an
    attribute to the `SemanticViewImplementation` to indicate support for them.
    """
    if not query_object.datasource:
        return False

    query_object = cast(ValidatedQueryObject, query_object)

    _validate_metrics(query_object)
    _validate_dimensions(query_object)
    _validate_filters(query_object)
    _validate_granularity(query_object)
    _validate_group_limit(query_object)
    _validate_orderby(query_object)

    return True


def _validate_metrics(query_object: ValidatedQueryObject) -> None:
    """
    Make sure metrics are defined in the semantic view.
    """
    semantic_view = query_object.datasource.implementation

    if any(not isinstance(metric, str) for metric in (query_object.metrics or [])):
        raise ValueError("Adhoc metrics are not supported in Semantic Views.")

    metric_names = {metric.name for metric in semantic_view.metrics}
    if not set(query_object.metrics or []) <= metric_names:
        raise ValueError("All metrics must be defined in the Semantic View.")


def _validate_dimensions(query_object: ValidatedQueryObject) -> None:
    """
    Make sure all dimensions are defined in the semantic view.
    """
    semantic_view = query_object.datasource.implementation
    dimension_names = {dimension.name for dimension in semantic_view.dimensions}

    # Normalize all columns to dimension names
    normalized_columns = [
        _normalize_column(column, dimension_names) for column in query_object.columns
    ]

    if not set(normalized_columns) <= dimension_names:
        raise ValueError("All dimensions must be defined in the Semantic View.")


def _validate_filters(query_object: ValidatedQueryObject) -> None:
    """
    Make sure all filters are valid.
    """
    for filter_ in query_object.filter:
        if isinstance(filter_["col"], dict):
            raise ValueError(
                "Adhoc columns are not supported in Semantic View filters."
            )
        if not filter_.get("op"):
            raise ValueError("All filters must have an operator defined.")


def _validate_granularity(query_object: ValidatedQueryObject) -> None:
    """
    Make sure time column and time grain are valid.
    """
    semantic_view = query_object.datasource.implementation
    dimension_names = {dimension.name for dimension in semantic_view.dimensions}

    if time_column := query_object.granularity:
        if time_column not in dimension_names:
            raise ValueError(
                "The time column must be defined in the Semantic View dimensions."
            )

    if time_grain := query_object.extras.get("time_grain_sqla"):
        if not time_column:
            raise ValueError(
                "A time column must be specified when a time grain is provided."
            )

        supported_time_grains = {
            dimension.grain
            for dimension in semantic_view.dimensions
            if dimension.name == time_column and dimension.grain
        }
        if _convert_time_grain(time_grain) not in supported_time_grains:
            raise ValueError(
                "The time grain is not supported for the time column in the "
                "Semantic View."
            )


def _validate_group_limit(query_object: ValidatedQueryObject) -> None:
    """
    Validate group limit related features in the query object.
    """
    semantic_view = query_object.datasource.implementation

    # no limit
    if query_object.series_limit == 0:
        return

    if (
        query_object.series_columns
        and SemanticViewFeature.GROUP_LIMIT not in semantic_view.features
    ):
        raise ValueError("Group limit is not supported in this Semantic View.")

    if any(not isinstance(col, str) for col in query_object.series_columns):
        raise ValueError("Adhoc dimensions are not supported in series columns.")

    metric_names = {metric.name for metric in semantic_view.metrics}
    if query_object.series_limit_metric and (
        not isinstance(query_object.series_limit_metric, str)
        or query_object.series_limit_metric not in metric_names
    ):
        raise ValueError(
            "The series limit metric must be defined in the Semantic View."
        )

    dimension_names = {dimension.name for dimension in semantic_view.dimensions}
    if not set(query_object.series_columns) <= dimension_names:
        raise ValueError("All series columns must be defined in the Semantic View.")

    if (
        query_object.group_others_when_limit_reached
        and SemanticViewFeature.GROUP_OTHERS not in semantic_view.features
    ):
        raise ValueError(
            "Grouping others when limit is reached is not supported in this Semantic "
            "View."
        )


def _validate_orderby(query_object: ValidatedQueryObject) -> None:
    """
    Validate order by elements in the query object.
    """
    semantic_view = query_object.datasource.implementation

    if (
        any(not isinstance(element, str) for element, _ in query_object.orderby)
        and SemanticViewFeature.ADHOC_EXPRESSIONS_IN_ORDERBY
        not in semantic_view.features
    ):
        raise ValueError(
            "Adhoc expressions in order by are not supported in this Semantic View."
        )

    elements = {orderby[0] for orderby in query_object.orderby}
    metric_names = {metric.name for metric in semantic_view.metrics}
    dimension_names = {dimension.name for dimension in semantic_view.dimensions}
    if not elements <= metric_names | dimension_names:
        raise ValueError("All order by elements must be defined in the Semantic View.")
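To make the time-offset join in `get_results` concrete, here is a self-contained sketch using plain pandas. The `"__"` separator stands in for Superset's `TIME_COMPARISON` constant; the sample frames are invented for illustration.

```python
import numpy as np
import pandas as pd

TIME_COMPARISON = "__"  # assumption: mirrors superset.utils.core.TIME_COMPARISON

main_df = pd.DataFrame({"country": ["US", "BR"], "revenue": [100.0, 50.0]})
offset_df = pd.DataFrame({"country": ["US"], "revenue": [90.0]})
metric_names = ["revenue"]
time_offset = "1 week ago"

# Join keys are all non-metric (dimension) columns.
join_keys = [col for col in main_df.columns if col not in metric_names]

if offset_df.empty:
    # Empty offset result: add NaN columns directly instead of merging.
    for metric in metric_names:
        main_df[TIME_COMPARISON.join([metric, time_offset])] = np.nan
else:
    # Rename offset metrics ("revenue" -> "revenue__1 week ago") and left-join.
    offset_df = offset_df.rename(
        columns={m: TIME_COMPARISON.join([m, time_offset]) for m in metric_names}
    )
    main_df = main_df.merge(offset_df, on=join_keys, how="left")

print(main_df)
#   country  revenue  revenue__1 week ago
# 0      US    100.0                 90.0
# 1      BR     50.0                  NaN
```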

superset/semantic_layers/models.py (new file, 408 lines)
@@ -0,0 +1,408 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Semantic layer models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from collections.abc import Hashable
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
import pyarrow as pa
|
||||
from flask_appbuilder import Model
|
||||
from sqlalchemy import Column, ForeignKey, Integer, String, Text
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy_utils import UUIDType
|
||||
from sqlalchemy_utils.types.json import JSONType
|
||||
from superset_core.semantic_layers.layer import (
|
||||
SemanticLayer as SemanticLayerABC,
|
||||
)
|
||||
from superset_core.semantic_layers.view import (
|
||||
SemanticView as SemanticViewABC,
|
||||
)
|
||||
|
||||
from superset.common.query_object import QueryObject
|
||||
from superset.explorables.base import TimeGrainDict
|
||||
from superset.extensions import encrypted_field_factory
|
||||
from superset.models.helpers import AuditMixinNullable, QueryResult
|
||||
from superset.semantic_layers.mapper import get_results
|
||||
from superset.semantic_layers.registry import registry
|
||||
from superset.utils import json
|
||||
from superset.utils.core import GenericDataType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.superset_typing import ExplorableData, QueryObjectDict
|
||||
|
||||
|
||||
def get_column_type(semantic_type: pa.DataType) -> GenericDataType:
|
||||
"""
|
||||
Map Arrow data types to generic data types.
|
||||
"""
|
||||
if pa.types.is_date(semantic_type) or pa.types.is_timestamp(semantic_type):
|
||||
return GenericDataType.TEMPORAL
|
||||
if pa.types.is_time(semantic_type):
|
||||
return GenericDataType.TEMPORAL
|
||||
if (
|
||||
pa.types.is_integer(semantic_type)
|
||||
or pa.types.is_floating(semantic_type)
|
||||
or pa.types.is_decimal(semantic_type)
|
||||
or pa.types.is_duration(semantic_type)
|
||||
):
|
||||
return GenericDataType.NUMERIC
|
||||
if pa.types.is_boolean(semantic_type):
|
||||
return GenericDataType.BOOLEAN
|
||||
return GenericDataType.STRING
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MetricMetadata:
|
||||
metric_name: str
|
||||
expression: str
|
||||
verbose_name: str | None = None
|
||||
description: str | None = None
|
||||
d3format: str | None = None
|
||||
currency: dict[str, Any] | None = None
|
||||
warning_text: str | None = None
|
||||
certified_by: str | None = None
|
||||
certification_details: str | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ColumnMetadata:
|
||||
column_name: str
|
||||
type: str
|
||||
is_dttm: bool
|
||||
verbose_name: str | None = None
|
||||
description: str | None = None
|
||||
groupby: bool = True
|
||||
filterable: bool = True
|
||||
expression: str | None = None
|
||||
python_date_format: str | None = None
|
||||
advanced_data_type: str | None = None
|
||||
extra: str | None = None
|
||||
|
||||
|
||||
class SemanticLayer(AuditMixinNullable, Model):
|
||||
"""
|
||||
Semantic layer model.
|
||||
|
||||
A semantic layer provides an abstraction over data sources,
|
||||
allowing users to query data through a semantic interface.
|
||||
"""
|
||||
|
||||
__tablename__ = "semantic_layers"
|
||||
|
||||
uuid = Column(UUIDType(binary=True), primary_key=True, default=uuid.uuid4)
|
||||
|
||||
# Core fields
|
||||
name = Column(String(250), nullable=False)
|
||||
description = Column(Text, nullable=True)
|
||||
type = Column(String(250), nullable=False) # snowflake, etc
|
||||
|
||||
configuration = Column(encrypted_field_factory.create(JSONType), default="{}")
|
||||
# Tracks the schema version of the configuration JSON field to aid with
|
||||
# migrations as the configuration schema evolves over time.
|
||||
configuration_version = Column(Integer, nullable=False, default=1)
|
||||
cache_timeout = Column(Integer, nullable=True)
|
||||
|
||||
# Semantic views relationship
|
||||
semantic_views: list[SemanticView] = relationship(
|
||||
"SemanticView",
|
||||
back_populates="semantic_layer",
|
||||
cascade="all, delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.name or str(self.uuid)
|
||||
|
||||
@cached_property
|
||||
def implementation(
|
||||
self,
|
||||
) -> SemanticLayerABC[Any, SemanticViewABC]:
|
||||
"""
|
||||
Return semantic layer implementation.
|
||||
"""
|
||||
# TODO (betodealmeida):
|
||||
# return extension_manager.get_contribution("semanticLayers", self.type)
|
||||
class_ = registry[self.type]
|
||||
return class_.from_configuration(json.loads(self.configuration))
|
||||
|
||||
|
||||
class SemanticView(AuditMixinNullable, Model):
|
||||
"""
|
||||
Semantic view model.
|
||||
|
||||
A semantic view represents a queryable view within a semantic layer.
|
||||
"""
|
||||
|
||||
__tablename__ = "semantic_views"
|
||||
|
||||
# Use integer as the primary key for cross-database auto-increment
|
||||
# compatibility (sa.Identity() is not supported in MySQL or SQLite).
|
||||
# The uuid column is a secondary unique identifier used in URLs and perms.
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
uuid = Column(UUIDType(binary=True), unique=True, default=uuid.uuid4)
|
||||
|
||||
# Core fields
|
||||
name = Column(String(250), nullable=False)
|
||||
description = Column(Text, nullable=True)
|
||||
|
||||
configuration = Column(encrypted_field_factory.create(JSONType), default="{}")
|
||||
# Tracks the schema version of the configuration JSON field to aid with
|
||||
# migrations as the configuration schema evolves over time.
|
||||
configuration_version = Column(Integer, nullable=False, default=1)
|
||||
cache_timeout = Column(Integer, nullable=True)
|
||||
|
||||
# Semantic layer relationship
|
||||
semantic_layer_uuid = Column(
|
||||
UUIDType(binary=True),
|
||||
ForeignKey("semantic_layers.uuid", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
semantic_layer: SemanticLayer = relationship(
|
||||
"SemanticLayer",
|
||||
back_populates="semantic_views",
|
        foreign_keys=[semantic_layer_uuid],
    )

    def __repr__(self) -> str:
        return self.name or str(self.uuid)

    @cached_property
    def implementation(self) -> SemanticViewABC:
        """
        Return semantic view implementation.
        """
        return self.semantic_layer.implementation.get_semantic_view(
            self.name,
            json.loads(self.configuration),
        )

    # =========================================================================
    # Explorable protocol implementation
    # =========================================================================

    def get_query_result(self, query_object: QueryObject) -> QueryResult:
        return get_results(query_object)

    def get_query_str(self, query_obj: QueryObjectDict) -> str:
        return "Not implemented for semantic layers"

    @property
    def table_name(self) -> str:
        return self.name

    @property
    def kind(self) -> str:
        return "semantic_view"

    @property
    def uid(self) -> str:
        return self.implementation.uid()

    @property
    def type(self) -> str:
        return "semantic_view"

    @property
    def metrics(self) -> list[MetricMetadata]:
        return [
            MetricMetadata(
                metric_name=metric.name,
                expression=metric.definition,
                description=metric.description,
            )
            for metric in self.implementation.get_metrics()
        ]

    @property
    def columns(self) -> list[ColumnMetadata]:
        return [
            ColumnMetadata(
                column_name=dimension.name,
                type=str(dimension.type),
                is_dttm=pa.types.is_date(dimension.type)
                or pa.types.is_time(dimension.type)
                or pa.types.is_timestamp(dimension.type),
                description=dimension.description,
                expression=dimension.definition,
                extra=json.dumps(
                    {"grain": dimension.grain.name if dimension.grain else None}
                ),
            )
            for dimension in self.implementation.get_dimensions()
        ]

    @property
    def column_names(self) -> list[str]:
        return [dimension.name for dimension in self.implementation.get_dimensions()]

    @property
    def data(self) -> ExplorableData:
        return {
            # core
            "id": self.id,
            "uid": self.uid,
            "type": "semantic_view",
            "name": self.name,
            "columns": [
                {
                    "advanced_data_type": None,
                    "certification_details": None,
                    "certified_by": None,
                    "column_name": dimension.name,
                    "description": dimension.description,
                    "expression": dimension.definition,
                    "filterable": True,
                    "groupby": True,
                    "id": None,
                    "uuid": None,
                    "is_certified": False,
                    "is_dttm": pa.types.is_date(dimension.type)
                    or pa.types.is_time(dimension.type)
                    or pa.types.is_timestamp(dimension.type),
                    "python_date_format": None,
                    "type": str(dimension.type),
                    "type_generic": get_column_type(dimension.type),
                    "verbose_name": None,
                    "warning_markdown": None,
                }
                for dimension in self.implementation.get_dimensions()
            ],
            "metrics": [
                {
                    "certification_details": None,
                    "certified_by": None,
                    "d3format": None,
                    "description": metric.description,
                    "expression": metric.definition,
                    "id": None,
                    "uuid": None,
                    "is_certified": False,
                    "metric_name": metric.name,
                    "warning_markdown": None,
                    "warning_text": None,
                    "verbose_name": None,
                }
                for metric in self.implementation.get_metrics()
            ],
            "database": {},
            # UI features
            "verbose_map": {},
            "order_by_choices": [],
            "filter_select": True,
            "filter_select_enabled": True,
            "sql": None,
            "select_star": None,
            "owners": [],
            "description": self.description,
            "table_name": self.name,
            "column_types": [
                get_column_type(dimension.type)
                for dimension in self.implementation.get_dimensions()
            ],
            "column_names": [
                dimension.name for dimension in self.implementation.get_dimensions()
            ],
            # rare
            "column_formats": {},
            "datasource_name": self.name,
            "perm": self.perm,
            "offset": self.offset,
            "cache_timeout": self.cache_timeout,
            "params": None,
            # sql-specific
            "schema": None,
            "catalog": None,
            "main_dttm_col": None,
            "time_grain_sqla": [],
            "granularity_sqla": [],
            "fetch_values_predicate": None,
            "template_params": None,
            "is_sqllab_view": False,
            "extra": None,
            "always_filter_main_dttm": False,
            "normalize_columns": False,
            "edit_url": "",
            "default_endpoint": None,
            "folders": [],
            "health_check_message": None,
        }

    def data_for_slices(self, slices: list[Any]) -> ExplorableData:
        return self.data

    def get_extra_cache_keys(self, query_obj: QueryObjectDict) -> list[Hashable]:
        return []

    @property
    def perm(self) -> str:
        return self.semantic_layer_uuid.hex + "::" + self.uuid.hex

    @property
    def catalog_perm(self) -> str | None:
        return None

    @property
    def schema_perm(self) -> str | None:
        return None

    @property
    def schema(self) -> str | None:
        return None

    @property
    def url(self) -> str:
        return f"/semantic_view/{self.uuid}/"

    @property
    def explore_url(self) -> str:
        return f"/explore/?datasource_type=semantic_view&datasource_id={self.id}"

    @property
    def offset(self) -> int:
        # always return datetime as UTC
        return 0

    def get_time_grains(self) -> list[TimeGrainDict]:
        return [
            {
                "name": dimension.grain.name,
                "function": "",
                "duration": dimension.grain.representation,
            }
            for dimension in self.implementation.get_dimensions()
            if dimension.grain
        ]

    def has_drill_by_columns(self, column_names: list[str]) -> bool:
        dimension_names = {
            dimension.name for dimension in self.implementation.get_dimensions()
        }
        return all(column_name in dimension_names for column_name in column_names)

    @property
    def is_rls_supported(self) -> bool:
        return False

    @property
    def query_language(self) -> str | None:
        return None
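For reference, a minimal sketch of how the `perm` property above composes its permission string from the two UUIDs; the UUID values here are made up, and the expression simply mirrors the property definition in the model:

```python
import uuid

# Hypothetical stand-ins for SemanticView.semantic_layer_uuid and SemanticView.uuid
layer_uuid = uuid.uuid4()
view_uuid = uuid.uuid4()

# Mirrors SemanticView.perm: "<layer uuid hex>::<view uuid hex>"
perm = layer_uuid.hex + "::" + view_uuid.hex
print(perm)  # e.g. "5f2c...::9c3a..."
```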
superset/semantic_layers/registry.py (new file, 24 lines)
@@ -0,0 +1,24 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from typing import Any

from superset_core.semantic_layers.layer import SemanticLayer

registry: dict[str, type[SemanticLayer[Any, Any]]] = {}
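The registry is just a module-level dict keyed by the layer type id. A minimal sketch of registering and resolving an implementation; the `MySemanticLayer` class is hypothetical, and whether registration happens manually like this or via the `@semantic_layer` decorator is not shown in this diff:

```python
from typing import Any

from superset_core.semantic_layers.layer import SemanticLayer

from superset.semantic_layers.registry import registry


class MySemanticLayer(SemanticLayer[Any, Any]):  # hypothetical implementation
    ...


# The key is the same string used as "type" in SemanticLayerPostSchema payloads.
registry["my_platform"] = MySemanticLayer

# Commands resolve the class from the same dict:
layer_cls = registry.get("my_platform")
if layer_cls is None:
    raise ValueError("Unknown semantic layer type: my_platform")
```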
superset/semantic_layers/schemas.py (new file, 37 lines)
@@ -0,0 +1,37 @@
# (standard ASF Apache-2.0 license header, as above)
from marshmallow import fields, Schema


class SemanticViewPutSchema(Schema):
    description = fields.String(allow_none=True)
    cache_timeout = fields.Integer(allow_none=True)


class SemanticLayerPostSchema(Schema):
    name = fields.String(required=True)
    description = fields.String(allow_none=True)
    type = fields.String(required=True)
    configuration = fields.Dict(required=True)
    cache_timeout = fields.Integer(allow_none=True)


class SemanticLayerPutSchema(Schema):
    name = fields.String()
    description = fields.String(allow_none=True)
    configuration = fields.Dict()
    cache_timeout = fields.Integer(allow_none=True)
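As a quick illustration of the schemas, a `POST` payload validated with `SemanticLayerPostSchema`; the field values are made up:

```python
from marshmallow import ValidationError

from superset.semantic_layers.schemas import SemanticLayerPostSchema

payload = {
    "name": "My Layer",
    "type": "snowflake",
    "configuration": {"account": "test"},
    "cache_timeout": 300,
}

try:
    # load() raises ValidationError on missing required or unknown fields
    data = SemanticLayerPostSchema().load(payload)
except ValidationError as err:
    print(err.messages)
```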
superset/superset_typing.py
@@ -30,6 +30,46 @@ if TYPE_CHECKING:
 SQLType: TypeAlias = TypeEngine | type[TypeEngine]


+class DatasetColumnData(TypedDict, total=False):
+    """Type for column metadata in ExplorableData datasets."""
+
+    advanced_data_type: str | None
+    certification_details: str | None
+    certified_by: str | None
+    column_name: str
+    description: str | None
+    expression: str | None
+    filterable: bool
+    groupby: bool
+    id: int | None
+    uuid: str | None
+    is_certified: bool
+    is_dttm: bool
+    python_date_format: str | None
+    type: str
+    type_generic: NotRequired["GenericDataType" | None]
+    verbose_name: str | None
+    warning_markdown: str | None
+
+
+class DatasetMetricData(TypedDict, total=False):
+    """Type for metric metadata in ExplorableData datasets."""
+
+    certification_details: str | None
+    certified_by: str | None
+    currency: NotRequired[dict[str, Any]]
+    d3format: str | None
+    description: str | None
+    expression: str | None
+    id: int | None
+    uuid: str | None
+    is_certified: bool
+    metric_name: str
+    warning_markdown: str | None
+    warning_text: str | None
+    verbose_name: str | None
+
+
 class LegacyMetric(TypedDict):
     label: str | None

@@ -254,7 +294,7 @@ class ExplorableData(TypedDict, total=False):
     """

     # Core fields from BaseDatasource.data
-    id: int
+    id: int | str  # String for UUID-based explorables like SemanticView
     uid: str
     column_formats: dict[str, str | None]
     description: str | None

@@ -274,8 +314,8 @@ class ExplorableData(TypedDict, total=False):
     perm: str | None
     edit_url: str
     sql: str | None
-    columns: list[dict[str, Any]]
-    metrics: list[dict[str, Any]]
+    columns: list["DatasetColumnData"]
+    metrics: list["DatasetMetricData"]
     folders: Any  # JSON field, can be list or dict
     order_by_choices: list[tuple[str, str]]
     owners: list[int] | list[dict[str, Any]]  # Can be either format

@@ -283,8 +323,8 @@ class ExplorableData(TypedDict, total=False):
     select_star: str | None

     # Additional fields from SqlaTable and data_for_slices
-    column_types: list[Any]
-    column_names: set[str] | set[Any]
+    column_types: list["GenericDataType"]
+    column_names: set[str] | list[str]
     granularity_sqla: list[tuple[Any, Any]]
     time_grain_sqla: list[tuple[Any, Any]]
     main_dttm_col: str | None
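Both TypedDicts are declared with `total=False`, so partially populated dicts still type-check; a minimal sketch with illustrative values:

```python
from superset.superset_typing import DatasetColumnData

# Only the keys that are actually known need to be present.
col: DatasetColumnData = {
    "column_name": "order_date",
    "type": "TIMESTAMP",
    "is_dttm": True,
    "filterable": True,
    "groupby": True,
}
```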
superset/utils/core.py
@@ -96,7 +96,6 @@ from superset.exceptions import (
     SupersetException,
     SupersetTimeoutException,
 )
-from superset.explorables.base import Explorable
 from superset.sql.parse import sanitize_clause
 from superset.superset_typing import (
     AdhocColumn,

@@ -115,7 +114,7 @@ from superset.utils.hashing import hash_from_dict, hash_from_str
 from superset.utils.pandas import detect_datetime_format

 if TYPE_CHECKING:
-    from superset.connectors.sqla.models import TableColumn
+    from superset.explorables.base import ColumnMetadata, Explorable
     from superset.models.core import Database

 logging.getLogger("MARKDOWN").setLevel(logging.INFO)

@@ -200,6 +199,7 @@ class DatasourceType(StrEnum):
     QUERY = "query"
     SAVEDQUERY = "saved_query"
     VIEW = "view"
+    SEMANTIC_VIEW = "semantic_view"


 class LoggerLevel(StrEnum):

@@ -1730,15 +1730,12 @@ def get_metric_type_from_column(column: Any, datasource: Explorable) -> str:
     :return: The inferred metric type as a string, or an empty string if the
         column is not a metric or no valid operation is found.
     """
-    from superset.connectors.sqla.models import SqlMetric
-
-    metric: SqlMetric = next(
-        (metric for metric in datasource.metrics if metric.metric_name == column),
-        SqlMetric(metric_name=""),
+    metric = next(
+        (m for m in datasource.metrics if m.metric_name == column),
+        None,
     )

-    if metric.metric_name == "":
+    if metric is None:
         return ""

     expression: str = metric.expression

@@ -1784,7 +1781,7 @@ def extract_dataframe_dtypes(
     generic_types: list[GenericDataType] = []
     for column in df.columns:
-        column_object = columns_by_name.get(column)
+        column_object = columns_by_name.get(str(column))
         series = df[column]
         inferred_type: str = ""
         if series.isna().all():

@@ -1814,11 +1811,17 @@
     return generic_types


-def extract_column_dtype(col: TableColumn) -> GenericDataType:
-    if col.is_temporal:
+def extract_column_dtype(col: ColumnMetadata) -> GenericDataType:
+    # Check for temporal type
+    if hasattr(col, "is_temporal") and col.is_temporal:
+        return GenericDataType.TEMPORAL
+    if col.is_dttm:
         return GenericDataType.TEMPORAL
-    if col.is_numeric:
+
+    # Check for numeric type
+    if hasattr(col, "is_numeric") and col.is_numeric:
         return GenericDataType.NUMERIC
+
     # TODO: add check for boolean data type when proper support is added
     return GenericDataType.STRING

@@ -1832,9 +1835,7 @@ def get_time_filter_status(
     applied_time_extras: dict[str, str],
 ) -> tuple[list[dict[str, str]], list[dict[str, str]]]:
     temporal_columns: set[Any] = {
-        (col.column_name if hasattr(col, "column_name") else col.get("column_name"))
-        for col in datasource.columns
-        if (col.is_dttm if hasattr(col, "is_dttm") else col.get("is_dttm"))
+        col.column_name for col in datasource.columns if col.is_dttm
     }
     applied: list[dict[str, str]] = []
     rejected: list[dict[str, str]] = []
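The rewritten `extract_column_dtype` dispatches by duck typing rather than by the concrete `TableColumn` class, so both ORM columns (which expose `is_temporal`/`is_numeric`) and protocol-level `ColumnMetadata` objects (which expose `is_dttm`) resolve correctly. A small sketch with stand-in objects; the dataclasses here are hypothetical and only mimic the relevant attribute surfaces:

```python
from dataclasses import dataclass

from superset.utils.core import GenericDataType, extract_column_dtype


@dataclass
class OrmLikeColumn:  # mimics TableColumn's surface
    is_temporal: bool = True
    is_numeric: bool = False
    is_dttm: bool = True


@dataclass
class MetadataColumn:  # mimics ColumnMetadata's surface (no is_temporal attr)
    is_dttm: bool = True


# The hasattr branch handles the ORM column; the is_dttm branch handles metadata.
assert extract_column_dtype(OrmLikeColumn()) == GenericDataType.TEMPORAL
assert extract_column_dtype(MetadataColumn()) == GenericDataType.TEMPORAL
```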
@@ -626,7 +626,8 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
         assert response == {
             "message": {
                 "datasource_type": [
-                    "Must be one of: table, dataset, query, saved_query, view."
+                    "Must be one of: table, dataset, query, saved_query, view, "
+                    "semantic_view."
                 ]
             }
         }

@@ -981,7 +982,8 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
         assert response == {
             "message": {
                 "datasource_type": [
-                    "Must be one of: table, dataset, query, saved_query, view."
+                    "Must be one of: table, dataset, query, saved_query, view, "
+                    "semantic_view."
                 ]
             }
         }
@@ -204,3 +204,55 @@ class TestDatasourceApi(SupersetTestCase):
        assert rv.status_code == 200
        response = json.loads(rv.data.decode("utf-8"))
        assert response["result"] == []

    @patch("superset.datasource.api.security_manager.can_access")
    @patch("superset.datasource.api.GetCombinedDatasourceListCommand.run")
    def test_combined_list_invalid_order_column(
        self,
        run_mock,
        can_access_mock,
    ):
        security_manager.add_permission_view_menu("can_combined_list", "Datasource")
        perm = security_manager.find_permission_view_menu(
            "can_combined_list", "Datasource"
        )
        admin_role = security_manager.find_role("Admin")
        security_manager.add_permission_role(admin_role, perm)
        can_access_mock.side_effect = [True, True]
        run_mock.side_effect = ValueError("Invalid order column: invalid")
        self.login(ADMIN_USERNAME)

        rv = self.client.get(
            "api/v1/datasource/?q=(order_column:invalid,order_direction:desc,page:0,page_size:25)"
        )

        assert rv.status_code == 400
        response = json.loads(rv.data.decode("utf-8"))
        assert response["message"] == "Invalid order column: invalid"

    @patch("superset.datasource.api.security_manager.can_access")
    @patch("superset.datasource.api.GetCombinedDatasourceListCommand.run")
    def test_combined_list_semantic_layers_off(
        self,
        run_mock,
        can_access_mock,
    ):
        security_manager.add_permission_view_menu("can_combined_list", "Datasource")
        perm = security_manager.find_permission_view_menu(
            "can_combined_list", "Datasource"
        )
        admin_role = security_manager.find_role("Admin")
        security_manager.add_permission_role(admin_role, perm)
        can_access_mock.return_value = True
        run_mock.return_value = {"count": 1, "result": []}
        self.login(ADMIN_USERNAME)

        with patch("superset.datasource.api.is_feature_enabled", return_value=False):
            rv = self.client.get(
                "api/v1/datasource/?q=(order_column:changed_on_delta_humanized,order_direction:desc,page:0,page_size:25)"
            )

        assert rv.status_code == 200
        run_mock.assert_called_once()
        _, kwargs = run_mock.call_args
        assert kwargs == {}
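The `q` parameter in these requests is Rison-encoded. A sketch of building it programmatically, assuming the `prison` package (commonly used alongside Superset's test helpers) is available:

```python
import prison  # Rison (de)serialization

query = {
    "order_column": "changed_on_delta_humanized",
    "order_direction": "desc",
    "page": 0,
    "page_size": 25,
}

url = f"api/v1/datasource/?q={prison.dumps(query)}"
# -> "api/v1/datasource/?q=(order_column:changed_on_delta_humanized,
#     order_direction:desc,page:0,page_size:25)"
```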
tests/unit_tests/commands/datasource/__init__.py (new file, 16 lines)
@@ -0,0 +1,16 @@
# (standard ASF Apache-2.0 license header, as above)
tests/unit_tests/commands/datasource/list_test.py (new file, 141 lines)
@@ -0,0 +1,141 @@
# (standard ASF Apache-2.0 license header, as above)
from unittest.mock import patch

import pytest
from sqlalchemy import literal, select

from superset.commands.datasource.list import GetCombinedDatasourceListCommand


def test_parse_filters_semantic_view_requires_dataset_operator() -> None:
    source_type, name_filter, sql_filter, type_filter = (
        GetCombinedDatasourceListCommand._parse_filters(
            [{"col": "sql", "opr": "eq", "value": "semantic_view"}]
        )
    )

    assert source_type == "all"
    assert name_filter is None
    assert sql_filter is None
    assert type_filter is None


def test_parse_filters_semantic_view_with_dataset_operator() -> None:
    source_type, name_filter, sql_filter, type_filter = (
        GetCombinedDatasourceListCommand._parse_filters(
            [
                {
                    "col": "sql",
                    "opr": "dataset_is_null_or_empty",
                    "value": "semantic_view",
                }
            ]
        )
    )

    assert source_type == "all"
    assert name_filter is None
    assert sql_filter is None
    assert type_filter == "semantic_view"


def test_parse_filters_sql_bool_requires_dataset_operator() -> None:
    source_type, name_filter, sql_filter, type_filter = (
        GetCombinedDatasourceListCommand._parse_filters(
            [{"col": "sql", "opr": "eq", "value": True}]
        )
    )

    assert source_type == "all"
    assert name_filter is None
    assert sql_filter is None
    assert type_filter is None


def test_resolve_source_type_semantic_view_filter_forces_semantic_layer() -> None:
    command = GetCombinedDatasourceListCommand(
        args={},
        can_read_datasets=True,
        can_read_semantic_views=True,
    )

    source_type = command._resolve_source_type(
        source_type="all",
        sql_filter=None,
        type_filter="semantic_view",
    )

    assert source_type == "semantic_layer"


def test_resolve_source_type_sql_filter_forces_database() -> None:
    command = GetCombinedDatasourceListCommand(
        args={},
        can_read_datasets=True,
        can_read_semantic_views=True,
    )

    source_type = command._resolve_source_type(
        source_type="all",
        sql_filter=True,
        type_filter=None,
    )

    assert source_type == "database"


@pytest.mark.parametrize(
    "order_column",
    ["unknown", "database.database_name", "id"],
)
def test_run_raises_for_invalid_sort_column(order_column: str) -> None:
    command = GetCombinedDatasourceListCommand(
        args={"order_column": order_column, "order_direction": "desc"},
        can_read_datasets=True,
        can_read_semantic_views=True,
    )

    ds_q = select(
        literal(1).label("item_id"),
        literal("database").label("source_type"),
        literal("2026-01-01").label("changed_on"),
        literal("name").label("table_name"),
    )
    sv_q = select(
        literal(2).label("item_id"),
        literal("semantic_layer").label("source_type"),
        literal("2026-01-01").label("changed_on"),
        literal("name").label("table_name"),
    )

    with (
        patch(
            "superset.commands.datasource.list.DatasourceDAO.build_dataset_query",
            return_value=ds_q,
        ),
        patch(
            "superset.commands.datasource.list.DatasourceDAO.build_semantic_view_query",
            return_value=sv_q,
        ),
        patch(
            "superset.commands.datasource.list.DatasourceDAO.paginate_combined_query",
            side_effect=ValueError(f"Invalid order column: {order_column}"),
        ),
    ):
        with pytest.raises(ValueError, match=f"Invalid order column: {order_column}"):
            command.run()
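These tests stub `build_dataset_query` / `build_semantic_view_query` with two literal SELECTs shaped like the real ones. A minimal sketch of the kind of combined pagination the DAO performs over such queries, written directly in SQLAlchemy; the column labels come from the test fixtures, while the exact DAO internals are not shown in this diff:

```python
from sqlalchemy import literal, select, union_all

ds_q = select(
    literal(1).label("item_id"),
    literal("database").label("source_type"),
    literal("2026-01-01").label("changed_on"),
    literal("name").label("table_name"),
)
sv_q = select(
    literal(2).label("item_id"),
    literal("semantic_layer").label("source_type"),
    literal("2026-01-01").label("changed_on"),
    literal("name").label("table_name"),
)

combined = union_all(ds_q, sv_q).subquery()

# Sort columns must be validated against the subquery's labels, which is
# why "invalid" raises ValueError in the tests above.
order_column = "table_name"
if order_column not in combined.c:
    raise ValueError(f"Invalid order column: {order_column}")

page, page_size = 0, 25
stmt = (
    select(combined)
    .order_by(combined.c[order_column].desc())
    .limit(page_size)
    .offset(page * page_size)
)
```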
tests/unit_tests/commands/semantic_layer/__init__.py (new file, 16 lines)
@@ -0,0 +1,16 @@
# (standard ASF Apache-2.0 license header, as above)
tests/unit_tests/commands/semantic_layer/create_test.py (new file, 148 lines)
@@ -0,0 +1,148 @@
# (standard ASF Apache-2.0 license header, as above)

from unittest.mock import MagicMock

import pytest
from pytest_mock import MockerFixture

from superset.commands.semantic_layer.create import CreateSemanticLayerCommand
from superset.commands.semantic_layer.exceptions import (
    SemanticLayerCreateFailedError,
    SemanticLayerInvalidError,
)


def test_create_semantic_layer_success(mocker: MockerFixture) -> None:
    """Test successful creation of a semantic layer."""
    new_model = MagicMock()

    dao = mocker.patch(
        "superset.commands.semantic_layer.create.SemanticLayerDAO",
    )
    dao.validate_uniqueness.return_value = True
    dao.create.return_value = new_model

    mock_cls = MagicMock()
    mocker.patch.dict(
        "superset.commands.semantic_layer.create.registry",
        {"snowflake": mock_cls},
    )

    data = {
        "name": "My Layer",
        "type": "snowflake",
        "configuration": {"account": "test"},
    }
    result = CreateSemanticLayerCommand(data).run()

    assert result == new_model
    dao.create.assert_called_once_with(attributes=data)
    mock_cls.from_configuration.assert_called_once_with({"account": "test"})


def test_create_semantic_layer_unknown_type(mocker: MockerFixture) -> None:
    """Test that SemanticLayerInvalidError is raised for unknown type."""
    mocker.patch(
        "superset.commands.semantic_layer.create.SemanticLayerDAO",
    )
    mocker.patch.dict(
        "superset.commands.semantic_layer.create.registry",
        {},
        clear=True,
    )

    data = {
        "name": "My Layer",
        "type": "nonexistent",
        "configuration": {},
    }
    with pytest.raises(SemanticLayerInvalidError):
        CreateSemanticLayerCommand(data).run()


def test_create_semantic_layer_duplicate_name(mocker: MockerFixture) -> None:
    """Test that SemanticLayerInvalidError is raised for duplicate names."""
    dao = mocker.patch(
        "superset.commands.semantic_layer.create.SemanticLayerDAO",
    )
    dao.validate_uniqueness.return_value = False

    mocker.patch.dict(
        "superset.commands.semantic_layer.create.registry",
        {"snowflake": MagicMock()},
    )

    data = {
        "name": "Duplicate",
        "type": "snowflake",
        "configuration": {},
    }
    with pytest.raises(SemanticLayerInvalidError):
        CreateSemanticLayerCommand(data).run()


def test_create_semantic_layer_invalid_configuration(
    mocker: MockerFixture,
) -> None:
    """Test that invalid configuration is caught by the @transaction decorator."""
    dao = mocker.patch(
        "superset.commands.semantic_layer.create.SemanticLayerDAO",
    )
    dao.validate_uniqueness.return_value = True

    mock_cls = MagicMock()
    mock_cls.from_configuration.side_effect = ValueError("bad config")
    mocker.patch.dict(
        "superset.commands.semantic_layer.create.registry",
        {"snowflake": mock_cls},
    )

    data = {
        "name": "My Layer",
        "type": "snowflake",
        "configuration": {"bad": "data"},
    }
    with pytest.raises(SemanticLayerCreateFailedError):
        CreateSemanticLayerCommand(data).run()


def test_create_semantic_layer_copies_data(mocker: MockerFixture) -> None:
    """Test that the command copies input data and does not mutate it."""
    dao = mocker.patch(
        "superset.commands.semantic_layer.create.SemanticLayerDAO",
    )
    dao.validate_uniqueness.return_value = True
    dao.create.return_value = MagicMock()

    mocker.patch.dict(
        "superset.commands.semantic_layer.create.registry",
        {"snowflake": MagicMock()},
    )

    original_data = {
        "name": "Original",
        "type": "snowflake",
        "configuration": {"account": "test"},
    }
    CreateSemanticLayerCommand(original_data).run()

    assert original_data == {
        "name": "Original",
        "type": "snowflake",
        "configuration": {"account": "test"},
    }
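From the assertions above, the create command's flow can be inferred: resolve the type against the registry, check name uniqueness, instantiate the plugin from the configuration, then persist. A rough sketch of that flow; the method layout and the `validate_uniqueness` signature are assumptions, and only the calls asserted in the tests are grounded:

```python
# Hypothetical reconstruction of CreateSemanticLayerCommand.run(), inferred
# from the tests above; not the actual Superset source.
def run(self):
    data = dict(self._properties)  # copy, so callers' dicts are not mutated

    layer_cls = registry.get(data["type"])
    if layer_cls is None:
        raise SemanticLayerInvalidError()

    if not SemanticLayerDAO.validate_uniqueness(data["name"]):
        raise SemanticLayerInvalidError()

    # A ValueError here propagates out of the @transaction decorator
    # as SemanticLayerCreateFailedError (see the invalid-configuration test).
    layer_cls.from_configuration(data["configuration"])

    return SemanticLayerDAO.create(attributes=data)
```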
tests/unit_tests/commands/semantic_layer/delete_test.py (new file, 50 lines)
@@ -0,0 +1,50 @@
# (standard ASF Apache-2.0 license header, as above)

from unittest.mock import MagicMock

import pytest
from pytest_mock import MockerFixture

from superset.commands.semantic_layer.delete import DeleteSemanticLayerCommand
from superset.commands.semantic_layer.exceptions import SemanticLayerNotFoundError


def test_delete_semantic_layer_success(mocker: MockerFixture) -> None:
    """Test successful deletion of a semantic layer."""
    mock_model = MagicMock()

    dao = mocker.patch(
        "superset.commands.semantic_layer.delete.SemanticLayerDAO",
    )
    dao.find_by_uuid.return_value = mock_model

    DeleteSemanticLayerCommand("some-uuid").run()

    dao.find_by_uuid.assert_called_once_with("some-uuid")
    dao.delete.assert_called_once_with([mock_model])


def test_delete_semantic_layer_not_found(mocker: MockerFixture) -> None:
    """Test that SemanticLayerNotFoundError is raised when model is missing."""
    dao = mocker.patch(
        "superset.commands.semantic_layer.delete.SemanticLayerDAO",
    )
    dao.find_by_uuid.return_value = None

    with pytest.raises(SemanticLayerNotFoundError):
        DeleteSemanticLayerCommand("missing-uuid").run()
tests/unit_tests/commands/semantic_layer/exceptions_test.py (new file, 91 lines)
@@ -0,0 +1,91 @@
# (standard ASF Apache-2.0 license header, as above)

from superset.commands.semantic_layer.exceptions import (
    SemanticLayerCreateFailedError,
    SemanticLayerDeleteFailedError,
    SemanticLayerForbiddenError,
    SemanticLayerInvalidError,
    SemanticLayerNotFoundError,
    SemanticLayerUpdateFailedError,
    SemanticViewForbiddenError,
    SemanticViewInvalidError,
    SemanticViewNotFoundError,
    SemanticViewUpdateFailedError,
)


def test_semantic_view_not_found_error() -> None:
    """Test SemanticViewNotFoundError has correct status and message."""
    error = SemanticViewNotFoundError()
    assert error.status == 404
    assert str(error.message) == "Semantic view does not exist"


def test_semantic_view_forbidden_error() -> None:
    """Test SemanticViewForbiddenError has correct message."""
    error = SemanticViewForbiddenError()
    assert str(error.message) == "Changing this semantic view is forbidden"


def test_semantic_view_invalid_error() -> None:
    """Test SemanticViewInvalidError has correct message."""
    error = SemanticViewInvalidError()
    assert str(error.message) == "Semantic view parameters are invalid."


def test_semantic_view_update_failed_error() -> None:
    """Test SemanticViewUpdateFailedError has correct message."""
    error = SemanticViewUpdateFailedError()
    assert str(error.message) == "Semantic view could not be updated."


def test_semantic_layer_not_found_error() -> None:
    """Test SemanticLayerNotFoundError has correct status and message."""
    error = SemanticLayerNotFoundError()
    assert error.status == 404
    assert str(error.message) == "Semantic layer does not exist"


def test_semantic_layer_forbidden_error() -> None:
    """Test SemanticLayerForbiddenError has correct message."""
    error = SemanticLayerForbiddenError()
    assert str(error.message) == "Changing this semantic layer is forbidden"


def test_semantic_layer_invalid_error() -> None:
    """Test SemanticLayerInvalidError has correct message."""
    error = SemanticLayerInvalidError()
    assert str(error.message) == "Semantic layer parameters are invalid."


def test_semantic_layer_create_failed_error() -> None:
    """Test SemanticLayerCreateFailedError has correct message."""
    error = SemanticLayerCreateFailedError()
    assert str(error.message) == "Semantic layer could not be created."


def test_semantic_layer_update_failed_error() -> None:
    """Test SemanticLayerUpdateFailedError has correct message."""
    error = SemanticLayerUpdateFailedError()
    assert str(error.message) == "Semantic layer could not be updated."


def test_semantic_layer_delete_failed_error() -> None:
    """Test SemanticLayerDeleteFailedError has correct message."""
    error = SemanticLayerDeleteFailedError()
    assert str(error.message) == "Semantic layer could not be deleted."
tests/unit_tests/commands/semantic_layer/update_test.py (new file, 326 lines)
@@ -0,0 +1,326 @@
# (standard ASF Apache-2.0 license header, as above)

from unittest.mock import MagicMock

import pytest
from pytest_mock import MockerFixture

from superset.commands.semantic_layer.exceptions import (
    SemanticLayerInvalidError,
    SemanticLayerNotFoundError,
    SemanticViewForbiddenError,
    SemanticViewNotFoundError,
)
from superset.commands.semantic_layer.update import (
    UpdateSemanticLayerCommand,
    UpdateSemanticViewCommand,
)
from superset.exceptions import SupersetSecurityException


def test_update_semantic_view_success(mocker: MockerFixture) -> None:
    """Test successful update of a semantic view."""
    mock_model = MagicMock()
    mock_model.id = 1
    mock_model.configuration = "{}"

    dao = mocker.patch(
        "superset.commands.semantic_layer.update.SemanticViewDAO",
    )
    dao.find_by_id.return_value = mock_model
    dao.update.return_value = mock_model

    mocker.patch(
        "superset.commands.semantic_layer.update.security_manager",
    )

    data = {"description": "Updated", "cache_timeout": 300}
    result = UpdateSemanticViewCommand(1, data).run()

    assert result == mock_model
    dao.find_by_id.assert_called_once_with(1)
    dao.update.assert_called_once_with(mock_model, attributes=data)


def test_update_semantic_view_not_found(mocker: MockerFixture) -> None:
    """Test that SemanticViewNotFoundError is raised when model is missing."""
    dao = mocker.patch(
        "superset.commands.semantic_layer.update.SemanticViewDAO",
    )
    dao.find_by_id.return_value = None

    with pytest.raises(SemanticViewNotFoundError):
        UpdateSemanticViewCommand(999, {"description": "test"}).run()


def test_update_semantic_view_forbidden(mocker: MockerFixture) -> None:
    """Test that SemanticViewForbiddenError is raised on ownership failure."""
    mock_model = MagicMock()

    dao = mocker.patch(
        "superset.commands.semantic_layer.update.SemanticViewDAO",
    )
    dao.find_by_id.return_value = mock_model

    sm = mocker.patch(
        "superset.commands.semantic_layer.update.security_manager",
    )
    # Use a regular MagicMock for raise_for_ownership to avoid AsyncMock issues
    sm.raise_for_ownership = MagicMock(
        side_effect=SupersetSecurityException(MagicMock()),
    )

    with pytest.raises(SemanticViewForbiddenError):
        UpdateSemanticViewCommand(1, {"description": "test"}).run()


def test_update_semantic_view_copies_data(mocker: MockerFixture) -> None:
    """Test that the command copies input data and does not mutate it."""
    mock_model = MagicMock()
    mock_model.configuration = "{}"

    dao = mocker.patch(
        "superset.commands.semantic_layer.update.SemanticViewDAO",
    )
    dao.find_by_id.return_value = mock_model
    dao.update.return_value = mock_model

    mocker.patch(
        "superset.commands.semantic_layer.update.security_manager",
    )

    original_data = {"description": "Original"}
    UpdateSemanticViewCommand(1, original_data).run()

    # The original dict should not have been modified
    assert original_data == {"description": "Original"}


# =============================================================================
# UpdateSemanticLayerCommand tests
# =============================================================================


def test_update_semantic_layer_success(mocker: MockerFixture) -> None:
    """Test successful update of a semantic layer."""
    mock_model = MagicMock()
    mock_model.type = "snowflake"

    dao = mocker.patch(
        "superset.commands.semantic_layer.update.SemanticLayerDAO",
    )
    dao.find_by_uuid.return_value = mock_model
    dao.update.return_value = mock_model

    data = {"name": "Updated", "description": "New desc"}
    result = UpdateSemanticLayerCommand("some-uuid", data).run()

    assert result == mock_model
    dao.find_by_uuid.assert_called_once_with("some-uuid")
    dao.update.assert_called_once_with(mock_model, attributes=data)


def test_update_semantic_layer_not_found(mocker: MockerFixture) -> None:
    """Test that SemanticLayerNotFoundError is raised when model is missing."""
    dao = mocker.patch(
        "superset.commands.semantic_layer.update.SemanticLayerDAO",
    )
    dao.find_by_uuid.return_value = None

    with pytest.raises(SemanticLayerNotFoundError):
        UpdateSemanticLayerCommand("missing-uuid", {"name": "test"}).run()


def test_update_semantic_layer_duplicate_name(mocker: MockerFixture) -> None:
    """Test that SemanticLayerInvalidError is raised for duplicate names."""
    mock_model = MagicMock()
    mock_model.type = "snowflake"

    dao = mocker.patch(
        "superset.commands.semantic_layer.update.SemanticLayerDAO",
    )
    dao.find_by_uuid.return_value = mock_model
    dao.validate_update_uniqueness.return_value = False

    with pytest.raises(SemanticLayerInvalidError):
        UpdateSemanticLayerCommand("some-uuid", {"name": "Duplicate"}).run()


def test_update_semantic_layer_validates_configuration(
    mocker: MockerFixture,
) -> None:
    """Test that configuration is validated against the plugin."""
    mock_model = MagicMock()
    mock_model.type = "snowflake"

    dao = mocker.patch(
        "superset.commands.semantic_layer.update.SemanticLayerDAO",
    )
    dao.find_by_uuid.return_value = mock_model
    dao.update.return_value = mock_model

    mock_cls = MagicMock()
    mocker.patch.dict(
        "superset.commands.semantic_layer.update.registry",
        {"snowflake": mock_cls},
    )

    config = {"account": "test"}
    UpdateSemanticLayerCommand("some-uuid", {"configuration": config}).run()

    mock_cls.from_configuration.assert_called_once_with(config)


def test_update_semantic_layer_skips_name_check_when_no_name(
    mocker: MockerFixture,
) -> None:
    """Test that name uniqueness is not checked when name is not provided."""
    mock_model = MagicMock()
    mock_model.type = "snowflake"

    dao = mocker.patch(
        "superset.commands.semantic_layer.update.SemanticLayerDAO",
    )
    dao.find_by_uuid.return_value = mock_model
    dao.update.return_value = mock_model

    UpdateSemanticLayerCommand("some-uuid", {"description": "Updated"}).run()

    dao.validate_update_uniqueness.assert_not_called()


def test_update_semantic_layer_copies_data(mocker: MockerFixture) -> None:
    """Test that the command copies input data and does not mutate it."""
    mock_model = MagicMock()
    mock_model.type = "snowflake"

    dao = mocker.patch(
        "superset.commands.semantic_layer.update.SemanticLayerDAO",
    )
    dao.find_by_uuid.return_value = mock_model
    dao.update.return_value = mock_model

    original_data = {"description": "Original"}
    UpdateSemanticLayerCommand("some-uuid", original_data).run()

    assert original_data == {"description": "Original"}


def _make_view_model(
    uuid: str = "view-uuid-1",
    name: str = "my_view",
    layer_uuid: str = "layer-uuid-1",
    configuration: str = '{"schema": "prod"}',
) -> MagicMock:
    model = MagicMock()
    model.uuid = uuid
    model.name = name
    model.semantic_layer_uuid = layer_uuid
    model.configuration = configuration
    return model


def test_update_uniqueness_different_config_same_name(
    mocker: MockerFixture,
) -> None:
    """Same name but different configuration is allowed."""
    mock_model = _make_view_model(configuration='{"schema": "prod"}')

    dao = mocker.patch(
        "superset.commands.semantic_layer.update.SemanticViewDAO",
    )
    dao.find_by_id.return_value = mock_model
    dao.update.return_value = mock_model
    dao.validate_update_uniqueness.return_value = True

    mocker.patch(
        "superset.commands.semantic_layer.update.security_manager",
    )

    # Update to a config that differs from an existing view
    data = {"name": "my_view", "configuration": {"schema": "testing"}}
    result = UpdateSemanticViewCommand(1, data).run()

    assert result == mock_model
    dao.validate_update_uniqueness.assert_called_once_with(
        view_uuid="view-uuid-1",
        name="my_view",
        layer_uuid="layer-uuid-1",
        configuration={"schema": "testing"},
    )


def test_update_uniqueness_same_config_different_name(
    mocker: MockerFixture,
) -> None:
    """Same configuration but different name is allowed."""
    mock_model = _make_view_model(configuration='{"schema": "prod"}')

    dao = mocker.patch(
        "superset.commands.semantic_layer.update.SemanticViewDAO",
    )
    dao.find_by_id.return_value = mock_model
    dao.update.return_value = mock_model
    dao.validate_update_uniqueness.return_value = True

    mocker.patch(
        "superset.commands.semantic_layer.update.security_manager",
    )

    data = {"name": "renamed_view", "configuration": {"schema": "prod"}}
    result = UpdateSemanticViewCommand(1, data).run()

    assert result == mock_model
    dao.validate_update_uniqueness.assert_called_once_with(
        view_uuid="view-uuid-1",
        name="renamed_view",
        layer_uuid="layer-uuid-1",
        configuration={"schema": "prod"},
    )


def test_update_uniqueness_same_config_same_name_fails(
    mocker: MockerFixture,
) -> None:
    """Same name and same configuration is a duplicate."""
    mock_model = _make_view_model(configuration='{"schema": "prod"}')

    dao = mocker.patch(
        "superset.commands.semantic_layer.update.SemanticViewDAO",
    )
    dao.find_by_id.return_value = mock_model
    dao.validate_update_uniqueness.return_value = False

    mocker.patch(
        "superset.commands.semantic_layer.update.security_manager",
    )

    from superset.commands.semantic_layer.exceptions import (
        SemanticViewUpdateFailedError,
    )

    data = {"name": "my_view", "configuration": {"schema": "prod"}}
    with pytest.raises(SemanticViewUpdateFailedError):
        UpdateSemanticViewCommand(1, data).run()

    dao.validate_update_uniqueness.assert_called_once_with(
        view_uuid="view-uuid-1",
        name="my_view",
        layer_uuid="layer-uuid-1",
        configuration={"schema": "prod"},
    )
@@ -18,6 +18,7 @@
 from collections.abc import Iterator

 import pytest
+from sqlalchemy import literal, select
 from sqlalchemy.orm.session import Session

 from superset.utils.core import DatasourceType

@@ -138,3 +139,31 @@ def test_not_found_datasource(session_with_data: Session) -> None:
            datasource_type="table",
            database_id_or_uuid=500000,
        )


def test_escape_ilike_fragment() -> None:
    from superset.daos.datasource import _escape_ilike_fragment

    assert _escape_ilike_fragment("foo%bar_baz\\") == "foo\\%bar\\_baz\\\\"


def test_paginate_combined_query_invalid_sort_column() -> None:
    from superset.daos.datasource import DatasourceDAO

    combined = (
        select(
            literal(1).label("item_id"),
            literal("database").label("source_type"),
            literal("2026-01-01").label("changed_on"),
            literal("name").label("table_name"),
        )
    ).subquery()

    with pytest.raises(ValueError, match="Invalid order column: invalid"):
        DatasourceDAO.paginate_combined_query(
            combined=combined,
            order_column="invalid",
            order_direction="desc",
            page=0,
            page_size=25,
        )
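The escape test pins down the contract: backslash, percent, and underscore must all be backslash-escaped before being embedded in an ILIKE pattern, with the backslash handled first. One plausible implementation consistent with that expectation (not necessarily the actual Superset source):

```python
def _escape_ilike_fragment(fragment: str) -> str:
    """Escape LIKE/ILIKE wildcards so user input matches literally."""
    return (
        fragment.replace("\\", "\\\\")  # escape the escape character first
        .replace("%", "\\%")
        .replace("_", "\\_")
    )


# Matches the expectation in test_escape_ilike_fragment above.
assert _escape_ilike_fragment("foo%bar_baz\\") == "foo\\%bar\\_baz\\\\"
```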
tests/unit_tests/semantic_layers/api_test.py (new file, 912 lines)
@@ -0,0 +1,912 @@
# (standard ASF Apache-2.0 license header, as above)

import uuid as uuid_lib
from typing import Any
from unittest.mock import MagicMock

import pytest
from pytest_mock import MockerFixture

from superset.commands.semantic_layer.exceptions import (
    SemanticLayerCreateFailedError,
    SemanticLayerDeleteFailedError,
    SemanticLayerInvalidError,
    SemanticLayerNotFoundError,
    SemanticLayerUpdateFailedError,
    SemanticViewForbiddenError,
    SemanticViewInvalidError,
    SemanticViewNotFoundError,
    SemanticViewUpdateFailedError,
)

SEMANTIC_LAYERS_APP = pytest.mark.parametrize(
    "app",
    [{"FEATURE_FLAGS": {"SEMANTIC_LAYERS": True}}],
    indirect=True,
)


@SEMANTIC_LAYERS_APP
def test_put_semantic_view(
    client: Any,
    full_api_access: None,
    mocker: MockerFixture,
) -> None:
    """Test successful PUT updates a semantic view."""
    changed_model = MagicMock()
    changed_model.id = 1

    mock_command = mocker.patch(
        "superset.semantic_layers.api.UpdateSemanticViewCommand",
    )
    mock_command.return_value.run.return_value = changed_model

    payload = {"description": "Updated description", "cache_timeout": 300}
    response = client.put(
        "/api/v1/semantic_view/1",
        json=payload,
    )

    assert response.status_code == 200
    assert response.json["id"] == 1
    assert response.json["result"] == payload
    mock_command.assert_called_once_with("1", payload)


@SEMANTIC_LAYERS_APP
def test_put_semantic_view_not_found(
    client: Any,
    full_api_access: None,
    mocker: MockerFixture,
) -> None:
    """Test PUT returns 404 when semantic view does not exist."""
    mock_command = mocker.patch(
        "superset.semantic_layers.api.UpdateSemanticViewCommand",
    )
    mock_command.return_value.run.side_effect = SemanticViewNotFoundError()

    response = client.put(
        "/api/v1/semantic_view/999",
        json={"description": "Updated"},
    )

    assert response.status_code == 404


@SEMANTIC_LAYERS_APP
def test_put_semantic_view_forbidden(
    client: Any,
    full_api_access: None,
    mocker: MockerFixture,
) -> None:
    """Test PUT returns 403 when user lacks ownership."""
    mock_command = mocker.patch(
        "superset.semantic_layers.api.UpdateSemanticViewCommand",
    )
    mock_command.return_value.run.side_effect = SemanticViewForbiddenError()

    response = client.put(
        "/api/v1/semantic_view/1",
        json={"description": "Updated"},
    )

    assert response.status_code == 403


@SEMANTIC_LAYERS_APP
def test_put_semantic_view_invalid(
    client: Any,
    full_api_access: None,
    mocker: MockerFixture,
) -> None:
    """Test PUT returns 422 when validation fails."""
    mock_command = mocker.patch(
        "superset.semantic_layers.api.UpdateSemanticViewCommand",
    )
    mock_command.return_value.run.side_effect = SemanticViewInvalidError()

    response = client.put(
        "/api/v1/semantic_view/1",
        json={"description": "Updated"},
    )

    assert response.status_code == 422


@SEMANTIC_LAYERS_APP
def test_put_semantic_view_update_failed(
    client: Any,
    full_api_access: None,
    mocker: MockerFixture,
) -> None:
    """Test PUT returns 422 when the update operation fails."""
    mock_command = mocker.patch(
        "superset.semantic_layers.api.UpdateSemanticViewCommand",
    )
    mock_command.return_value.run.side_effect = SemanticViewUpdateFailedError()

    response = client.put(
        "/api/v1/semantic_view/1",
        json={"description": "Updated"},
    )

    assert response.status_code == 422


@SEMANTIC_LAYERS_APP
def test_put_semantic_view_bad_request(
    client: Any,
    full_api_access: None,
    mocker: MockerFixture,
) -> None:
    """Test PUT returns 400 when the request payload has invalid fields."""
    # Marshmallow raises ValidationError for unknown fields
    mocker.patch(
        "superset.semantic_layers.api.UpdateSemanticViewCommand",
    )

    response = client.put(
        "/api/v1/semantic_view/1",
        json={"invalid_field": "value"},
    )

    assert response.status_code == 400


@SEMANTIC_LAYERS_APP
def test_put_semantic_view_description_only(
    client: Any,
    full_api_access: None,
    mocker: MockerFixture,
) -> None:
    """Test PUT with only description field."""
    changed_model = MagicMock()
    changed_model.id = 1

    mock_command = mocker.patch(
        "superset.semantic_layers.api.UpdateSemanticViewCommand",
    )
    mock_command.return_value.run.return_value = changed_model

    payload = {"description": "New description"}
    response = client.put(
        "/api/v1/semantic_view/1",
        json=payload,
    )

    assert response.status_code == 200
    assert response.json["result"] == payload


@SEMANTIC_LAYERS_APP
def test_put_semantic_view_cache_timeout_only(
    client: Any,
    full_api_access: None,
    mocker: MockerFixture,
) -> None:
    """Test PUT with only cache_timeout field."""
    changed_model = MagicMock()
    changed_model.id = 2

    mock_command = mocker.patch(
        "superset.semantic_layers.api.UpdateSemanticViewCommand",
    )
    mock_command.return_value.run.return_value = changed_model

    payload = {"cache_timeout": 600}
    response = client.put(
        "/api/v1/semantic_view/2",
        json=payload,
    )

    assert response.status_code == 200
    assert response.json["id"] == 2
    assert response.json["result"] == payload


@SEMANTIC_LAYERS_APP
def test_put_semantic_view_null_values(
    client: Any,
    full_api_access: None,
    mocker: MockerFixture,
) -> None:
    """Test PUT with null values for both fields."""
    changed_model = MagicMock()
    changed_model.id = 1

    mock_command = mocker.patch(
        "superset.semantic_layers.api.UpdateSemanticViewCommand",
    )
    mock_command.return_value.run.return_value = changed_model

    payload = {"description": None, "cache_timeout": None}
    response = client.put(
        "/api/v1/semantic_view/1",
        json=payload,
    )

    assert response.status_code == 200
    assert response.json["result"] == payload


@SEMANTIC_LAYERS_APP
def test_put_semantic_view_empty_payload(
    client: Any,
    full_api_access: None,
    mocker: MockerFixture,
) -> None:
    """Test PUT with empty payload."""
    changed_model = MagicMock()
    changed_model.id = 1

    mock_command = mocker.patch(
        "superset.semantic_layers.api.UpdateSemanticViewCommand",
    )
    mock_command.return_value.run.return_value = changed_model

    response = client.put(
        "/api/v1/semantic_view/1",
        json={},
    )

    assert response.status_code == 200


# =============================================================================
# SemanticLayerRestApi tests
# =============================================================================


@SEMANTIC_LAYERS_APP
def test_get_types(
    client: Any,
    full_api_access: None,
    mocker: MockerFixture,
) -> None:
    """Test GET /types returns registered semantic layer types."""
    mock_cls = MagicMock()
    mock_cls.name = "Snowflake Semantic Layer"
    mock_cls.description = "Connect to Snowflake."

    mocker.patch.dict(
        "superset.semantic_layers.api.registry",
        {"snowflake": mock_cls},
        clear=True,
    )

    response = client.get("/api/v1/semantic_layer/types")

    assert response.status_code == 200
    result = response.json["result"]
    assert len(result) == 1
    assert result[0] == {
        "id": "snowflake",
        "name": "Snowflake Semantic Layer",
        "description": "Connect to Snowflake.",
    }


@SEMANTIC_LAYERS_APP
def test_get_types_empty(
    client: Any,
    full_api_access: None,
    mocker: MockerFixture,
) -> None:
    """Test GET /types returns empty list when no types registered."""
    mocker.patch.dict(
        "superset.semantic_layers.api.registry",
        {},
        clear=True,
    )

    response = client.get("/api/v1/semantic_layer/types")

    assert response.status_code == 200
    assert response.json["result"] == []


@SEMANTIC_LAYERS_APP
def test_configuration_schema(
    client: Any,
    full_api_access: None,
    mocker: MockerFixture,
) -> None:
    """Test POST /schema/configuration returns schema without partial config."""
    mock_cls = MagicMock()
    mock_cls.get_configuration_schema.return_value = {"type": "object"}

    mocker.patch.dict(
        "superset.semantic_layers.api.registry",
        {"snowflake": mock_cls},
        clear=True,
    )

    response = client.post(
        "/api/v1/semantic_layer/schema/configuration",
        json={"type": "snowflake"},
    )

    assert response.status_code == 200
    assert response.json["result"] == {"type": "object"}
    mock_cls.get_configuration_schema.assert_called_once_with(None)


@SEMANTIC_LAYERS_APP
def test_configuration_schema_with_partial_config(
    client: Any,
    full_api_access: None,
    mocker: MockerFixture,
) -> None:
    """Test POST /schema/configuration enriches schema with partial config."""
    mock_instance = MagicMock()
    mock_instance.configuration = {"account": "test"}

    mock_cls = MagicMock()
    mock_cls.from_configuration.return_value = mock_instance
    mock_cls.get_configuration_schema.return_value = {
        "type": "object",
        "properties": {"database": {"enum": ["db1", "db2"]}},
    }

    mocker.patch.dict(
        "superset.semantic_layers.api.registry",
        {"snowflake": mock_cls},
        clear=True,
    )

    response = client.post(
        "/api/v1/semantic_layer/schema/configuration",
        json={"type": "snowflake", "configuration": {"account": "test"}},
    )

    assert response.status_code == 200
    mock_cls.get_configuration_schema.assert_called_once_with({"account": "test"})


@SEMANTIC_LAYERS_APP
def test_configuration_schema_with_invalid_partial_config(
    client: Any,
    full_api_access: None,
    mocker: MockerFixture,
) -> None:
    """Test /schema/configuration returns schema when partial config fails."""
    mock_cls = MagicMock()
    mock_cls.from_configuration.side_effect = ValueError("bad config")
    mock_cls.get_configuration_schema.return_value = {"type": "object"}

    mocker.patch.dict(
        "superset.semantic_layers.api.registry",
        {"snowflake": mock_cls},
        clear=True,
    )

    response = client.post(
        "/api/v1/semantic_layer/schema/configuration",
        json={"type": "snowflake", "configuration": {"bad": "data"}},
    )

    assert response.status_code == 200
    mock_cls.get_configuration_schema.assert_called_once_with(None)
|
||||
|
||||
|
||||
@SEMANTIC_LAYERS_APP
|
||||
def test_configuration_schema_unknown_type(
|
||||
client: Any,
|
||||
full_api_access: None,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""Test POST /schema/configuration returns 400 for unknown type."""
|
||||
mocker.patch.dict(
|
||||
"superset.semantic_layers.api.registry",
|
||||
{},
|
||||
clear=True,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/semantic_layer/schema/configuration",
|
||||
json={"type": "nonexistent"},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@SEMANTIC_LAYERS_APP
|
||||
def test_runtime_schema(
|
||||
client: Any,
|
||||
full_api_access: None,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""Test POST /<uuid>/schema/runtime returns runtime schema."""
|
||||
test_uuid = str(uuid_lib.uuid4())
|
||||
mock_layer = MagicMock()
|
||||
mock_layer.type = "snowflake"
|
||||
mock_layer.implementation.configuration = {"account": "test"}
|
||||
|
||||
mock_dao = mocker.patch("superset.semantic_layers.api.SemanticLayerDAO")
|
||||
mock_dao.find_by_uuid.return_value = mock_layer
|
||||
|
||||
mock_cls = MagicMock()
|
||||
mock_cls.get_runtime_schema.return_value = {"type": "object"}
|
||||
|
||||
mocker.patch.dict(
|
||||
"superset.semantic_layers.api.registry",
|
||||
{"snowflake": mock_cls},
|
||||
clear=True,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
f"/api/v1/semantic_layer/{test_uuid}/schema/runtime",
|
||||
json={"runtime_data": {"database": "mydb"}},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json["result"] == {"type": "object"}
|
||||
mock_cls.get_runtime_schema.assert_called_once_with(
|
||||
{"account": "test"}, {"database": "mydb"}
|
||||
)
|
||||
|
||||
|
||||
@SEMANTIC_LAYERS_APP
|
||||
def test_runtime_schema_no_body(
|
||||
client: Any,
|
||||
full_api_access: None,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""Test POST /<uuid>/schema/runtime works without a request body."""
|
||||
test_uuid = str(uuid_lib.uuid4())
|
||||
mock_layer = MagicMock()
|
||||
mock_layer.type = "snowflake"
|
||||
mock_layer.implementation.configuration = {"account": "test"}
|
||||
|
||||
mock_dao = mocker.patch("superset.semantic_layers.api.SemanticLayerDAO")
|
||||
mock_dao.find_by_uuid.return_value = mock_layer
|
||||
|
||||
mock_cls = MagicMock()
|
||||
mock_cls.get_runtime_schema.return_value = {"type": "object"}
|
||||
|
||||
mocker.patch.dict(
|
||||
"superset.semantic_layers.api.registry",
|
||||
{"snowflake": mock_cls},
|
||||
clear=True,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
f"/api/v1/semantic_layer/{test_uuid}/schema/runtime",
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_cls.get_runtime_schema.assert_called_once_with({"account": "test"}, None)
|
||||
|
||||
|
||||
@SEMANTIC_LAYERS_APP
|
||||
def test_runtime_schema_not_found(
|
||||
client: Any,
|
||||
full_api_access: None,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""Test POST /<uuid>/schema/runtime returns 404 when layer not found."""
|
||||
mock_dao = mocker.patch("superset.semantic_layers.api.SemanticLayerDAO")
|
||||
mock_dao.find_by_uuid.return_value = None
|
||||
|
||||
response = client.post(
|
||||
f"/api/v1/semantic_layer/{uuid_lib.uuid4()}/schema/runtime",
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@SEMANTIC_LAYERS_APP
|
||||
def test_runtime_schema_unknown_type(
|
||||
client: Any,
|
||||
full_api_access: None,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""Test POST /<uuid>/schema/runtime returns 400 for unknown type."""
|
||||
test_uuid = str(uuid_lib.uuid4())
|
||||
mock_layer = MagicMock()
|
||||
mock_layer.type = "unknown_type"
|
||||
|
||||
mock_dao = mocker.patch("superset.semantic_layers.api.SemanticLayerDAO")
|
||||
mock_dao.find_by_uuid.return_value = mock_layer
|
||||
|
||||
mocker.patch.dict(
|
||||
"superset.semantic_layers.api.registry",
|
||||
{},
|
||||
clear=True,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
f"/api/v1/semantic_layer/{test_uuid}/schema/runtime",
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "Unknown type" in response.json["message"]
|
||||
|
||||
|
||||
@SEMANTIC_LAYERS_APP
|
||||
def test_runtime_schema_exception(
|
||||
client: Any,
|
||||
full_api_access: None,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""Test POST /<uuid>/schema/runtime returns 400 when schema raises."""
|
||||
test_uuid = str(uuid_lib.uuid4())
|
||||
mock_layer = MagicMock()
|
||||
mock_layer.type = "snowflake"
|
||||
mock_layer.implementation.configuration = {"account": "test"}
|
||||
|
||||
mock_dao = mocker.patch("superset.semantic_layers.api.SemanticLayerDAO")
|
||||
mock_dao.find_by_uuid.return_value = mock_layer
|
||||
|
||||
mock_cls = MagicMock()
|
||||
mock_cls.get_runtime_schema.side_effect = ValueError("Bad config")
|
||||
|
||||
mocker.patch.dict(
|
||||
"superset.semantic_layers.api.registry",
|
||||
{"snowflake": mock_cls},
|
||||
clear=True,
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
f"/api/v1/semantic_layer/{test_uuid}/schema/runtime",
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "Bad config" in response.json["message"]
|
||||
|
||||
|
||||
@SEMANTIC_LAYERS_APP
|
||||
def test_post_semantic_layer(
|
||||
client: Any,
|
||||
full_api_access: None,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""Test POST / creates a semantic layer."""
|
||||
test_uuid = uuid_lib.uuid4()
|
||||
new_model = MagicMock()
|
||||
new_model.uuid = test_uuid
|
||||
|
||||
mock_command = mocker.patch(
|
||||
"superset.semantic_layers.api.CreateSemanticLayerCommand",
|
||||
)
|
||||
mock_command.return_value.run.return_value = new_model
|
||||
|
||||
payload = {
|
||||
"name": "My Layer",
|
||||
"type": "snowflake",
|
||||
"configuration": {"account": "test"},
|
||||
}
|
||||
response = client.post("/api/v1/semantic_layer/", json=payload)
|
||||
|
||||
assert response.status_code == 201
|
||||
assert response.json["result"]["uuid"] == str(test_uuid)
|
||||
mock_command.assert_called_once_with(payload)
|
||||
|
||||
|
||||
@SEMANTIC_LAYERS_APP
|
||||
def test_post_semantic_layer_invalid(
|
||||
client: Any,
|
||||
full_api_access: None,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""Test POST / returns 422 when validation fails."""
|
||||
mock_command = mocker.patch(
|
||||
"superset.semantic_layers.api.CreateSemanticLayerCommand",
|
||||
)
|
||||
mock_command.return_value.run.side_effect = SemanticLayerInvalidError(
|
||||
"Unknown type: bad"
|
||||
)
|
||||
|
||||
payload = {
|
||||
"name": "My Layer",
|
||||
"type": "bad",
|
||||
"configuration": {},
|
||||
}
|
||||
response = client.post("/api/v1/semantic_layer/", json=payload)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
@SEMANTIC_LAYERS_APP
|
||||
def test_post_semantic_layer_create_failed(
|
||||
client: Any,
|
||||
full_api_access: None,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""Test POST / returns 422 when creation fails."""
|
||||
mock_command = mocker.patch(
|
||||
"superset.semantic_layers.api.CreateSemanticLayerCommand",
|
||||
)
|
||||
mock_command.return_value.run.side_effect = SemanticLayerCreateFailedError()
|
||||
|
||||
payload = {
|
||||
"name": "My Layer",
|
||||
"type": "snowflake",
|
||||
"configuration": {"account": "test"},
|
||||
}
|
||||
response = client.post("/api/v1/semantic_layer/", json=payload)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
@SEMANTIC_LAYERS_APP
|
||||
def test_post_semantic_layer_missing_required_fields(
|
||||
client: Any,
|
||||
full_api_access: None,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""Test POST / returns 400 when required fields are missing."""
|
||||
mocker.patch(
|
||||
"superset.semantic_layers.api.CreateSemanticLayerCommand",
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/semantic_layer/",
|
||||
json={"name": "Only name"},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@SEMANTIC_LAYERS_APP
|
||||
def test_put_semantic_layer(
|
||||
client: Any,
|
||||
full_api_access: None,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""Test PUT /<uuid> updates a semantic layer."""
|
||||
test_uuid = uuid_lib.uuid4()
|
||||
changed_model = MagicMock()
|
||||
changed_model.uuid = test_uuid
|
||||
|
||||
mock_command = mocker.patch(
|
||||
"superset.semantic_layers.api.UpdateSemanticLayerCommand",
|
||||
)
|
||||
mock_command.return_value.run.return_value = changed_model
|
||||
|
||||
payload = {"name": "Updated Name"}
|
||||
response = client.put(
|
||||
f"/api/v1/semantic_layer/{test_uuid}",
|
||||
json=payload,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json["result"]["uuid"] == str(test_uuid)
|
||||
mock_command.assert_called_once_with(str(test_uuid), payload)
|
||||
|
||||
|
||||
@SEMANTIC_LAYERS_APP
|
||||
def test_put_semantic_layer_not_found(
|
||||
client: Any,
|
||||
full_api_access: None,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""Test PUT /<uuid> returns 404 when layer not found."""
|
||||
mock_command = mocker.patch(
|
||||
"superset.semantic_layers.api.UpdateSemanticLayerCommand",
|
||||
)
|
||||
mock_command.return_value.run.side_effect = SemanticLayerNotFoundError()
|
||||
|
||||
response = client.put(
|
||||
f"/api/v1/semantic_layer/{uuid_lib.uuid4()}",
|
||||
json={"name": "New"},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@SEMANTIC_LAYERS_APP
|
||||
def test_put_semantic_layer_invalid(
|
||||
client: Any,
|
||||
full_api_access: None,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""Test PUT /<uuid> returns 422 when validation fails."""
|
||||
mock_command = mocker.patch(
|
||||
"superset.semantic_layers.api.UpdateSemanticLayerCommand",
|
||||
)
|
||||
mock_command.return_value.run.side_effect = SemanticLayerInvalidError(
|
||||
"Name already exists"
|
||||
)
|
||||
|
||||
response = client.put(
|
||||
f"/api/v1/semantic_layer/{uuid_lib.uuid4()}",
|
||||
json={"name": "Duplicate"},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
@SEMANTIC_LAYERS_APP
|
||||
def test_put_semantic_layer_update_failed(
|
||||
client: Any,
|
||||
full_api_access: None,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""Test PUT /<uuid> returns 422 when update fails."""
|
||||
mock_command = mocker.patch(
|
||||
"superset.semantic_layers.api.UpdateSemanticLayerCommand",
|
||||
)
|
||||
mock_command.return_value.run.side_effect = SemanticLayerUpdateFailedError()
|
||||
|
||||
response = client.put(
|
||||
f"/api/v1/semantic_layer/{uuid_lib.uuid4()}",
|
||||
json={"name": "Test"},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
@SEMANTIC_LAYERS_APP
|
||||
def test_put_semantic_layer_validation_error(
|
||||
client: Any,
|
||||
full_api_access: None,
|
||||
) -> None:
|
||||
"""Test PUT /<uuid> returns 400 when payload fails schema validation."""
|
||||
response = client.put(
|
||||
f"/api/v1/semantic_layer/{uuid_lib.uuid4()}",
|
||||
json={"cache_timeout": "not_a_number"},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@SEMANTIC_LAYERS_APP
|
||||
def test_delete_semantic_layer(
|
||||
client: Any,
|
||||
full_api_access: None,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""Test DELETE /<uuid> deletes a semantic layer."""
|
||||
test_uuid = str(uuid_lib.uuid4())
|
||||
mock_command = mocker.patch(
|
||||
"superset.semantic_layers.api.DeleteSemanticLayerCommand",
|
||||
)
|
||||
mock_command.return_value.run.return_value = None
|
||||
|
||||
response = client.delete(f"/api/v1/semantic_layer/{test_uuid}")
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_command.assert_called_once_with(test_uuid)
|
||||
|
||||
|
||||
@SEMANTIC_LAYERS_APP
|
||||
def test_delete_semantic_layer_not_found(
|
||||
client: Any,
|
||||
full_api_access: None,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""Test DELETE /<uuid> returns 404 when layer not found."""
|
||||
mock_command = mocker.patch(
|
||||
"superset.semantic_layers.api.DeleteSemanticLayerCommand",
|
||||
)
|
||||
mock_command.return_value.run.side_effect = SemanticLayerNotFoundError()
|
||||
|
||||
response = client.delete(f"/api/v1/semantic_layer/{uuid_lib.uuid4()}")
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@SEMANTIC_LAYERS_APP
|
||||
def test_delete_semantic_layer_failed(
|
||||
client: Any,
|
||||
full_api_access: None,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""Test DELETE /<uuid> returns 422 when deletion fails."""
|
||||
mock_command = mocker.patch(
|
||||
"superset.semantic_layers.api.DeleteSemanticLayerCommand",
|
||||
)
|
||||
mock_command.return_value.run.side_effect = SemanticLayerDeleteFailedError()
|
||||
|
||||
response = client.delete(f"/api/v1/semantic_layer/{uuid_lib.uuid4()}")
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
@SEMANTIC_LAYERS_APP
|
||||
def test_get_list_semantic_layers(
|
||||
client: Any,
|
||||
full_api_access: None,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""Test GET / returns list of semantic layers."""
|
||||
layer1 = MagicMock()
|
||||
layer1.uuid = uuid_lib.uuid4()
|
||||
layer1.name = "Layer 1"
|
||||
layer1.description = "First"
|
||||
layer1.type = "snowflake"
|
||||
layer1.cache_timeout = None
|
||||
|
||||
layer2 = MagicMock()
|
||||
layer2.uuid = uuid_lib.uuid4()
|
||||
layer2.name = "Layer 2"
|
||||
layer2.description = None
|
||||
layer2.type = "snowflake"
|
||||
layer2.cache_timeout = 300
|
||||
|
||||
mock_dao = mocker.patch("superset.semantic_layers.api.SemanticLayerDAO")
|
||||
mock_dao.find_all.return_value = [layer1, layer2]
|
||||
|
||||
response = client.get("/api/v1/semantic_layer/")
|
||||
|
||||
assert response.status_code == 200
|
||||
result = response.json["result"]
|
||||
assert len(result) == 2
|
||||
assert result[0]["name"] == "Layer 1"
|
||||
assert result[1]["name"] == "Layer 2"
|
||||
assert result[1]["cache_timeout"] == 300
|
||||
|
||||
|
||||
@SEMANTIC_LAYERS_APP
|
||||
def test_get_list_semantic_layers_empty(
|
||||
client: Any,
|
||||
full_api_access: None,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""Test GET / returns empty list when no layers exist."""
|
||||
mock_dao = mocker.patch("superset.semantic_layers.api.SemanticLayerDAO")
|
||||
mock_dao.find_all.return_value = []
|
||||
|
||||
response = client.get("/api/v1/semantic_layer/")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json["result"] == []
|
||||
|
||||
|
||||
@SEMANTIC_LAYERS_APP
|
||||
def test_get_semantic_layer(
|
||||
client: Any,
|
||||
full_api_access: None,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""Test GET /<uuid> returns a single semantic layer."""
|
||||
test_uuid = uuid_lib.uuid4()
|
||||
layer = MagicMock()
|
||||
layer.uuid = test_uuid
|
||||
layer.name = "My Layer"
|
||||
layer.description = "A layer"
|
||||
layer.type = "snowflake"
|
||||
layer.cache_timeout = 600
|
||||
|
||||
mock_dao = mocker.patch("superset.semantic_layers.api.SemanticLayerDAO")
|
||||
mock_dao.find_by_uuid.return_value = layer
|
||||
|
||||
response = client.get(f"/api/v1/semantic_layer/{test_uuid}")
|
||||
|
||||
assert response.status_code == 200
|
||||
result = response.json["result"]
|
||||
assert result["uuid"] == str(test_uuid)
|
||||
assert result["name"] == "My Layer"
|
||||
assert result["type"] == "snowflake"
|
||||
assert result["cache_timeout"] == 600
|
||||
|
||||
|
||||
@SEMANTIC_LAYERS_APP
|
||||
def test_get_semantic_layer_not_found(
|
||||
client: Any,
|
||||
full_api_access: None,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""Test GET /<uuid> returns 404 when layer not found."""
|
||||
mock_dao = mocker.patch("superset.semantic_layers.api.SemanticLayerDAO")
|
||||
mock_dao.find_by_uuid.return_value = None
|
||||
|
||||
response = client.get(f"/api/v1/semantic_layer/{uuid_lib.uuid4()}")
|
||||
|
||||
assert response.status_code == 404
|
||||
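Taken together, the configuration-schema tests above pin down a small contract: look the type up in the registry (400 if missing), try to instantiate from any partial configuration to enrich the schema, and fall back to the plain schema when instantiation fails. Below is a minimal sketch of that contract, assuming a dict-shaped registry; it is inferred from the tests, not taken from the PR's actual handler.

```python
from typing import Any

registry: dict[str, Any] = {}  # type id -> semantic layer class (assumed shape)


def configuration_schema(payload: dict[str, Any]) -> tuple[dict[str, Any], int]:
    """Return the JSON schema for a layer type, optionally enriched."""
    cls = registry.get(payload.get("type", ""))
    if cls is None:
        return {"message": f"Unknown type: {payload.get('type')}"}, 400

    configuration = None
    if "configuration" in payload:
        try:
            # A valid partial config lets the class enrich the schema,
            # e.g. populate enums of reachable databases.
            instance = cls.from_configuration(payload["configuration"])
            configuration = instance.configuration
        except Exception:
            # An invalid partial config degrades to the plain schema (still 200).
            configuration = None

    return {"result": cls.get_configuration_schema(configuration)}, 200
```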
tests/unit_tests/semantic_layers/dao_test.py (new file, 85 lines)
@@ -0,0 +1,85 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Tests for SemanticViewDAO."""

from __future__ import annotations

import uuid
from collections.abc import Iterator

import pytest
from sqlalchemy.orm.session import Session


@pytest.fixture
def session_with_semantic_view(session: Session) -> Iterator[Session]:
    """Create an in-memory DB with a SemanticLayer and one SemanticView."""
    from superset.semantic_layers.models import SemanticLayer, SemanticView

    engine = session.get_bind()
    SemanticView.metadata.create_all(engine)  # pylint: disable=no-member

    layer = SemanticLayer(
        uuid=uuid.uuid4(),
        name="test_layer",
        type="test",
        configuration="{}",
    )
    session.add(layer)
    session.flush()

    view = SemanticView(
        id=1,
        uuid=uuid.uuid4(),
        name="test_view",
        semantic_layer_uuid=layer.uuid,
        configuration="{}",
    )
    session.add(view)
    session.flush()

    yield session


def test_find_by_id_uses_integer_id_column(
    session_with_semantic_view: Session,
) -> None:
    """
    SemanticViewDAO.find_by_id must look up by the integer ``id`` column, not
    by ``uuid``.

    Regression test: SemanticViewDAO previously set ``id_column_name = "uuid"``,
    which caused find_by_id(pk) to filter on the UUID column using an integer
    value, always returning None and making every PUT request return 404.
    """
    from superset.daos.semantic_layer import SemanticViewDAO
    from superset.semantic_layers.models import SemanticView

    view = session_with_semantic_view.query(SemanticView).one()

    # Sanity check: the view has an auto-assigned integer id
    assert isinstance(view.id, int)

    result = SemanticViewDAO.find_by_id(view.id)

    assert result is not None, (
        "find_by_id returned None for a valid integer id — "
        "id_column_name is likely set to 'uuid' instead of 'id'"
    )
    assert result.id == view.id
    assert result.name == "test_view"
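For reference, a sketch of the fix this regression test guards against: the DAO must keep filtering on the integer primary key. The class body below is an assumption based on the docstring above, not the actual contents of superset/daos/semantic_layer.py.

```python
from superset.daos.base import BaseDAO
from superset.semantic_layers.models import SemanticView


class SemanticViewDAO(BaseDAO[SemanticView]):
    model_cls = SemanticView
    # BaseDAO's default id_column_name is "id"; the regression was an
    # explicit override to "uuid", which made find_by_id compare the UUID
    # column against an integer primary key and always return None.
    id_column_name = "id"
```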
tests/unit_tests/semantic_layers/decorators_test.py (new file, 103 lines)
@@ -0,0 +1,103 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from unittest.mock import MagicMock, patch

import pytest


def test_semantic_layer_stub_raises() -> None:
    """The stub decorator raises NotImplementedError before initialization."""
    import importlib

    import superset_core.semantic_layers.decorators as mod

    # Reload to get the original stub (injection may have replaced it)
    importlib.reload(mod)

    with pytest.raises(NotImplementedError):
        mod.semantic_layer(id="test", name="Test")


def test_inject_semantic_layer_host_context() -> None:
    """The injected decorator registers a class in host context."""
    from superset.core.api.core_api_injection import (
        inject_semantic_layer_implementations,
    )
    from superset.semantic_layers.registry import registry

    # Clear registry for test isolation
    registry.clear()

    inject_semantic_layer_implementations()

    import superset_core.semantic_layers.decorators as mod

    # Host context: no extension context active, so no prefix
    with patch(
        "superset.extensions.context.get_current_extension_context",
        return_value=None,
    ):

        @mod.semantic_layer(id="test_layer", name="Test Layer", description="A test")
        class FakeLayer:
            pass

    assert "test_layer" in registry
    assert registry["test_layer"] is FakeLayer
    assert FakeLayer.name == "Test Layer"  # type: ignore[attr-defined]
    assert FakeLayer.description == "A test"  # type: ignore[attr-defined]

    # Cleanup
    registry.pop("test_layer", None)


def test_inject_semantic_layer_extension_context() -> None:
    """The injected decorator prefixes ID in extension context."""
    from superset.core.api.core_api_injection import (
        inject_semantic_layer_implementations,
    )
    from superset.semantic_layers.registry import registry

    registry.clear()

    mock_context = MagicMock()
    mock_context.manifest.publisher = "acme"
    mock_context.manifest.name = "analytics"

    inject_semantic_layer_implementations()

    import superset_core.semantic_layers.decorators as mod

    # Extension context is checked at decorator call time via module lookup
    with patch(
        "superset.extensions.context.get_current_extension_context",
        return_value=mock_context,
    ):

        @mod.semantic_layer(id="ext_layer", name="Extension Layer")
        class ExtLayer:
            pass

    expected_id = "extensions.acme.analytics.ext_layer"
    assert expected_id in registry
    assert registry[expected_id] is ExtLayer

    # Cleanup
    registry.pop(expected_id, None)
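A sketch of the decorator behavior these two tests describe: in the host context the class registers under its bare id, while inside an extension context the id is prefixed with `extensions.<publisher>.<name>`. The names below (`get_current_extension_context`, the registry dict) are stand-ins for the real wiring, which the tests patch.

```python
from __future__ import annotations

from typing import Any, Callable

registry: dict[str, type] = {}  # stand-in for superset.semantic_layers.registry


def get_current_extension_context() -> Any:
    """Stand-in for superset.extensions.context.get_current_extension_context."""
    return None


def semantic_layer(
    id: str,  # noqa: A002 - mirrors the keyword used in the tests
    name: str,
    description: str | None = None,
) -> Callable[[type], type]:
    def decorator(cls: type) -> type:
        context = get_current_extension_context()
        # Host code registers under the bare id; extension code gets a
        # namespaced id so two extensions cannot collide.
        layer_id = (
            f"extensions.{context.manifest.publisher}.{context.manifest.name}.{id}"
            if context is not None
            else id
        )
        cls.name = name  # type: ignore[attr-defined]
        cls.description = description  # type: ignore[attr-defined]
        registry[layer_id] = cls
        return cls

    return decorator
```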
tests/unit_tests/semantic_layers/mapper_test.py (new file, 2743 lines)
File diff suppressed because it is too large
tests/unit_tests/semantic_layers/models_test.py (new file, 712 lines)
@@ -0,0 +1,712 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Tests for semantic layer models."""

from __future__ import annotations

import uuid
from unittest.mock import MagicMock, patch

import pyarrow as pa
import pytest
from superset_core.semantic_layers.types import (
    Dimension,
    Grains,
    Metric,
)

from superset.semantic_layers.models import (
    ColumnMetadata,
    get_column_type,
    MetricMetadata,
    SemanticLayer,
    SemanticView,
)
from superset.utils.core import GenericDataType

# =============================================================================
# get_column_type tests
# =============================================================================


def test_get_column_type_temporal_date() -> None:
    """Test that date types map to TEMPORAL."""
    assert get_column_type(pa.date32()) == GenericDataType.TEMPORAL
    assert get_column_type(pa.date64()) == GenericDataType.TEMPORAL


def test_get_column_type_temporal_timestamp() -> None:
    """Test that timestamp types map to TEMPORAL."""
    assert get_column_type(pa.timestamp("us")) == GenericDataType.TEMPORAL


def test_get_column_type_temporal_time() -> None:
    """Test that time types map to TEMPORAL."""
    assert get_column_type(pa.time64("us")) == GenericDataType.TEMPORAL
    assert get_column_type(pa.time32("ms")) == GenericDataType.TEMPORAL


def test_get_column_type_numeric_integer() -> None:
    """Test that integer types map to NUMERIC."""
    assert get_column_type(pa.int64()) == GenericDataType.NUMERIC
    assert get_column_type(pa.int32()) == GenericDataType.NUMERIC


def test_get_column_type_numeric_float() -> None:
    """Test that float types map to NUMERIC."""
    assert get_column_type(pa.float64()) == GenericDataType.NUMERIC


def test_get_column_type_numeric_decimal() -> None:
    """Test that decimal types map to NUMERIC."""
    assert get_column_type(pa.decimal128(38, 10)) == GenericDataType.NUMERIC


def test_get_column_type_numeric_duration() -> None:
    """Test that duration types map to NUMERIC."""
    assert get_column_type(pa.duration("us")) == GenericDataType.NUMERIC


def test_get_column_type_boolean() -> None:
    """Test that boolean types map to BOOLEAN."""
    assert get_column_type(pa.bool_()) == GenericDataType.BOOLEAN


def test_get_column_type_string() -> None:
    """Test that string types map to STRING."""
    assert get_column_type(pa.utf8()) == GenericDataType.STRING
    assert get_column_type(pa.large_utf8()) == GenericDataType.STRING


def test_get_column_type_binary() -> None:
    """Test that binary types map to STRING."""
    assert get_column_type(pa.binary()) == GenericDataType.STRING


def test_get_column_type_unknown() -> None:
    """Test that unknown types default to STRING."""
    assert get_column_type(pa.null()) == GenericDataType.STRING

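The tests above fully determine a pyarrow-to-GenericDataType mapping. Here is a minimal sketch consistent with them; the real get_column_type may branch differently (note that duration has to be classified before any broad temporal check, since the tests map it to NUMERIC):

```python
import pyarrow as pa

from superset.utils.core import GenericDataType


def get_column_type_sketch(pa_type: pa.DataType) -> GenericDataType:
    """Map a pyarrow type to Superset's GenericDataType (sketch only)."""
    types = pa.types
    if types.is_date(pa_type) or types.is_time(pa_type) or types.is_timestamp(pa_type):
        return GenericDataType.TEMPORAL
    if (
        types.is_integer(pa_type)
        or types.is_floating(pa_type)
        or types.is_decimal(pa_type)
        or types.is_duration(pa_type)  # durations count as NUMERIC here
    ):
        return GenericDataType.NUMERIC
    if types.is_boolean(pa_type):
        return GenericDataType.BOOLEAN
    # Strings, binary, and anything unrecognized (e.g. pa.null()) fall
    # back to STRING, matching test_get_column_type_unknown.
    return GenericDataType.STRING
```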
# =============================================================================
# MetricMetadata tests
# =============================================================================


def test_metric_metadata_required_fields() -> None:
    """Test MetricMetadata with required fields only."""
    metadata = MetricMetadata(
        metric_name="revenue",
        expression="SUM(amount)",
    )
    assert metadata.metric_name == "revenue"
    assert metadata.expression == "SUM(amount)"
    assert metadata.verbose_name is None
    assert metadata.description is None
    assert metadata.d3format is None
    assert metadata.currency is None
    assert metadata.warning_text is None
    assert metadata.certified_by is None
    assert metadata.certification_details is None


def test_metric_metadata_all_fields() -> None:
    """Test MetricMetadata with all fields."""
    metadata = MetricMetadata(
        metric_name="revenue",
        expression="SUM(amount)",
        verbose_name="Total Revenue",
        description="Sum of all revenue",
        d3format="$,.2f",
        currency={"symbol": "$", "symbolPosition": "prefix"},
        warning_text="Data may be incomplete",
        certified_by="Data Team",
        certification_details="Verified Q1 2024",
    )
    assert metadata.metric_name == "revenue"
    assert metadata.expression == "SUM(amount)"
    assert metadata.verbose_name == "Total Revenue"
    assert metadata.description == "Sum of all revenue"
    assert metadata.d3format == "$,.2f"
    assert metadata.currency == {"symbol": "$", "symbolPosition": "prefix"}
    assert metadata.warning_text == "Data may be incomplete"
    assert metadata.certified_by == "Data Team"
    assert metadata.certification_details == "Verified Q1 2024"


# =============================================================================
# ColumnMetadata tests
# =============================================================================


def test_column_metadata_required_fields() -> None:
    """Test ColumnMetadata with required fields only."""
    metadata = ColumnMetadata(
        column_name="order_date",
        type="DATE",
        is_dttm=True,
    )
    assert metadata.column_name == "order_date"
    assert metadata.type == "DATE"
    assert metadata.is_dttm is True
    assert metadata.verbose_name is None
    assert metadata.description is None
    assert metadata.groupby is True
    assert metadata.filterable is True
    assert metadata.expression is None
    assert metadata.python_date_format is None
    assert metadata.advanced_data_type is None
    assert metadata.extra is None


def test_column_metadata_all_fields() -> None:
    """Test ColumnMetadata with all fields."""
    metadata = ColumnMetadata(
        column_name="order_date",
        type="DATE",
        is_dttm=True,
        verbose_name="Order Date",
        description="Date of the order",
        groupby=True,
        filterable=True,
        expression="DATE(order_timestamp)",
        python_date_format="%Y-%m-%d",
        advanced_data_type="date",
        extra='{"grain": "day"}',
    )
    assert metadata.column_name == "order_date"
    assert metadata.type == "DATE"
    assert metadata.is_dttm is True
    assert metadata.verbose_name == "Order Date"
    assert metadata.description == "Date of the order"
    assert metadata.groupby is True
    assert metadata.filterable is True
    assert metadata.expression == "DATE(order_timestamp)"
    assert metadata.python_date_format == "%Y-%m-%d"
    assert metadata.advanced_data_type == "date"
    assert metadata.extra == '{"grain": "day"}'

# =============================================================================
# SemanticLayer tests
# =============================================================================


def test_semantic_layer_repr_with_name() -> None:
    """Test SemanticLayer __repr__ with name."""
    layer = SemanticLayer()
    layer.name = "My Semantic Layer"
    layer.uuid = uuid.uuid4()
    assert repr(layer) == "My Semantic Layer"


def test_semantic_layer_repr_without_name() -> None:
    """Test SemanticLayer __repr__ without name (uses uuid)."""
    layer = SemanticLayer()
    layer.name = None
    test_uuid = uuid.uuid4()
    layer.uuid = test_uuid
    assert repr(layer) == str(test_uuid)


def test_semantic_layer_implementation_not_implemented() -> None:
    """Test that implementation raises KeyError for unregistered type."""
    layer = SemanticLayer()
    with pytest.raises(KeyError):
        _ = layer.implementation


def test_semantic_layer_implementation() -> None:
    """Test that implementation returns a configured semantic layer."""
    layer = SemanticLayer()
    layer.type = "test_type"
    layer.configuration = '{"key": "value"}'

    mock_class = MagicMock()
    mock_impl = MagicMock()
    mock_class.from_configuration.return_value = mock_impl

    with patch.dict(
        "superset.semantic_layers.models.registry",
        {"test_type": mock_class},
    ):
        # Clear cached property if it exists
        if "implementation" in layer.__dict__:
            del layer.__dict__["implementation"]

        result = layer.implementation

    mock_class.from_configuration.assert_called_once_with({"key": "value"})
    assert result == mock_impl

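These tests hint that ``implementation`` is a cached property keyed off the registry: the cached value lives in ``layer.__dict__``, an unknown ``type`` raises KeyError, and the JSON configuration is parsed before being handed to ``from_configuration``. A hypothetical sketch of that shape, assuming ``functools.cached_property``; the real model code may differ in its details:

```python
import json
from functools import cached_property
from typing import Any

registry: dict[str, Any] = {}  # type id -> implementation class (assumed)


class SemanticLayerSketch:
    type: str
    configuration: str  # JSON-encoded, as the tests set it

    @cached_property
    def implementation(self) -> Any:
        # KeyError for unregistered types, matching
        # test_semantic_layer_implementation_not_implemented.
        cls = registry[self.type]
        # cached_property stores the result in self.__dict__, which is why
        # the test clears layer.__dict__["implementation"] between runs.
        return cls.from_configuration(json.loads(self.configuration))
```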
# =============================================================================
# SemanticView tests
# =============================================================================


@pytest.fixture
def mock_dimensions() -> list[Dimension]:
    """Create mock dimensions for testing."""
    return [
        Dimension(
            id="orders.order_date",
            name="order_date",
            type=pa.date32(),
            definition="orders.order_date",
            description="Date of the order",
            grain=Grains.DAY,
        ),
        Dimension(
            id="products.category",
            name="category",
            type=pa.utf8(),
            definition="products.category",
            description="Product category",
            grain=None,
        ),
    ]


@pytest.fixture
def mock_metrics() -> list[Metric]:
    """Create mock metrics for testing."""
    return [
        Metric(
            id="orders.revenue",
            name="revenue",
            type=pa.float64(),
            definition="SUM(orders.amount)",
            description="Total revenue",
        ),
        Metric(
            id="orders.count",
            name="order_count",
            type=pa.int64(),
            definition="COUNT(*)",
            description="Number of orders",
        ),
    ]


@pytest.fixture
def mock_implementation(
    mock_dimensions: list[Dimension],
    mock_metrics: list[Metric],
) -> MagicMock:
    """Create a mock implementation."""
    impl = MagicMock()
    impl.get_dimensions.return_value = mock_dimensions
    impl.get_metrics.return_value = mock_metrics
    impl.uid.return_value = "semantic_view_uid_123"
    return impl


@pytest.fixture
def semantic_view(mock_implementation: MagicMock) -> SemanticView:
    """Create a SemanticView with mocked implementation."""
    view = SemanticView()
    view.name = "Orders View"
    view.description = "View of order data"
    view.uuid = uuid.UUID("12345678-1234-5678-1234-567812345678")
    view.semantic_layer_uuid = uuid.UUID("87654321-4321-8765-4321-876543218765")
    view.cache_timeout = 3600
    view.configuration = "{}"

    # Persist mocked implementation on this instance
    view.__dict__["implementation"] = mock_implementation

    return view


def test_semantic_view_repr_with_name() -> None:
    """Test SemanticView __repr__ with name."""
    view = SemanticView()
    view.name = "My View"
    view.uuid = uuid.uuid4()
    assert repr(view) == "My View"


def test_semantic_view_repr_without_name() -> None:
    """Test SemanticView __repr__ without name (uses uuid)."""
    view = SemanticView()
    view.name = None
    test_uuid = uuid.uuid4()
    view.uuid = test_uuid
    assert repr(view) == str(test_uuid)


def test_semantic_view_type() -> None:
    """Test SemanticView type property."""
    view = SemanticView()
    assert view.type == "semantic_view"


def test_semantic_view_table_name() -> None:
    """Test SemanticView table_name property."""
    view = SemanticView()
    view.name = "Orders View"
    assert view.table_name == "Orders View"


def test_semantic_view_kind() -> None:
    """Test SemanticView kind property."""
    view = SemanticView()
    assert view.kind == "semantic_view"


def test_semantic_view_offset() -> None:
    """Test SemanticView offset property."""
    view = SemanticView()
    assert view.offset == 0


def test_semantic_view_is_rls_supported() -> None:
    """Test SemanticView is_rls_supported property."""
    view = SemanticView()
    assert view.is_rls_supported is False


def test_semantic_view_query_language() -> None:
    """Test SemanticView query_language property."""
    view = SemanticView()
    assert view.query_language is None


def test_semantic_view_get_query_str() -> None:
    """Test SemanticView get_query_str method."""
    view = SemanticView()
    result = view.get_query_str({})
    assert result == "Not implemented for semantic layers"


def test_semantic_view_get_extra_cache_keys() -> None:
    """Test SemanticView get_extra_cache_keys method."""
    view = SemanticView()
    result = view.get_extra_cache_keys({})
    assert result == []


def test_semantic_view_perm() -> None:
    """Test SemanticView perm property."""
    view = SemanticView()
    view.uuid = uuid.UUID("12345678-1234-5678-1234-567812345678")
    view.semantic_layer_uuid = uuid.UUID("87654321-4321-8765-4321-876543218765")
    assert (
        view.perm
        == "87654321432187654321876543218765::12345678123456781234567812345678"
    )
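The expected perm string above is, plausibly, just the two UUIDs rendered in their 32-character hex form (dashes stripped) and joined with ``::``, layer first. A quick check of that reading:

```python
import uuid

layer_uuid = uuid.UUID("87654321-4321-8765-4321-876543218765")
view_uuid = uuid.UUID("12345678-1234-5678-1234-567812345678")

# UUID.hex is the 32-character form with dashes stripped.
perm = f"{layer_uuid.hex}::{view_uuid.hex}"
assert perm == "87654321432187654321876543218765::12345678123456781234567812345678"
```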
def test_semantic_view_uid(
    mock_implementation: MagicMock,
    mock_dimensions: list[Dimension],
    mock_metrics: list[Metric],
) -> None:
    """Test SemanticView uid property."""
    view = SemanticView()
    view.name = "Test View"
    view.uuid = uuid.uuid4()
    view.semantic_layer_uuid = uuid.uuid4()

    with patch.object(
        SemanticView,
        "implementation",
        new_callable=lambda: property(lambda s: mock_implementation),
    ):
        assert view.uid == "semantic_view_uid_123"


def test_semantic_view_metrics(
    mock_implementation: MagicMock,
    mock_metrics: list[Metric],
) -> None:
    """Test SemanticView metrics property."""
    view = SemanticView()

    with patch.object(
        SemanticView,
        "implementation",
        new_callable=lambda: property(lambda s: mock_implementation),
    ):
        metrics = view.metrics
        assert len(metrics) == 2
        assert metrics[0].metric_name == "revenue"
        assert metrics[0].expression == "SUM(orders.amount)"
        assert metrics[0].description == "Total revenue"
        assert metrics[1].metric_name == "order_count"


def test_semantic_view_columns(
    mock_implementation: MagicMock,
    mock_dimensions: list[Dimension],
) -> None:
    """Test SemanticView columns property."""
    view = SemanticView()

    with patch.object(
        SemanticView,
        "implementation",
        new_callable=lambda: property(lambda s: mock_implementation),
    ):
        columns = view.columns
        assert len(columns) == 2
        assert columns[0].column_name == "order_date"
        assert columns[0].type == "date32[day]"
        assert columns[0].is_dttm is True
        assert columns[0].description == "Date of the order"
        assert columns[1].column_name == "category"
        assert columns[1].type == "string"
        assert columns[1].is_dttm is False


def test_semantic_view_column_names(
    mock_implementation: MagicMock,
    mock_dimensions: list[Dimension],
) -> None:
    """Test SemanticView column_names property."""
    view = SemanticView()

    with patch.object(
        SemanticView,
        "implementation",
        new_callable=lambda: property(lambda s: mock_implementation),
    ):
        column_names = view.column_names
        assert column_names == ["order_date", "category"]


def test_semantic_view_get_time_grains(
    mock_implementation: MagicMock,
    mock_dimensions: list[Dimension],
) -> None:
    """Test SemanticView get_time_grains property."""
    view = SemanticView()

    with patch.object(
        SemanticView,
        "implementation",
        new_callable=lambda: property(lambda s: mock_implementation),
    ):
        time_grains = view.get_time_grains()
        assert len(time_grains) == 1
        assert time_grains[0]["name"] == "Day"
        assert time_grains[0]["duration"] == "P1D"


def test_semantic_view_has_drill_by_columns_all_exist(
    mock_implementation: MagicMock,
    mock_dimensions: list[Dimension],
) -> None:
    """Test has_drill_by_columns when all columns exist."""
    view = SemanticView()

    with patch.object(
        SemanticView,
        "implementation",
        new_callable=lambda: property(lambda s: mock_implementation),
    ):
        assert view.has_drill_by_columns(["order_date", "category"]) is True


def test_semantic_view_has_drill_by_columns_some_missing(
    mock_implementation: MagicMock,
    mock_dimensions: list[Dimension],
) -> None:
    """Test has_drill_by_columns when some columns are missing."""
    view = SemanticView()

    with patch.object(
        SemanticView,
        "implementation",
        new_callable=lambda: property(lambda s: mock_implementation),
    ):
        assert view.has_drill_by_columns(["order_date", "nonexistent"]) is False


def test_semantic_view_has_drill_by_columns_empty(
    mock_implementation: MagicMock,
    mock_dimensions: list[Dimension],
) -> None:
    """Test has_drill_by_columns with empty list."""
    view = SemanticView()

    with patch.object(
        SemanticView,
        "implementation",
        new_callable=lambda: property(lambda s: mock_implementation),
    ):
        assert view.has_drill_by_columns([]) is True


def test_semantic_view_data(
    mock_implementation: MagicMock,
    mock_dimensions: list[Dimension],
    mock_metrics: list[Metric],
) -> None:
    """Test SemanticView data property."""
    view = SemanticView()
    view.name = "Orders View"
    view.description = "View of order data"
    view.id = 1
    view.uuid = uuid.UUID("12345678-1234-5678-1234-567812345678")
    view.semantic_layer_uuid = uuid.UUID("87654321-4321-8765-4321-876543218765")
    view.cache_timeout = 3600

    with patch.object(
        SemanticView,
        "implementation",
        new_callable=lambda: property(lambda s: mock_implementation),
    ):
        data = view.data

    # Check core fields
    assert data["id"] == 1
    assert data["uid"] == "semantic_view_uid_123"
    assert data["type"] == "semantic_view"
    assert data["name"] == "Orders View"
    assert data["description"] == "View of order data"
    assert data["cache_timeout"] == 3600

    # Check columns
    assert len(data["columns"]) == 2
    assert data["columns"][0]["column_name"] == "order_date"
    assert data["columns"][0]["type"] == "date32[day]"
    assert data["columns"][0]["is_dttm"] is True
    assert data["columns"][0]["type_generic"] == GenericDataType.TEMPORAL
    assert data["columns"][1]["column_name"] == "category"
    assert data["columns"][1]["type"] == "string"
    assert data["columns"][1]["type_generic"] == GenericDataType.STRING

    # Check metrics
    assert len(data["metrics"]) == 2
    assert data["metrics"][0]["metric_name"] == "revenue"
    assert data["metrics"][0]["expression"] == "SUM(orders.amount)"
    assert data["metrics"][1]["metric_name"] == "order_count"

    # Check column_types and column_names
    assert data["column_types"] == [
        GenericDataType.TEMPORAL,
        GenericDataType.STRING,
    ]
    assert data["column_names"] == ["order_date", "category"]

    # Check other fields
    assert data["table_name"] == "Orders View"
    assert data["datasource_name"] == "Orders View"
    assert data["offset"] == 0


def test_semantic_view_get_query_result(
    mock_implementation: MagicMock,
) -> None:
    """Test SemanticView get_query_result method."""
    view = SemanticView()

    mock_query_object = MagicMock()
    mock_result = MagicMock()

    with patch(
        "superset.semantic_layers.models.get_results",
        return_value=mock_result,
    ) as mock_get_results:
        result = view.get_query_result(mock_query_object)

    mock_get_results.assert_called_once_with(mock_query_object)
    assert result == mock_result


def test_semantic_view_data_for_slices(
    mock_implementation: MagicMock,
    mock_dimensions: list[Dimension],
    mock_metrics: list[Metric],
) -> None:
    """Test SemanticView data_for_slices returns same as data."""
    view = SemanticView()
    view.name = "Orders View"
    view.description = "View of order data"
    view.uuid = uuid.UUID("12345678-1234-5678-1234-567812345678")
    view.semantic_layer_uuid = uuid.UUID("87654321-4321-8765-4321-876543218765")
    view.cache_timeout = 3600

    with patch.object(
        SemanticView,
        "implementation",
        new_callable=lambda: property(lambda s: mock_implementation),
    ):
        assert view.data_for_slices([]) == view.data


def test_semantic_view_catalog_perm() -> None:
    """Test SemanticView catalog_perm returns None."""
    view = SemanticView()
    assert view.catalog_perm is None


def test_semantic_view_schema_perm() -> None:
    """Test SemanticView schema_perm returns None."""
    view = SemanticView()
    assert view.schema_perm is None


def test_semantic_view_schema() -> None:
    """Test SemanticView schema returns None."""
    view = SemanticView()
    assert view.schema is None


def test_semantic_view_url() -> None:
    """Test SemanticView url property."""
    view = SemanticView()
    view.uuid = uuid.UUID("12345678-1234-5678-1234-567812345678")
    assert view.url == "/semantic_view/12345678-1234-5678-1234-567812345678/"


def test_semantic_view_explore_url() -> None:
    """Test SemanticView explore_url property."""
    view = SemanticView()
    view.id = 42
    assert (
        view.explore_url == "/explore/?datasource_type=semantic_view&datasource_id=42"
    )


def test_semantic_view_implementation() -> None:
    """Test SemanticView implementation property."""
    view = SemanticView()
    view.name = "Test View"
    view.configuration = '{"key": "value"}'

    mock_semantic_layer = MagicMock()
    mock_semantic_view_impl = MagicMock()
    mock_semantic_layer.implementation.get_semantic_view.return_value = (
        mock_semantic_view_impl
    )
    view.semantic_layer = mock_semantic_layer

    # Clear cached property if it exists
    if "implementation" in view.__dict__:
        del view.__dict__["implementation"]

    result = view.implementation

    mock_semantic_layer.implementation.get_semantic_view.assert_called_once_with(
        "Test View",
        {"key": "value"},
    )
    assert result == mock_semantic_view_impl
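The last test suggests that a view resolves its implementation by delegating to its parent layer, passing its own name and parsed configuration, with the result cached on the instance. A hypothetical sketch, with attribute names assumed from the test rather than taken from the model code:

```python
import json
from functools import cached_property
from typing import Any


class SemanticViewSketch:
    name: str
    configuration: str  # JSON-encoded
    semantic_layer: Any  # parent SemanticLayer model (assumed backref)

    @cached_property
    def implementation(self) -> Any:
        # Delegate to the parent layer's implementation, as asserted by
        # test_semantic_view_implementation above.
        return self.semantic_layer.implementation.get_semantic_view(
            self.name, json.loads(self.configuration)
        )
```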
tests/unit_tests/semantic_layers/schemas_test.py (new file, 208 lines)
@@ -0,0 +1,208 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import pytest
from marshmallow import ValidationError

from superset.semantic_layers.schemas import (
    SemanticLayerPostSchema,
    SemanticLayerPutSchema,
    SemanticViewPutSchema,
)


def test_semantic_view_put_schema_both_fields() -> None:
    """Test loading both description and cache_timeout."""
    schema = SemanticViewPutSchema()
    result = schema.load({"description": "A description", "cache_timeout": 300})
    assert result == {"description": "A description", "cache_timeout": 300}


def test_semantic_view_put_schema_description_only() -> None:
    """Test loading with only description."""
    schema = SemanticViewPutSchema()
    result = schema.load({"description": "Just a description"})
    assert result == {"description": "Just a description"}


def test_semantic_view_put_schema_cache_timeout_only() -> None:
    """Test loading with only cache_timeout."""
    schema = SemanticViewPutSchema()
    result = schema.load({"cache_timeout": 600})
    assert result == {"cache_timeout": 600}


def test_semantic_view_put_schema_empty() -> None:
    """Test loading empty payload."""
    schema = SemanticViewPutSchema()
    result = schema.load({})
    assert result == {}


def test_semantic_view_put_schema_null_description() -> None:
    """Test that description accepts None."""
    schema = SemanticViewPutSchema()
    result = schema.load({"description": None})
    assert result == {"description": None}


def test_semantic_view_put_schema_null_cache_timeout() -> None:
    """Test that cache_timeout accepts None."""
    schema = SemanticViewPutSchema()
    result = schema.load({"cache_timeout": None})
    assert result == {"cache_timeout": None}


def test_semantic_view_put_schema_invalid_cache_timeout() -> None:
    """Test that non-integer cache_timeout raises ValidationError."""
    schema = SemanticViewPutSchema()
    with pytest.raises(ValidationError) as exc_info:
        schema.load({"cache_timeout": "not_a_number"})
    assert "cache_timeout" in exc_info.value.messages


def test_semantic_view_put_schema_unknown_field() -> None:
    """Test that unknown fields raise ValidationError."""
    schema = SemanticViewPutSchema()
    with pytest.raises(ValidationError) as exc_info:
        schema.load({"unknown_field": "value"})
    assert "unknown_field" in exc_info.value.messages


# =============================================================================
# SemanticLayerPostSchema tests
# =============================================================================


def test_post_schema_all_fields() -> None:
    """Test loading all fields."""
    schema = SemanticLayerPostSchema()
    result = schema.load(
        {
            "name": "My Layer",
            "description": "A layer",
            "type": "snowflake",
            "configuration": {"account": "test"},
            "cache_timeout": 300,
        }
    )
    assert result["name"] == "My Layer"
    assert result["type"] == "snowflake"
    assert result["configuration"] == {"account": "test"}
    assert result["cache_timeout"] == 300


def test_post_schema_required_fields_only() -> None:
    """Test loading with only required fields."""
    schema = SemanticLayerPostSchema()
    result = schema.load(
        {
            "name": "My Layer",
            "type": "snowflake",
            "configuration": {"account": "test"},
        }
    )
    assert result["name"] == "My Layer"
    assert "description" not in result
    assert "cache_timeout" not in result


def test_post_schema_missing_name() -> None:
    """Test that missing name raises ValidationError."""
    schema = SemanticLayerPostSchema()
    with pytest.raises(ValidationError) as exc_info:
        schema.load({"type": "snowflake", "configuration": {}})
    assert "name" in exc_info.value.messages


def test_post_schema_missing_type() -> None:
    """Test that missing type raises ValidationError."""
    schema = SemanticLayerPostSchema()
    with pytest.raises(ValidationError) as exc_info:
        schema.load({"name": "My Layer", "configuration": {}})
    assert "type" in exc_info.value.messages


def test_post_schema_missing_configuration() -> None:
    """Test that missing configuration raises ValidationError."""
    schema = SemanticLayerPostSchema()
    with pytest.raises(ValidationError) as exc_info:
        schema.load({"name": "My Layer", "type": "snowflake"})
    assert "configuration" in exc_info.value.messages


def test_post_schema_null_description() -> None:
    """Test that description accepts None."""
    schema = SemanticLayerPostSchema()
    result = schema.load(
        {
            "name": "My Layer",
            "type": "snowflake",
            "configuration": {},
            "description": None,
        }
    )
    assert result["description"] is None


# =============================================================================
# SemanticLayerPutSchema tests
# =============================================================================


def test_put_schema_all_fields() -> None:
    """Test loading all fields."""
    schema = SemanticLayerPutSchema()
    result = schema.load(
        {
            "name": "Updated",
            "description": "New desc",
            "configuration": {"account": "new"},
            "cache_timeout": 600,
        }
    )
    assert result["name"] == "Updated"
    assert result["configuration"] == {"account": "new"}


def test_put_schema_empty() -> None:
    """Test loading empty payload."""
    schema = SemanticLayerPutSchema()
    result = schema.load({})
    assert result == {}


def test_put_schema_name_only() -> None:
    """Test loading with only name."""
    schema = SemanticLayerPutSchema()
    result = schema.load({"name": "New Name"})
    assert result == {"name": "New Name"}


def test_put_schema_configuration_only() -> None:
    """Test loading with only configuration."""
    schema = SemanticLayerPutSchema()
    result = schema.load({"configuration": {"key": "value"}})
    assert result == {"configuration": {"key": "value"}}


def test_put_schema_unknown_field() -> None:
    """Test that unknown fields raise ValidationError."""
    schema = SemanticLayerPutSchema()
    with pytest.raises(ValidationError) as exc_info:
        schema.load({"unknown_field": "value"})
    assert "unknown_field" in exc_info.value.messages
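The SemanticViewPutSchema tests above imply a schema with exactly two optional, nullable fields that rejects unknown keys. A minimal marshmallow equivalent, with field names inferred from the tests (the real schema lives in superset/semantic_layers/schemas.py and may declare more):

```python
from marshmallow import RAISE, Schema, fields


class SemanticViewPutSchemaSketch(Schema):
    class Meta:
        unknown = RAISE  # unknown fields raise ValidationError, per the tests

    description = fields.String(allow_none=True)
    cache_timeout = fields.Integer(allow_none=True)
```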