Compare commits

...

47 Commits

Author · SHA1 · Message · Date
Beto Dealmeida · 7e917a18e7 · WIP · 2025-11-07 14:03:39 -05:00
Beto Dealmeida · 76fd626d57 · Working on exporable · 2025-11-05 16:59:06 -05:00
Beto Dealmeida · df2cad1aed · WIP · 2025-11-05 15:34:00 -05:00
Beto Dealmeida · 278e982ab0 · WIP · 2025-11-05 13:19:18 -05:00
Beto Dealmeida · e6beebc88c · WIP · 2025-11-05 13:19:12 -05:00
Beto Dealmeida · 8a9cca546b · DAOS/commands · 2025-11-04 15:19:34 -05:00
Beto Dealmeida · 9997cdeb62 · SQLAlchemy models · 2025-11-04 11:41:21 -05:00
Beto Dealmeida · ff20d991ab · Fix lint · 2025-11-04 10:36:38 -05:00
Beto Dealmeida · 97a0eb5ffa · Cleaning up · 2025-10-30 18:24:59 -04:00
Beto Dealmeida · 3891cfeeb3 · Split tests · 2025-10-30 14:58:22 -04:00
Beto Dealmeida · 97e52d9485 · Split snowflake · 2025-10-30 13:54:57 -04:00
Beto Dealmeida · debcde2057 · Rearrange mapper · 2025-10-29 15:41:14 -04:00
Beto Dealmeida · f6be0b4dea · Fix tests · 2025-10-29 15:36:30 -04:00
Beto Dealmeida · 836dddafc6 · Add tests · 2025-10-29 15:21:31 -04:00
Beto Dealmeida · fb39bcbde3 · Return with queries · 2025-10-29 14:35:08 -04:00
Beto Dealmeida · 0348fe93bd · QueryObject to df · 2025-10-29 14:25:45 -04:00
Beto Dealmeida · 9215d3f064 · Unit tests for Snowflake SL · 2025-10-28 22:00:12 -04:00
Beto Dealmeida · 4aa4985562 · WIP · 2025-10-28 18:50:05 -04:00
Beto Dealmeida · d07e209a9d · WIP · 2025-10-28 15:11:33 -04:00
Beto Dealmeida · 29e335aa3e · WIP · 2025-10-28 15:08:35 -04:00
Beto Dealmeida · e3dec47a5e · More Snowflake tests · 2025-10-27 16:25:20 -04:00
Beto Dealmeida · 0cc1f46516 · WIP Snowflake tests · 2025-10-27 14:56:52 -04:00
Beto Dealmeida · 7aa9c63b66 · More tests · 2025-10-27 13:42:47 -04:00
Beto Dealmeida · 5ba6db46a7 · Add unit tests · 2025-10-27 12:50:06 -04:00
Beto Dealmeida · d525b05d71 · WIP fix tuple · 2025-10-23 12:20:28 -04:00
Beto Dealmeida · d14bcba501 · WIP filters inner · 2025-10-22 17:04:32 -04:00
Beto Dealmeida · 15d286aacf · WIP filters · 2025-10-22 16:59:18 -04:00
Beto Dealmeida · 9e16d111fb · WIP · 2025-10-22 16:43:26 -04:00
Beto Dealmeida · 297bd1e732 · Add protocols · 2025-10-22 10:10:47 -04:00
Beto Dealmeida · bfa930a3ac · Adhoc order by · 2025-10-21 15:41:28 -04:00
Beto Dealmeida · befcf96027 · WIP · 2025-10-20 19:00:58 -04:00
Beto Dealmeida · 6f6567d5c9 · Working on mapper · 2025-10-20 18:38:46 -04:00
Beto Dealmeida · 3fb58b996a · Improve response to include queries · 2025-10-20 18:38:46 -04:00
Beto Dealmeida · ffae2063e2 · Fix lint · 2025-10-20 18:38:46 -04:00
Beto Dealmeida · e1899f1014 · GroupFilter · 2025-10-20 18:38:46 -04:00
Beto Dealmeida · dfc6aad5f0 · Working on GroupFilter · 2025-10-20 18:38:46 -04:00
Beto Dealmeida · 837ea2a07f · Improving dataframe method · 2025-10-20 18:38:46 -04:00
Beto Dealmeida · b83596893a · WIP · 2025-10-20 18:38:46 -04:00
Beto Dealmeida · a7e446d2ff · Dynamic configuration · 2025-10-20 18:38:46 -04:00
Beto Dealmeida · ccbdc2359e · Add docs · 2025-10-20 18:38:46 -04:00
Beto Dealmeida · 7e40403287 · Adding get_dataframe · 2025-10-20 18:38:45 -04:00
Beto Dealmeida · 4d83840f81 · Working on get_values · 2025-10-20 18:38:45 -04:00
Beto Dealmeida · 4c77a527c5 · Add types · 2025-10-20 18:38:45 -04:00
Beto Dealmeida · 52b1530666 · WIP · 2025-10-20 18:38:45 -04:00
Beto Dealmeida · 70fd9ff617 · WIP · 2025-10-20 18:38:45 -04:00
Beto Dealmeida · ae415b93d5 · WIP · 2025-10-20 18:38:45 -04:00
Beto Dealmeida · 15bfab6b1e · WIP · 2025-10-20 18:38:45 -04:00
43 changed files with 11448 additions and 1 deletion

chart-request-flow.md (new file, 359 lines)

@@ -0,0 +1,359 @@
# Chart Data Request Flow in Apache Superset
This document traces the complete path of a chart data request through the Superset backend, from API endpoint to database query and back.
## Overview
When a client requests chart data (e.g., loading a histogram chart), the request flows through multiple layers:
1. API Endpoint
2. Schema Validation/Parsing
3. Command Pattern (Business Logic)
4. Query Context Processing
5. Database Execution
6. Post-Processing
7. Response Formatting
## Detailed Flow
### 1. Entry Point: API Endpoint
**File**: `superset/charts/data/api.py:187`
**Endpoint**: `POST /api/v1/chart/data`
The request hits the `ChartDataRestApi.data()` method, which:
- Parses the JSON body from the request
- Creates a `QueryContext` object from the form data via `ChartDataQueryContextSchema`
- Creates a `ChartDataCommand` to execute the query
- Validates and executes the command
```python
def data(self) -> Response:
    json_body = request.json
    query_context = self._create_query_context_from_form(json_body)
    command = ChartDataCommand(query_context)
    command.validate()
    return self._get_data_response(command, ...)
```
### 2. Schema Layer: Request Parsing
**File**: `superset/charts/schemas.py:1384`
`ChartDataQueryContextSchema.load()` deserializes the request into:
**QueryContext object** (the main container):
- datasource: Database table/query info
- queries: List of query objects
- result_format: JSON/CSV/XLSX
- result_type: FULL/SAMPLES/QUERY, etc.
- force: Whether to bypass cache
**List of QueryObject instances** (one per query in the request):
- columns: Columns to select (e.g., ["age"])
- metrics: Aggregations to compute
- filters: WHERE clause filters
- post_processing: Client-side transformations (e.g., histogram with bins=25)
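A minimal sketch of this deserialization step (simplified; the exact factory wiring behind the schema's `post_load` hook varies by Superset version):
```python
from superset.charts.schemas import ChartDataQueryContextSchema

raw_body = {
    "datasource": {"id": 19, "type": "table"},
    "queries": [{"columns": ["age"], "row_limit": 10000}],
    "result_format": "json",
    "result_type": "full",
}

# marshmallow validates the payload and returns a fully populated
# QueryContext with one QueryObject per entry in "queries"
query_context = ChartDataQueryContextSchema().load(raw_body)
```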
### 3. Command Pattern: Business Logic
**File**: `superset/commands/chart/data/get_data_command.py:39`
`ChartDataCommand.run()` orchestrates the execution:
```python
def run(self, **kwargs: Any) -> dict[str, Any]:
    payload = self._query_context.get_payload(
        cache_query_context=cache_query_context,
        force_cached=force_cached,
    )
    for query in payload["queries"]:
        if query.get("error"):
            raise ChartDataQueryFailedError(query["error"])
    return {
        "query_context": self._query_context,
        "queries": payload["queries"],
    }
```
### 4. Query Context Processor: Core Execution
**File**: `superset/common/query_context_processor.py:1052`
`QueryContextProcessor.get_payload()`:
- Iterates through each `QueryObject` in `query_context.queries`
- For each query, calls `get_query_results()` which routes based on result_type:
- `FULL``_get_full()``get_df_payload()`
- `SAMPLES``_get_samples()`
- `QUERY``_get_query()`
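A condensed sketch of this dispatch (illustrative only; handler bodies are stubbed, and the real code also handles errors and annotations):
```python
# Illustrative stubs mirroring the handlers listed above
def _get_full(query_context, query_obj, force_cached):
    ...  # delegates to get_df_payload()

def _get_samples(query_context, query_obj, force_cached):
    ...  # re-runs the query without aggregations to return raw rows

def _get_query(query_context, query_obj, force_cached):
    ...  # returns only the generated SQL, no data

_handlers = {
    "full": _get_full,
    "samples": _get_samples,
    "query": _get_query,
}

def get_query_results(result_type, query_context, query_obj, force_cached=False):
    # look up the handler for the requested result type and delegate
    if handler := _handlers.get(result_type):
        return handler(query_context, query_obj, force_cached)
    raise ValueError(f"unsupported result type: {result_type}")
```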
**File**: `superset/common/query_context_processor.py:128`
`QueryContextProcessor.get_df_payload()`:
1. **Generate cache key** from query object
2. **Check cache** using `QueryCacheManager`
3. **If cache miss**:
- Validate columns exist in datasource
- Call `get_query_result(query_obj)` to execute SQL
- Get annotation data if needed
- Cache the result with appropriate timeout
4. **Return payload** with DataFrame and metadata
```python
def get_df_payload(self, query_obj, force_cached=False):
    cache_key = self.query_cache_key(query_obj)
    timeout = self.get_cache_timeout()
    cache = QueryCacheManager.get(key=cache_key, ...)
    if not cache.is_loaded:
        query_result = self.get_query_result(query_obj)
        annotation_data = self.get_annotation_data(query_obj)
        cache.set_query_result(...)
    return {
        "cache_key": cache_key,
        "df": cache.df,
        "query": cache.query,
        "is_cached": cache.is_cached,
        ...
    }
```
### 5. Database Query Execution
**File**: `superset/common/query_context_processor.py:267`
`QueryContextProcessor.get_query_result()`:
```python
def get_query_result(self, query_object: QueryObject) -> QueryResult:
    # Execute SQL query on the datasource
    result = query_context.datasource.query(query_object.to_dict())
    df = result.df

    # Normalize timestamps to pandas datetime format
    if not df.empty:
        df = self.normalize_df(df, query_object)

    # Handle time offset comparisons if specified
    if query_object.time_offsets:
        time_offsets = self.processing_time_offsets(df, query_object)
        df = time_offsets["df"]

    # Apply post-processing operations
    df = query_object.exec_post_processing(df)

    result.df = df
    return result
```
The `datasource.query()` call dispatches to the datasource's connector (e.g., `SqlaTable.query()`), which:
- Converts the QueryObject dict to SQL using SQLAlchemy
- Executes the query via database engine
- Returns a `QueryResult` with a pandas DataFrame
### 6. Post-Processing
**File**: `superset/common/query_object.py:484`
`QueryObject.exec_post_processing()`:
- Applies operations from `post_processing` list in sequence
- Each operation is a pandas transformation (e.g., pivot, aggregate, histogram)
- Uses functions from `superset.utils.pandas_postprocessing`
Example for histogram:
```python
def exec_post_processing(self, df: DataFrame) -> DataFrame:
    for post_process in self.post_processing:
        operation = post_process.get("operation")  # "histogram"
        options = post_process.get("options", {})  # {column: "age", bins: 25}
        df = getattr(pandas_postprocessing, operation)(df, **options)
    return df
```
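For intuition, here is a self-contained approximation of what the histogram operation does to the queried DataFrame (not the exact Superset implementation; it bins the raw rows into counts):
```python
import pandas as pd

# Raw rows returned by the query: one row per record, column "age"
df = pd.DataFrame({"age": [23, 35, 35, 47, 51, 62]})

# Roughly what post_processing [{"operation": "histogram",
# "options": {"column": "age", "bins": 3}}] produces: row counts per
# equal-width bin across the column's range.
bins = pd.cut(df["age"], bins=3)
hist = df.groupby(bins, observed=False).size().reset_index(name="count")
print(hist)
```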
### 7. Response Formatting
**File**: `superset/charts/data/api.py:346`
`ChartDataRestApi._send_chart_response()`:
- Takes the result dict from command
- Formats based on `result_format`:
- **JSON**: Converts DataFrame to list of dicts
- **CSV**: Converts to CSV string
- **XLSX**: Converts to Excel binary
- Returns Flask Response with appropriate headers
```python
def _send_chart_response(self, result, form_data=None, datasource=None):
    result_format = result["query_context"].result_format
    if result_format == ChartDataResultFormat.JSON:
        queries = result["queries"]
        response_data = json.dumps(
            {"result": queries},
            default=json.json_int_dttm_ser,
            ignore_nan=True,
        )
        resp = make_response(response_data, 200)
        resp.headers["Content-Type"] = "application/json; charset=utf-8"
        return resp
```
## Key Objects and Data Structures
### QueryContext
**File**: `superset/common/query_context.py:41`
The main container for a chart data request.
```python
{
    datasource: BaseDatasource,            # Dataset (e.g., id=19, type="table")
    queries: list[QueryObject],            # List of queries to execute
    result_type: ChartDataResultType,      # "full", "samples", "query", etc.
    result_format: ChartDataResultFormat,  # "json", "csv", "xlsx"
    force: bool,                           # Bypass cache flag
    form_data: dict,                       # Original form_data from client
    custom_cache_timeout: int | None,      # Override cache timeout
}
```
### QueryObject
**File**: `superset/common/query_object.py:79`
Represents a single database query.
```python
{
    columns: list[Column],           # Columns to select ["age"]
    metrics: list[Metric] | None,    # Aggregations to compute
    filters: list[FilterClause],     # WHERE clause filters
    extras: dict[str, Any],          # Additional query options
    post_processing: list[dict],     # Client-side transformations
    row_limit: int | None,           # LIMIT clause
    row_offset: int,                 # OFFSET clause
    order_desc: bool,                # Sort direction
    time_range: str | None,          # Time filter range
    granularity: str | None,         # Temporal grouping column
    annotation_layers: list[dict],   # Annotations to overlay
    from_dttm: datetime | None,      # Computed time range start
    to_dttm: datetime | None,        # Computed time range end
}
```
### QueryResult
**File**: `superset/models/helpers.py`
Returned from `datasource.query()`.
```python
{
    df: pd.DataFrame,      # The data from database
    query: str,            # Executed SQL query
    from_dttm: datetime,   # Time range start
    to_dttm: datetime,     # Time range end
    error: str | None,     # Error message if failed
    status: QueryStatus,   # success, failed, etc.
}
```
## Example Request Flow
For a histogram chart request like:
```bash
curl 'https://example.com/api/v1/chart/data' \
  -H 'content-type: application/json' \
  --data-raw '{
    "datasource": {"id": 19, "type": "table"},
    "queries": [{
      "columns": ["age"],
      "filters": [{
        "col": "time_start",
        "op": "TEMPORAL_RANGE",
        "val": "No filter"
      }],
      "row_limit": 10000,
      "post_processing": [{
        "operation": "histogram",
        "options": {"column": "age", "bins": 25}
      }]
    }],
    "result_format": "json",
    "result_type": "full"
  }'
```
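The JSON response mirrors the payload assembled in `get_df_payload()`: one entry per query, with the DataFrame serialized under `data`. An abbreviated sketch of the shape (illustrative values; exact fields vary by version):
```python
# Abbreviated response shape (illustrative values)
{
    "result": [
        {
            "cache_key": "...",
            "is_cached": False,
            "status": "success",
            "query": "SELECT age FROM ...",  # generated SQL
            "data": [
                {"bin": "23 - 35", "count": 3},  # histogram rows
                # ...
            ],
        }
    ]
}
```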
### Flow Summary
```
Client Request (curl)
  ↓
ChartDataRestApi.data()
  ↓ (parses JSON)
ChartDataQueryContextSchema.load()
  ↓ (creates objects)
QueryContext + [QueryObject]
  ↓
ChartDataCommand.run()
  ↓
QueryContextProcessor.get_payload()
  ↓ (for each QueryObject)
get_query_results() → _get_full()
  ↓
get_df_payload()
  ├→ Check Cache (QueryCacheManager)
  └→ get_query_result()
       ├→ datasource.query() → Build SQL → Execute → pandas DataFrame
       ├→ normalize_df() → Timestamp normalization
       └→ exec_post_processing() → Apply histogram operation
  ↓
Return payload {df, query, metadata}
  ↓
_send_chart_response()
  ↓ (format as JSON)
Flask Response → Client
```
## Architecture Patterns
The codebase follows a clean separation of concerns:
1. **API Layer** (`superset/charts/data/api.py`): Handles HTTP requests/responses
2. **Schema Layer** (`superset/charts/schemas.py`): Validates and deserializes input
3. **Command Layer** (`superset/commands/`): Orchestrates business logic
4. **Query Context/Processor** (`superset/common/`): Manages execution and caching
5. **Query Object**: Represents individual database queries
6. **Datasource Layer** (`superset/connectors/`): Database abstraction and SQL generation
### Key Benefits
- **Caching**: Results cached at multiple levels (query result, query context)
- **Security**: Access control enforced via `raise_for_access()`
- **Flexibility**: Supports multiple result types and formats
- **Post-processing**: Client-side transformations without re-querying database
- **Time Comparison**: Built-in support for time offset queries
- **Annotations**: Overlay additional data layers on charts
## Caching Strategy
**File**: `superset/common/utils/query_cache_manager.py`
Cache keys are generated from:
- Query object (columns, metrics, filters, etc.)
- Datasource UID
- RLS (Row Level Security) rules
- User context (if per-user caching enabled)
- Time range (using relative time strings, not absolute timestamps)
This ensures that:
- Same query returns cached results
- Different users see appropriate cached data
- Time-relative queries (e.g., "Last 7 days") cache correctly
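A hedged sketch of how such a key can be derived from these inputs (the real logic lives in `QueryContextProcessor.query_cache_key()` and hashes more fields; the function and parameter names here are illustrative):
```python
import hashlib
import json

def query_cache_key(query_obj_dict, datasource_uid, rls_clauses, username=None):
    # Everything that can change the result must change the key
    payload = {
        "query": query_obj_dict,       # columns, metrics, filters, time range
        "datasource": datasource_uid,  # e.g. "19__table"
        "rls": sorted(rls_clauses),    # row-level security predicates
        "user": username,              # only when per-user caching is enabled
    }
    serialized = json.dumps(payload, sort_keys=True, default=str)
    return hashlib.md5(serialized.encode("utf-8")).hexdigest()
```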


@@ -19,6 +19,7 @@
# Import all settings from the main config first
from flask_caching.backends.filesystemcache import FileSystemCache
from superset_config import * # noqa: F403
# Override caching to use simple in-memory cache instead of Redis

(File diff suppressed because it is too large)

(File diff suppressed because it is too large)


@@ -64148,7 +64148,7 @@
"reselect": "^5.1.1",
"rison": "^0.1.1",
"seedrandom": "^3.0.5",
"xss": "^1.0.14"
"xss": "^1.0.15"
},
"devDependencies": {
"@emotion/styled": "^11.14.1",


@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


@@ -0,0 +1,77 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Create semantic layer command."""
from __future__ import annotations
from functools import partial
from typing import Any
from flask_appbuilder.models.sqla import Model
from marshmallow.validate import ValidationError
from superset.commands.base import BaseCommand
from superset.commands.semantic_layer.exceptions import (
SemanticLayerCreateFailedError,
SemanticLayerExistsValidationError,
SemanticLayerInvalidError,
SemanticLayerRequiredFieldValidationError,
)
from superset.daos.semantic_layer import SemanticLayerDAO
from superset.utils.decorators import on_error, transaction
class CreateSemanticLayerCommand(BaseCommand):
"""Command to create a semantic layer."""
def __init__(self, data: dict[str, Any]):
self._properties = data.copy()
@transaction(on_error=partial(on_error, reraise=SemanticLayerCreateFailedError))
def run(self) -> Model:
"""
Create a semantic layer.
:return: The created semantic layer
"""
self.validate()
return SemanticLayerDAO.create(attributes=self._properties)
def validate(self) -> None:
"""
Validate the semantic layer data.
:raises SemanticLayerInvalidError: If validation fails
"""
exceptions: list[ValidationError] = []
# Validate required fields
if not self._properties.get("name"):
exceptions.append(SemanticLayerRequiredFieldValidationError("name"))
if not self._properties.get("type"):
exceptions.append(SemanticLayerRequiredFieldValidationError("type"))
# Validate uniqueness
name = self._properties.get("name")
if name and not SemanticLayerDAO.validate_uniqueness(name):
exceptions.append(SemanticLayerExistsValidationError())
if exceptions:
exception = SemanticLayerInvalidError()
exception.extend(exceptions)
raise exception
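A hypothetical invocation of this command (the API wiring is not part of this diff; field names follow the `validate()` checks above, and "snowflake" is an example type):
```python
# Assumes an app context and an authenticated user, as with other
# Superset commands.
command = CreateSemanticLayerCommand(
    {
        "name": "finance",
        "type": "snowflake",
        "description": "Finance semantic layer",
    }
)
semantic_layer = command.run()  # validates, then persists via SemanticLayerDAO
```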


@@ -0,0 +1,59 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Delete semantic layer command."""
from __future__ import annotations
from functools import partial
from superset.commands.base import BaseCommand
from superset.commands.semantic_layer.exceptions import (
SemanticLayerDeleteFailedError,
SemanticLayerNotFoundError,
)
from superset.daos.semantic_layer import SemanticLayerDAO
from superset.semantic_layers.models import SemanticLayer
from superset.utils.decorators import on_error, transaction
class DeleteSemanticLayerCommand(BaseCommand):
"""Command to delete a semantic layer."""
def __init__(self, model_id: str):
self._model_id = model_id
self._model: SemanticLayer | None = None
@transaction(on_error=partial(on_error, reraise=SemanticLayerDeleteFailedError))
def run(self) -> None:
"""
Delete a semantic layer.
Semantic views will be cascade deleted.
"""
self.validate()
assert self._model
SemanticLayerDAO.delete([self._model])
def validate(self) -> None:
"""
Validate the semantic layer deletion.
:raises SemanticLayerNotFoundError: If semantic layer not found
"""
self._model = SemanticLayerDAO.find_by_id(self._model_id)
if not self._model:
raise SemanticLayerNotFoundError()


@@ -0,0 +1,76 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Exceptions for semantic layer commands."""
from flask_babel import lazy_gettext as _
from marshmallow.validate import ValidationError
from superset.commands.exceptions import (
CommandInvalidError,
CreateFailedError,
DeleteFailedError,
ObjectNotFoundError,
UpdateFailedError,
)
class SemanticLayerInvalidError(CommandInvalidError):
"""Semantic layer parameters are invalid."""
message = _("Semantic layer parameters are invalid.")
class SemanticLayerNotFoundError(ObjectNotFoundError):
"""Semantic layer not found."""
def __init__(self) -> None:
super().__init__("Semantic layer", None)
class SemanticLayerCreateFailedError(CreateFailedError):
"""Semantic layer could not be created."""
message = _("Semantic layer could not be created.")
class SemanticLayerUpdateFailedError(UpdateFailedError):
"""Semantic layer could not be updated."""
message = _("Semantic layer could not be updated.")
class SemanticLayerDeleteFailedError(DeleteFailedError):
"""Semantic layer could not be deleted."""
message = _("Semantic layer could not be deleted.")
class SemanticLayerRequiredFieldValidationError(ValidationError):
"""Required field validation error."""
def __init__(self, field_name: str) -> None:
super().__init__([_("Field is required")], field_name=field_name)
class SemanticLayerExistsValidationError(ValidationError):
"""Semantic layer already exists validation error."""
def __init__(self) -> None:
super().__init__(
[_("A semantic layer with this name already exists")],
field_name="name",
)


@@ -0,0 +1,81 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Update semantic layer command."""
from __future__ import annotations
from functools import partial
from typing import Any
from flask_appbuilder.models.sqla import Model
from marshmallow.validate import ValidationError
from superset.commands.base import BaseCommand
from superset.commands.semantic_layer.exceptions import (
SemanticLayerExistsValidationError,
SemanticLayerInvalidError,
SemanticLayerNotFoundError,
SemanticLayerUpdateFailedError,
)
from superset.daos.semantic_layer import SemanticLayerDAO
from superset.semantic_layers.models import SemanticLayer
from superset.utils.decorators import on_error, transaction
class UpdateSemanticLayerCommand(BaseCommand):
"""Command to update a semantic layer."""
def __init__(self, model_id: str, data: dict[str, Any]):
self._properties = data.copy()
self._model_id = model_id
self._model: SemanticLayer | None = None
@transaction(on_error=partial(on_error, reraise=SemanticLayerUpdateFailedError))
def run(self) -> Model:
"""
Update a semantic layer.
:return: The updated semantic layer
"""
self.validate()
assert self._model
return SemanticLayerDAO.update(self._model, self._properties)
def validate(self) -> None:
"""
Validate the semantic layer update.
:raises SemanticLayerNotFoundError: If semantic layer not found
:raises SemanticLayerInvalidError: If validation fails
"""
exceptions: list[ValidationError] = []
# Find the model
self._model = SemanticLayerDAO.find_by_id(self._model_id)
if not self._model:
raise SemanticLayerNotFoundError()
# Validate uniqueness if name is being changed
if name := self._properties.get("name"):
if not SemanticLayerDAO.validate_update_uniqueness(self._model_id, name):
exceptions.append(SemanticLayerExistsValidationError())
if exceptions:
exception = SemanticLayerInvalidError()
exception.extend(exceptions)
raise exception


@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


@@ -0,0 +1,87 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Create semantic view command."""
from __future__ import annotations
from functools import partial
from typing import Any
from flask_appbuilder.models.sqla import Model
from marshmallow.validate import ValidationError
from superset.commands.base import BaseCommand
from superset.commands.semantic_layer.exceptions import SemanticLayerNotFoundError
from superset.commands.semantic_view.exceptions import (
SemanticViewCreateFailedError,
SemanticViewExistsValidationError,
SemanticViewInvalidError,
SemanticViewRequiredFieldValidationError,
)
from superset.daos.semantic_layer import SemanticLayerDAO, SemanticViewDAO
from superset.utils.decorators import on_error, transaction
class CreateSemanticViewCommand(BaseCommand):
"""Command to create a semantic view."""
def __init__(self, data: dict[str, Any]):
self._properties = data.copy()
@transaction(on_error=partial(on_error, reraise=SemanticViewCreateFailedError))
def run(self) -> Model:
"""
Create a semantic view.
:return: The created semantic view
"""
self.validate()
return SemanticViewDAO.create(attributes=self._properties)
def validate(self) -> None:
"""
Validate the semantic view data.
:raises SemanticViewInvalidError: If validation fails
:raises SemanticLayerNotFoundError: If semantic layer not found
"""
exceptions: list[ValidationError] = []
# Validate required fields
if not self._properties.get("name"):
exceptions.append(SemanticViewRequiredFieldValidationError("name"))
layer_uuid = self._properties.get("semantic_layer_uuid")
if not layer_uuid:
exceptions.append(
SemanticViewRequiredFieldValidationError("semantic_layer_uuid")
)
else:
# Validate semantic layer exists
semantic_layer = SemanticLayerDAO.find_by_id(layer_uuid)
if not semantic_layer:
raise SemanticLayerNotFoundError()
# Validate uniqueness within semantic layer
name = self._properties.get("name")
if name and not SemanticViewDAO.validate_uniqueness(name, layer_uuid):
exceptions.append(SemanticViewExistsValidationError())
if exceptions:
exception = SemanticViewInvalidError()
exception.extend(exceptions)
raise exception


@@ -0,0 +1,55 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Delete semantic view command."""
from __future__ import annotations
from functools import partial
from superset.commands.base import BaseCommand
from superset.commands.semantic_view.exceptions import (
SemanticViewDeleteFailedError,
SemanticViewNotFoundError,
)
from superset.daos.semantic_layer import SemanticViewDAO
from superset.semantic_layers.models import SemanticView
from superset.utils.decorators import on_error, transaction
class DeleteSemanticViewCommand(BaseCommand):
"""Command to delete a semantic view."""
def __init__(self, model_id: str):
self._model_id = model_id
self._model: SemanticView | None = None
@transaction(on_error=partial(on_error, reraise=SemanticViewDeleteFailedError))
def run(self) -> None:
"""Delete a semantic view."""
self.validate()
assert self._model
SemanticViewDAO.delete([self._model])
def validate(self) -> None:
"""
Validate the semantic view deletion.
:raises SemanticViewNotFoundError: If semantic view not found
"""
self._model = SemanticViewDAO.find_by_id(self._model_id)
if not self._model:
raise SemanticViewNotFoundError()


@@ -0,0 +1,76 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Exceptions for semantic view commands."""
from flask_babel import lazy_gettext as _
from marshmallow.validate import ValidationError
from superset.commands.exceptions import (
CommandInvalidError,
CreateFailedError,
DeleteFailedError,
ObjectNotFoundError,
UpdateFailedError,
)
class SemanticViewInvalidError(CommandInvalidError):
"""Semantic view parameters are invalid."""
message = _("Semantic view parameters are invalid.")
class SemanticViewNotFoundError(ObjectNotFoundError):
"""Semantic view not found."""
def __init__(self) -> None:
super().__init__("Semantic view", None)
class SemanticViewCreateFailedError(CreateFailedError):
"""Semantic view could not be created."""
message = _("Semantic view could not be created.")
class SemanticViewUpdateFailedError(UpdateFailedError):
"""Semantic view could not be updated."""
message = _("Semantic view could not be updated.")
class SemanticViewDeleteFailedError(DeleteFailedError):
"""Semantic view could not be deleted."""
message = _("Semantic view could not be deleted.")
class SemanticViewRequiredFieldValidationError(ValidationError):
"""Required field validation error."""
def __init__(self, field_name: str) -> None:
super().__init__([_("Field is required")], field_name=field_name)
class SemanticViewExistsValidationError(ValidationError):
"""Semantic view already exists validation error."""
def __init__(self) -> None:
super().__init__(
[_("A semantic view with this name already exists in this semantic layer")],
field_name="name",
)


@@ -0,0 +1,83 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Update semantic view command."""
from __future__ import annotations
from functools import partial
from typing import Any
from flask_appbuilder.models.sqla import Model
from marshmallow.validate import ValidationError
from superset.commands.base import BaseCommand
from superset.commands.semantic_view.exceptions import (
SemanticViewExistsValidationError,
SemanticViewInvalidError,
SemanticViewNotFoundError,
SemanticViewUpdateFailedError,
)
from superset.daos.semantic_layer import SemanticViewDAO
from superset.semantic_layers.models import SemanticView
from superset.utils.decorators import on_error, transaction
class UpdateSemanticViewCommand(BaseCommand):
"""Command to update a semantic view."""
def __init__(self, model_id: str, data: dict[str, Any]):
self._properties = data.copy()
self._model_id = model_id
self._model: SemanticView | None = None
@transaction(on_error=partial(on_error, reraise=SemanticViewUpdateFailedError))
def run(self) -> Model:
"""
Update a semantic view.
:return: The updated semantic view
"""
self.validate()
assert self._model
return SemanticViewDAO.update(self._model, self._properties)
def validate(self) -> None:
"""
Validate the semantic view update.
:raises SemanticViewNotFoundError: If semantic view not found
:raises SemanticViewInvalidError: If validation fails
"""
exceptions: list[ValidationError] = []
# Find the model
self._model = SemanticViewDAO.find_by_id(self._model_id)
if not self._model:
raise SemanticViewNotFoundError()
# Validate uniqueness if name is being changed
if name := self._properties.get("name"):
if not SemanticViewDAO.validate_update_uniqueness(
self._model_id, name, self._model.semantic_layer_uuid
):
exceptions.append(SemanticViewExistsValidationError())
if exceptions:
exception = SemanticViewInvalidError()
exception.extend(exceptions)
raise exception


@@ -0,0 +1,152 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""DAOs for semantic layer models."""
from __future__ import annotations
from superset.daos.base import BaseDAO
from superset.extensions import db
from superset.semantic_layers.models import SemanticLayer, SemanticView
class SemanticLayerDAO(BaseDAO[SemanticLayer]):
"""
Data Access Object for SemanticLayer model.
"""
@staticmethod
def validate_uniqueness(name: str) -> bool:
"""
Validate that semantic layer name is unique.
:param name: Semantic layer name
:return: True if name is unique, False otherwise
"""
query = db.session.query(SemanticLayer).filter(SemanticLayer.name == name)
return not db.session.query(query.exists()).scalar()
@staticmethod
def validate_update_uniqueness(layer_uuid: str, name: str) -> bool:
"""
Validate that semantic layer name is unique for updates.
:param layer_uuid: UUID of the semantic layer being updated
:param name: New name to validate
:return: True if name is unique, False otherwise
"""
query = db.session.query(SemanticLayer).filter(
SemanticLayer.name == name,
SemanticLayer.uuid != layer_uuid,
)
return not db.session.query(query.exists()).scalar()
@staticmethod
def find_by_name(name: str) -> SemanticLayer | None:
"""
Find semantic layer by name.
:param name: Semantic layer name
:return: SemanticLayer instance or None
"""
return (
db.session.query(SemanticLayer)
.filter(SemanticLayer.name == name)
.one_or_none()
)
@classmethod
def get_semantic_views(cls, layer_uuid: str) -> list[SemanticView]:
"""
Get all semantic views for a semantic layer.
:param layer_uuid: UUID of the semantic layer
:return: List of SemanticView instances
"""
return (
db.session.query(SemanticView)
.filter(SemanticView.semantic_layer_uuid == layer_uuid)
.all()
)
class SemanticViewDAO(BaseDAO[SemanticView]):
"""Data Access Object for SemanticView model."""
@staticmethod
def find_by_semantic_layer(layer_uuid: str) -> list[SemanticView]:
"""
Find all views for a semantic layer.
:param layer_uuid: UUID of the semantic layer
:return: List of SemanticView instances
"""
return (
db.session.query(SemanticView)
.filter(SemanticView.semantic_layer_uuid == layer_uuid)
.all()
)
@staticmethod
def validate_uniqueness(name: str, layer_uuid: str) -> bool:
"""
Validate that view name is unique within semantic layer.
:param name: View name
:param layer_uuid: UUID of the semantic layer
:return: True if name is unique within layer, False otherwise
"""
query = db.session.query(SemanticView).filter(
SemanticView.name == name,
SemanticView.semantic_layer_uuid == layer_uuid,
)
return not db.session.query(query.exists()).scalar()
@staticmethod
def validate_update_uniqueness(view_uuid: str, name: str, layer_uuid: str) -> bool:
"""
Validate that view name is unique within semantic layer for updates.
:param view_uuid: UUID of the view being updated
:param name: New name to validate
:param layer_uuid: UUID of the semantic layer
:return: True if name is unique within layer, False otherwise
"""
query = db.session.query(SemanticView).filter(
SemanticView.name == name,
SemanticView.semantic_layer_uuid == layer_uuid,
SemanticView.uuid != view_uuid,
)
return not db.session.query(query.exists()).scalar()
@staticmethod
def find_by_name(name: str, layer_uuid: str) -> SemanticView | None:
"""
Find semantic view by name within a semantic layer.
:param name: View name
:param layer_uuid: UUID of the semantic layer
:return: SemanticView instance or None
"""
return (
db.session.query(SemanticView)
.filter(
SemanticView.name == name,
SemanticView.semantic_layer_uuid == layer_uuid,
)
.one_or_none()
)
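A hypothetical illustration of how the commands in this changeset use these guards (names follow the modules above; "finance" is an example layer name):
```python
from superset.commands.semantic_layer.exceptions import (
    SemanticLayerExistsValidationError,
)
from superset.daos.semantic_layer import SemanticLayerDAO

# Sketch: the create-path uniqueness guard, then a lookup of child views
if not SemanticLayerDAO.validate_uniqueness("finance"):
    raise SemanticLayerExistsValidationError()

layer = SemanticLayerDAO.find_by_name("finance")
views = SemanticLayerDAO.get_semantic_views(layer.uuid) if layer else []
```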


@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


@@ -0,0 +1,248 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Base protocol for explorable data sources in Superset.
An "explorable" is any data source that can be explored to create charts,
including SQL datasets, saved queries, and semantic layer views.
"""
from __future__ import annotations
from collections.abc import Hashable
from datetime import datetime
from typing import Any, Protocol, runtime_checkable
from superset.common.query_object import QueryObject
from superset.models.helpers import QueryResult
from superset.superset_typing import QueryObjectDict
@runtime_checkable
class Explorable(Protocol):
"""
Protocol for objects that can be explored to create charts.
This protocol defines the minimal interface required for a data source
to be visualizable in Superset. It is implemented by:
- BaseDatasource (SQL datasets and queries)
- SemanticView (semantic layer views)
- Future: Other data source types
The protocol focuses on the essential methods and properties needed
for query execution, caching, and security.
"""
# =========================================================================
# Core Query Interface
# =========================================================================
def get_query_result(self, query_object: QueryObject) -> QueryResult:
"""
Execute a query and return results.
This is the primary method for data retrieval. It takes a query
object describing what data to fetch (columns, metrics, filters, time range,
etc.) and returns a QueryResult containing a pandas DataFrame with the results.
:param query_obj: QueryObject describing the query
:return: QueryResult containing:
- df: pandas DataFrame with query results
- query: string representation of the executed query
- duration: query execution time
- status: QueryStatus (SUCCESS/FAILED)
- error_message: error details if query failed
"""
def get_query_str(self, query_obj: QueryObjectDict) -> str:
"""
Get the query string without executing.
Returns a string representation of the query that would be executed
for the given query object. This is used for display in the UI
and debugging.
:param query_obj: Dictionary describing the query
:return: String representation of the query (SQL, GraphQL, etc.)
"""
# =========================================================================
# Identity & Metadata
# =========================================================================
@property
def uid(self) -> str:
"""
Unique identifier for this explorable.
Used as part of cache keys and for tracking. Should be stable
across application restarts but change when the explorable's
data or structure changes.
Format convention: "{type}_{id}" (e.g., "table_123", "semantic_view_abc")
:return: Unique identifier string
"""
@property
def type(self) -> str:
"""
Type discriminator for this explorable.
Identifies the kind of data source (e.g., 'table', 'query', 'semantic_view').
Used for routing and type-specific behavior.
:return: Type identifier string
"""
@property
def columns(self) -> list[Any]:
"""
List of column metadata objects.
Each object should provide at minimum:
- column_name: str - the column's name
- type: str - the column's data type
- is_dttm: bool - whether it's a datetime column
Used for validation, autocomplete, and query building.
:return: List of column metadata objects
"""
@property
def column_names(self) -> list[str]:
"""
List of available column names.
A simple list of all column names in the explorable.
Used for quick validation and filtering.
:return: List of column name strings
"""
@property
def data(self) -> dict[str, Any]:
"""
Full metadata representation sent to the frontend.
This property returns a dictionary containing all the metadata
needed by the Explore UI, including columns, metrics, and
other configuration.
Required keys in the returned dictionary:
- id: unique identifier (int or str)
- uid: unique string identifier
- name: display name
- type: explorable type ('table', 'query', 'semantic_view', etc.)
- columns: list of column metadata dicts (with column_name, type, etc.)
- metrics: list of metric metadata dicts (with metric_name, expression, etc.)
- database: database metadata dict (with id, backend, etc.)
Optional keys:
- description: human-readable description
- schema: schema name (if applicable)
- catalog: catalog name (if applicable)
- cache_timeout: default cache timeout
- offset: timezone offset
- owners: list of owner IDs
- verbose_map: dict mapping column/metric names to display names
:return: Dictionary with complete explorable metadata
"""
# =========================================================================
# Caching
# =========================================================================
@property
def cache_timeout(self) -> int | None:
"""
Default cache timeout in seconds.
Determines how long query results should be cached.
Returns None to use the system default cache timeout.
:return: Cache timeout in seconds, or None for system default
"""
@property
def changed_on(self) -> datetime | None:
"""
Last modification timestamp.
Used for cache invalidation - when this changes, cached
results for this explorable become invalid.
:return: Datetime of last modification, or None
"""
def get_extra_cache_keys(self, query_obj: QueryObjectDict) -> list[Hashable]:
"""
Additional cache key components specific to this explorable.
Provides explorable-specific values to include in cache keys.
Used to ensure cache invalidation when the explorable's
underlying data or configuration changes in ways not captured
by uid or changed_on.
:param query_obj: The query being executed
:return: List of additional hashable values for cache key
"""
# =========================================================================
# Security
# =========================================================================
@property
def perm(self) -> str:
"""
Permission string for this explorable.
Used by the security manager to check if a user has access
to this data source. Format depends on the explorable type
(e.g., "[database].[schema].[table]" for SQL tables).
:return: Permission identifier string
"""
@property
def schema_perm(self) -> str | None:
"""
Schema-level permission string.
Optional permission string for schema-level access control.
Some explorables don't have a schema concept and can return None.
:return: Schema permission string, or None
"""
# =========================================================================
# Time/Date Handling
# =========================================================================
@property
def offset(self) -> int:
"""
Timezone offset for datetime columns.
Used to normalize datetime values to the user's timezone.
Returns 0 for UTC, or an offset in seconds.
:return: Timezone offset in seconds (0 for UTC)
"""


@@ -0,0 +1,124 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""add_semantic_layers_and_views
Revision ID: 33d7e0e21daa
Revises: c233f5365c9e
Create Date: 2025-11-04 11:26:00.000000
"""
import uuid
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
from sqlalchemy_utils import UUIDType
from superset.migrations.shared.utils import (
create_fks_for_table,
create_table,
drop_table,
)
# revision identifiers, used by Alembic.
revision = "33d7e0e21daa"
down_revision = "c233f5365c9e"
def upgrade():
# Create semantic_layers table
create_table(
"semantic_layers",
sa.Column("uuid", UUIDType(binary=True), default=uuid.uuid4, nullable=False),
sa.Column("created_on", sa.DateTime(), nullable=True),
sa.Column("changed_on", sa.DateTime(), nullable=True),
sa.Column("name", sa.String(length=250), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("type", sa.String(length=250), nullable=False),
sa.Column(
"configuration",
sa.Text().with_variant(mysql.MEDIUMTEXT(), "mysql"),
nullable=True,
),
sa.Column("cache_timeout", sa.Integer(), nullable=True),
sa.Column("created_by_fk", sa.Integer(), nullable=True),
sa.Column("changed_by_fk", sa.Integer(), nullable=True),
sa.PrimaryKeyConstraint("uuid"),
)
# Create foreign key constraints for semantic_layers
create_fks_for_table(
"fk_semantic_layers_created_by_fk_ab_user",
"semantic_layers",
"ab_user",
["created_by_fk"],
["id"],
)
create_fks_for_table(
"fk_semantic_layers_changed_by_fk_ab_user",
"semantic_layers",
"ab_user",
["changed_by_fk"],
["id"],
)
# Create semantic_views table
create_table(
"semantic_views",
sa.Column("uuid", UUIDType(binary=True), default=uuid.uuid4, nullable=False),
sa.Column("created_on", sa.DateTime(), nullable=True),
sa.Column("changed_on", sa.DateTime(), nullable=True),
sa.Column("name", sa.String(length=250), nullable=False),
sa.Column(
"configuration",
sa.Text().with_variant(mysql.MEDIUMTEXT(), "mysql"),
nullable=True,
),
sa.Column("cache_timeout", sa.Integer(), nullable=True),
sa.Column(
"semantic_layer_uuid",
UUIDType(binary=True),
sa.ForeignKey("semantic_layers.uuid", ondelete="CASCADE"),
nullable=False,
),
sa.Column("created_by_fk", sa.Integer(), nullable=True),
sa.Column("changed_by_fk", sa.Integer(), nullable=True),
sa.PrimaryKeyConstraint("uuid"),
)
# Create foreign key constraints for semantic_views
create_fks_for_table(
"fk_semantic_views_created_by_fk_ab_user",
"semantic_views",
"ab_user",
["created_by_fk"],
["id"],
)
create_fks_for_table(
"fk_semantic_views_changed_by_fk_ab_user",
"semantic_views",
"ab_user",
["changed_by_fk"],
["id"],
)
def downgrade():
drop_table("semantic_views")
drop_table("semantic_layers")


@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


@@ -0,0 +1,869 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Functions for mapping `QueryObject` to semantic layers.
These functions validate and convert a `QueryObject` into one or more `SemanticQuery`,
which are then passed to semantic layer implementations for execution, returning a
single dataframe.
"""
from datetime import datetime, timedelta
from time import time
from typing import Any, cast, Sequence, TypeGuard
import numpy as np
from superset.common.db_query_status import QueryStatus
from superset.common.query_object import QueryObject
from superset.common.utils.time_range_utils import get_since_until_from_query_object
from superset.connectors.sqla.models import BaseDatasource
from superset.models.helpers import QueryResult
from superset.semantic_layers.types import (
AdhocExpression,
AdhocFilter,
DateGrain,
Dimension,
Filter,
FilterValues,
GroupLimit,
Metric,
Operator,
OrderDirection,
OrderTuple,
PredicateType,
SemanticQuery,
SemanticResult,
SemanticViewFeature,
TimeGrain,
)
from superset.utils.core import (
FilterOperator,
QueryObjectFilterClause,
TIME_COMPARISON,
)
from superset.utils.date_parser import get_past_or_future
class ValidatedQueryObjectFilterClause(QueryObjectFilterClause):
"""
A validated QueryObject filter clause with a string column name.
The `col` in a `QueryObjectFilterClause` can be either a string (column name) or an
adhoc column, but we only support the former in semantic layers.
"""
# overwrite to narrow type; mypy complains about more restrictive typed dicts,
# but the alternative would be to redefine the object
col: str # type: ignore[misc]
op: str # type: ignore[misc]
class ValidatedQueryObject(QueryObject):
"""
A query object that has a datasource defined.
"""
datasource: BaseDatasource
# overwrite to narrow type; mypy complains about the assignment since the base type
# allows adhoc filters, but we only support validated filters here
filter: list[ValidatedQueryObjectFilterClause] # type: ignore[assignment]
series_columns: Sequence[str] # type: ignore[assignment]
series_limit_metric: str | None
def get_results(query_object: QueryObject) -> QueryResult:
"""
Run 1+ queries based on `QueryObject` and return the results.
:param query_object: The QueryObject containing query specifications
:return: QueryResult compatible with Superset's query interface
"""
if not validate_query_object(query_object):
raise ValueError("QueryObject must have a datasource defined.")
# Track execution time
start_time = time()
semantic_view = query_object.datasource.implementation
dispatcher = (
semantic_view.get_row_count
if query_object.is_rowcount
else semantic_view.get_dataframe
)
# Step 1: Convert QueryObject to list of SemanticQuery objects
# The first query is the main query, subsequent queries are for time offsets
queries = map_query_object(query_object)
# Step 2: Execute the main query (first in the list)
main_query = queries[0]
main_result = dispatcher(
metrics=main_query.metrics,
dimensions=main_query.dimensions,
filters=main_query.filters,
order=main_query.order,
limit=main_query.limit,
offset=main_query.offset,
group_limit=main_query.group_limit,
)
main_df = main_result.results
# Collect all requests (SQL queries, HTTP requests, etc.) for troubleshooting
all_requests = list(main_result.requests)
# If no time offsets, return the main result as-is
if not query_object.time_offsets or len(queries) <= 1:
semantic_result = SemanticResult(
requests=all_requests,
results=main_df,
)
duration = timedelta(seconds=time() - start_time)
return map_semantic_result_to_query_result(
semantic_result,
query_object,
duration,
)
# Get metric names from the main query
# These are the columns that will be renamed with offset suffixes
metric_names = [metric.name for metric in main_query.metrics]
# Join keys are all columns except metrics
# These will be used to match rows between main and offset DataFrames
join_keys = [col for col in main_df.columns if col not in metric_names]
# Step 3 & 4: Execute each time offset query and join results
for offset_query, time_offset in zip(
queries[1:],
query_object.time_offsets,
strict=False,
):
# Execute the offset query
result = dispatcher(
metrics=offset_query.metrics,
dimensions=offset_query.dimensions,
filters=offset_query.filters,
order=offset_query.order,
limit=offset_query.limit,
offset=offset_query.offset,
group_limit=offset_query.group_limit,
)
# Add this query's requests to the collection
all_requests.extend(result.requests)
offset_df = result.results
# Handle empty results - add NaN columns directly instead of merging
# This avoids dtype mismatch issues with empty DataFrames
if offset_df.empty:
# Add offset metric columns with NaN values directly to main_df
for metric in metric_names:
offset_col_name = TIME_COMPARISON.join([metric, time_offset])
main_df[offset_col_name] = np.nan
else:
# Rename metric columns with time offset suffix
# Format: "{metric_name}__{time_offset}"
# Example: "revenue" -> "revenue__1 week ago"
offset_df = offset_df.rename(
columns={
metric: TIME_COMPARISON.join([metric, time_offset])
for metric in metric_names
}
)
# Step 5: Perform left join on dimension columns
# This preserves all rows from main_df and adds offset metrics
# where they match
main_df = main_df.merge(
offset_df,
on=join_keys,
how="left",
suffixes=("", "__duplicate"),
)
# Clean up any duplicate columns that might have been created
# (shouldn't happen with proper join keys, but defensive programming)
duplicate_cols = [
col for col in main_df.columns if col.endswith("__duplicate")
]
if duplicate_cols:
main_df = main_df.drop(columns=duplicate_cols)
# Convert final result to QueryResult
semantic_result = SemanticResult(requests=all_requests, results=main_df)
duration = timedelta(seconds=time() - start_time)
return map_semantic_result_to_query_result(
semantic_result,
query_object,
duration,
)
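# Illustrative sketch of the offset join above (standalone, with made-up data;
# not part of the module):
#
#     import pandas as pd
#     main = pd.DataFrame({"gender": ["F", "M"], "revenue": [10, 20]})
#     offset = pd.DataFrame({"gender": ["F"], "revenue__1 week ago": [8]})
#     main.merge(offset, on=["gender"], how="left")
#     # -> the "M" row gets NaN in "revenue__1 week ago"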
def map_semantic_result_to_query_result(
semantic_result: SemanticResult,
query_object: ValidatedQueryObject,
duration: timedelta,
) -> QueryResult:
"""
Convert a SemanticResult to a QueryResult.
:param semantic_result: Result from the semantic layer
:param query_object: Original QueryObject (for passthrough attributes)
:param duration: Time taken to execute the query
:return: QueryResult compatible with Superset's query interface
"""
# Get the query string from requests (typically one or more SQL queries)
query_str = ""
if semantic_result.requests:
# Join all requests for display (could be multiple for time comparisons)
query_str = "\n\n".join(
f"-- {req.type}\n{req.definition}" for req in semantic_result.requests
)
return QueryResult(
# Core data
df=semantic_result.results,
query=query_str,
duration=duration,
# Template filters - not applicable to semantic layers
# (semantic layers don't use Jinja templates)
applied_template_filters=None,
# Filter columns - not applicable to semantic layers
# (semantic layers handle filter validation internally)
applied_filter_columns=None,
rejected_filter_columns=None,
# Status - always success if we got here
# (errors would raise exceptions before reaching this point)
status=QueryStatus.SUCCESS,
error_message=None,
errors=None,
# Time range - pass through from original query_object
from_dttm=query_object.from_dttm,
to_dttm=query_object.to_dttm,
)
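# For example, a result with two recorded SQL requests would render a query
# string like this (illustrative):
#
#     -- snowflake
#     SELECT ...  (main query)
#
#     -- snowflake
#     SELECT ...  (time offset query)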
def map_query_object(query_object: ValidatedQueryObject) -> list[SemanticQuery]:
"""
Convert a `QueryObject` into a list of `SemanticQuery`.
This function maps the `QueryObject` into query objects that focus less on
visualization and more on semantics.
"""
semantic_view = query_object.datasource.implementation
all_metrics = {metric.name: metric for metric in semantic_view.metrics}
all_dimensions = {
dimension.name: dimension for dimension in semantic_view.dimensions
}
metrics = [all_metrics[metric] for metric in (query_object.metrics or [])]
grain = (
_convert_time_grain(query_object.extras["time_grain_sqla"])
if "time_grain_sqla" in query_object.extras
else None
)
dimensions = [
dimension
for dimension in semantic_view.dimensions
if dimension.name in query_object.columns
and (
# if a grain is specified, only include the time dimension if its grain
# matches the requested grain
grain is None
or dimension.name != query_object.granularity
or dimension.grain == grain
)
]
order = _get_order_from_query_object(query_object, all_metrics, all_dimensions)
limit = query_object.row_limit
offset = query_object.row_offset
group_limit = _get_group_limit_from_query_object(
query_object,
all_metrics,
all_dimensions,
)
queries = []
for time_offset in [None] + query_object.time_offsets:
filters = _get_filters_from_query_object(
query_object,
time_offset,
all_dimensions,
)
queries.append(
SemanticQuery(
metrics=metrics,
dimensions=dimensions,
filters=filters,
order=order,
limit=limit,
offset=offset,
group_limit=group_limit,
)
)
return queries
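# For example, a query object with `time_offsets=["1 week ago"]` yields two
# `SemanticQuery` objects that differ only in their time-range filters: the
# first covers [from_dttm, to_dttm) and the second covers the same window
# shifted back one week (see `_get_time_bounds` below).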
def _get_filters_from_query_object(
query_object: ValidatedQueryObject,
time_offset: str | None,
all_dimensions: dict[str, Dimension],
) -> set[Filter | AdhocFilter]:
"""
Extract all filters from the query object, including time range filters.
This simplifies the complexity of from_dttm/to_dttm/inner_from_dttm/inner_to_dttm
by converting all time constraints into filters.
"""
filters: set[Filter | AdhocFilter] = set()
# 1. Add fetch values predicate if present
if (
query_object.apply_fetch_values_predicate
and query_object.datasource.fetch_values_predicate
):
filters.add(
AdhocFilter(
type=PredicateType.WHERE,
definition=query_object.datasource.fetch_values_predicate,
)
)
# 2. Add time range filter based on from_dttm/to_dttm
# For time offsets, this automatically calculates the shifted bounds
time_filters = _get_time_filter(query_object, time_offset, all_dimensions)
filters.update(time_filters)
# 3. Add filters from query_object.extras (WHERE and HAVING clauses)
extras_filters = _get_filters_from_extras(query_object.extras)
filters.update(extras_filters)
# 4. Add all other filters from query_object.filter
for filter_ in query_object.filter:
converted_filter = _convert_query_object_filter(filter_, all_dimensions)
if converted_filter:
filters.add(converted_filter)
return filters
def _get_filters_from_extras(extras: dict[str, Any]) -> set[AdhocFilter]:
"""
Extract filters from the extras dict.
The extras dict can contain various keys that affect query behavior:
Supported keys (converted to filters):
- "where": SQL WHERE clause expression (e.g., "customer_id > 100")
- "having": SQL HAVING clause expression (e.g., "SUM(sales) > 1000")
Other keys in extras (handled elsewhere in the mapper):
- "time_grain_sqla": Time granularity (e.g., "P1D", "PT1H")
Handled in _convert_time_grain() and used for dimension grain matching
Note: The WHERE and HAVING clauses from extras are SQL expressions that
are passed through as-is to the semantic layer as AdhocFilter objects.
"""
filters: set[AdhocFilter] = set()
# Add WHERE clause from extras
if where_clause := extras.get("where"):
filters.add(
AdhocFilter(
type=PredicateType.WHERE,
definition=where_clause,
)
)
# Add HAVING clause from extras
if having_clause := extras.get("having"):
filters.add(
AdhocFilter(
type=PredicateType.HAVING,
definition=having_clause,
)
)
return filters
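# Illustrative example (hypothetical extras dict):
#
#     _get_filters_from_extras(
#         {"where": "customer_id > 100", "having": "SUM(sales) > 1000"}
#     )
#     # -> {AdhocFilter(type=PredicateType.WHERE, definition="customer_id > 100"),
#     #     AdhocFilter(type=PredicateType.HAVING, definition="SUM(sales) > 1000")}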
def _get_time_filter(
query_object: ValidatedQueryObject,
time_offset: str | None,
all_dimensions: dict[str, Dimension],
) -> set[Filter]:
"""
Create a time range filter from the query object.
This handles both regular queries and time offset queries, simplifying the
complexity of from_dttm/to_dttm/inner_from_dttm/inner_to_dttm by using the
same time bounds for both the main query and series limit subqueries.
"""
filters: set[Filter] = set()
if not query_object.granularity:
return filters
time_dimension = all_dimensions.get(query_object.granularity)
if not time_dimension:
return filters
# Get the appropriate time bounds based on whether this is a time offset query
from_dttm, to_dttm = _get_time_bounds(query_object, time_offset)
if not from_dttm or not to_dttm:
return filters
# Create a filter with >= and < operators
return {
Filter(
type=PredicateType.WHERE,
column=time_dimension,
operator=Operator.GREATER_THAN_OR_EQUAL,
value=from_dttm,
),
Filter(
type=PredicateType.WHERE,
column=time_dimension,
operator=Operator.LESS_THAN,
value=to_dttm,
),
}
def _get_time_bounds(
query_object: ValidatedQueryObject,
time_offset: str | None,
) -> tuple[datetime | None, datetime | None]:
"""
Get the appropriate time bounds for the query.
For regular queries (time_offset is None), returns from_dttm/to_dttm.
For time offset queries, calculates the shifted bounds.
This simplifies the inner_from_dttm/inner_to_dttm complexity by using
the same bounds for both main queries and series limit subqueries (Option 1).
"""
if time_offset is None:
# Main query: use from_dttm/to_dttm directly
return query_object.from_dttm, query_object.to_dttm
# Time offset query: calculate shifted bounds
# Use from_dttm/to_dttm if available, otherwise try to get from time_range
outer_from = query_object.from_dttm
outer_to = query_object.to_dttm
if not outer_from or not outer_to:
# Fall back to parsing time_range if from_dttm/to_dttm not set
outer_from, outer_to = get_since_until_from_query_object(query_object)
if not outer_from or not outer_to:
return None, None
# Apply the offset to both bounds
offset_from = get_past_or_future(time_offset, outer_from)
offset_to = get_past_or_future(time_offset, outer_to)
return offset_from, offset_to
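# Sketch of the bound shifting, assuming `get_past_or_future` moves a datetime
# by the human-readable offset:
#
#     from datetime import datetime
#     get_past_or_future("1 week ago", datetime(2024, 1, 8))
#     # -> datetime(2024, 1, 1, 0, 0)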
def _convert_query_object_filter(
filter_: ValidatedQueryObjectFilterClause,
all_dimensions: dict[str, Dimension],
) -> Filter | AdhocFilter | None:
"""
Convert a QueryObject filter dict to a semantic layer Filter or AdhocFilter.
"""
operator_str = filter_["op"]
# Handle TEMPORAL_RANGE filters (these are already handled by _get_time_filter)
if operator_str == FilterOperator.TEMPORAL_RANGE.value:
# Skip - already handled in _get_time_filter
return None
# Handle simple column filters
col = filter_.get("col")
if col not in all_dimensions:
return None
dimension = all_dimensions[col]
val_str = filter_["val"]
value: FilterValues | set[FilterValues]
if val_str is None:
value = None
elif isinstance(val_str, (list, tuple)):
value = set(val_str)
else:
value = val_str
# Map QueryObject operators to semantic layer operators
operator_mapping = {
FilterOperator.EQUALS.value: Operator.EQUALS,
FilterOperator.NOT_EQUALS.value: Operator.NOT_EQUALS,
FilterOperator.GREATER_THAN.value: Operator.GREATER_THAN,
FilterOperator.LESS_THAN.value: Operator.LESS_THAN,
FilterOperator.GREATER_THAN_OR_EQUALS.value: Operator.GREATER_THAN_OR_EQUAL,
FilterOperator.LESS_THAN_OR_EQUALS.value: Operator.LESS_THAN_OR_EQUAL,
FilterOperator.IN.value: Operator.IN,
FilterOperator.NOT_IN.value: Operator.NOT_IN,
FilterOperator.LIKE.value: Operator.LIKE,
FilterOperator.NOT_LIKE.value: Operator.NOT_LIKE,
FilterOperator.IS_NULL.value: Operator.IS_NULL,
FilterOperator.IS_NOT_NULL.value: Operator.IS_NOT_NULL,
}
operator = operator_mapping.get(operator_str)
if not operator:
# Unknown operator - not supported by the semantic layer; skip the filter
return None
return Filter(
type=PredicateType.WHERE,
column=dimension,
operator=operator,
value=value,
)
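# Example conversion (values are made up):
#
#     _convert_query_object_filter(
#         {"col": "gender", "op": "IN", "val": ["F", "M"]},
#         all_dimensions,
#     )
#     # -> Filter(type=PredicateType.WHERE, column=<gender dimension>,
#     #           operator=Operator.IN, value={"F", "M"})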
def _get_order_from_query_object(
query_object: ValidatedQueryObject,
all_metrics: dict[str, Metric],
all_dimensions: dict[str, Dimension],
) -> list[OrderTuple]:
order: list[OrderTuple] = []
for element, ascending in query_object.orderby:
direction = OrderDirection.ASC if ascending else OrderDirection.DESC
# adhoc
if isinstance(element, dict):
if element["sqlExpression"] is not None:
order.append(
(
AdhocExpression(
id=element["label"] or element["sqlExpression"],
definition=element["sqlExpression"],
),
direction,
)
)
elif element in all_dimensions:
order.append((all_dimensions[element], direction))
elif element in all_metrics:
order.append((all_metrics[element], direction))
return order
def _get_group_limit_from_query_object(
query_object: ValidatedQueryObject,
all_metrics: dict[str, Metric],
all_dimensions: dict[str, Dimension],
) -> GroupLimit | None:
# no limit
if query_object.series_limit == 0 or not query_object.columns:
return None
dimensions = [all_dimensions[dim_id] for dim_id in query_object.series_columns]
top = query_object.series_limit
metric = (
all_metrics[query_object.series_limit_metric]
if query_object.series_limit_metric
else None
)
direction = OrderDirection.DESC if query_object.order_desc else OrderDirection.ASC
group_others = query_object.group_others_when_limit_reached
# Check if we need separate filters for the group limit subquery
# This happens when inner_from_dttm/inner_to_dttm differ from from_dttm/to_dttm
group_limit_filters = _get_group_limit_filters(query_object, all_dimensions)
return GroupLimit(
dimensions=dimensions,
top=top,
metric=metric,
direction=direction,
group_others=group_others,
filters=group_limit_filters,
)
def _get_group_limit_filters(
query_object: ValidatedQueryObject,
all_dimensions: dict[str, Dimension],
) -> set[Filter | AdhocFilter] | None:
"""
Get separate filters for the group limit subquery if needed.
This is used when inner_from_dttm/inner_to_dttm differ from from_dttm/to_dttm,
which happens during time comparison queries. The group limit subquery may need
different time bounds to determine the top N groups.
Returns None if the group limit should use the same filters as the main query.
"""
# Check if inner time bounds are explicitly set and differ from outer bounds
if (
query_object.inner_from_dttm is None
or query_object.inner_to_dttm is None
or (
query_object.inner_from_dttm == query_object.from_dttm
and query_object.inner_to_dttm == query_object.to_dttm
)
):
# No separate bounds needed - use the same filters as the main query
return None
# Create separate filters for the group limit subquery
filters: set[Filter | AdhocFilter] = set()
# Add time range filter using inner bounds
if query_object.granularity:
time_dimension = all_dimensions.get(query_object.granularity)
if (
time_dimension
and query_object.inner_from_dttm
and query_object.inner_to_dttm
):
filters.update(
{
Filter(
type=PredicateType.WHERE,
column=time_dimension,
operator=Operator.GREATER_THAN_OR_EQUAL,
value=query_object.inner_from_dttm,
),
Filter(
type=PredicateType.WHERE,
column=time_dimension,
operator=Operator.LESS_THAN,
value=query_object.inner_to_dttm,
),
}
)
# Add fetch values predicate if present
if (
query_object.apply_fetch_values_predicate
and query_object.datasource.fetch_values_predicate
):
filters.add(
AdhocFilter(
type=PredicateType.WHERE,
definition=query_object.datasource.fetch_values_predicate,
)
)
# Add filters from query_object.extras (WHERE and HAVING clauses)
extras_filters = _get_filters_from_extras(query_object.extras)
filters.update(extras_filters)
# Add all other non-temporal filters from query_object.filter
for filter_ in query_object.filter:
# Skip temporal range filters - we're using inner bounds instead
if filter_.get("op") == FilterOperator.TEMPORAL_RANGE.value:
continue
converted_filter = _convert_query_object_filter(filter_, all_dimensions)
if converted_filter:
filters.add(converted_filter)
return filters if filters else None
def _convert_time_grain(time_grain: str) -> TimeGrain | DateGrain | None:
"""
Convert a time grain string from the query object to a TimeGrain or DateGrain enum.
"""
if time_grain in TimeGrain.__members__:
return TimeGrain[time_grain]
if time_grain in DateGrain.__members__:
return DateGrain[time_grain]
return None
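# For example, assuming the enums use the grain string as the member name:
#
#     _convert_time_grain("P1D")    # -> DateGrain.P1D, if defined
#     _convert_time_grain("bogus")  # -> None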
def validate_query_object(
query_object: QueryObject,
) -> TypeGuard[ValidatedQueryObject]:
"""
Validate that the `QueryObject` is compatible with the `SemanticView`.
The checks below reject features that semantic layers don't currently support
(adhoc metrics, adhoc dimensions, adhoc columns in filters, etc.). If some
semantic view implementation supports these features we should add an
attribute to the `SemanticViewImplementation` to indicate support for them.
"""
if not query_object.datasource:
return False
query_object = cast(ValidatedQueryObject, query_object)
_validate_metrics(query_object)
_validate_dimensions(query_object)
_validate_filters(query_object)
_validate_granularity(query_object)
_validate_group_limit(query_object)
_validate_orderby(query_object)
return True
def _validate_metrics(query_object: ValidatedQueryObject) -> None:
"""
Make sure metrics are defined in the semantic view.
"""
semantic_view = query_object.datasource.implementation
if any(not isinstance(metric, str) for metric in (query_object.metrics or [])):
raise ValueError("Adhoc metrics are not supported in Semantic Views.")
metric_names = {metric.name for metric in semantic_view.metrics}
if not set(query_object.metrics or []) <= metric_names:
raise ValueError("All metrics must be defined in the Semantic View.")
def _validate_dimensions(query_object: ValidatedQueryObject) -> None:
"""
Make sure all dimensions are defined in the semantic view.
"""
semantic_view = query_object.datasource.implementation
if any(not isinstance(column, str) for column in query_object.columns):
raise ValueError("Adhoc dimensions are not supported in Semantic Views.")
dimension_names = {dimension.name for dimension in semantic_view.dimensions}
if not set(query_object.columns) <= dimension_names:
raise ValueError("All dimensions must be defined in the Semantic View.")
def _validate_filters(query_object: ValidatedQueryObject) -> None:
"""
Make sure all filters are valid.
"""
for filter_ in query_object.filter:
if isinstance(filter_["col"], dict):
raise ValueError(
"Adhoc columns are not supported in Semantic View filters."
)
if not filter_.get("op"):
raise ValueError("All filters must have an operator defined.")
def _validate_granularity(query_object: ValidatedQueryObject) -> None:
"""
Make sure time column and time grain are valid.
"""
semantic_view = query_object.datasource.implementation
dimension_names = {dimension.name for dimension in semantic_view.dimensions}
if time_column := query_object.granularity:
if time_column not in dimension_names:
raise ValueError(
"The time column must be defined in the Semantic View dimensions."
)
if time_grain := query_object.extras.get("time_grain_sqla"):
if not time_column:
raise ValueError(
"A time column must be specified when a time grain is provided."
)
supported_time_grains = {
dimension.grain
for dimension in semantic_view.dimensions
if dimension.name == time_column and dimension.grain
}
if _convert_time_grain(time_grain) not in supported_time_grains:
raise ValueError(
"The time grain is not supported for the time column in the "
"Semantic View."
)
def _validate_group_limit(query_object: ValidatedQueryObject) -> None:
"""
Validate group limit related features in the query object.
"""
semantic_view = query_object.datasource.implementation
# no limit
if query_object.series_limit == 0:
return
if (
query_object.series_columns
and SemanticViewFeature.GROUP_LIMIT not in semantic_view.features
):
raise ValueError("Group limit is not supported in this Semantic View.")
if any(not isinstance(col, str) for col in query_object.series_columns):
raise ValueError("Adhoc dimensions are not supported in series columns.")
metric_names = {metric.name for metric in semantic_view.metrics}
if query_object.series_limit_metric and (
not isinstance(query_object.series_limit_metric, str)
or query_object.series_limit_metric not in metric_names
):
raise ValueError(
"The series limit metric must be defined in the Semantic View."
)
dimension_names = {dimension.name for dimension in semantic_view.dimensions}
if not set(query_object.series_columns) <= dimension_names:
raise ValueError("All series columns must be defined in the Semantic View.")
if (
query_object.group_others_when_limit_reached
and SemanticViewFeature.GROUP_OTHERS not in semantic_view.features
):
raise ValueError(
"Grouping others when limit is reached is not supported in this Semantic "
"View."
)
def _validate_orderby(query_object: ValidatedQueryObject) -> None:
"""
Validate order by elements in the query object.
"""
semantic_view = query_object.datasource.implementation
if (
any(not isinstance(element, str) for element, _ in query_object.orderby)
and SemanticViewFeature.ADHOC_EXPRESSIONS_IN_ORDERBY
not in semantic_view.features
):
raise ValueError(
"Adhoc expressions in order by are not supported in this Semantic View."
)
elements = {
    element for element, _ in query_object.orderby if isinstance(element, str)
}
metric_names = {metric.name for metric in semantic_view.metrics}
dimension_names = {dimension.name for dimension in semantic_view.dimensions}
if not elements <= metric_names | dimension_names:
raise ValueError("All order by elements must be defined in the Semantic View.")


@@ -0,0 +1,205 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Semantic layer models."""
from __future__ import annotations
import uuid
from dataclasses import dataclass
from importlib.metadata import entry_points
from typing import Any
from flask_appbuilder import Model
from sqlalchemy import Column, ForeignKey, Integer, String, Text
from sqlalchemy.orm import relationship
from sqlalchemy_utils import UUIDType
from superset.common.query_object import QueryObject
from superset.models.helpers import AuditMixinNullable, QueryResult
from superset.semantic_layers.mapper import get_results
from superset.semantic_layers.types import (
DATE,
DATETIME,
SemanticLayerImplementation,
SemanticViewImplementation,
TIME,
)
from superset.superset_typing import QueryObjectDict
from superset.utils import core as utils
@dataclass(frozen=True)
class ColumnMetadata:
column_name: str
type: str
is_dttm: bool
class SemanticLayer(AuditMixinNullable, Model):
"""
Semantic layer model.
A semantic layer provides an abstraction over data sources,
allowing users to query data through a semantic interface.
"""
__tablename__ = "semantic_layers"
uuid = Column(UUIDType(binary=True), primary_key=True, default=uuid.uuid4)
# Core fields
name = Column(String(250), nullable=False)
description = Column(Text, nullable=True)
type = Column(String(250), nullable=False)
# XXX: encrypt at rest
configuration = Column(utils.MediumText(), default="{}")
cache_timeout = Column(Integer, nullable=True)
# Semantic views relationship
semantic_views: list[SemanticView] = relationship(
"SemanticView",
back_populates="semantic_layer",
cascade="all, delete-orphan",
passive_deletes=True,
)
def __repr__(self) -> str:
return self.name or str(self.uuid)
@property
def implementation(
self,
) -> SemanticLayerImplementation[Any, SemanticViewImplementation]:
"""
Return semantic layer implementation.
"""
entry_point = next(
iter(
entry_points(
group="superset.semantic_layers",
name=self.type,
)
)
)
implementation_class = entry_point.load()
if not issubclass(implementation_class, SemanticLayerImplementation):
raise TypeError(
f"Entry point for semantic layer type '{self.type}' "
"must be a subclass of SemanticLayerImplementation"
)
# XXX store in self._implementation
return implementation_class.from_configuration(self.configuration)
class SemanticView(AuditMixinNullable, Model):
"""
Semantic view model.
A semantic view represents a queryable view within a semantic layer.
"""
__tablename__ = "semantic_views"
uuid = Column(UUIDType(binary=True), primary_key=True, default=uuid.uuid4)
# Core fields
name = Column(String(250), nullable=False)
# XXX: encrypt at rest
configuration = Column(utils.MediumText(), default="{}")
cache_timeout = Column(Integer, nullable=True)
# Semantic layer relationship
semantic_layer_uuid = Column(
UUIDType(binary=True),
ForeignKey("semantic_layers.uuid", ondelete="CASCADE"),
nullable=False,
)
semantic_layer: SemanticLayer = relationship(
"SemanticLayer",
back_populates="semantic_views",
foreign_keys=[semantic_layer_uuid],
)
def __repr__(self) -> str:
return self.name or str(self.uuid)
@property
def implementation(self) -> SemanticViewImplementation:
"""
Return semantic view implementation.
"""
# XXX store in self._implementation
return self.semantic_layer.implementation.get_semantic_view(
self.name,
self.configuration,
)
# =========================================================================
# Explorable protocol implementation
# =========================================================================
def get_query_result(self, query_object: QueryObject) -> QueryResult:
return get_results(query_object)
def get_query_str(self, query_obj: QueryObjectDict) -> str:
return "Not implemented for semantic layers"
@property
def uid(self) -> str:
return self.implementation.uid()
@property
def type(self) -> str:
return "semantic_view"
@property
def columns(self) -> list[ColumnMetadata]:
return [
ColumnMetadata(
column_name=dimension.name,
type=dimension.type.__name__,
is_dttm=dimension.type in {DATE, TIME, DATETIME},
)
for dimension in self.implementation.dimensions
]
@property
def column_names(self) -> list[str]:
return [dimension.name for dimension in self.implementation.dimensions]
@property
def data(self) -> dict[str, Any]:
return {
"id": str(self.uuid),
"uid": self.uid,
"name": self.name,
"type": self.type,
"columns": [],
"metrics": [],
"database": [],
"description": self.description,
"schema": None,
"catalog": None,
"cache_timeout": self.cache_timeout,
"offset": None, # XXX
"owners": [], # XXX
"verbose_map": {}, # XXX
}


@@ -0,0 +1,26 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.semantic_layers.snowflake.schemas import SnowflakeConfiguration
from superset.semantic_layers.snowflake.semantic_layer import SnowflakeSemanticLayer
from superset.semantic_layers.snowflake.semantic_view import SnowflakeSemanticView
__all__ = [
"SnowflakeConfiguration",
"SnowflakeSemanticLayer",
"SnowflakeSemanticView",
]


@@ -0,0 +1,130 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from typing import Literal, Union
from pydantic import BaseModel, ConfigDict, Field, model_validator, SecretStr
class UserPasswordAuth(BaseModel):
"""
Username and password authentication.
"""
model_config = ConfigDict(title="Username and password")
auth_type: Literal["user_password"] = "user_password"
username: str = Field(description="The username to authenticate as.")
password: SecretStr = Field(
description="The password to authenticate with.",
repr=False,
)
class PrivateKeyAuth(BaseModel):
"""
Private key authentication.
"""
model_config = ConfigDict(title="Private key")
auth_type: Literal["private_key"] = "private_key"
private_key: SecretStr = Field(
description="The private key to authenticate with, in PEM format.",
repr=False,
)
private_key_password: SecretStr = Field(
description="The password to decrypt the private key with.",
repr=False,
)
class SnowflakeConfiguration(BaseModel):
"""
Parameters needed to connect to Snowflake.
"""
# account_identifier and auth are the only required parameters
account_identifier: str = Field(
description="The Snowflake account identifier.",
json_schema_extra={"examples": ["abc12345"]},
)
role: str | None = Field(
default=None,
description="The default role to use.",
json_schema_extra={"examples": ["myrole"]},
)
warehouse: str | None = Field(
default=None,
description="The default warehouse to use.",
json_schema_extra={"examples": ["testwh"]},
)
auth: Union[UserPasswordAuth, PrivateKeyAuth] = Field(
discriminator="auth_type",
description="Authentication method",
)
# database and schema can be optionally provided; if not provided the user
# will be able to browse databases/schemas
database: str | None = Field(
default=None,
description="The default database to use.",
json_schema_extra={
"examples": ["testdb"],
"x-dynamic": True,
"x-dependsOn": ["account_identifier", "auth"],
},
)
allow_changing_database: bool = Field(
default=False,
description="Allow changing the default database.",
)
schema_: str | None = Field(
default=None,
description="The default schema to use.",
json_schema_extra={
"examples": ["public"],
"x-dynamic": True,
"x-dependsOn": ["account_identifier", "auth", "database"],
},
# `schema` is an attribute of `BaseModel` so it needs to be aliased
alias="schema",
)
allow_changing_schema: bool = Field(
default=False,
description="Allow changing the default schema.",
)
@model_validator(mode="after")
def validate_database_schema_settings(self) -> SnowflakeConfiguration:
"""
Validate that if database or schema is not specified, the corresponding
allow_changing flag must be true.
"""
if not self.database and not self.allow_changing_database:
raise ValueError(
"If no database is specified, allow_changing_database must be true"
)
if not self.schema_ and not self.allow_changing_schema:
raise ValueError(
"If no schema is specified, allow_changing_schema must be true"
)
return self
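# Illustrative usage (all values are made up):
#
#     config = SnowflakeConfiguration(
#         account_identifier="abc12345",
#         auth=UserPasswordAuth(username="alice", password="secret"),
#         database="testdb",
#         schema="public",
#     )
#
# Omitting `database` without `allow_changing_database=True` (or `schema`
# without `allow_changing_schema=True`) raises a validation error, per the
# validator above.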


@@ -0,0 +1,236 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from textwrap import dedent
from typing import Any, Literal, TYPE_CHECKING
from pydantic import create_model, Field
from snowflake.connector import connect
from snowflake.connector.connection import SnowflakeConnection
from superset.semantic_layers.snowflake.schemas import SnowflakeConfiguration
from superset.semantic_layers.snowflake.utils import get_connection_parameters
from superset.semantic_layers.types import (
SemanticLayerImplementation,
)
if TYPE_CHECKING:
from superset.semantic_layers.snowflake.semantic_view import SnowflakeSemanticView
class SnowflakeSemanticLayer(
SemanticLayerImplementation[SnowflakeConfiguration, SnowflakeSemanticView]
):
id = "snowflake"
name = "Snowflake Semantic Layer"
description = "Connect to semantic views stored in Snowflake."
@classmethod
def from_configuration(
cls,
configuration: dict[str, Any],
) -> SnowflakeSemanticLayer:
"""
Create a SnowflakeSemanticLayer from a configuration dictionary.
"""
config = SnowflakeConfiguration.model_validate(configuration)
return cls(config)
@classmethod
def get_configuration_schema(
cls,
configuration: SnowflakeConfiguration | None = None,
) -> dict[str, Any]:
"""
Get the JSON schema for the configuration needed to add the semantic layer.
A partial configuration can be sent to improve the schema. For example,
providing account and auth will allow the schema to provide a list of
databases; providing a database will allow the schema to provide a list of
schemas.
Note that database and schema can both be left empty when the semantic layer is
added to Superset; the user will then have to provide them when loading
semantic views.
"""
schema = SnowflakeConfiguration.model_json_schema()
properties = schema["properties"]
if configuration is None:
# set these to empty; they will be populated when a partial configuration is
# passed
properties["database"]["enum"] = []
properties["schema"]["enum"] = []
return schema
connection_parameters = get_connection_parameters(configuration)
with connect(**connection_parameters) as connection:
if all(
getattr(configuration, dependency)
for dependency in properties["database"].get("x-dependsOn", [])
):
options = cls._fetch_databases(connection)
properties["database"]["enum"] = list(options)
if (
all(
getattr(configuration, dependency)
for dependency in properties["schema"].get("x-dependsOn", [])
)
and configuration.database
):
options = cls._fetch_schemas(connection, configuration.database)
properties["schema"]["enum"] = list(options)
return schema
@classmethod
def get_runtime_schema(
cls,
configuration: SnowflakeConfiguration,
runtime_data: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""
Get the JSON schema for the runtime parameters needed to load semantic views.
The schema can be enriched with actual values when `runtime_data` is provided,
enabling dynamic schema updates (e.g., populating schema dropdown after
database is selected).
"""
fields: dict[str, tuple[Any, Any]] = {}
# update configuration with runtime data, for example, to select a schema after
# the database has been selected
configuration = configuration.model_copy(update=runtime_data)
connection_parameters = get_connection_parameters(configuration)
with connect(**connection_parameters) as connection:
if not configuration.database or configuration.allow_changing_database:
options = cls._fetch_databases(connection)
fields["database"] = (
Literal[*options],
Field(description="The default database to use."),
)
if not configuration.schema_ or configuration.allow_changing_schema:
if configuration.database:
options = cls._fetch_schemas(connection, configuration.database)
fields["schema_"] = (
Literal[*options],
Field(
description="The default schema to use.",
alias="schema",
json_schema_extra=(
{
"x-dynamic": True,
"x-dependsOn": ["database"],
}
if "database" in fields
else {}
),
),
)
else:
# Database not provided yet, add schema as empty
# (will be populated dynamically)
fields["schema_"] = (
str | None,
Field(
default=None,
description="The default schema to use.",
alias="schema",
json_schema_extra={
"x-dynamic": True,
"x-dependsOn": ["database"],
},
),
)
return create_model("RuntimeParameters", **fields).model_json_schema()
@classmethod
def _fetch_databases(cls, connection: SnowflakeConnection) -> set[str]:
"""
Fetch the list of databases available in the Snowflake account.
We use `SHOW DATABASES` instead of querying the information schema since it
allows retrieving the list of databases without having to specify a database
when connecting.
"""
cursor = connection.cursor()
cursor.execute("SHOW DATABASES")
return {row[1] for row in cursor}
@classmethod
def _fetch_schemas(
cls,
connection: SnowflakeConnection,
database: str | None,
) -> set[str]:
"""
Fetch the list of schemas available in a given database.
The connection should already have the database set in its context.
"""
if not database:
return set()
cursor = connection.cursor()
query = dedent(
"""
SELECT SCHEMA_NAME
FROM INFORMATION_SCHEMA.SCHEMATA
WHERE CATALOG_NAME = ?
"""
).strip()
return {row[0] for row in cursor.execute(query, (database,))}
def __init__(self, configuration: SnowflakeConfiguration):
self.configuration = configuration
def get_semantic_views(
self,
runtime_configuration: dict[str, Any],
) -> set[SnowflakeSemanticView]:
"""
Get the semantic views available in the semantic layer.
"""
# Avoid circular import
from superset.semantic_layers.snowflake.semantic_view import (
SnowflakeSemanticView,
)
# create a new configuration with the runtime parameters
configuration = self.configuration.model_copy(update=runtime_configuration)
connection_parameters = get_connection_parameters(configuration)
with connect(**connection_parameters) as connection:
cursor = connection.cursor()
query = dedent(
"""
SHOW SEMANTIC VIEWS
->> SELECT "name" FROM $1;
"""
).strip()
views = {
SnowflakeSemanticView(row[0], configuration)
for row in cursor.execute(query)
}
return views


@@ -0,0 +1,817 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# ruff: noqa: S608
from __future__ import annotations
import itertools
import re
from collections import defaultdict
from textwrap import dedent
from pandas import DataFrame
from snowflake.connector import connect, DictCursor
from snowflake.sqlalchemy.snowdialect import SnowflakeDialect
from superset.semantic_layers.snowflake.schemas import SnowflakeConfiguration
from superset.semantic_layers.snowflake.utils import (
get_connection_parameters,
substitute_parameters,
validate_order_by,
)
from superset.semantic_layers.types import (
AdhocExpression,
AdhocFilter,
BINARY,
BOOLEAN,
DATE,
DATETIME,
DECIMAL,
Dimension,
Filter,
FilterValues,
GroupLimit,
INTEGER,
Metric,
NUMBER,
OBJECT,
Operator,
OrderTuple,
PredicateType,
SemanticRequest,
SemanticResult,
SemanticViewFeature,
SemanticViewImplementation,
STRING,
TIME,
Type,
)
REQUEST_TYPE = "snowflake"
class SnowflakeSemanticView(SemanticViewImplementation):
features = frozenset(
{
SemanticViewFeature.ADHOC_EXPRESSIONS_IN_ORDERBY,
SemanticViewFeature.GROUP_LIMIT,
SemanticViewFeature.GROUP_OTHERS,
}
)
def __init__(self, name: str, configuration: SnowflakeConfiguration):
self.configuration = configuration
self.name = name
self._quote = SnowflakeDialect().identifier_preparer.quote
self.dimensions = self.get_dimensions()
self.metrics = self.get_metrics()
def uid(self) -> str:
return ".".join(
self._quote(part)
for part in (
self.configuration.database,
self.configuration.schema_,
self.name,
)
)
def get_dimensions(self) -> set[Dimension]:
"""
Get the dimensions defined in the semantic view.
Even though Snowflake supports `SHOW SEMANTIC DIMENSIONS IN my_semantic_view`,
it doesn't return the dimension expressions, so we use a slightly more
complicated query to get all the information we need in one go.
"""
dimensions: set[Dimension] = set()
query = dedent(
f"""
DESC SEMANTIC VIEW {self.uid()}
->> SELECT "object_name", "property", "property_value"
FROM $1
WHERE
"object_kind" = 'DIMENSION' AND
"property" IN ('COMMENT', 'DATA_TYPE', 'EXPRESSION', 'TABLE');
"""
).strip()
connection_parameters = get_connection_parameters(self.configuration)
with connect(**connection_parameters) as connection:
cursor = connection.cursor(DictCursor)
rows = cursor.execute(query).fetchall()
for name, group in itertools.groupby(rows, key=lambda x: x["object_name"]):
attributes = defaultdict(set)
for row in group:
attributes[row["property"]].add(row["property_value"])
table = next(iter(attributes["TABLE"]))
id_ = table + "." + name
type_ = self._get_type(next(iter(attributes["DATA_TYPE"])))
description = next(iter(attributes["COMMENT"]), None)
definition = next(iter(attributes["EXPRESSION"]), None)
dimensions.add(Dimension(id_, name, type_, description, definition))
return dimensions
def get_metrics(self) -> set[Metric]:
"""
Get the metrics defined in the semantic view.
"""
metrics: set[Metric] = set()
query = dedent(
f"""
DESC SEMANTIC VIEW {self.uid()}
->> SELECT "object_name", "property", "property_value"
FROM $1
WHERE
"object_kind" = 'METRIC' AND
"property" IN ('COMMENT', 'DATA_TYPE', 'EXPRESSION', 'TABLE');
"""
).strip()
connection_parameters = get_connection_parameters(self.configuration)
with connect(**connection_parameters) as connection:
cursor = connection.cursor(DictCursor)
rows = cursor.execute(query).fetchall()
for name, group in itertools.groupby(rows, key=lambda x: x["object_name"]):
attributes = defaultdict(set)
for row in group:
attributes[row["property"]].add(row["property_value"])
table = next(iter(attributes["TABLE"]))
id_ = table + "." + name
type_ = self._get_type(next(iter(attributes["DATA_TYPE"])))
description = next(iter(attributes["COMMENT"]), None)
definition = next(iter(attributes["EXPRESSION"]), None)
metrics.add(Metric(id_, name, type_, definition, description))
return metrics
def _get_type(self, snowflake_type: str | None) -> type[Type]:
"""
Return the semantic type corresponding to a Snowflake type.
"""
if snowflake_type is None:
return STRING
type_map = {
STRING: {r"VARCHAR\(\d+\)$", "STRING$", "TEXT$", r"CHAR\(\d+\)$"},
INTEGER: {r"NUMBER\(38,\s?0\)$", "INT$", "INTEGER$", "BIGINT$"},
DECIMAL: {r"NUMBER\(10,\s?2\)$"},
NUMBER: {r"NUMBER\(\d+,\s?\d+\)$", "FLOAT$", "DOUBLE$"},
BOOLEAN: {"BOOLEAN$"},
DATE: {"DATE$"},
DATETIME: {"TIMESTAMP_TZ$", "TIMESTAMP__NTZ$"},
TIME: {"TIME$"},
OBJECT: {"OBJECT$"},
BINARY: {r"BINARY\(\d+\)$", r"VARBINARY\(\d+\)$"},
}
for semantic_type, patterns in type_map.items():
if any(
re.match(pattern, snowflake_type, re.IGNORECASE) for pattern in patterns
):
return semantic_type
return STRING
def _build_predicates(
self,
filters: list[Filter | AdhocFilter],
) -> tuple[str, tuple[FilterValues, ...]]:
"""
Convert a set of filters to a single `AND`ed predicate.
Caller should check the types of filters beforehand, as this method does not
differentiate between `WHERE` and `HAVING` predicates.
"""
if not filters:
return "", ()
# convert filters to predicates with associated bind parameters; adhoc
# filters are already SQL strings, so we keep them as-is
unary_operators = {Operator.IS_NULL, Operator.IS_NOT_NULL}
predicates: list[str] = []
parameters: list[FilterValues] = []
for filter_ in filters or set():
if isinstance(filter_, AdhocFilter):
predicates.append(f"({filter_.definition})")
else:
predicates.append(f"({self._build_native_filter(filter_)})")
if filter_.operator not in unary_operators:
parameters.extend(
[filter_.value]
if not isinstance(filter_.value, (set, frozenset))
else filter_.value
)
return " AND ".join(predicates), tuple(parameters)
def get_values(
self,
dimension: Dimension,
filters: set[Filter | AdhocFilter] | None = None,
) -> SemanticResult:
"""
Return distinct values for a dimension.
"""
where_clause, parameters = self._build_predicates(
sorted(
filter_
for filter_ in (filters or [])
if filter_.type == PredicateType.WHERE
)
)
query = dedent(
f"""
SELECT {self._quote(dimension.name)}
FROM SEMANTIC_VIEW(
{self.uid()}
DIMENSIONS {dimension.id}
{"WHERE " + where_clause if where_clause else ""}
)
"""
).strip()
connection_parameters = get_connection_parameters(self.configuration)
with connect(**connection_parameters) as connection:
df = connection.cursor().execute(query, parameters).fetch_pandas_all()
return SemanticResult(
requests=[
SemanticRequest(
REQUEST_TYPE,
substitute_parameters(query, parameters),
)
],
results=df,
)
def _build_native_filter(self, filter_: Filter) -> str:
"""
Convert a Filter to a SQL predicate string with `?` bind placeholders.
"""
column = filter_.column
operator = filter_.operator
value = filter_.value
column_name = self._quote(column.name)
# Handle IS NULL and IS NOT NULL operators (no value needed)
if operator in {Operator.IS_NULL, Operator.IS_NOT_NULL}:
return f"{column_name} {operator.value}"
# Handle IN and NOT IN operators (set values)
if operator in {Operator.IN, Operator.NOT_IN}:
parameter_count = len(value) if isinstance(value, (set, frozenset)) else 1
formatted_values = ", ".join("?" for _ in range(parameter_count))
return f"{column_name} {operator.value} ({formatted_values})"
return f"{column_name} {operator.value} ?"
def get_dataframe(
self,
metrics: list[Metric],
dimensions: list[Dimension],
filters: set[Filter | AdhocFilter] | None = None,
order: list[OrderTuple] | None = None,
limit: int | None = None,
offset: int | None = None,
*,
group_limit: GroupLimit | None = None,
) -> SemanticResult:
"""
Execute a query and return the results as a Pandas DataFrame.
"""
if not metrics and not dimensions:
return SemanticResult(requests=[], results=DataFrame())
query, parameters = self._get_query(
metrics,
dimensions,
filters,
order,
limit,
offset,
group_limit,
)
connection_parameters = get_connection_parameters(self.configuration)
with connect(**connection_parameters) as connection:
df = connection.cursor().execute(query, parameters).fetch_pandas_all()
return SemanticResult(
requests=[
SemanticRequest(
REQUEST_TYPE,
substitute_parameters(query, parameters),
)
],
results=df,
)
def get_row_count(
self,
metrics: list[Metric],
dimensions: list[Dimension],
filters: set[Filter | AdhocFilter] | None = None,
order: list[OrderTuple] | None = None,
limit: int | None = None,
offset: int | None = None,
*,
group_limit: GroupLimit | None = None,
) -> SemanticResult:
"""
Execute a query and return the number of rows the result would have.
"""
if not metrics and not dimensions:
return SemanticResult(
requests=[],
results=DataFrame([[0]], columns=["COUNT"]),
)
query, parameters = self._get_query(
metrics,
dimensions,
filters,
order,
limit,
offset,
group_limit,
)
query = f"SELECT COUNT(*) FROM ({query}) AS subquery"
connection_parameters = get_connection_parameters(self.configuration)
with connect(**connection_parameters) as connection:
count = connection.cursor().execute(query, parameters).fetchone()[0]
return SemanticResult(
requests=[
SemanticRequest(
REQUEST_TYPE,
substitute_parameters(query, parameters),
)
],
results=DataFrame([[count]], columns=["COUNT"]),
)
def _get_query(
self,
metrics: list[Metric],
dimensions: list[Dimension],
filters: set[Filter | AdhocFilter] | None = None,
order: list[OrderTuple] | None = None,
limit: int | None = None,
offset: int | None = None,
group_limit: GroupLimit | None = None,
) -> tuple[str, tuple[FilterValues, ...]]:
"""
Build a query to fetch data from the semantic view.
This also returns the parameters needed to run `cursor.execute()`, passed
separately to prevent SQL injection.
"""
if limit is None and offset is not None:
raise ValueError("Offset cannot be set without limit")
filters = filters or set()
where_clause, where_parameters = self._build_predicates(
sorted(
filter_ for filter_ in filters if filter_.type == PredicateType.WHERE
)
)
# having clauses are not supported, since there's no GROUP BY
if any(filter_.type == PredicateType.HAVING for filter_ in filters):
raise ValueError("HAVING filters are not supported")
if group_limit:
query, cte_parameters = self._build_query_with_group_limit(
metrics,
dimensions,
where_clause,
order,
limit,
offset,
group_limit,
)
# Combine parameters: CTE params first, then main query params
all_parameters = cte_parameters + where_parameters
else:
query = self._build_simple_query(
metrics,
dimensions,
where_clause,
order,
limit,
offset,
)
all_parameters = where_parameters
return query, all_parameters
def _alias_element(self, element: Metric | Dimension) -> str:
"""
Generate an aliased column expression for a metric or dimension.
"""
return f"{element.id} AS {self._quote(element.id)}"
def _build_order_clause(
self,
order: list[OrderTuple] | None = None,
) -> str:
"""
Build the ORDER BY clause from a list of (element, direction) tuples.
Note that for adhoc expressions, Superset will still add `ASC` or `DESC` to the
end, which means adhoc expressions can contain multiple columns as long as the
last one has no direction specified.
This is fine:
gender ASC, COUNT(*)
But this is not
gender ASC, COUNT(*) DESC
The latter will produce a query that looks like this:
... ORDER BY gender ASC, COUNT(*) DESC DESC
"""
if not order:
return ""
def build_element(element: Metric | Dimension | AdhocExpression) -> str:
if isinstance(element, AdhocExpression):
validate_order_by(element.definition)
return element.definition
return self._quote(element.id)
return ", ".join(
f"{build_element(element)} {direction.value}"
for element, direction in order
)
def _build_simple_query(
self,
metrics: list[Metric],
dimensions: list[Dimension],
where_clause: str,
order: list[OrderTuple] | None,
limit: int | None,
offset: int | None,
) -> str:
"""
Build a query without group limiting.
"""
dimension_arguments = ", ".join(
self._alias_element(dimension) for dimension in dimensions
)
metric_arguments = ", ".join(self._alias_element(metric) for metric in metrics)
order_clause = self._build_order_clause(order)
return dedent(
f"""
SELECT * FROM SEMANTIC_VIEW(
{self.uid()}
{"DIMENSIONS " + dimension_arguments if dimension_arguments else ""}
{"METRICS " + metric_arguments if metric_arguments else ""}
{"WHERE " + where_clause if where_clause else ""}
)
{"ORDER BY " + order_clause if order_clause else ""}
{"LIMIT " + str(limit) if limit is not None else ""}
{"OFFSET " + str(offset) if offset is not None else ""}
"""
).strip()
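# Illustrative rendered query (made-up semantic view, dimension, and metric):
#
#     SELECT * FROM SEMANTIC_VIEW(
#         "testdb"."public"."orders_sv"
#         DIMENSIONS orders.gender AS "orders.gender"
#         METRICS orders.revenue AS "orders.revenue"
#         WHERE ("orders.gender" = ?)
#     )
#     ORDER BY "orders.revenue" DESC
#     LIMIT 100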
def _build_top_groups_cte(
self,
group_limit: GroupLimit,
where_clause: str,
) -> tuple[str, tuple[FilterValues, ...]]:
"""
Build a CTE that finds the top N combinations of limited dimensions.
If group_limit.filters is set, it uses those filters instead of the main
query's where clause. This allows using different time bounds for finding top
groups vs showing data.
Returns:
Tuple of (CTE SQL, parameters for the CTE)
"""
limited_dimension_arguments = ", ".join(
self._alias_element(dimension) for dimension in group_limit.dimensions
)
limited_dimension_names = ", ".join(
self._quote(dimension.id) for dimension in group_limit.dimensions
)
# Use separate filters for group limit if provided (Option 2)
# Otherwise use the same filters as the main query (Option 1)
if group_limit.filters is not None:
group_where_clause, group_where_params = self._build_predicates(
sorted(
filter_
for filter_ in group_limit.filters
if filter_.type == PredicateType.WHERE
)
)
if any(
filter_.type == PredicateType.HAVING for filter_ in group_limit.filters
):
raise ValueError(
"HAVING filters are not supported in group limit filters"
)
cte_params = group_where_params
else:
group_where_clause = where_clause
cte_params = () # No additional params - using main query params
# Build METRICS clause and ORDER BY based on whether metric is provided
if group_limit.metric is not None:
metrics_clause = (
f"METRICS {group_limit.metric.id}"
f" AS {self._quote(group_limit.metric.id)}"
)
order_by_clause = (
f"{self._quote(group_limit.metric.id)} {group_limit.direction.value}"
)
else:
# No metric provided - order by first dimension
metrics_clause = ""
order_by_clause = (
f"{self._quote(group_limit.dimensions[0].id)} "
f"{group_limit.direction.value}"
)
# Build SEMANTIC_VIEW arguments
semantic_view_args = [
f"DIMENSIONS {limited_dimension_arguments}",
]
if metrics_clause:
semantic_view_args.append(metrics_clause)
if group_where_clause:
semantic_view_args.append(f"WHERE {group_where_clause}")
semantic_view_args_str = "\n ".join(semantic_view_args)
# Add trailing blank line if there's no WHERE clause
# This matches the original template behavior
if not group_where_clause:
semantic_view_args_str += "\n"
cte_sql = dedent(
f"""
WITH top_groups AS (
SELECT {limited_dimension_names}
FROM SEMANTIC_VIEW(
{self.uid()}
{semantic_view_args_str}
)
ORDER BY
{order_by_clause}
LIMIT {group_limit.top}
)
"""
).strip()
return cte_sql, cte_params
def _build_group_filter(self, group_limit: GroupLimit) -> str:
"""
Build a WHERE filter that restricts results to top N groups.
"""
if len(group_limit.dimensions) == 1:
dimension_id = self._quote(group_limit.dimensions[0].id)
return f"{dimension_id} IN (SELECT {dimension_id} FROM top_groups)"
# Multi-column IN clause
dimension_tuple = ", ".join(
self._quote(dim.id) for dim in group_limit.dimensions
)
return f"({dimension_tuple}) IN (SELECT {dimension_tuple} FROM top_groups)"
def _build_case_expression(
self,
dimension: Dimension,
group_condition: str,
) -> str:
"""
Build a CASE expression that replaces non-top values with 'Other'.
Args:
dimension: The dimension to build the CASE for
group_condition: The condition to check if value is in top groups
(e.g., "staff_id IN (SELECT staff_id FROM top_groups)")
Returns:
SQL CASE expression
"""
dimension_id = self._quote(dimension.id)
return f"""CASE
WHEN {group_condition} THEN {dimension_id}
ELSE CAST('Other' AS VARCHAR)
END"""
def _build_query_with_others(
self,
metrics: list[Metric],
dimensions: list[Dimension],
where_clause: str,
order: list[OrderTuple] | None,
limit: int | None,
offset: int | None,
group_limit: GroupLimit,
) -> tuple[str, tuple[FilterValues, ...]]:
"""
Build a query that groups non-top N values as 'Other'.
This uses a two-stage approach:
1. CTE to find top N groups
2. Subquery with CASE expressions to replace non-top values with 'Other'
3. Outer query to re-aggregate with the new grouping
Returns:
Tuple of (SQL query, CTE parameters)
"""
top_groups_cte, cte_params = self._build_top_groups_cte(
group_limit,
where_clause,
)
# Determine which dimensions are limited vs non-limited
limited_dimension_ids = {dim.id for dim in group_limit.dimensions}
non_limited_dimensions = [
dim for dim in dimensions if dim.id not in limited_dimension_ids
]
# Build the group condition for CASE expressions
if len(group_limit.dimensions) == 1:
dimension_id = self._quote(group_limit.dimensions[0].id)
group_condition = (
f"{dimension_id} IN (SELECT {dimension_id} FROM top_groups)"
)
else:
dimension_tuple = ", ".join(
self._quote(dim.id) for dim in group_limit.dimensions
)
group_condition = (
f"({dimension_tuple}) IN (SELECT {dimension_tuple} FROM top_groups)"
)
# Build CASE expressions for limited dimensions
case_expressions = []
case_expressions_for_groupby = []
for dim in group_limit.dimensions:
case_expr = self._build_case_expression(dim, group_condition)
alias = self._quote(dim.id)
case_expressions.append(f"{case_expr} AS {alias}")
# Store the full CASE expression for GROUP BY (not just alias)
case_expressions_for_groupby.append(case_expr)
# Build SELECT for non-limited dimensions (pass through)
non_limited_selects = [
f"{self._quote(dim.id)} AS {self._quote(dim.id)}"
for dim in non_limited_dimensions
]
# Build metric aggregations
metric_aggregations = [
f"SUM({self._quote(metric.id)}) AS {self._quote(metric.id)}"
for metric in metrics
]
# Build the subquery that gets raw data from SEMANTIC_VIEW
dimension_arguments = ", ".join(
self._alias_element(dimension) for dimension in dimensions
)
metric_arguments = ", ".join(self._alias_element(metric) for metric in metrics)
subquery = dedent(
f"""
raw_data AS (
SELECT * FROM SEMANTIC_VIEW(
{self.uid()}
DIMENSIONS {dimension_arguments}
METRICS {metric_arguments}
{"WHERE " + where_clause if where_clause else ""}
)
)
"""
).strip()
# Build GROUP BY clause (full CASE expressions + non-limited dimensions)
# We need to repeat the full CASE expressions, not use aliases, because
# Snowflake may interpret the alias as the original column reference
group_by_columns = case_expressions_for_groupby + [
self._quote(dim.id) for dim in non_limited_dimensions
]
group_by_clause = ", ".join(group_by_columns)
# Build final SELECT columns
select_columns = case_expressions + non_limited_selects + metric_aggregations
select_clause = ",\n ".join(select_columns)
# Build ORDER BY clause (need to reference the aliased columns)
order_clause = self._build_order_clause(order)
query = dedent(
f"""
{top_groups_cte},
{subquery}
SELECT
{select_clause}
FROM raw_data
GROUP BY {group_by_clause}
{"ORDER BY " + order_clause if order_clause else ""}
{"LIMIT " + str(limit) if limit is not None else ""}
{"OFFSET " + str(offset) if offset is not None else ""}
"""
).strip()
return query, cte_params
def _build_query_with_group_limit(
self,
metrics: list[Metric],
dimensions: list[Dimension],
where_clause: str,
order: list[OrderTuple] | None,
limit: int | None,
offset: int | None,
group_limit: GroupLimit,
) -> tuple[str, tuple[FilterValues, ...]]:
"""
Build a query with group limiting (top N groups).
If group_others is True, groups non-top values as 'Other'.
Otherwise, filters to show only top N groups.
Returns:
Tuple of (SQL query, CTE parameters)
"""
if group_limit.group_others:
return self._build_query_with_others(
metrics,
dimensions,
where_clause,
order,
limit,
offset,
group_limit,
)
# Standard group limiting: just filter to top N groups
# We can't use CTE references inside SEMANTIC_VIEW(), so we wrap it
dimension_arguments = ", ".join(
self._alias_element(dimension) for dimension in dimensions
)
metric_arguments = ", ".join(self._alias_element(metric) for metric in metrics)
order_clause = self._build_order_clause(order)
top_groups_cte, cte_params = self._build_top_groups_cte(
group_limit,
where_clause,
)
group_filter = self._build_group_filter(group_limit)
query = dedent(
f"""
{top_groups_cte}
SELECT * FROM SEMANTIC_VIEW(
{self.uid()}
{"DIMENSIONS " + dimension_arguments if dimension_arguments else ""}
{"METRICS " + metric_arguments if metric_arguments else ""}
{"WHERE " + where_clause if where_clause else ""}
) AS subquery
WHERE {group_filter}
{"ORDER BY " + order_clause if order_clause else ""}
{"LIMIT " + str(limit) if limit is not None else ""}
{"OFFSET " + str(offset) if offset is not None else ""}
"""
).strip()
return query, cte_params
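    # Illustrative sketch of the SQL this method produces, with placeholder
    # names; the CTE text comes from _build_top_groups_cte, and the group
    # filter is assumed to mirror the single-dimension condition used in
    # _build_query_with_others:
    #
    #     WITH top_groups AS (...)
    #     SELECT * FROM SEMANTIC_VIEW(
    #         MY_SEMANTIC_VIEW
    #         DIMENSIONS "REGION"
    #         METRICS "SALES"
    #     ) AS subquery
    #     WHERE "REGION" IN (SELECT "REGION" FROM top_groups)
    #     ORDER BY "SALES" DESC
    #     LIMIT 100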
__repr__ = uid


@@ -0,0 +1,123 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# ruff: noqa: S608
from __future__ import annotations
from typing import Any, Sequence
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from superset.exceptions import SupersetParseError
from superset.semantic_layers.snowflake.schemas import (
PrivateKeyAuth,
SnowflakeConfiguration,
UserPasswordAuth,
)
from superset.sql.parse import SQLStatement
def substitute_parameters(query: str, parameters: Sequence[Any] | None) -> str:
"""
    Substitute parameters in a templated query.
    This is used to inline bound query parameters so that we can return the
    executed query for logging/auditing purposes. With Snowflake the binding
    happens on the server, so the only way to get the true executed query would
    be to query the database, which is inefficient.
"""
if not parameters:
return query
result = query
for parameter in parameters:
if parameter is None:
replacement = "NULL"
elif isinstance(parameter, bool):
# Check bool before int/float since bool is a subclass of int
replacement = str(parameter).upper()
elif isinstance(parameter, (int, float)):
replacement = str(parameter)
else:
# String - escape single quotes
quoted = str(parameter).replace("'", "''")
replacement = f"'{quoted}'"
result = result.replace("?", replacement, 1)
return result
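# Example (a sketch, not part of the original module): each "?" placeholder is
# replaced in order with a safely rendered literal.
#
#     >>> substitute_parameters("SELECT * FROM t WHERE a = ? AND b = ?", ["O'Brien", 10])
#     "SELECT * FROM t WHERE a = 'O''Brien' AND b = 10"
#     >>> substitute_parameters("SELECT ? AS flag", [True])
#     'SELECT TRUE AS flag'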
def validate_order_by(definition: str) -> None:
"""
Validate that an ORDER BY expression is safe to use.
Note that `definition` could contain multiple expressions separated by commas.
"""
try:
# this ensures that we have a single statement, preventing SQL injection via a
# semicolon in the order by clause
SQLStatement(f"SELECT 1 ORDER BY {definition}", "snowflake")
except SupersetParseError as ex:
raise ValueError("Invalid ORDER BY expression") from ex
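# Example (illustrative): a plain expression list parses as a single statement,
# while an injected semicolon fails the single-statement check and raises
# ValueError.
#
#     validate_order_by('"REGION" DESC, "SALES"')   # OK
#     validate_order_by('1; DROP TABLE users; --')  # raises ValueError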
def get_connection_parameters(configuration: SnowflakeConfiguration) -> dict[str, Any]:
"""
Convert the configuration to connection parameters for the Snowflake connector.
"""
params = {
"account": configuration.account_identifier,
"application": "Apache Superset",
"paramstyle": "qmark",
"insecure_mode": True,
}
if configuration.role:
params["role"] = configuration.role
if configuration.warehouse:
params["warehouse"] = configuration.warehouse
if configuration.database:
params["database"] = configuration.database
if configuration.schema_:
params["schema"] = configuration.schema_
auth = configuration.auth
if isinstance(auth, UserPasswordAuth):
params["user"] = auth.username
params["password"] = auth.password.get_secret_value()
elif isinstance(auth, PrivateKeyAuth):
pem_private_key = serialization.load_pem_private_key(
auth.private_key.get_secret_value().encode(),
password=(
auth.private_key_password.get_secret_value().encode()
if auth.private_key_password
else None
),
backend=default_backend(),
)
params["private_key"] = pem_private_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
else:
raise ValueError("Unsupported authentication method")
return params
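# Example (a sketch; the configuration fields follow SnowflakeConfiguration as
# exercised in the unit tests below):
#
#     config = SnowflakeConfiguration(
#         account_identifier="my_account",
#         database="ANALYTICS_DB",
#         auth={
#             "auth_type": "user_password",
#             "username": "bob",
#             "password": "hunter2",
#         },
#     )
#     params = get_connection_parameters(config)
#     # params["account"] == "my_account", params["user"] == "bob", and
#     # params["paramstyle"] == "qmark", matching substitute_parameters above.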


@@ -0,0 +1,443 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import enum
from dataclasses import dataclass
from datetime import date, datetime, time, timedelta
from functools import total_ordering
from typing import Any, Protocol, runtime_checkable, TypeVar
from pandas import DataFrame
from pydantic import BaseModel
__all__ = [
"BINARY",
"BOOLEAN",
"DATE",
"DATETIME",
"DECIMAL",
"DateGrain",
"Dimension",
"INTEGER",
"INTERVAL",
"NUMBER",
"OBJECT",
"STRING",
"TIME",
"TimeGrain",
]
class Type:
"""
Base class for types.
"""
class INTEGER(Type):
"""
Represents an integer type.
"""
class NUMBER(Type):
"""
Represents a number type.
"""
class DECIMAL(Type):
"""
Represents a decimal type.
"""
class STRING(Type):
"""
Represents a string type.
"""
class BOOLEAN(Type):
"""
Represents a boolean type.
"""
class DATE(Type):
"""
Represents a date type.
"""
class TIME(Type):
"""
Represents a time type.
"""
class DATETIME(DATE, TIME):
"""
Represents a datetime type.
"""
class INTERVAL(Type):
"""
Represents an interval type.
"""
class OBJECT(Type):
"""
Represents an object type.
"""
class BINARY(Type):
"""
Represents a binary type.
"""
@total_ordering
class ComparableEnum(enum.Enum):
def __eq__(self, other: object) -> bool:
if isinstance(other, enum.Enum):
return self.value == other.value
return NotImplemented
def __lt__(self, other: object) -> bool:
if isinstance(other, enum.Enum):
return self.value < other.value
return NotImplemented
def __hash__(self) -> int:
return hash((self.__class__, self.name))
class TimeGrain(ComparableEnum):
PT1S = timedelta(seconds=1)
PT1M = timedelta(minutes=1)
PT1H = timedelta(hours=1)
class DateGrain(ComparableEnum):
P1D = timedelta(days=1)
P1W = timedelta(weeks=1)
P1M = timedelta(days=30)
P3M = timedelta(days=90)
P1Y = timedelta(days=365)
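# Example (illustrative): grains compare by their timedelta values, even across
# the two enum classes, courtesy of ComparableEnum.
#
#     assert TimeGrain.PT1M < TimeGrain.PT1H
#     assert TimeGrain.PT1H < DateGrain.P1D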
@dataclass(frozen=True)
class Dimension:
id: str
name: str
type: type[Type]
definition: str | None = None
description: str | None = None
grain: DateGrain | TimeGrain | None = None
@dataclass(frozen=True)
class Metric:
id: str
name: str
type: type[Type]
definition: str | None
description: str | None = None
@dataclass(frozen=True)
class AdhocExpression:
id: str
definition: str
class Operator(str, enum.Enum):
EQUALS = "="
NOT_EQUALS = "!="
GREATER_THAN = ">"
LESS_THAN = "<"
GREATER_THAN_OR_EQUAL = ">="
LESS_THAN_OR_EQUAL = "<="
IN = "IN"
NOT_IN = "NOT IN"
LIKE = "LIKE"
NOT_LIKE = "NOT LIKE"
IS_NULL = "IS NULL"
IS_NOT_NULL = "IS NOT NULL"
FilterValues = str | int | float | bool | datetime | date | time | timedelta | None
class PredicateType(enum.Enum):
WHERE = "WHERE"
HAVING = "HAVING"
@dataclass(frozen=True, order=True)
class Filter:
type: PredicateType
column: Dimension | Metric
operator: Operator
value: FilterValues | set[FilterValues]
@dataclass(frozen=True, order=True)
class AdhocFilter:
type: PredicateType
definition: str
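# Example (illustrative, with a hypothetical `region_dim` Dimension): a WHERE
# predicate restricting a dimension to a set of values.
#
#     region_filter = Filter(
#         type=PredicateType.WHERE,
#         column=region_dim,
#         operator=Operator.IN,
#         value={"EMEA", "APAC"},
#     )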
class OrderDirection(enum.Enum):
ASC = "ASC"
DESC = "DESC"
OrderTuple = tuple[Metric | Dimension | AdhocExpression, OrderDirection]
@dataclass(frozen=True)
class GroupLimit:
"""
Limit query to top/bottom N combinations of specified dimensions.
The `filters` parameter allows specifying separate filter constraints for the
group limit subquery. This is useful when you want to determine the top N groups
using different criteria (e.g., a different time range) than the main query.
For example, you might want to find the top 10 products by sales over the last
30 days, but then show daily sales for those products over the last 7 days.
"""
dimensions: list[Dimension]
top: int
metric: Metric | None
direction: OrderDirection = OrderDirection.DESC
group_others: bool = False
filters: set[Filter | AdhocFilter] | None = None
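# Example (a sketch of the scenario in the docstring, using hypothetical
# `product_dim`, `sales_metric`, and `last_30_days` objects): pick the top 10
# products by sales, determined over a wider window than the main query.
#
#     group_limit = GroupLimit(
#         dimensions=[product_dim],
#         top=10,
#         metric=sales_metric,
#         direction=OrderDirection.DESC,
#         filters={last_30_days},
#     )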
@dataclass(frozen=True)
class SemanticRequest:
"""
Represents a request made to obtain semantic results.
This could be a SQL query, an HTTP request, etc.
"""
type: str
definition: str
@dataclass(frozen=True)
class SemanticResult:
"""
Represents the results of a semantic query.
    This includes any requests (SQL queries, HTTP requests) that were performed
    to obtain the results, to help with troubleshooting.
"""
requests: list[SemanticRequest]
results: DataFrame
@dataclass(frozen=True)
class SemanticQuery:
"""
Represents a semantic query.
"""
metrics: list[Metric]
dimensions: list[Dimension]
filters: set[Filter | AdhocFilter] | None = None
order: list[OrderTuple] | None = None
limit: int | None = None
offset: int | None = None
group_limit: GroupLimit | None = None
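# Example (a sketch reusing the hypothetical objects above): a complete
# semantic query, ordered by the metric descending.
#
#     query = SemanticQuery(
#         metrics=[sales_metric],
#         dimensions=[region_dim],
#         filters={region_filter},
#         order=[(sales_metric, OrderDirection.DESC)],
#         limit=100,
#     )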
class SemanticViewFeature(enum.Enum):
"""
Custom features supported by semantic layers.
"""
ADHOC_EXPRESSIONS_IN_ORDERBY = "ADHOC_EXPRESSIONS_IN_ORDERBY"
GROUP_LIMIT = "GROUP_LIMIT"
GROUP_OTHERS = "GROUP_OTHERS"
ConfigT = TypeVar("ConfigT", bound=BaseModel, contravariant=True)
SemanticViewT = TypeVar("SemanticViewT", bound="SemanticViewImplementation")
@runtime_checkable
class SemanticLayerImplementation(Protocol[ConfigT, SemanticViewT]):
"""
A protocol for semantic layers.
"""
@classmethod
def from_configuration(
cls,
configuration: dict[str, Any],
) -> SemanticLayerImplementation[ConfigT, SemanticViewT]:
"""
Create a semantic layer from its configuration.
"""
@classmethod
def get_configuration_schema(
cls,
configuration: ConfigT | None = None,
) -> dict[str, Any]:
"""
Get the JSON schema for the configuration needed to add the semantic layer.
        A partial `configuration` can be sent to improve the schema, allowing
        for progressive validation and better UX. For example, a semantic
layer might require:
- auth information
- a database
If the user provides the auth information, a client can send the partial
configuration to this method, and the resulting JSON schema would include
the list of databases the user has access to, allowing a dropdown to be
populated.
The Snowflake semantic layer has an example implementation of this method, where
database and schema names are populated based on the provided connection info.
"""
@classmethod
def get_runtime_schema(
cls,
configuration: ConfigT,
runtime_data: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""
Get the JSON schema for the runtime parameters needed to load semantic views.
This returns the schema needed to connect to a semantic view given the
configuration for the semantic layer. For example, a semantic layer might
be configured by:
- auth information
- an optional database
If the user does not provide a database when creating the semantic layer, the
runtime schema would require the database name to be provided before loading any
semantic views. This allows users to create semantic layers that connect to a
specific database (or project, account, etc.), or that allow users to select it
at query time.
The Snowflake semantic layer has an example implementation of this method, where
database and schema names are required if they were not provided in the initial
configuration.
"""
def get_semantic_views(
self,
runtime_configuration: dict[str, Any],
) -> set[SemanticViewT]:
"""
Get the semantic views available in the semantic layer.
The runtime configuration can provide information like a given project or
schema, used to restrict the semantic views returned.
"""
def get_semantic_view(
self,
name: str,
additional_configuration: dict[str, Any],
) -> SemanticViewT:
"""
Get a specific semantic view by its name and additional configuration.
"""
@runtime_checkable
class SemanticViewImplementation(Protocol):
"""
A protocol for semantic views.
"""
features: frozenset[SemanticViewFeature]
def uid(self) -> str:
"""
Returns a unique identifier for the semantic view.
"""
def get_dimensions(self) -> set[Dimension]:
"""
Get the dimensions defined in the semantic view.
"""
def get_metrics(self) -> set[Metric]:
"""
Get the metrics defined in the semantic view.
"""
def get_values(
self,
dimension: Dimension,
filters: set[Filter | AdhocFilter] | None = None,
) -> SemanticResult:
"""
Return distinct values for a dimension.
"""
def get_dataframe(
self,
metrics: list[Metric],
dimensions: list[Dimension],
filters: set[Filter | AdhocFilter] | None = None,
order: list[OrderTuple] | None = None,
limit: int | None = None,
offset: int | None = None,
*,
group_limit: GroupLimit | None = None,
) -> SemanticResult:
"""
Execute a semantic query and return the results as a DataFrame.
"""
def get_row_count(
self,
metrics: list[Metric],
dimensions: list[Dimension],
filters: set[Filter | AdhocFilter] | None = None,
order: list[OrderTuple] | None = None,
limit: int | None = None,
offset: int | None = None,
*,
group_limit: GroupLimit | None = None,
) -> SemanticResult:
"""
Execute a query and return the number of rows the result would have.
"""


@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,166 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for CreateSemanticLayerCommand."""
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from superset.commands.semantic_layer.create import CreateSemanticLayerCommand
from superset.commands.semantic_layer.exceptions import (
SemanticLayerExistsValidationError,
SemanticLayerInvalidError,
SemanticLayerRequiredFieldValidationError,
)
def test_create_semantic_layer_success(mocker: MockerFixture) -> None:
"""
Test successful semantic layer creation.
"""
mock_layer = MagicMock()
mock_layer.uuid = "test-uuid"
mock_layer.name = "test_layer"
dao = mocker.patch("superset.commands.semantic_layer.create.SemanticLayerDAO")
dao.validate_uniqueness.return_value = True
dao.create.return_value = mock_layer
properties = {
"name": "test_layer",
"type": "cube",
"configuration": '{"url": "http://localhost:4000"}',
}
command = CreateSemanticLayerCommand(properties)
result = command.run()
assert result == mock_layer
dao.create.assert_called_once_with(attributes=properties)
def test_create_semantic_layer_missing_name(mocker: MockerFixture) -> None:
"""
Test create fails when name is missing.
"""
mocker.patch("superset.commands.semantic_layer.create.SemanticLayerDAO")
properties = {
"type": "cube",
"configuration": '{"url": "http://localhost:4000"}',
}
command = CreateSemanticLayerCommand(properties)
with pytest.raises(SemanticLayerInvalidError) as exc_info:
command.run()
assert len(exc_info.value._exceptions) == 1
assert isinstance(
exc_info.value._exceptions[0], SemanticLayerRequiredFieldValidationError
)
def test_create_semantic_layer_missing_type(mocker: MockerFixture) -> None:
"""
Test create fails when type is missing.
"""
mocker.patch("superset.commands.semantic_layer.create.SemanticLayerDAO")
properties = {
"name": "test_layer",
"configuration": '{"url": "http://localhost:4000"}',
}
command = CreateSemanticLayerCommand(properties)
with pytest.raises(SemanticLayerInvalidError) as exc_info:
command.run()
assert len(exc_info.value._exceptions) == 1
assert isinstance(
exc_info.value._exceptions[0], SemanticLayerRequiredFieldValidationError
)
def test_create_semantic_layer_duplicate_name(mocker: MockerFixture) -> None:
"""
Test create fails when name already exists.
"""
dao = mocker.patch("superset.commands.semantic_layer.create.SemanticLayerDAO")
dao.validate_uniqueness.return_value = False
properties = {
"name": "existing_layer",
"type": "cube",
"configuration": '{"url": "http://localhost:4000"}',
}
command = CreateSemanticLayerCommand(properties)
with pytest.raises(SemanticLayerInvalidError) as exc_info:
command.run()
assert len(exc_info.value._exceptions) == 1
assert isinstance(exc_info.value._exceptions[0], SemanticLayerExistsValidationError)
def test_create_semantic_layer_multiple_errors(mocker: MockerFixture) -> None:
"""
Test create accumulates multiple validation errors.
"""
mocker.patch("superset.commands.semantic_layer.create.SemanticLayerDAO")
properties = {
"configuration": '{"url": "http://localhost:4000"}',
}
command = CreateSemanticLayerCommand(properties)
with pytest.raises(SemanticLayerInvalidError) as exc_info:
command.run()
assert len(exc_info.value._exceptions) == 2
def test_create_semantic_layer_with_optional_fields(mocker: MockerFixture) -> None:
"""
Test create with optional fields.
"""
mock_layer = MagicMock()
mock_layer.uuid = "test-uuid"
mock_layer.name = "test_layer"
dao = mocker.patch("superset.commands.semantic_layer.create.SemanticLayerDAO")
dao.validate_uniqueness.return_value = True
dao.create.return_value = mock_layer
properties = {
"name": "test_layer",
"type": "cube",
"description": "Test description",
"configuration": '{"url": "http://localhost:4000"}',
"cache_timeout": 3600,
}
command = CreateSemanticLayerCommand(properties)
result = command.run()
assert result == mock_layer
dao.create.assert_called_once_with(attributes=properties)


@@ -0,0 +1,82 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for DeleteSemanticLayerCommand."""
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from superset.commands.semantic_layer.delete import DeleteSemanticLayerCommand
from superset.commands.semantic_layer.exceptions import SemanticLayerNotFoundError
def test_delete_semantic_layer_success(mocker: MockerFixture) -> None:
"""
Test successful semantic layer deletion.
"""
mock_layer = MagicMock()
mock_layer.uuid = "test-uuid"
mock_layer.name = "test_layer"
dao = mocker.patch("superset.commands.semantic_layer.delete.SemanticLayerDAO")
dao.find_by_id.return_value = mock_layer
dao.delete.return_value = None
command = DeleteSemanticLayerCommand("test-uuid")
result = command.run()
assert result is None
dao.delete.assert_called_once_with([mock_layer])
def test_delete_semantic_layer_not_found(mocker: MockerFixture) -> None:
"""
Test delete fails when semantic layer not found.
"""
dao = mocker.patch("superset.commands.semantic_layer.delete.SemanticLayerDAO")
dao.find_by_id.return_value = None
command = DeleteSemanticLayerCommand("nonexistent-uuid")
with pytest.raises(SemanticLayerNotFoundError):
command.run()
def test_delete_semantic_layer_cascades_views(mocker: MockerFixture) -> None:
"""
Test delete cascades to semantic views.
"""
mock_layer = MagicMock()
mock_layer.uuid = "test-uuid"
mock_layer.name = "test_layer"
# Mock semantic views that will be cascade deleted
mock_view1 = MagicMock()
mock_view2 = MagicMock()
mock_layer.semantic_views = [mock_view1, mock_view2]
dao = mocker.patch("superset.commands.semantic_layer.delete.SemanticLayerDAO")
dao.find_by_id.return_value = mock_layer
dao.delete.return_value = None
command = DeleteSemanticLayerCommand("test-uuid")
result = command.run()
assert result is None
dao.delete.assert_called_once_with([mock_layer])


@@ -0,0 +1,166 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for UpdateSemanticLayerCommand."""
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from superset.commands.semantic_layer.exceptions import (
SemanticLayerExistsValidationError,
SemanticLayerInvalidError,
SemanticLayerNotFoundError,
)
from superset.commands.semantic_layer.update import UpdateSemanticLayerCommand
def test_update_semantic_layer_success(mocker: MockerFixture) -> None:
"""
Test successful semantic layer update.
"""
mock_layer = MagicMock()
mock_layer.uuid = "test-uuid"
mock_layer.name = "test_layer"
dao = mocker.patch("superset.commands.semantic_layer.update.SemanticLayerDAO")
dao.find_by_id.return_value = mock_layer
dao.validate_update_uniqueness.return_value = True
dao.update.return_value = mock_layer
properties = {
"description": "Updated description",
"cache_timeout": 7200,
}
command = UpdateSemanticLayerCommand("test-uuid", properties)
result = command.run()
assert result == mock_layer
dao.update.assert_called_once_with(mock_layer, properties)
def test_update_semantic_layer_not_found(mocker: MockerFixture) -> None:
"""
Test update fails when semantic layer not found.
"""
dao = mocker.patch("superset.commands.semantic_layer.update.SemanticLayerDAO")
dao.find_by_id.return_value = None
properties = {"description": "Updated description"}
command = UpdateSemanticLayerCommand("nonexistent-uuid", properties)
with pytest.raises(SemanticLayerNotFoundError):
command.run()
def test_update_semantic_layer_duplicate_name(mocker: MockerFixture) -> None:
"""
Test update fails when new name already exists.
"""
mock_layer = MagicMock()
mock_layer.uuid = "test-uuid"
mock_layer.name = "test_layer"
dao = mocker.patch("superset.commands.semantic_layer.update.SemanticLayerDAO")
dao.find_by_id.return_value = mock_layer
dao.validate_update_uniqueness.return_value = False
properties = {"name": "existing_layer"}
command = UpdateSemanticLayerCommand("test-uuid", properties)
with pytest.raises(SemanticLayerInvalidError) as exc_info:
command.run()
assert len(exc_info.value._exceptions) == 1
assert isinstance(exc_info.value._exceptions[0], SemanticLayerExistsValidationError)
def test_update_semantic_layer_name_unchanged(mocker: MockerFixture) -> None:
"""
Test update with same name doesn't trigger uniqueness validation.
"""
mock_layer = MagicMock()
mock_layer.uuid = "test-uuid"
mock_layer.name = "test_layer"
dao = mocker.patch("superset.commands.semantic_layer.update.SemanticLayerDAO")
dao.find_by_id.return_value = mock_layer
dao.update.return_value = mock_layer
properties = {"description": "Updated description"}
command = UpdateSemanticLayerCommand("test-uuid", properties)
result = command.run()
assert result == mock_layer
dao.validate_update_uniqueness.assert_not_called()
def test_update_semantic_layer_name_changed(mocker: MockerFixture) -> None:
"""
Test update with new name triggers uniqueness validation.
"""
mock_layer = MagicMock()
mock_layer.uuid = "test-uuid"
mock_layer.name = "test_layer"
dao = mocker.patch("superset.commands.semantic_layer.update.SemanticLayerDAO")
dao.find_by_id.return_value = mock_layer
dao.validate_update_uniqueness.return_value = True
dao.update.return_value = mock_layer
properties = {"name": "new_layer_name"}
command = UpdateSemanticLayerCommand("test-uuid", properties)
result = command.run()
assert result == mock_layer
dao.validate_update_uniqueness.assert_called_once_with(
"test-uuid", "new_layer_name"
)
def test_update_semantic_layer_all_fields(mocker: MockerFixture) -> None:
"""
Test update with all fields.
"""
mock_layer = MagicMock()
mock_layer.uuid = "test-uuid"
mock_layer.name = "test_layer"
dao = mocker.patch("superset.commands.semantic_layer.update.SemanticLayerDAO")
dao.find_by_id.return_value = mock_layer
dao.validate_update_uniqueness.return_value = True
dao.update.return_value = mock_layer
properties = {
"name": "updated_layer",
"description": "Updated description",
"type": "dbt",
"configuration": '{"token": "new-token"}',
"cache_timeout": 7200,
}
command = UpdateSemanticLayerCommand("test-uuid", properties)
result = command.run()
assert result == mock_layer
dao.update.assert_called_once_with(mock_layer, properties)


@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


@@ -0,0 +1,210 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for CreateSemanticViewCommand."""
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from superset.commands.semantic_layer.exceptions import SemanticLayerNotFoundError
from superset.commands.semantic_view.create import CreateSemanticViewCommand
from superset.commands.semantic_view.exceptions import (
SemanticViewExistsValidationError,
SemanticViewInvalidError,
SemanticViewRequiredFieldValidationError,
)
def test_create_semantic_view_success(mocker: MockerFixture) -> None:
"""
Test successful semantic view creation.
"""
mock_layer = MagicMock()
mock_layer.uuid = "layer-uuid"
mock_view = MagicMock()
mock_view.uuid = "view-uuid"
mock_view.name = "test_view"
layer_dao = mocker.patch("superset.commands.semantic_view.create.SemanticLayerDAO")
layer_dao.find_by_id.return_value = mock_layer
view_dao = mocker.patch("superset.commands.semantic_view.create.SemanticViewDAO")
view_dao.validate_uniqueness.return_value = True
view_dao.create.return_value = mock_view
properties = {
"name": "test_view",
"semantic_layer_uuid": "layer-uuid",
"configuration": '{"columns": ["id", "name"]}',
}
command = CreateSemanticViewCommand(properties)
result = command.run()
assert result == mock_view
view_dao.create.assert_called_once_with(attributes=properties)
def test_create_semantic_view_missing_name(mocker: MockerFixture) -> None:
"""
Test create fails when name is missing.
"""
mocker.patch("superset.commands.semantic_view.create.SemanticLayerDAO")
mocker.patch("superset.commands.semantic_view.create.SemanticViewDAO")
properties = {
"semantic_layer_uuid": "layer-uuid",
"configuration": '{"columns": ["id"]}',
}
command = CreateSemanticViewCommand(properties)
with pytest.raises(SemanticViewInvalidError) as exc_info:
command.run()
assert len(exc_info.value._exceptions) == 1
assert isinstance(
exc_info.value._exceptions[0], SemanticViewRequiredFieldValidationError
)
def test_create_semantic_view_missing_semantic_layer_uuid(
mocker: MockerFixture,
) -> None:
"""
Test create fails when semantic_layer_uuid is missing.
"""
mocker.patch("superset.commands.semantic_view.create.SemanticLayerDAO")
mocker.patch("superset.commands.semantic_view.create.SemanticViewDAO")
properties = {
"name": "test_view",
"configuration": '{"columns": ["id"]}',
}
command = CreateSemanticViewCommand(properties)
with pytest.raises(SemanticViewInvalidError) as exc_info:
command.run()
assert len(exc_info.value._exceptions) == 1
assert isinstance(
exc_info.value._exceptions[0], SemanticViewRequiredFieldValidationError
)
def test_create_semantic_view_semantic_layer_not_found(mocker: MockerFixture) -> None:
"""
Test create fails when semantic layer not found.
"""
layer_dao = mocker.patch("superset.commands.semantic_view.create.SemanticLayerDAO")
layer_dao.find_by_id.return_value = None
mocker.patch("superset.commands.semantic_view.create.SemanticViewDAO")
properties = {
"name": "test_view",
"semantic_layer_uuid": "nonexistent-uuid",
"configuration": '{"columns": ["id"]}',
}
command = CreateSemanticViewCommand(properties)
with pytest.raises(SemanticLayerNotFoundError):
command.run()
def test_create_semantic_view_duplicate_name(mocker: MockerFixture) -> None:
"""
Test create fails when name already exists in layer.
"""
mock_layer = MagicMock()
mock_layer.uuid = "layer-uuid"
layer_dao = mocker.patch("superset.commands.semantic_view.create.SemanticLayerDAO")
layer_dao.find_by_id.return_value = mock_layer
view_dao = mocker.patch("superset.commands.semantic_view.create.SemanticViewDAO")
view_dao.validate_uniqueness.return_value = False
properties = {
"name": "existing_view",
"semantic_layer_uuid": "layer-uuid",
"configuration": '{"columns": ["id"]}',
}
command = CreateSemanticViewCommand(properties)
with pytest.raises(SemanticViewInvalidError) as exc_info:
command.run()
assert len(exc_info.value._exceptions) == 1
assert isinstance(exc_info.value._exceptions[0], SemanticViewExistsValidationError)
def test_create_semantic_view_multiple_errors(mocker: MockerFixture) -> None:
"""
Test create accumulates multiple validation errors.
"""
mocker.patch("superset.commands.semantic_view.create.SemanticLayerDAO")
mocker.patch("superset.commands.semantic_view.create.SemanticViewDAO")
properties = {
"configuration": '{"columns": ["id"]}',
}
command = CreateSemanticViewCommand(properties)
with pytest.raises(SemanticViewInvalidError) as exc_info:
command.run()
assert len(exc_info.value._exceptions) == 2
def test_create_semantic_view_with_optional_fields(mocker: MockerFixture) -> None:
"""
Test create with optional fields.
"""
mock_layer = MagicMock()
mock_layer.uuid = "layer-uuid"
mock_view = MagicMock()
mock_view.uuid = "view-uuid"
mock_view.name = "test_view"
layer_dao = mocker.patch("superset.commands.semantic_view.create.SemanticLayerDAO")
layer_dao.find_by_id.return_value = mock_layer
view_dao = mocker.patch("superset.commands.semantic_view.create.SemanticViewDAO")
view_dao.validate_uniqueness.return_value = True
view_dao.create.return_value = mock_view
properties = {
"name": "test_view",
"semantic_layer_uuid": "layer-uuid",
"configuration": '{"columns": ["id", "name"]}',
"cache_timeout": 1800,
}
command = CreateSemanticViewCommand(properties)
result = command.run()
assert result == mock_view
view_dao.create.assert_called_once_with(attributes=properties)


@@ -0,0 +1,58 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for DeleteSemanticViewCommand."""
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from superset.commands.semantic_view.delete import DeleteSemanticViewCommand
from superset.commands.semantic_view.exceptions import SemanticViewNotFoundError
def test_delete_semantic_view_success(mocker: MockerFixture) -> None:
"""
Test successful semantic view deletion.
"""
mock_view = MagicMock()
mock_view.uuid = "view-uuid"
mock_view.name = "test_view"
dao = mocker.patch("superset.commands.semantic_view.delete.SemanticViewDAO")
dao.find_by_id.return_value = mock_view
dao.delete.return_value = None
command = DeleteSemanticViewCommand("view-uuid")
result = command.run()
assert result is None
dao.delete.assert_called_once_with([mock_view])
def test_delete_semantic_view_not_found(mocker: MockerFixture) -> None:
"""
Test delete fails when semantic view not found.
"""
dao = mocker.patch("superset.commands.semantic_view.delete.SemanticViewDAO")
dao.find_by_id.return_value = None
command = DeleteSemanticViewCommand("nonexistent-uuid")
with pytest.raises(SemanticViewNotFoundError):
command.run()


@@ -0,0 +1,169 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for UpdateSemanticViewCommand."""
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from superset.commands.semantic_view.exceptions import (
SemanticViewExistsValidationError,
SemanticViewInvalidError,
SemanticViewNotFoundError,
)
from superset.commands.semantic_view.update import UpdateSemanticViewCommand
def test_update_semantic_view_success(mocker: MockerFixture) -> None:
"""
Test successful semantic view update.
"""
mock_view = MagicMock()
mock_view.uuid = "view-uuid"
mock_view.name = "test_view"
mock_view.semantic_layer_uuid = "layer-uuid"
dao = mocker.patch("superset.commands.semantic_view.update.SemanticViewDAO")
dao.find_by_id.return_value = mock_view
dao.validate_update_uniqueness.return_value = True
dao.update.return_value = mock_view
properties = {
"configuration": '{"columns": ["id", "name", "email"]}',
"cache_timeout": 3600,
}
command = UpdateSemanticViewCommand("view-uuid", properties)
result = command.run()
assert result == mock_view
dao.update.assert_called_once_with(mock_view, properties)
def test_update_semantic_view_not_found(mocker: MockerFixture) -> None:
"""
Test update fails when semantic view not found.
"""
dao = mocker.patch("superset.commands.semantic_view.update.SemanticViewDAO")
dao.find_by_id.return_value = None
properties = {"configuration": '{"columns": ["id"]}'}
command = UpdateSemanticViewCommand("nonexistent-uuid", properties)
with pytest.raises(SemanticViewNotFoundError):
command.run()
def test_update_semantic_view_duplicate_name(mocker: MockerFixture) -> None:
"""
Test update fails when new name already exists in layer.
"""
mock_view = MagicMock()
mock_view.uuid = "view-uuid"
mock_view.name = "test_view"
mock_view.semantic_layer_uuid = "layer-uuid"
dao = mocker.patch("superset.commands.semantic_view.update.SemanticViewDAO")
dao.find_by_id.return_value = mock_view
dao.validate_update_uniqueness.return_value = False
properties = {"name": "existing_view"}
command = UpdateSemanticViewCommand("view-uuid", properties)
with pytest.raises(SemanticViewInvalidError) as exc_info:
command.run()
assert len(exc_info.value._exceptions) == 1
assert isinstance(exc_info.value._exceptions[0], SemanticViewExistsValidationError)
def test_update_semantic_view_name_unchanged(mocker: MockerFixture) -> None:
"""
Test update with same name doesn't trigger uniqueness validation.
"""
mock_view = MagicMock()
mock_view.uuid = "view-uuid"
mock_view.name = "test_view"
mock_view.semantic_layer_uuid = "layer-uuid"
dao = mocker.patch("superset.commands.semantic_view.update.SemanticViewDAO")
dao.find_by_id.return_value = mock_view
dao.update.return_value = mock_view
properties = {"configuration": '{"columns": ["id", "name"]}'}
command = UpdateSemanticViewCommand("view-uuid", properties)
result = command.run()
assert result == mock_view
dao.validate_update_uniqueness.assert_not_called()
def test_update_semantic_view_name_changed(mocker: MockerFixture) -> None:
"""
Test update with new name triggers uniqueness validation.
"""
mock_view = MagicMock()
mock_view.uuid = "view-uuid"
mock_view.name = "test_view"
mock_view.semantic_layer_uuid = "layer-uuid"
dao = mocker.patch("superset.commands.semantic_view.update.SemanticViewDAO")
dao.find_by_id.return_value = mock_view
dao.validate_update_uniqueness.return_value = True
dao.update.return_value = mock_view
properties = {"name": "new_view_name"}
command = UpdateSemanticViewCommand("view-uuid", properties)
result = command.run()
assert result == mock_view
dao.validate_update_uniqueness.assert_called_once_with(
"view-uuid", "new_view_name", "layer-uuid"
)
def test_update_semantic_view_all_fields(mocker: MockerFixture) -> None:
"""
Test update with all fields.
"""
mock_view = MagicMock()
mock_view.uuid = "view-uuid"
mock_view.name = "test_view"
mock_view.semantic_layer_uuid = "layer-uuid"
dao = mocker.patch("superset.commands.semantic_view.update.SemanticViewDAO")
dao.find_by_id.return_value = mock_view
dao.validate_update_uniqueness.return_value = True
dao.update.return_value = mock_view
properties = {
"name": "updated_view",
"configuration": '{"columns": ["id", "name", "email"]}',
"cache_timeout": 3600,
}
command = UpdateSemanticViewCommand("view-uuid", properties)
result = command.run()
assert result == mock_view
dao.update.assert_called_once_with(mock_view, properties)


@@ -0,0 +1,305 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for semantic layer DAOs."""
from __future__ import annotations
from typing import Iterator
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from superset.daos.semantic_layer import SemanticLayerDAO, SemanticViewDAO
from superset.semantic_layers.models import SemanticLayer, SemanticView
@pytest.fixture
def session_with_data(session: Session) -> Iterator[Session]:
"""
Create session with semantic layer test data.
"""
engine = session.get_bind()
SemanticLayer.metadata.create_all(engine)
layer1 = SemanticLayer(
uuid=uuid4(),
name="layer1",
description="First layer",
type="cube",
configuration='{"url": "http://localhost:4000"}',
cache_timeout=3600,
)
layer2 = SemanticLayer(
uuid=uuid4(),
name="layer2",
description="Second layer",
type="dbt",
configuration='{"token": "secret"}',
)
session.add_all([layer1, layer2])
session.flush()
view1 = SemanticView(
uuid=uuid4(),
name="view1",
configuration='{"columns": ["id", "name"]}',
cache_timeout=1800,
semantic_layer_uuid=layer1.uuid,
)
view2 = SemanticView(
uuid=uuid4(),
name="view2",
configuration='{"columns": ["id", "value"]}',
semantic_layer_uuid=layer1.uuid,
)
view3 = SemanticView(
uuid=uuid4(),
name="view1",
configuration='{"columns": ["id"]}',
semantic_layer_uuid=layer2.uuid,
)
session.add_all([view1, view2, view3])
session.flush()
yield session
session.rollback()
def test_semantic_layer_find_by_name(session_with_data: Session) -> None:
"""
Test finding semantic layer by name.
"""
result = SemanticLayerDAO.find_by_name("layer1")
assert result is not None
assert result.name == "layer1"
assert result.description == "First layer"
def test_semantic_layer_find_by_name_not_found(session_with_data: Session) -> None:
"""
Test finding non-existent semantic layer by name.
"""
result = SemanticLayerDAO.find_by_name("nonexistent")
assert result is None
def test_semantic_layer_validate_uniqueness_true(session_with_data: Session) -> None:
"""
Test validating uniqueness returns True for new name.
"""
result = SemanticLayerDAO.validate_uniqueness("new_layer")
assert result is True
def test_semantic_layer_validate_uniqueness_false(session_with_data: Session) -> None:
"""
Test validating uniqueness returns False for existing name.
"""
result = SemanticLayerDAO.validate_uniqueness("layer1")
assert result is False
def test_semantic_layer_validate_update_uniqueness_same_name(
session_with_data: Session,
) -> None:
"""
Test validating update uniqueness allows keeping same name.
"""
layer = session_with_data.query(SemanticLayer).filter_by(name="layer1").first()
assert layer is not None
result = SemanticLayerDAO.validate_update_uniqueness(str(layer.uuid), "layer1")
assert result is True
def test_semantic_layer_validate_update_uniqueness_new_name(
session_with_data: Session,
) -> None:
"""
Test validating update uniqueness allows new unique name.
"""
layer = session_with_data.query(SemanticLayer).filter_by(name="layer1").first()
assert layer is not None
result = SemanticLayerDAO.validate_update_uniqueness(str(layer.uuid), "new_name")
assert result is True
def test_semantic_layer_validate_update_uniqueness_existing_name(
session_with_data: Session,
) -> None:
"""
Test validating update uniqueness rejects existing name.
"""
layer = session_with_data.query(SemanticLayer).filter_by(name="layer1").first()
assert layer is not None
result = SemanticLayerDAO.validate_update_uniqueness(str(layer.uuid), "layer2")
assert result is False
def test_semantic_layer_get_semantic_views(session_with_data: Session) -> None:
"""
Test getting all semantic views for a layer.
"""
layer = session_with_data.query(SemanticLayer).filter_by(name="layer1").first()
assert layer is not None
views = SemanticLayerDAO.get_semantic_views(layer.uuid)
assert len(views) == 2
assert views[0].name in ["view1", "view2"]
assert views[1].name in ["view1", "view2"]
def test_semantic_view_find_by_semantic_layer(session_with_data: Session) -> None:
"""
Test finding all views for a semantic layer.
"""
layer = session_with_data.query(SemanticLayer).filter_by(name="layer1").first()
assert layer is not None
views = SemanticViewDAO.find_by_semantic_layer(layer.uuid)
assert len(views) == 2
assert all(view.semantic_layer_uuid == layer.uuid for view in views)
def test_semantic_view_find_by_name(session_with_data: Session) -> None:
"""
Test finding semantic view by name within layer.
"""
layer = session_with_data.query(SemanticLayer).filter_by(name="layer1").first()
assert layer is not None
view = SemanticViewDAO.find_by_name("view1", layer.uuid)
assert view is not None
assert view.name == "view1"
assert view.semantic_layer_uuid == layer.uuid
def test_semantic_view_find_by_name_not_found(session_with_data: Session) -> None:
"""
Test finding non-existent semantic view by name.
"""
layer = session_with_data.query(SemanticLayer).filter_by(name="layer1").first()
assert layer is not None
view = SemanticViewDAO.find_by_name("nonexistent", layer.uuid)
assert view is None
def test_semantic_view_validate_uniqueness_true(session_with_data: Session) -> None:
"""
Test validating uniqueness returns True for new name in layer.
"""
layer = session_with_data.query(SemanticLayer).filter_by(name="layer1").first()
assert layer is not None
result = SemanticViewDAO.validate_uniqueness("new_view", layer.uuid)
assert result is True
def test_semantic_view_validate_uniqueness_false(session_with_data: Session) -> None:
"""
Test validating uniqueness returns False for existing name in layer.
"""
layer = session_with_data.query(SemanticLayer).filter_by(name="layer1").first()
assert layer is not None
result = SemanticViewDAO.validate_uniqueness("view1", layer.uuid)
assert result is False
def test_semantic_view_validate_uniqueness_different_layer(
session_with_data: Session,
) -> None:
"""
Test validating uniqueness allows same name in different layer.
"""
layer2 = session_with_data.query(SemanticLayer).filter_by(name="layer2").first()
assert layer2 is not None
# view1 exists in layer1, but we're checking layer2 where view1 also exists
# So this should return False
result = SemanticViewDAO.validate_uniqueness("view1", layer2.uuid)
assert result is False
def test_semantic_view_validate_update_uniqueness_same_name(
session_with_data: Session,
) -> None:
"""
Test validating update uniqueness allows keeping same name.
"""
layer = session_with_data.query(SemanticLayer).filter_by(name="layer1").first()
assert layer is not None
view = (
session_with_data.query(SemanticView)
.filter_by(name="view1", semantic_layer_uuid=layer.uuid)
.first()
)
assert view is not None
result = SemanticViewDAO.validate_update_uniqueness(view.uuid, "view1", layer.uuid)
assert result is True
def test_semantic_view_validate_update_uniqueness_new_name(
session_with_data: Session,
) -> None:
"""
Test validating update uniqueness allows new unique name.
"""
layer = session_with_data.query(SemanticLayer).filter_by(name="layer1").first()
assert layer is not None
view = (
session_with_data.query(SemanticView)
.filter_by(name="view1", semantic_layer_uuid=layer.uuid)
.first()
)
assert view is not None
result = SemanticViewDAO.validate_update_uniqueness(
view.uuid, "new_view", layer.uuid
)
assert result is True
def test_semantic_view_validate_update_uniqueness_existing_name(
session_with_data: Session,
) -> None:
"""
Test validating update uniqueness rejects existing name in same layer.
"""
layer = session_with_data.query(SemanticLayer).filter_by(name="layer1").first()
assert layer is not None
view = (
session_with_data.query(SemanticView)
.filter_by(name="view1", semantic_layer_uuid=layer.uuid)
.first()
)
assert view is not None
result = SemanticViewDAO.validate_update_uniqueness(view.uuid, "view2", layer.uuid)
assert result is False


@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


@@ -0,0 +1,356 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# flake8: noqa: E501
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
from superset.semantic_layers.snowflake import (
SnowflakeConfiguration,
SnowflakeSemanticLayer,
)
@pytest.mark.parametrize(
"configuration, databases, schemas, expected_db_enum, expected_schema_enum",
[
# No configuration - empty enums
(
None,
None,
None,
[],
[],
),
# Configuration with account + auth - populates databases
(
{
"account_identifier": "test_account",
"auth": {
"auth_type": "user_password",
"username": "test_user",
"password": "test_password",
},
"allow_changing_database": True,
"allow_changing_schema": True,
},
["ANALYTICS_DB", "SALES_DB", "MARKETING_DB"],
None,
["ANALYTICS_DB", "SALES_DB", "MARKETING_DB"],
[],
),
# Configuration with account + auth + database - populates both
(
{
"account_identifier": "test_account",
"database": "ANALYTICS_DB",
"auth": {
"auth_type": "user_password",
"username": "test_user",
"password": "test_password",
},
"allow_changing_schema": True,
},
["ANALYTICS_DB", "SALES_DB", "MARKETING_DB"],
["PUBLIC", "STAGING", "DEV"],
["ANALYTICS_DB", "SALES_DB", "MARKETING_DB"],
["PUBLIC", "STAGING", "DEV"],
),
# Configuration with account + auth, single database
(
{
"account_identifier": "prod_account",
"auth": {
"auth_type": "user_password",
"username": "admin",
"password": "secret",
},
"allow_changing_database": True,
"allow_changing_schema": True,
},
["PRODUCTION"],
None,
["PRODUCTION"],
[],
),
],
)
def test_get_configuration_schema(
configuration: dict[str, Any] | None,
databases: list[str] | None,
schemas: list[str] | None,
expected_db_enum: list[str],
expected_schema_enum: list[str],
) -> None:
"""
Test configuration schema generation with dynamic database/schema enums.
"""
if configuration is None:
# Test without configuration
schema = SnowflakeSemanticLayer.get_configuration_schema()
assert "properties" in schema
assert "database" in schema["properties"]
assert "schema" in schema["properties"]
assert schema["properties"]["database"]["enum"] == expected_db_enum
assert schema["properties"]["schema"]["enum"] == expected_schema_enum
else:
# Create configuration
config = SnowflakeConfiguration(**configuration)
# Mock the connection and cursor
mock_cursor = MagicMock()
mock_connection = MagicMock()
mock_connection.cursor.return_value = mock_cursor
# Setup cursor responses
if databases:
            # SHOW DATABASES returns rows with the database name in the second column
mock_cursor.__iter__.return_value = iter(
[(i, db, "", "", "", "", "") for i, db in enumerate(databases)]
)
if schemas:
# SELECT SCHEMA_NAME returns (schema_name,)
mock_cursor.execute.return_value = iter([(schema,) for schema in schemas])
# Mock connect to return our mock connection
with patch(
"superset.semantic_layers.snowflake.semantic_layer.connect"
) as mock_connect:
mock_connect.return_value.__enter__.return_value = mock_connection
# Get the schema
schema = SnowflakeSemanticLayer.get_configuration_schema(config)
# Verify connect was called
mock_connect.assert_called_once()
# Verify schema structure
assert "properties" in schema
assert "database" in schema["properties"]
assert "schema" in schema["properties"]
# Verify database enum (compare as sets since order isn't guaranteed)
assert set(schema["properties"]["database"]["enum"]) == set(
expected_db_enum
)
# Verify schema enum (may not have 'enum' key if database not set)
if expected_schema_enum:
assert set(schema["properties"]["schema"]["enum"]) == set(
expected_schema_enum
)
else:
# When no schemas are expected, enum key may not exist
# or may be an empty list
schema_enum = schema["properties"]["schema"].get("enum", [])
assert set(schema_enum) == set(expected_schema_enum)
@pytest.mark.parametrize(
"configuration, runtime_data, databases, schemas, expect_database, expect_schema",
[
# Database + schema configured, no changing allowed -> empty runtime schema
(
{
"account_identifier": "test_account",
"database": "ANALYTICS_DB",
"schema": "PUBLIC",
"auth": {
"auth_type": "user_password",
"username": "test_user",
"password": "test_password",
},
"allow_changing_database": False,
"allow_changing_schema": False,
},
None,
None,
None,
False,
False,
),
# Database configured, schema not configured -> shows schemas
(
{
"account_identifier": "test_account",
"database": "ANALYTICS_DB",
"auth": {
"auth_type": "user_password",
"username": "test_user",
"password": "test_password",
},
"allow_changing_schema": True,
},
None,
None,
["PUBLIC", "STAGING", "DEV"],
False,
True,
),
# Database configured, allow_changing_schema=True -> shows schemas
(
{
"account_identifier": "test_account",
"database": "ANALYTICS_DB",
"schema": "PUBLIC",
"auth": {
"auth_type": "user_password",
"username": "test_user",
"password": "test_password",
},
"allow_changing_schema": True,
},
None,
None,
["PUBLIC", "STAGING", "DEV"],
False,
True,
),
# Database not configured -> shows databases
(
{
"account_identifier": "test_account",
"auth": {
"auth_type": "user_password",
"username": "test_user",
"password": "test_password",
},
"allow_changing_database": True,
"allow_changing_schema": True,
},
None,
["ANALYTICS_DB", "SALES_DB"],
None,
True,
True,
),
# Database configured, allow_changing_database=True -> shows databases
(
{
"account_identifier": "test_account",
"database": "ANALYTICS_DB",
"schema": "PUBLIC",
"auth": {
"auth_type": "user_password",
"username": "test_user",
"password": "test_password",
},
"allow_changing_database": True,
"allow_changing_schema": False,
},
None,
["ANALYTICS_DB", "SALES_DB"],
None,
True,
False,
),
# Runtime data provides database -> shows schemas for that database
(
{
"account_identifier": "test_account",
"auth": {
"auth_type": "user_password",
"username": "test_user",
"password": "test_password",
},
"allow_changing_database": True,
"allow_changing_schema": True,
},
{"database": "SALES_DB"},
["ANALYTICS_DB", "SALES_DB"],
["SALES_SCHEMA", "CUSTOMER_SCHEMA"],
True,
True,
),
],
)
def test_get_runtime_schema(
configuration: dict[str, Any],
runtime_data: dict[str, Any] | None,
databases: list[str] | None,
schemas: list[str] | None,
expect_database: bool,
expect_schema: bool,
) -> None:
"""
Test runtime schema generation with various configuration combinations.
The runtime schema should only include fields that the user can change:
- database field if database is not configured or changing is allowed
- schema field if schema is not configured or changing is allowed
"""
# Create configuration
config = SnowflakeConfiguration(**configuration)
# Mock the connection and cursor
mock_cursor = MagicMock()
mock_connection = MagicMock()
mock_connection.cursor.return_value = mock_cursor
# Setup cursor responses
if databases:
# SHOW DATABASES rows carry the database name in the second column
mock_cursor.__iter__.return_value = iter(
[(i, db, "", "", "", "", "") for i, db in enumerate(databases)]
)
if schemas:
# SELECT SCHEMA_NAME returns (schema_name,)
mock_cursor.execute.return_value = iter([(schema,) for schema in schemas])
# Mock connect to return our mock connection
with patch(
"superset.semantic_layers.snowflake.semantic_layer.connect"
) as mock_connect:
mock_connect.return_value.__enter__.return_value = mock_connection
# Get the runtime schema
schema = SnowflakeSemanticLayer.get_runtime_schema(config, runtime_data)
# Verify connect was called
mock_connect.assert_called_once()
# Verify schema structure
assert "properties" in schema
# Verify database field presence
if expect_database:
assert "database" in schema["properties"]
# Should have enum with available databases
if databases:
db_enum = schema["properties"]["database"].get("enum", [])
assert set(db_enum) == set(databases)
else:
assert "database" not in schema["properties"]
# Verify schema field presence
if expect_schema:
assert "schema" in schema["properties"]
# Should have enum with available schemas if we have a database
if schemas and (
configuration.get("database")
or (runtime_data and runtime_data.get("database"))
):
schema_enum = schema["properties"]["schema"].get("enum", [])
assert set(schema_enum) == set(schemas)
else:
assert "schema" not in schema["properties"]

(File diff suppressed because it is too large.)


@@ -0,0 +1,281 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# flake8: noqa: E501
from contextlib import nullcontext
from typing import Any
import pytest
from superset.semantic_layers.snowflake import SnowflakeConfiguration
from superset.semantic_layers.snowflake.utils import (
get_connection_parameters,
substitute_parameters,
validate_order_by,
)

@pytest.mark.parametrize(
"query, parameters, expected",
[
# No parameters
("SELECT * FROM table", None, "SELECT * FROM table"),
("SELECT * FROM table", [], "SELECT * FROM table"),
# NULL values
(
"SELECT * FROM table WHERE id = ?",
[None],
"SELECT * FROM table WHERE id = NULL",
),
# Integer values
(
"SELECT * FROM table WHERE id = ?",
[123],
"SELECT * FROM table WHERE id = 123",
),
(
"SELECT * FROM table WHERE id = ? AND status = ?",
[123, 456],
"SELECT * FROM table WHERE id = 123 AND status = 456",
),
# Float values
(
"SELECT * FROM table WHERE price = ?",
[99.99],
"SELECT * FROM table WHERE price = 99.99",
),
(
"SELECT * FROM table WHERE price BETWEEN ? AND ?",
[10.5, 99.99],
"SELECT * FROM table WHERE price BETWEEN 10.5 AND 99.99",
),
# Boolean values
(
"SELECT * FROM table WHERE active = ?",
[True],
"SELECT * FROM table WHERE active = TRUE",
),
(
"SELECT * FROM table WHERE active = ? AND deleted = ?",
[True, False],
"SELECT * FROM table WHERE active = TRUE AND deleted = FALSE",
),
# String values
(
"SELECT * FROM table WHERE name = ?",
["John"],
"SELECT * FROM table WHERE name = 'John'",
),
(
"SELECT * FROM table WHERE name = ? OR name = ?",
["John", "Jane"],
"SELECT * FROM table WHERE name = 'John' OR name = 'Jane'",
),
# String with single quotes (should be escaped)
(
"SELECT * FROM table WHERE name = ?",
["O'Brien"],
"SELECT * FROM table WHERE name = 'O''Brien'",
),
(
"SELECT * FROM table WHERE text = ?",
["It's a test"],
"SELECT * FROM table WHERE text = 'It''s a test'",
),
# Mixed types
(
(
"SELECT * FROM table WHERE name = ? "
"AND age = ? AND active = ? AND salary = ?"
),
["John", 30, True, 50000.5],
(
"SELECT * FROM table WHERE name = 'John' "
"AND age = 30 AND active = TRUE AND salary = 50000.5"
),
),
(
"SELECT * FROM table WHERE col1 = ? AND col2 = ? AND col3 = ?",
[None, "test", 42],
"SELECT * FROM table WHERE col1 = NULL AND col2 = 'test' AND col3 = 42",
),
],
)
def test_substitute_parameters(
query: str,
parameters: list[Any] | None,
expected: str,
) -> None:
"""
Test parameter substitution for various types and combinations.
"""
assert substitute_parameters(query, parameters) == expected
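
# --- Illustrative sketch (not part of the test suite) -----------------------
# A minimal reference model of the qmark substitution the cases above describe
# (sketch only; the real logic lives in
# superset.semantic_layers.snowflake.utils.substitute_parameters):
def _substitute_sketch(query: str, parameters: list[Any] | None) -> str:
    def render(value: Any) -> str:
        if value is None:
            return "NULL"
        if isinstance(value, bool):  # must precede int: bool is an int subclass
            return "TRUE" if value else "FALSE"
        if isinstance(value, (int, float)):
            return str(value)
        # strings: double any embedded single quote, then wrap in single quotes
        return "'" + str(value).replace("'", "''") + "'"

    for value in parameters or []:
        query = query.replace("?", render(value), 1)  # fill placeholders left to right
    return query
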
@pytest.mark.parametrize(
"definition, should_raise",
[
# Valid simple cases
("column_name", False),
("COUNT(*)", False),
("SUM(amount)", False),
("table.column", False),
("schema.table.column", False),
# Valid with direction
("column_name ASC", False),
("column_name DESC", False),
("COUNT(*) DESC", False),
("SUM(revenue) ASC", False),
# Valid with NULLS handling
("column_name NULLS FIRST", False),
("column_name NULLS LAST", False),
("column_name ASC NULLS FIRST", False),
("column_name DESC NULLS LAST", False),
("COUNT(*) DESC NULLS FIRST", False),
# Valid complex expressions
("gender ASC, COUNT(*)", False),
("gender ASC, COUNT(*) DESC", False),
("col1 ASC, col2 DESC, col3", False),
("CASE WHEN x > 0 THEN 1 ELSE 0 END", False),
("CAST(column AS INTEGER)", False),
("UPPER(name)", False),
("CONCAT(first_name, ' ', last_name)", False),
# Valid with mixed complexity
("table.column ASC NULLS FIRST, COUNT(*) DESC", False),
("schema.table.col1, func(col2) DESC NULLS LAST", False),
# Invalid - SQL injection attempts with semicolons
("column_name; DROP TABLE users;", True),
("column_name; DELETE FROM data; --", True),
("name; UPDATE users SET admin=1; --", True),
# Invalid - SQL injection with multiple statements
("col1; SELECT * FROM passwords", True),
("col1; INSERT INTO logs VALUES(1)", True),
# Edge cases - incomplete syntax
("column/*", True),
],
)
def test_validate_order_by(definition: str, should_raise: bool) -> None:
"""
Test ORDER BY validation for valid expressions and SQL injection prevention.
"""
context = (
pytest.raises(ValueError, match="Invalid ORDER BY")
if should_raise
else nullcontext()
)
with context:
validate_order_by(definition)
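
# --- Illustrative sketch (not part of the test suite) -----------------------
# One plausible shape for this kind of guard: embed the fragment in a single
# SELECT and reject anything that fails to parse or splits into multiple
# statements. The sqlglot usage below is an assumption for illustration, not
# how validate_order_by is actually implemented:
def _check_order_by_sketch(definition: str) -> None:
    import sqlglot
    from sqlglot.errors import SqlglotError

    try:
        statements = sqlglot.parse(
            f"SELECT 1 ORDER BY {definition}", dialect="snowflake"
        )
    except SqlglotError as ex:
        raise ValueError(f"Invalid ORDER BY clause: {definition}") from ex
    # "col; DROP TABLE users" parses as two statements and is rejected here
    if len(statements) != 1:
        raise ValueError(f"Invalid ORDER BY clause: {definition}")
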
@pytest.mark.parametrize(
"configuration, expected",
[
# Minimal UserPasswordAuth configuration
(
{
"account_identifier": "test_account",
"auth": {
"auth_type": "user_password",
"username": "test_user",
"password": "test_password",
},
"allow_changing_database": True,
"allow_changing_schema": True,
},
{
"account": "test_account",
"application": "Apache Superset",
"paramstyle": "qmark",
"insecure_mode": True,
"user": "test_user",
"password": "test_password",
},
),
# Full UserPasswordAuth configuration
(
{
"account_identifier": "test_account",
"role": "ACCOUNTADMIN",
"warehouse": "COMPUTE_WH",
"database": "TEST_DB",
"schema": "PUBLIC",
"auth": {
"auth_type": "user_password",
"username": "admin",
"password": "secret123",
},
},
{
"account": "test_account",
"application": "Apache Superset",
"paramstyle": "qmark",
"insecure_mode": True,
"role": "ACCOUNTADMIN",
"warehouse": "COMPUTE_WH",
"database": "TEST_DB",
"schema": "PUBLIC",
"user": "admin",
"password": "secret123",
},
),
# UserPasswordAuth with some optional fields
(
{
"account_identifier": "mycompany.us-east-1",
"warehouse": "ETL_WH",
"database": "ANALYTICS",
"auth": {
"auth_type": "user_password",
"username": "analyst",
"password": "p@ssw0rd",
},
"allow_changing_schema": True,
},
{
"account": "mycompany.us-east-1",
"application": "Apache Superset",
"paramstyle": "qmark",
"insecure_mode": True,
"warehouse": "ETL_WH",
"database": "ANALYTICS",
"user": "analyst",
"password": "p@ssw0rd",
},
),
],
)
def test_get_connection_parameters(
configuration: dict[str, Any],
expected: dict[str, Any],
) -> None:
"""
Test connection parameter generation for various configurations.
"""
# Create configuration from params
config = SnowflakeConfiguration(**configuration)
# Get connection parameters
result = get_connection_parameters(config)
# Check that all expected keys are present with correct values
for key, value in expected.items():
assert key in result, f"Expected key '{key}' not found in result"
assert result[key] == value, f"Expected {key}={value}, got {result[key]}"
# Verify no unexpected keys
assert set(result.keys()) == set(expected.keys())
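
# --- Illustrative sketch (not part of the test suite) -----------------------
# The expected dictionaries above imply a straightforward mapping: fixed
# connector settings, optional fields copied only when set, plus the auth
# credentials. A sketch assuming the user/password auth variant (the helper
# below is hypothetical, not the actual implementation):
def _connection_parameters_sketch(config) -> dict[str, Any]:
    params: dict[str, Any] = {
        "account": config.account_identifier,
        # constants the expected values pin down
        "application": "Apache Superset",
        "paramstyle": "qmark",
        "insecure_mode": True,
    }
    for field in ("role", "warehouse", "database", "schema"):
        value = getattr(config, field, None)
        if value is not None:  # optional fields appear only when configured
            params[field] = value
    params["user"] = config.auth.username
    params["password"] = config.auth.password
    return params
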

(File diff suppressed because it is too large.)