mirror of
https://github.com/apache/superset.git
synced 2026-05-01 14:04:21 +00:00
Compare commits
9 Commits
feat/toolt
...
auth-issue
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d279a36eb4 | ||
|
|
5adae553ae | ||
|
|
8e32387201 | ||
|
|
737b99b5a9 | ||
|
|
957b298ae1 | ||
|
|
f29d82b3b1 | ||
|
|
3f550f166f | ||
|
|
86eb6176d1 | ||
|
|
4244ae87bf |
@@ -707,7 +707,7 @@ protobuf==4.25.8
|
||||
# proto-plus
|
||||
psutil==6.1.0
|
||||
# via apache-superset
|
||||
psycopg2-binary==2.9.9
|
||||
psycopg2-binary==2.9.12
|
||||
# via apache-superset
|
||||
py-key-value-aio==0.4.4
|
||||
# via fastmcp
|
||||
|
||||
@@ -22,7 +22,6 @@ from datetime import datetime
|
||||
from pprint import pformat
|
||||
from typing import Any, NamedTuple, TYPE_CHECKING
|
||||
|
||||
from flask import g
|
||||
from flask_babel import gettext as _
|
||||
from jinja2.exceptions import TemplateError
|
||||
from pandas import DataFrame
|
||||
@@ -38,6 +37,7 @@ from superset.extensions import event_logger
|
||||
from superset.sql.parse import sanitize_clause, transpile_to_dialect
|
||||
from superset.superset_typing import Column, Metric, OrderBy, QueryObjectDict
|
||||
from superset.utils import json, pandas_postprocessing
|
||||
from superset.utils.cache_keys import add_impersonation_cache_key_if_needed
|
||||
from superset.utils.core import (
|
||||
DTTM_ALIAS,
|
||||
find_duplicates,
|
||||
@@ -479,24 +479,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
|
||||
# or if the CACHE_QUERY_BY_USER flag is on or per_user_caching is enabled on
|
||||
# the database
|
||||
try:
|
||||
database = self.datasource.database # type: ignore
|
||||
extra = json.loads(database.extra or "{}")
|
||||
if (
|
||||
(
|
||||
feature_flag_manager.is_feature_enabled("CACHE_IMPERSONATION")
|
||||
and database.impersonate_user
|
||||
)
|
||||
or feature_flag_manager.is_feature_enabled("CACHE_QUERY_BY_USER")
|
||||
or extra.get("per_user_caching", False)
|
||||
):
|
||||
if key := database.db_engine_spec.get_impersonation_key(
|
||||
getattr(g, "user", None)
|
||||
):
|
||||
logger.debug(
|
||||
"Adding impersonation key to QueryObject cache dict: %s", key
|
||||
)
|
||||
|
||||
cache_dict["impersonation_key"] = key
|
||||
add_impersonation_cache_key_if_needed(self.datasource.database, cache_dict) # type: ignore
|
||||
except AttributeError:
|
||||
# datasource or database do not exist
|
||||
pass
|
||||
|
||||
@@ -590,7 +590,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
# Driver-specific params to be included in the `get_oauth2_token` request body
|
||||
oauth2_additional_token_request_params: dict[str, Any] = {}
|
||||
# Driver-specific exception that should be mapped to OAuth2RedirectError
|
||||
oauth2_exception = OAuth2RedirectError
|
||||
oauth2_exception: type[Exception] | tuple[type[Exception], ...] = (
|
||||
OAuth2RedirectError
|
||||
)
|
||||
|
||||
# Does the query id related to the connection?
|
||||
# The default value is True, which means that the query id is determined when
|
||||
|
||||
@@ -31,6 +31,7 @@ from marshmallow import fields, Schema
|
||||
from marshmallow.exceptions import ValidationError
|
||||
from requests import Session
|
||||
from shillelagh.adapters.api.gsheets.lib import SCOPES
|
||||
from shillelagh.exceptions import UnauthenticatedError
|
||||
from sqlalchemy.engine import create_engine
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.engine.url import URL
|
||||
@@ -40,7 +41,7 @@ from superset.databases.schemas import encrypted_field_properties, EncryptedStri
|
||||
from superset.db_engine_specs.base import DatabaseCategory
|
||||
from superset.db_engine_specs.shillelagh import ShillelaghEngineSpec
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.exceptions import SupersetException
|
||||
from superset.exceptions import OAuth2TokenRefreshError, SupersetException
|
||||
from superset.utils import json
|
||||
from superset.utils.oauth2 import get_oauth2_access_token
|
||||
|
||||
@@ -151,6 +152,7 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
|
||||
"https://accounts.google.com/o/oauth2/v2/auth"
|
||||
)
|
||||
oauth2_token_request_uri = "https://oauth2.googleapis.com/token" # noqa: S105
|
||||
oauth2_exception = (UnauthenticatedError, OAuth2TokenRefreshError)
|
||||
|
||||
@classmethod
|
||||
def get_oauth2_authorization_uri(
|
||||
|
||||
@@ -62,6 +62,7 @@ Dataset Management:
|
||||
- list_datasets: List datasets with advanced filters (1-based pagination)
|
||||
- get_dataset_info: Get detailed dataset information by ID (includes columns/metrics)
|
||||
- create_virtual_dataset: Save a SQL query as a virtual dataset for charting
|
||||
- query_dataset: Query a dataset using its semantic layer (saved metrics, dimensions, filters) without needing a saved chart
|
||||
|
||||
Chart Management:
|
||||
- list_charts: List charts with advanced filters (1-based pagination)
|
||||
@@ -164,6 +165,17 @@ Use created_by_me for authorship, owned_by_me for edit ownership, or both
|
||||
together for the union. All flags can be combined with 'filters' but not
|
||||
with 'search'.
|
||||
|
||||
To query a dataset's semantic layer (metrics, dimensions):
|
||||
1. list_datasets(request={{}}) -> find a dataset
|
||||
2. get_dataset_info(request={{"identifier": <id>}}) -> examine columns AND metrics
|
||||
3. query_dataset(request={{
|
||||
"dataset_id": <id>,
|
||||
"metrics": ["count", "avg_revenue"],
|
||||
"columns": ["category"],
|
||||
"time_range": "Last 7 days",
|
||||
"row_limit": 100
|
||||
}}) -> returns tabular data using saved metrics and dimensions
|
||||
|
||||
To explore data with SQL:
|
||||
1. list_datasets(request={{}}) -> find a dataset and note its database_id
|
||||
2. execute_sql(request={{"database_id": <id>, "sql": "SELECT ..."}})
|
||||
@@ -520,6 +532,7 @@ from superset.mcp_service.dataset.tool import ( # noqa: F401, E402
|
||||
create_virtual_dataset,
|
||||
get_dataset_info,
|
||||
list_datasets,
|
||||
query_dataset,
|
||||
)
|
||||
from superset.mcp_service.explore.tool import ( # noqa: F401, E402
|
||||
generate_explore_link,
|
||||
|
||||
@@ -632,6 +632,15 @@ def mcp_auth_hook(tool_func: F) -> F: # noqa: C901
|
||||
@functools.wraps(tool_func)
|
||||
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
with _get_app_context_manager():
|
||||
# Clear any stale thread-local SQLAlchemy session before user lookup.
|
||||
# Thread pool workers reuse threads across requests; db.session is
|
||||
# scoped by thread (not ContextVar), so a prior request's session may
|
||||
# still be bound to a different tenant's DB engine. Removing it here
|
||||
# ensures the next DB access creates a fresh session bound to the
|
||||
# correct engine for the current request.
|
||||
from superset.extensions import db
|
||||
|
||||
db.session.remove()
|
||||
user = _setup_user_context()
|
||||
|
||||
# No Flask context - this is a FastMCP internal operation
|
||||
|
||||
@@ -70,6 +70,8 @@ SORTABLE_CHART_COLUMNS = [
|
||||
"created_on",
|
||||
]
|
||||
|
||||
_DEFAULT_LIST_CHARTS_REQUEST = ListChartsRequest()
|
||||
|
||||
|
||||
@tool(
|
||||
tags=["core"],
|
||||
@@ -81,7 +83,8 @@ SORTABLE_CHART_COLUMNS = [
|
||||
),
|
||||
)
|
||||
async def list_charts(
|
||||
request: ListChartsRequest, ctx: Context
|
||||
request: ListChartsRequest | None = None,
|
||||
ctx: Context = None,
|
||||
) -> ChartList | ChartError:
|
||||
"""List charts with filtering and search.
|
||||
|
||||
@@ -91,6 +94,7 @@ async def list_charts(
|
||||
Sortable columns for order_column: id, slice_name, viz_type, description,
|
||||
changed_on, created_on
|
||||
"""
|
||||
request = request or _DEFAULT_LIST_CHARTS_REQUEST.model_copy(deep=True)
|
||||
await ctx.info(
|
||||
"Listing charts: page=%s, page_size=%s, search=%s"
|
||||
% (
|
||||
|
||||
@@ -65,6 +65,8 @@ SORTABLE_DASHBOARD_COLUMNS = [
|
||||
"created_on",
|
||||
]
|
||||
|
||||
_DEFAULT_LIST_DASHBOARDS_REQUEST = ListDashboardsRequest()
|
||||
|
||||
|
||||
@tool(
|
||||
tags=["core"],
|
||||
@@ -76,7 +78,8 @@ SORTABLE_DASHBOARD_COLUMNS = [
|
||||
),
|
||||
)
|
||||
async def list_dashboards(
|
||||
request: ListDashboardsRequest, ctx: Context
|
||||
request: ListDashboardsRequest | None = None,
|
||||
ctx: Context = None,
|
||||
) -> DashboardList:
|
||||
"""List dashboards with filtering and search. Returns dashboard metadata
|
||||
including title, slug, URL, and last modified time. Use select_columns to
|
||||
@@ -85,6 +88,7 @@ async def list_dashboards(
|
||||
Sortable columns for order_column: id, dashboard_title, slug, published,
|
||||
changed_on, created_on
|
||||
"""
|
||||
request = request or _DEFAULT_LIST_DASHBOARDS_REQUEST.model_copy(deep=True)
|
||||
await ctx.info(
|
||||
"Listing dashboards: page=%s, page_size=%s, search=%s"
|
||||
% (
|
||||
|
||||
@@ -36,10 +36,13 @@ from pydantic import (
|
||||
)
|
||||
|
||||
from superset.daos.base import ColumnOperator, ColumnOperatorEnum
|
||||
from superset.mcp_service.chart.schemas import DataColumn, PerformanceMetadata
|
||||
from superset.mcp_service.common.cache_schemas import (
|
||||
CacheStatus,
|
||||
CreatedByMeMixin,
|
||||
MetadataCacheControl,
|
||||
OwnedByMeMixin,
|
||||
QueryCacheControl,
|
||||
)
|
||||
from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE
|
||||
from superset.mcp_service.privacy import filter_user_directory_fields
|
||||
@@ -393,6 +396,146 @@ class CreateVirtualDatasetResponse(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
VALID_FILTER_OPS = Literal[
|
||||
"==",
|
||||
"!=",
|
||||
">",
|
||||
"<",
|
||||
">=",
|
||||
"<=",
|
||||
"LIKE",
|
||||
"NOT LIKE",
|
||||
"ILIKE",
|
||||
"NOT ILIKE",
|
||||
"IN",
|
||||
"NOT IN",
|
||||
"IS NULL",
|
||||
"IS NOT NULL",
|
||||
"IS TRUE",
|
||||
"IS FALSE",
|
||||
"TEMPORAL_RANGE",
|
||||
]
|
||||
|
||||
|
||||
class QueryDatasetFilter(BaseModel):
|
||||
"""A single filter condition for dataset queries."""
|
||||
|
||||
col: str = Field(..., description="Column name to filter on")
|
||||
op: VALID_FILTER_OPS = Field(
|
||||
...,
|
||||
description=(
|
||||
'Filter operator. Use "==" for equals, "!=" for not equals, '
|
||||
'"IN" / "NOT IN" for membership, "IS NULL" / "IS NOT NULL", '
|
||||
'"LIKE" for pattern matching, "TEMPORAL_RANGE" for time filters.'
|
||||
),
|
||||
)
|
||||
val: Any = Field(
|
||||
default=None,
|
||||
description="Filter value (omit for IS NULL/IS NOT NULL)",
|
||||
)
|
||||
|
||||
|
||||
class QueryDatasetRequest(QueryCacheControl):
|
||||
"""Request schema for query_dataset tool."""
|
||||
|
||||
dataset_id: int | str = Field(
|
||||
...,
|
||||
description="Dataset identifier — numeric ID or UUID string.",
|
||||
)
|
||||
metrics: List[str] = Field(
|
||||
default_factory=list,
|
||||
description=(
|
||||
"Saved metric names to compute (e.g. ['count', 'avg_revenue']). "
|
||||
"Use get_dataset_info to discover available metrics."
|
||||
),
|
||||
)
|
||||
columns: List[str] = Field(
|
||||
default_factory=list,
|
||||
description=(
|
||||
"Column/dimension names for GROUP BY or SELECT "
|
||||
"(e.g. ['category', 'region']). "
|
||||
"Use get_dataset_info to discover available columns."
|
||||
),
|
||||
)
|
||||
filters: List[QueryDatasetFilter] = Field(
|
||||
default_factory=list,
|
||||
description=(
|
||||
'Filter conditions (e.g. [{"col": "status", "op": "==", "val": "active"}]).'
|
||||
),
|
||||
)
|
||||
time_range: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Time range filter (e.g. 'Last 7 days', 'Last month', "
|
||||
"'2024-01-01 : 2024-12-31'). Requires a temporal column "
|
||||
"on the dataset."
|
||||
),
|
||||
)
|
||||
time_column: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Temporal column to apply time_range to. "
|
||||
"Defaults to the dataset's main datetime column."
|
||||
),
|
||||
)
|
||||
order_by: List[str] | None = Field(
|
||||
default=None,
|
||||
description="Column or metric names to sort results by.",
|
||||
)
|
||||
order_desc: bool = Field(
|
||||
default=True,
|
||||
description="Sort descending (True) or ascending (False).",
|
||||
)
|
||||
row_limit: int = Field(
|
||||
default=1000,
|
||||
ge=1,
|
||||
le=50000,
|
||||
description="Maximum number of rows to return (default 1000, max 50000).",
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_metrics_or_columns(self) -> "QueryDatasetRequest":
|
||||
"""At least one of metrics or columns must be provided."""
|
||||
if not self.metrics and not self.columns:
|
||||
raise ValueError(
|
||||
"At least one of 'metrics' or 'columns' must be provided. "
|
||||
"Use get_dataset_info to discover available metrics and columns."
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class QueryDatasetResponse(BaseModel):
|
||||
"""Response schema for query_dataset tool."""
|
||||
|
||||
model_config = ConfigDict(ser_json_timedelta="iso8601")
|
||||
|
||||
dataset_id: int = Field(..., description="Dataset ID")
|
||||
dataset_name: str = Field(..., description="Dataset name")
|
||||
columns: List[DataColumn] = Field(
|
||||
default_factory=list, description="Column metadata for returned data"
|
||||
)
|
||||
data: List[Dict[str, Any]] = Field(
|
||||
default_factory=list, description="Query result rows"
|
||||
)
|
||||
row_count: int = Field(0, description="Number of rows returned")
|
||||
total_rows: int | None = Field(
|
||||
None, description="Total row count from the query engine"
|
||||
)
|
||||
summary: str = Field("", description="Human-readable summary of the results")
|
||||
performance: PerformanceMetadata | None = Field(
|
||||
None, description="Query performance metadata"
|
||||
)
|
||||
cache_status: CacheStatus | None = Field(
|
||||
None, description="Cache hit/miss information"
|
||||
)
|
||||
applied_filters: List[QueryDatasetFilter] = Field(
|
||||
default_factory=list, description="Filters that were applied to the query"
|
||||
)
|
||||
warnings: List[str] = Field(
|
||||
default_factory=list, description="Any warnings encountered during execution"
|
||||
)
|
||||
|
||||
|
||||
def _parse_json_field(obj: Any, field_name: str) -> Dict[str, Any] | None:
|
||||
"""Parse a field that may be stored as a JSON string into a dict."""
|
||||
value = getattr(obj, field_name, None)
|
||||
|
||||
@@ -18,9 +18,11 @@
|
||||
from .create_virtual_dataset import create_virtual_dataset
|
||||
from .get_dataset_info import get_dataset_info
|
||||
from .list_datasets import list_datasets
|
||||
from .query_dataset import query_dataset
|
||||
|
||||
__all__ = [
|
||||
"create_virtual_dataset",
|
||||
"list_datasets",
|
||||
"get_dataset_info",
|
||||
"query_dataset",
|
||||
]
|
||||
|
||||
489
superset/mcp_service/dataset/tool/query_dataset.py
Normal file
489
superset/mcp_service/dataset/tool/query_dataset.py
Normal file
@@ -0,0 +1,489 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
MCP tool: query_dataset
|
||||
|
||||
Query a dataset using its semantic layer (saved metrics, calculated columns,
|
||||
dimensions) without requiring a saved chart.
|
||||
"""
|
||||
|
||||
import difflib
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from fastmcp import Context
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import joinedload, subqueryload
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.commands.exceptions import CommandException
|
||||
from superset.exceptions import OAuth2Error, OAuth2RedirectError, SupersetException
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.chart.schemas import DataColumn, PerformanceMetadata
|
||||
from superset.mcp_service.dataset.schemas import (
|
||||
DatasetError,
|
||||
QueryDatasetFilter,
|
||||
QueryDatasetRequest,
|
||||
QueryDatasetResponse,
|
||||
)
|
||||
from superset.mcp_service.privacy import (
|
||||
DATA_MODEL_METADATA_ERROR_TYPE,
|
||||
requires_data_model_metadata_access,
|
||||
user_can_view_data_model_metadata,
|
||||
)
|
||||
from superset.mcp_service.utils import _is_uuid
|
||||
from superset.mcp_service.utils.cache_utils import get_cache_status_from_result
|
||||
from superset.mcp_service.utils.oauth2_utils import build_oauth2_redirect_message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _resolve_dataset(identifier: int | str, eager_options: list[Any]) -> Any | None:
|
||||
"""Resolve a dataset by int ID or UUID string.
|
||||
|
||||
Replicates the identifier resolution logic from ModelGetInfoCore._find_object().
|
||||
"""
|
||||
from superset.daos.dataset import DatasetDAO
|
||||
|
||||
opts = eager_options or None
|
||||
|
||||
if isinstance(identifier, int):
|
||||
return DatasetDAO.find_by_id(identifier, query_options=opts)
|
||||
|
||||
# Try parsing as int
|
||||
try:
|
||||
id_val = int(identifier)
|
||||
return DatasetDAO.find_by_id(id_val, query_options=opts)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
# Try UUID
|
||||
if _is_uuid(str(identifier)):
|
||||
return DatasetDAO.find_by_id(identifier, id_column="uuid", query_options=opts)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _validate_names(
|
||||
requested: list[str],
|
||||
valid: set[str],
|
||||
kind: str,
|
||||
) -> list[str]:
|
||||
"""Return list of error messages for names not found in *valid*.
|
||||
|
||||
Includes close-match suggestions when available.
|
||||
"""
|
||||
errors: list[str] = []
|
||||
for name in requested:
|
||||
if name not in valid:
|
||||
suggestions = difflib.get_close_matches(name, valid, n=3, cutoff=0.6)
|
||||
msg = f"Unknown {kind}: '{name}'"
|
||||
if suggestions:
|
||||
msg += f". Did you mean: {', '.join(suggestions)}?"
|
||||
errors.append(msg)
|
||||
return errors
|
||||
|
||||
|
||||
@requires_data_model_metadata_access
|
||||
@tool(
|
||||
tags=["data"],
|
||||
class_permission_name="Dataset",
|
||||
annotations=ToolAnnotations(
|
||||
title="Query dataset",
|
||||
readOnlyHint=True,
|
||||
destructiveHint=False,
|
||||
),
|
||||
)
|
||||
async def query_dataset( # noqa: C901
|
||||
request: QueryDatasetRequest, ctx: Context
|
||||
) -> QueryDatasetResponse | DatasetError:
|
||||
"""Query a dataset using its semantic layer (saved metrics, dimensions, filters).
|
||||
|
||||
Returns tabular data without requiring a saved chart. Use this when you want
|
||||
to compute saved metrics, group by dimensions, or apply filters directly
|
||||
against a dataset's curated semantic layer.
|
||||
|
||||
Workflow:
|
||||
1. list_datasets -> find a dataset
|
||||
2. get_dataset_info -> discover available columns and metrics
|
||||
3. query_dataset -> query using metric names and column names
|
||||
|
||||
Example:
|
||||
```json
|
||||
{
|
||||
"dataset_id": 123,
|
||||
"metrics": ["count", "avg_revenue"],
|
||||
"columns": ["product_category"],
|
||||
"time_range": "Last 7 days",
|
||||
"row_limit": 100
|
||||
}
|
||||
```
|
||||
"""
|
||||
await ctx.info(
|
||||
"Starting dataset query: dataset_id=%s, metrics=%s, columns=%s, "
|
||||
"row_limit=%s"
|
||||
% (
|
||||
request.dataset_id,
|
||||
request.metrics,
|
||||
request.columns,
|
||||
request.row_limit,
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
from superset.commands.chart.data.get_data_command import ChartDataCommand
|
||||
from superset.common.query_context_factory import QueryContextFactory
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 1: Check data-model metadata access BEFORE the dataset lookup.
|
||||
# Doing this first prevents leaking dataset existence — restricted
|
||||
# users always receive DataModelMetadataRestricted, never NotFound.
|
||||
# The decorator hides this tool from search; this check enforces
|
||||
# direct calls that bypass tool discovery.
|
||||
# ------------------------------------------------------------------
|
||||
if not user_can_view_data_model_metadata():
|
||||
await ctx.warning("Dataset metadata access blocked by privacy controls")
|
||||
return DatasetError.create(
|
||||
error=(
|
||||
"You don't have permission to access dataset details for your role."
|
||||
),
|
||||
error_type=DATA_MODEL_METADATA_ERROR_TYPE,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 2: Resolve dataset
|
||||
# ------------------------------------------------------------------
|
||||
await ctx.report_progress(1, 5, "Looking up dataset")
|
||||
eager_options = [
|
||||
subqueryload(SqlaTable.columns),
|
||||
subqueryload(SqlaTable.metrics),
|
||||
joinedload(SqlaTable.database),
|
||||
]
|
||||
|
||||
with event_logger.log_context(action="mcp.query_dataset.lookup"):
|
||||
dataset = _resolve_dataset(request.dataset_id, eager_options)
|
||||
|
||||
if dataset is None:
|
||||
await ctx.error("Dataset not found: identifier=%s" % (request.dataset_id,))
|
||||
return DatasetError.create(
|
||||
error=f"No dataset found with identifier: {request.dataset_id}",
|
||||
error_type="NotFound",
|
||||
)
|
||||
|
||||
dataset_name = getattr(dataset, "table_name", None) or f"Dataset {dataset.id}"
|
||||
await ctx.info(
|
||||
"Dataset found: id=%s, name=%s, columns=%s, metrics=%s"
|
||||
% (
|
||||
dataset.id,
|
||||
dataset_name,
|
||||
len(dataset.columns),
|
||||
len(dataset.metrics),
|
||||
)
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 2: Validate requested columns and metrics
|
||||
# ------------------------------------------------------------------
|
||||
await ctx.report_progress(2, 5, "Validating columns and metrics")
|
||||
valid_columns = {c.column_name for c in dataset.columns}
|
||||
valid_metrics = {m.metric_name for m in dataset.metrics}
|
||||
|
||||
validation_errors: list[str] = []
|
||||
validation_errors.extend(
|
||||
_validate_names(request.columns, valid_columns, "column")
|
||||
)
|
||||
validation_errors.extend(
|
||||
_validate_names(request.metrics, valid_metrics, "metric")
|
||||
)
|
||||
# Validate filter column names against dataset columns
|
||||
filter_cols = [f.col for f in request.filters]
|
||||
validation_errors.extend(
|
||||
_validate_names(filter_cols, valid_columns, "filter column")
|
||||
)
|
||||
# Validate order_by names against columns + metrics
|
||||
if request.order_by:
|
||||
valid_orderby = valid_columns | valid_metrics
|
||||
validation_errors.extend(
|
||||
_validate_names(request.order_by, valid_orderby, "order_by")
|
||||
)
|
||||
|
||||
if validation_errors:
|
||||
error_msg = "; ".join(validation_errors)
|
||||
await ctx.error("Validation failed: %s" % (error_msg,))
|
||||
return DatasetError.create(
|
||||
error=error_msg,
|
||||
error_type="ValidationError",
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 3: Build filters and time range
|
||||
# ------------------------------------------------------------------
|
||||
warnings: list[str] = []
|
||||
query_filters: list[dict[str, Any]] = [
|
||||
{"col": f.col, "op": f.op, "val": f.val} for f in request.filters
|
||||
]
|
||||
# Track all applied filters (including synthesized ones) for the response.
|
||||
effective_filters: list[QueryDatasetFilter] = list(request.filters)
|
||||
granularity: str | None = None
|
||||
|
||||
if request.time_range:
|
||||
temporal_col = request.time_column or getattr(
|
||||
dataset, "main_dttm_col", None
|
||||
)
|
||||
if not temporal_col:
|
||||
await ctx.error("time_range provided but no temporal column available")
|
||||
return DatasetError.create(
|
||||
error=(
|
||||
"time_range was provided but no temporal column is available. "
|
||||
"Either set time_column explicitly or ensure the dataset has "
|
||||
"a main datetime column configured."
|
||||
),
|
||||
error_type="ValidationError",
|
||||
)
|
||||
# Validate that the temporal column actually exists on the dataset
|
||||
if temporal_col not in valid_columns:
|
||||
await ctx.error("time_column '%s' not found on dataset" % temporal_col)
|
||||
return DatasetError.create(
|
||||
error=(
|
||||
f"time_column '{temporal_col}' does not exist on this dataset."
|
||||
),
|
||||
error_type="ValidationError",
|
||||
)
|
||||
# Warn if the chosen temporal column isn't marked as datetime
|
||||
dttm_cols = {c.column_name for c in dataset.columns if c.is_dttm}
|
||||
if temporal_col not in dttm_cols:
|
||||
warnings.append(
|
||||
f"Column '{temporal_col}' is not marked as a datetime "
|
||||
f"column on this dataset. Time filtering may not work "
|
||||
f"as expected."
|
||||
)
|
||||
|
||||
query_filters.append(
|
||||
{
|
||||
"col": temporal_col,
|
||||
"op": "TEMPORAL_RANGE",
|
||||
"val": request.time_range,
|
||||
}
|
||||
)
|
||||
effective_filters.append(
|
||||
QueryDatasetFilter(
|
||||
col=temporal_col,
|
||||
op="TEMPORAL_RANGE",
|
||||
val=request.time_range,
|
||||
)
|
||||
)
|
||||
granularity = temporal_col
|
||||
await ctx.debug(
|
||||
"Time filter: column=%s, range=%s" % (temporal_col, request.time_range)
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 4: Build query dict
|
||||
# ------------------------------------------------------------------
|
||||
await ctx.report_progress(3, 5, "Building query")
|
||||
query_dict: dict[str, Any] = {
|
||||
"filters": query_filters,
|
||||
"columns": request.columns,
|
||||
"metrics": request.metrics,
|
||||
"row_limit": request.row_limit,
|
||||
"order_desc": request.order_desc,
|
||||
}
|
||||
if granularity:
|
||||
query_dict["granularity"] = granularity
|
||||
if request.order_by:
|
||||
# OrderBy = tuple[Metric | Column, bool] where bool is ascending
|
||||
query_dict["orderby"] = [
|
||||
(col, not request.order_desc) for col in request.order_by
|
||||
]
|
||||
|
||||
await ctx.debug("Query dict keys: %s" % (sorted(query_dict.keys()),))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 5: Create QueryContext and execute
|
||||
# ------------------------------------------------------------------
|
||||
await ctx.report_progress(4, 5, "Executing query")
|
||||
start_time = time.time()
|
||||
|
||||
with event_logger.log_context(action="mcp.query_dataset.execute"):
|
||||
factory = QueryContextFactory()
|
||||
# datasource_type is "table" because this tool queries SqlaTable
|
||||
# datasets (Superset's built-in semantic layer). External semantic
|
||||
# layers (dbt, Snowflake Cortex, etc.) use "semantic_view" and have
|
||||
# a different query path — see SemanticView + mapper.py.
|
||||
query_context = factory.create(
|
||||
datasource={"id": dataset.id, "type": "table"},
|
||||
queries=[query_dict],
|
||||
form_data={},
|
||||
force=not request.use_cache or request.force_refresh,
|
||||
custom_cache_timeout=request.cache_timeout,
|
||||
)
|
||||
|
||||
command = ChartDataCommand(query_context)
|
||||
command.validate()
|
||||
result = command.run()
|
||||
|
||||
query_duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
if not result or "queries" not in result or len(result["queries"]) == 0:
|
||||
await ctx.warning("Query returned no results for dataset %s" % dataset.id)
|
||||
return DatasetError.create(
|
||||
error="Query returned no results.",
|
||||
error_type="EmptyQuery",
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 6: Format response
|
||||
# ------------------------------------------------------------------
|
||||
await ctx.report_progress(5, 5, "Formatting results")
|
||||
query_result = result["queries"][0]
|
||||
data = query_result.get("data", [])
|
||||
raw_columns = query_result.get("colnames", [])
|
||||
|
||||
if not data:
|
||||
return QueryDatasetResponse(
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset_name,
|
||||
columns=[],
|
||||
data=[],
|
||||
row_count=0,
|
||||
total_rows=0,
|
||||
summary=f"Query on '{dataset_name}' returned no data.",
|
||||
performance=PerformanceMetadata(
|
||||
query_duration_ms=query_duration_ms,
|
||||
cache_status="no_data",
|
||||
),
|
||||
cache_status=get_cache_status_from_result(
|
||||
query_result, force_refresh=request.force_refresh
|
||||
),
|
||||
applied_filters=effective_filters,
|
||||
warnings=warnings,
|
||||
)
|
||||
|
||||
# Build column metadata in a single pass per column.
|
||||
# Cap stats computation at STATS_SAMPLE rows to avoid O(rows*cols)
|
||||
# overhead on large result sets (row_limit allows up to 50k).
|
||||
stats_sample_size = 5000
|
||||
stats_rows = data[:stats_sample_size]
|
||||
|
||||
columns_meta: list[DataColumn] = []
|
||||
for col_name in raw_columns:
|
||||
sample_values = [
|
||||
row.get(col_name) for row in data[:3] if row.get(col_name) is not None
|
||||
]
|
||||
data_type = "string"
|
||||
if sample_values:
|
||||
if all(isinstance(v, bool) for v in sample_values):
|
||||
data_type = "boolean"
|
||||
elif all(isinstance(v, (int, float)) for v in sample_values):
|
||||
data_type = "numeric"
|
||||
|
||||
# Compute null_count and unique non-null values in one pass
|
||||
null_count = 0
|
||||
unique_vals: set[str] = set()
|
||||
for row in stats_rows:
|
||||
val = row.get(col_name)
|
||||
if val is None:
|
||||
null_count += 1
|
||||
else:
|
||||
unique_vals.add(str(val))
|
||||
|
||||
columns_meta.append(
|
||||
DataColumn(
|
||||
name=col_name,
|
||||
display_name=col_name.replace("_", " ").title(),
|
||||
data_type=data_type,
|
||||
sample_values=sample_values[:3],
|
||||
null_count=null_count,
|
||||
unique_count=len(unique_vals),
|
||||
)
|
||||
)
|
||||
|
||||
cache_status = get_cache_status_from_result(
|
||||
query_result, force_refresh=request.force_refresh
|
||||
)
|
||||
|
||||
cache_label = "cached" if cache_status and cache_status.cache_hit else "fresh"
|
||||
summary = (
|
||||
f"Dataset '{dataset_name}': {len(data)} rows, "
|
||||
f"{len(raw_columns)} columns ({cache_label})."
|
||||
)
|
||||
|
||||
await ctx.info(
|
||||
"Query complete: rows=%s, columns=%s, duration=%sms"
|
||||
% (len(data), len(raw_columns), query_duration_ms)
|
||||
)
|
||||
|
||||
return QueryDatasetResponse(
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset_name,
|
||||
columns=columns_meta,
|
||||
data=data,
|
||||
row_count=len(data),
|
||||
total_rows=query_result.get("rowcount"),
|
||||
summary=summary,
|
||||
performance=PerformanceMetadata(
|
||||
query_duration_ms=query_duration_ms,
|
||||
cache_status=cache_label,
|
||||
),
|
||||
cache_status=cache_status,
|
||||
applied_filters=effective_filters,
|
||||
warnings=warnings,
|
||||
)
|
||||
|
||||
except OAuth2RedirectError as exc:
|
||||
redirect_msg = build_oauth2_redirect_message(exc)
|
||||
await ctx.error("OAuth2 redirect required: %s" % (redirect_msg,))
|
||||
return DatasetError.create(
|
||||
error=redirect_msg,
|
||||
error_type="OAuth2Redirect",
|
||||
)
|
||||
|
||||
except OAuth2Error as exc:
|
||||
await ctx.error("OAuth2 error: %s" % (str(exc),))
|
||||
return DatasetError.create(
|
||||
error=f"OAuth2 authentication error: {exc}",
|
||||
error_type="OAuth2Error",
|
||||
)
|
||||
|
||||
except (CommandException, SupersetException) as exc:
|
||||
await ctx.error("Query failed: %s" % (str(exc),))
|
||||
return DatasetError.create(
|
||||
error=f"Query execution failed: {exc}",
|
||||
error_type="QueryError",
|
||||
)
|
||||
|
||||
except SQLAlchemyError as exc:
|
||||
await ctx.error("Database error: %s" % (str(exc),))
|
||||
return DatasetError.create(
|
||||
error=f"Database error: {exc}",
|
||||
error_type="DatabaseError",
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"Unexpected error while querying dataset: %s: %s",
|
||||
type(exc).__name__,
|
||||
str(exc),
|
||||
)
|
||||
await ctx.error("Unexpected error: %s: %s" % (type(exc).__name__, str(exc)))
|
||||
return DatasetError.create(
|
||||
error="An unexpected error occurred while querying the dataset.",
|
||||
error_type="UnexpectedError",
|
||||
)
|
||||
54
superset/utils/cache_keys.py
Normal file
54
superset/utils/cache_keys.py
Normal file
@@ -0,0 +1,54 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
from flask import g
|
||||
|
||||
from superset import feature_flag_manager
|
||||
from superset.utils.json import loads as json_loads
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.models.core import Database
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def add_impersonation_cache_key_if_needed(
|
||||
database: Database,
|
||||
cache_dict: dict[str, Any],
|
||||
) -> None:
|
||||
"""
|
||||
Add a per-user cache-key when the DB connection is configured for
|
||||
per-user caching, no-op otherwise.
|
||||
"""
|
||||
extra = json_loads(database.extra or "{}")
|
||||
if (
|
||||
(
|
||||
feature_flag_manager.is_feature_enabled("CACHE_IMPERSONATION")
|
||||
and database.impersonate_user
|
||||
)
|
||||
or feature_flag_manager.is_feature_enabled("CACHE_QUERY_BY_USER")
|
||||
or extra.get("per_user_caching", False)
|
||||
):
|
||||
if key := database.db_engine_spec.get_impersonation_key(
|
||||
getattr(g, "user", None)
|
||||
):
|
||||
logger.debug("Adding impersonation key to cache dict: %s", key)
|
||||
cache_dict["impersonation_key"] = key
|
||||
@@ -65,6 +65,7 @@ from superset.superset_typing import (
|
||||
)
|
||||
from superset.utils import core as utils, csv, json
|
||||
from superset.utils.cache import set_and_log_cache
|
||||
from superset.utils.cache_keys import add_impersonation_cache_key_if_needed
|
||||
from superset.utils.core import (
|
||||
apply_max_row_limit,
|
||||
DateColumn,
|
||||
@@ -472,6 +473,16 @@ class BaseViz: # pylint: disable=too-many-public-methods
|
||||
cache_dict["extra_cache_keys"] = self.datasource.get_extra_cache_keys(query_obj)
|
||||
cache_dict["rls"] = security_manager.get_rls_cache_key(self.datasource)
|
||||
cache_dict["changed_on"] = self.datasource.changed_on
|
||||
|
||||
# Add an impersonation key to cache if impersonation is enabled on the db
|
||||
# or if the CACHE_QUERY_BY_USER flag is on or per_user_caching is enabled on
|
||||
# the database
|
||||
try:
|
||||
add_impersonation_cache_key_if_needed(self.datasource.database, cache_dict)
|
||||
except AttributeError:
|
||||
# datasource or database do not exist
|
||||
pass
|
||||
|
||||
json_data = self.json_dumps(cache_dict, sort_keys=True)
|
||||
return hash_from_str(json_data)
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ import pandas as pd
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
from requests.exceptions import HTTPError
|
||||
from shillelagh.exceptions import UnauthenticatedError
|
||||
from sqlalchemy.engine.url import make_url
|
||||
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
@@ -789,6 +790,36 @@ def test_needs_oauth2_with_other_error(mocker: MockerFixture) -> None:
|
||||
assert GSheetsEngineSpec.needs_oauth2(ex) is False
|
||||
|
||||
|
||||
def test_needs_oauth2_with_shillelagh_unauthenticated_error(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test that needs_oauth2 returns True when UnauthenticatedError is raised.
|
||||
"""
|
||||
from superset.db_engine_specs.gsheets import GSheetsEngineSpec
|
||||
|
||||
g = mocker.patch("superset.db_engine_specs.gsheets.g")
|
||||
g.user = mocker.MagicMock()
|
||||
|
||||
ex = UnauthenticatedError("Token has been revoked")
|
||||
assert GSheetsEngineSpec.needs_oauth2(ex) is True
|
||||
|
||||
|
||||
def test_needs_oauth2_with_unrelated_exception_type(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test that an unrelated exception type (with no matching message) returns
|
||||
False.
|
||||
"""
|
||||
from superset.db_engine_specs.gsheets import GSheetsEngineSpec
|
||||
|
||||
g = mocker.patch("superset.db_engine_specs.gsheets.g")
|
||||
g.user = mocker.MagicMock()
|
||||
|
||||
assert GSheetsEngineSpec.needs_oauth2(ValueError("unrelated")) is False
|
||||
|
||||
|
||||
def test_get_oauth2_fresh_token_success(
|
||||
mocker: MockerFixture,
|
||||
oauth2_config: OAuth2ClientConfig,
|
||||
|
||||
@@ -320,66 +320,13 @@ class TestChartDataModelMetadataPrivacy:
|
||||
assert data["error_type"] == DATA_MODEL_METADATA_ERROR_TYPE
|
||||
|
||||
|
||||
class TestListChartsCreatedByMe:
|
||||
"""Tests for the created_by_me flag on ListChartsRequest."""
|
||||
|
||||
def test_created_by_me_default_is_false(self):
|
||||
request = ListChartsRequest()
|
||||
assert request.created_by_me is False
|
||||
|
||||
def test_created_by_me_true_accepted(self):
|
||||
request = ListChartsRequest(created_by_me=True)
|
||||
assert request.created_by_me is True
|
||||
|
||||
def test_created_by_me_combined_with_filters(self):
|
||||
request = ListChartsRequest(
|
||||
created_by_me=True,
|
||||
filters=[ChartFilter(col="slice_name", opr="sw", value="My")],
|
||||
)
|
||||
assert request.created_by_me is True
|
||||
assert len(request.filters) == 1
|
||||
|
||||
def test_created_by_me_with_search_raises(self):
|
||||
from pydantic import ValidationError
|
||||
|
||||
with pytest.raises(ValidationError, match="created_by_me"):
|
||||
ListChartsRequest(created_by_me=True, search="My charts")
|
||||
|
||||
def test_chart_filter_rejects_created_by_fk(self):
|
||||
"""created_by_fk is not a public filter column; use created_by_me instead."""
|
||||
from pydantic import ValidationError
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
ChartFilter(col="created_by_fk", opr="eq", value=1)
|
||||
|
||||
|
||||
class TestListChartsOwnedByMe:
|
||||
"""Tests for the owned_by_me flag on ListChartsRequest."""
|
||||
|
||||
def test_owned_by_me_default_is_false(self):
|
||||
request = ListChartsRequest()
|
||||
assert request.owned_by_me is False
|
||||
|
||||
def test_owned_by_me_true_accepted(self):
|
||||
request = ListChartsRequest(owned_by_me=True)
|
||||
assert request.owned_by_me is True
|
||||
|
||||
def test_owned_by_me_combined_with_filters(self):
|
||||
request = ListChartsRequest(
|
||||
owned_by_me=True,
|
||||
filters=[ChartFilter(col="slice_name", opr="sw", value="My")],
|
||||
)
|
||||
assert request.owned_by_me is True
|
||||
assert len(request.filters) == 1
|
||||
|
||||
def test_owned_by_me_with_search_raises(self):
|
||||
from pydantic import ValidationError
|
||||
|
||||
with pytest.raises(ValidationError, match="owned_by_me"):
|
||||
ListChartsRequest(owned_by_me=True, search="My charts")
|
||||
|
||||
def test_owned_by_me_and_created_by_me_allowed(self):
|
||||
"""Both flags together are valid (OR logic — creator or owner)."""
|
||||
request = ListChartsRequest(owned_by_me=True, created_by_me=True)
|
||||
assert request.owned_by_me is True
|
||||
assert request.created_by_me is True
|
||||
@patch("superset.daos.chart.ChartDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_charts_no_arguments(mock_list, mcp_server):
|
||||
"""Regression test: list_charts must accept zero arguments without raising
|
||||
pydantic_core.ValidationError: Missing required argument: request."""
|
||||
mock_list.return_value = ([], 0)
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("list_charts", {})
|
||||
data = json.loads(result.content[0].text)
|
||||
assert "charts" in data
|
||||
|
||||
@@ -30,7 +30,6 @@ from flask import g
|
||||
|
||||
from superset.mcp_service.app import mcp
|
||||
from superset.mcp_service.dashboard.schemas import (
|
||||
DashboardFilter,
|
||||
ListDashboardsRequest,
|
||||
)
|
||||
from superset.mcp_service.dashboard.tool.get_dashboard_info import (
|
||||
@@ -1355,66 +1354,13 @@ class TestDashboardSortableColumns:
|
||||
assert col in list_dashboards.__doc__
|
||||
|
||||
|
||||
class TestListDashboardsCreatedByMe:
|
||||
"""Tests for the created_by_me flag on ListDashboardsRequest."""
|
||||
|
||||
def test_created_by_me_default_is_false(self):
|
||||
request = ListDashboardsRequest()
|
||||
assert request.created_by_me is False
|
||||
|
||||
def test_created_by_me_true_accepted(self):
|
||||
request = ListDashboardsRequest(created_by_me=True)
|
||||
assert request.created_by_me is True
|
||||
|
||||
def test_created_by_me_combined_with_filters(self):
|
||||
request = ListDashboardsRequest(
|
||||
created_by_me=True,
|
||||
filters=[DashboardFilter(col="published", opr="eq", value=True)],
|
||||
)
|
||||
assert request.created_by_me is True
|
||||
assert len(request.filters) == 1
|
||||
|
||||
def test_created_by_me_with_search_raises(self):
|
||||
from pydantic import ValidationError
|
||||
|
||||
with pytest.raises(ValidationError, match="created_by_me"):
|
||||
ListDashboardsRequest(created_by_me=True, search="My dashboards")
|
||||
|
||||
def test_dashboard_filter_rejects_created_by_fk(self):
|
||||
"""created_by_fk is not a public filter column; use created_by_me instead."""
|
||||
from pydantic import ValidationError
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
DashboardFilter(col="created_by_fk", opr="eq", value=1)
|
||||
|
||||
|
||||
class TestListDashboardsOwnedByMe:
|
||||
"""Tests for the owned_by_me flag on ListDashboardsRequest."""
|
||||
|
||||
def test_owned_by_me_default_is_false(self):
|
||||
request = ListDashboardsRequest()
|
||||
assert request.owned_by_me is False
|
||||
|
||||
def test_owned_by_me_true_accepted(self):
|
||||
request = ListDashboardsRequest(owned_by_me=True)
|
||||
assert request.owned_by_me is True
|
||||
|
||||
def test_owned_by_me_combined_with_filters(self):
|
||||
request = ListDashboardsRequest(
|
||||
owned_by_me=True,
|
||||
filters=[DashboardFilter(col="published", opr="eq", value=True)],
|
||||
)
|
||||
assert request.owned_by_me is True
|
||||
assert len(request.filters) == 1
|
||||
|
||||
def test_owned_by_me_with_search_raises(self):
|
||||
from pydantic import ValidationError
|
||||
|
||||
with pytest.raises(ValidationError, match="owned_by_me"):
|
||||
ListDashboardsRequest(owned_by_me=True, search="My dashboards")
|
||||
|
||||
def test_owned_by_me_and_created_by_me_allowed(self):
|
||||
"""Both flags together are valid (OR logic — creator or owner)."""
|
||||
request = ListDashboardsRequest(owned_by_me=True, created_by_me=True)
|
||||
assert request.owned_by_me is True
|
||||
assert request.created_by_me is True
|
||||
@patch("superset.daos.dashboard.DashboardDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_dashboards_no_arguments(mock_list, mcp_server):
|
||||
"""Regression test: list_dashboards must accept zero arguments without raising
|
||||
pydantic_core.ValidationError: Missing required argument: request."""
|
||||
mock_list.return_value = ([], 0)
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("list_dashboards", {})
|
||||
data = json.loads(result.content[0].text)
|
||||
assert "dashboards" in data
|
||||
|
||||
831
tests/unit_tests/mcp_service/dataset/tool/test_query_dataset.py
Normal file
831
tests/unit_tests/mcp_service/dataset/tool/test_query_dataset.py
Normal file
@@ -0,0 +1,831 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Tests for the query_dataset MCP tool."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastmcp import Client, FastMCP
|
||||
|
||||
from superset.mcp_service.app import mcp
|
||||
from superset.utils import json
|
||||
|
||||
query_dataset_module = importlib.import_module(
|
||||
"superset.mcp_service.dataset.tool.query_dataset"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mcp_server() -> FastMCP:
|
||||
return mcp
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_auth() -> Generator[MagicMock, None, None]:
|
||||
"""Mock authentication and metadata access for all tests."""
|
||||
with (
|
||||
patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user,
|
||||
patch.object(
|
||||
query_dataset_module,
|
||||
"user_can_view_data_model_metadata",
|
||||
return_value=True,
|
||||
),
|
||||
):
|
||||
mock_user = Mock()
|
||||
mock_user.id = 1
|
||||
mock_user.username = "admin"
|
||||
mock_get_user.return_value = mock_user
|
||||
yield mock_get_user
|
||||
|
||||
|
||||
def _make_column(name: str, is_dttm: bool = False) -> MagicMock:
|
||||
"""Build a mock SqlaTable column with the given name and datetime flag."""
|
||||
col = MagicMock()
|
||||
col.column_name = name
|
||||
col.is_dttm = is_dttm
|
||||
col.verbose_name = None
|
||||
col.type = "VARCHAR"
|
||||
col.groupby = True
|
||||
col.filterable = True
|
||||
col.description = None
|
||||
return col
|
||||
|
||||
|
||||
def _make_metric(name: str, expression: str = "COUNT(*)") -> MagicMock:
|
||||
"""Build a mock SqlMetric with the given name and SQL expression."""
|
||||
metric = MagicMock()
|
||||
metric.metric_name = name
|
||||
metric.verbose_name = None
|
||||
metric.expression = expression
|
||||
metric.description = None
|
||||
metric.d3format = None
|
||||
return metric
|
||||
|
||||
|
||||
def _make_dataset(
|
||||
dataset_id: int = 1,
|
||||
table_name: str = "orders",
|
||||
columns: list[Any] | None = None,
|
||||
metrics: list[Any] | None = None,
|
||||
main_dttm_col: str | None = None,
|
||||
) -> MagicMock:
|
||||
"""Build a mock SqlaTable dataset with default columns and metrics."""
|
||||
ds = MagicMock()
|
||||
ds.id = dataset_id
|
||||
ds.table_name = table_name
|
||||
ds.uuid = f"test-uuid-{dataset_id}"
|
||||
ds.main_dttm_col = main_dttm_col
|
||||
ds.database = MagicMock()
|
||||
ds.database.database_name = "examples"
|
||||
ds.columns = columns or [
|
||||
_make_column("category"),
|
||||
_make_column("region"),
|
||||
_make_column("order_date", is_dttm=True),
|
||||
]
|
||||
ds.metrics = metrics or [
|
||||
_make_metric("count", "COUNT(*)"),
|
||||
_make_metric("total_revenue", "SUM(revenue)"),
|
||||
]
|
||||
return ds
|
||||
|
||||
|
||||
def _mock_command_result(
|
||||
data: list[dict[str, Any]] | None = None,
|
||||
colnames: list[str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build the result dict that ChartDataCommand.run() returns."""
|
||||
data = data or [
|
||||
{"category": "Electronics", "count": 42},
|
||||
{"category": "Clothing", "count": 17},
|
||||
]
|
||||
colnames = colnames or ["category", "count"]
|
||||
return {
|
||||
"queries": [
|
||||
{
|
||||
"data": data,
|
||||
"colnames": colnames,
|
||||
"rowcount": len(data),
|
||||
"cache_key": "abc123",
|
||||
"is_cached": False,
|
||||
"cached_dttm": None,
|
||||
"cache_timeout": 300,
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_dataset_success(mcp_server: FastMCP) -> None:
|
||||
"""Happy path: metrics + columns returns data."""
|
||||
dataset = _make_dataset()
|
||||
result_data = _mock_command_result()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
query_dataset_module,
|
||||
"_resolve_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch(
|
||||
"superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
|
||||
),
|
||||
patch(
|
||||
"superset.commands.chart.data.get_data_command.ChartDataCommand.run",
|
||||
return_value=result_data,
|
||||
),
|
||||
patch(
|
||||
"superset.common.query_context_factory.QueryContextFactory.create",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"query_dataset",
|
||||
{
|
||||
"request": {
|
||||
"dataset_id": 1,
|
||||
"metrics": ["count"],
|
||||
"columns": ["category"],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["dataset_id"] == 1
|
||||
assert data["dataset_name"] == "orders"
|
||||
assert data["row_count"] == 2
|
||||
assert len(data["data"]) == 2
|
||||
assert data["data"][0]["category"] == "Electronics"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_dataset_not_found(mcp_server: FastMCP) -> None:
|
||||
"""Dataset ID that doesn't exist returns error."""
|
||||
with patch.object(
|
||||
query_dataset_module,
|
||||
"_resolve_dataset",
|
||||
return_value=None,
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"query_dataset",
|
||||
{
|
||||
"request": {
|
||||
"dataset_id": 999,
|
||||
"metrics": ["count"],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["error_type"] == "NotFound"
|
||||
assert "999" in data["error"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_dataset_invalid_metric(mcp_server: FastMCP) -> None:
|
||||
"""Unknown metric name returns validation error with suggestions."""
|
||||
dataset = _make_dataset()
|
||||
|
||||
with patch.object(
|
||||
query_dataset_module,
|
||||
"_resolve_dataset",
|
||||
return_value=dataset,
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"query_dataset",
|
||||
{
|
||||
"request": {
|
||||
"dataset_id": 1,
|
||||
"metrics": ["countt"], # typo
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["error_type"] == "ValidationError"
|
||||
assert "countt" in data["error"]
|
||||
# Should suggest "count" as a close match
|
||||
assert "count" in data["error"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_dataset_invalid_column(mcp_server: FastMCP) -> None:
|
||||
"""Unknown column name returns validation error."""
|
||||
dataset = _make_dataset()
|
||||
|
||||
with patch.object(
|
||||
query_dataset_module,
|
||||
"_resolve_dataset",
|
||||
return_value=dataset,
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"query_dataset",
|
||||
{
|
||||
"request": {
|
||||
"dataset_id": 1,
|
||||
"columns": ["nonexistent_col"],
|
||||
"metrics": ["count"],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["error_type"] == "ValidationError"
|
||||
assert "nonexistent_col" in data["error"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_dataset_no_metrics_no_columns(mcp_server: FastMCP) -> None:
|
||||
"""Providing neither metrics nor columns raises validation error."""
|
||||
from fastmcp.exceptions import ToolError
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
with pytest.raises(ToolError, match="metrics.*columns"):
|
||||
await client.call_tool(
|
||||
"query_dataset",
|
||||
{
|
||||
"request": {
|
||||
"dataset_id": 1,
|
||||
"metrics": [],
|
||||
"columns": [],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_dataset_with_time_range(mcp_server: FastMCP) -> None:
|
||||
"""time_range is converted to TEMPORAL_RANGE filter + granularity."""
|
||||
dataset = _make_dataset(main_dttm_col="order_date")
|
||||
result_data = _mock_command_result()
|
||||
captured_queries: list[dict[str, Any]] = []
|
||||
|
||||
def capture_create(**kwargs):
|
||||
captured_queries.extend(kwargs.get("queries", []))
|
||||
return MagicMock()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
query_dataset_module,
|
||||
"_resolve_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch(
|
||||
"superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
|
||||
),
|
||||
patch(
|
||||
"superset.commands.chart.data.get_data_command.ChartDataCommand.run",
|
||||
return_value=result_data,
|
||||
),
|
||||
patch(
|
||||
"superset.common.query_context_factory.QueryContextFactory.create",
|
||||
side_effect=capture_create,
|
||||
),
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"query_dataset",
|
||||
{
|
||||
"request": {
|
||||
"dataset_id": 1,
|
||||
"metrics": ["count"],
|
||||
"time_range": "Last 7 days",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert len(captured_queries) == 1
|
||||
query_dict = captured_queries[0]
|
||||
# Should have TEMPORAL_RANGE filter
|
||||
temporal_filters = [f for f in query_dict["filters"] if f["op"] == "TEMPORAL_RANGE"]
|
||||
assert len(temporal_filters) == 1
|
||||
assert temporal_filters[0]["col"] == "order_date"
|
||||
assert temporal_filters[0]["val"] == "Last 7 days"
|
||||
# Should set granularity
|
||||
assert query_dict["granularity"] == "order_date"
|
||||
# applied_filters in response must include the synthesized TEMPORAL_RANGE filter
|
||||
data = json.loads(result.content[0].text)
|
||||
resp_filters = data["applied_filters"]
|
||||
temporal_resp = [f for f in resp_filters if f["op"] == "TEMPORAL_RANGE"]
|
||||
assert len(temporal_resp) == 1
|
||||
assert temporal_resp[0]["col"] == "order_date"
|
||||
assert temporal_resp[0]["val"] == "Last 7 days"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_dataset_time_range_no_temporal_column(mcp_server: FastMCP) -> None:
|
||||
"""time_range without a temporal column returns error."""
|
||||
dataset = _make_dataset(main_dttm_col=None)
|
||||
|
||||
with patch.object(
|
||||
query_dataset_module,
|
||||
"_resolve_dataset",
|
||||
return_value=dataset,
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"query_dataset",
|
||||
{
|
||||
"request": {
|
||||
"dataset_id": 1,
|
||||
"metrics": ["count"],
|
||||
"time_range": "Last 7 days",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["error_type"] == "ValidationError"
|
||||
assert "temporal column" in data["error"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_dataset_with_filters(mcp_server: FastMCP) -> None:
|
||||
"""User-provided filters are passed through to the query."""
|
||||
dataset = _make_dataset()
|
||||
result_data = _mock_command_result()
|
||||
captured_queries: list[dict[str, Any]] = []
|
||||
|
||||
def capture_create(**kwargs):
|
||||
captured_queries.extend(kwargs.get("queries", []))
|
||||
return MagicMock()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
query_dataset_module,
|
||||
"_resolve_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch(
|
||||
"superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
|
||||
),
|
||||
patch(
|
||||
"superset.commands.chart.data.get_data_command.ChartDataCommand.run",
|
||||
return_value=result_data,
|
||||
),
|
||||
patch(
|
||||
"superset.common.query_context_factory.QueryContextFactory.create",
|
||||
side_effect=capture_create,
|
||||
),
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
await client.call_tool(
|
||||
"query_dataset",
|
||||
{
|
||||
"request": {
|
||||
"dataset_id": 1,
|
||||
"metrics": ["count"],
|
||||
"filters": [
|
||||
{"col": "category", "op": "==", "val": "Electronics"}
|
||||
],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert len(captured_queries) == 1
|
||||
filters = captured_queries[0]["filters"]
|
||||
assert len(filters) == 1
|
||||
assert filters[0]["col"] == "category"
|
||||
assert filters[0]["op"] == "=="
|
||||
assert filters[0]["val"] == "Electronics"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_dataset_empty_results(mcp_server: FastMCP) -> None:
|
||||
"""Query that returns no data gives a response with row_count=0."""
|
||||
dataset = _make_dataset()
|
||||
empty_result = {
|
||||
"queries": [
|
||||
{
|
||||
"data": [],
|
||||
"colnames": [],
|
||||
"rowcount": 0,
|
||||
"is_cached": False,
|
||||
"cached_dttm": None,
|
||||
"cache_timeout": 300,
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
query_dataset_module,
|
||||
"_resolve_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch(
|
||||
"superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
|
||||
),
|
||||
patch(
|
||||
"superset.commands.chart.data.get_data_command.ChartDataCommand.run",
|
||||
return_value=empty_result,
|
||||
),
|
||||
patch(
|
||||
"superset.common.query_context_factory.QueryContextFactory.create",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"query_dataset",
|
||||
{
|
||||
"request": {
|
||||
"dataset_id": 1,
|
||||
"metrics": ["count"],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["row_count"] == 0
|
||||
assert data["data"] == []
|
||||
assert "no data" in data["summary"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_dataset_by_uuid(mcp_server: FastMCP) -> None:
|
||||
"""UUID-based lookup works."""
|
||||
dataset = _make_dataset()
|
||||
result_data = _mock_command_result()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
query_dataset_module,
|
||||
"_resolve_dataset",
|
||||
return_value=dataset,
|
||||
) as mock_resolve,
|
||||
patch(
|
||||
"superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
|
||||
),
|
||||
patch(
|
||||
"superset.commands.chart.data.get_data_command.ChartDataCommand.run",
|
||||
return_value=result_data,
|
||||
),
|
||||
patch(
|
||||
"superset.common.query_context_factory.QueryContextFactory.create",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"query_dataset",
|
||||
{
|
||||
"request": {
|
||||
"dataset_id": "a1b2c3d4-5678-90ab-cdef-1234567890ab",
|
||||
"metrics": ["count"],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Verify the resolve function was called with the UUID
|
||||
mock_resolve.assert_called_once()
|
||||
call_args = mock_resolve.call_args
|
||||
assert call_args[0][0] == "a1b2c3d4-5678-90ab-cdef-1234567890ab"
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["dataset_id"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_dataset_permission_denied(mcp_server: FastMCP) -> None:
|
||||
"""Permission denied from ChartDataCommand.validate() returns error."""
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.exceptions import SupersetSecurityException
|
||||
|
||||
dataset = _make_dataset()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
query_dataset_module,
|
||||
"_resolve_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch(
|
||||
"superset.common.query_context_factory.QueryContextFactory.create",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
|
||||
side_effect=SupersetSecurityException(
|
||||
SupersetError(
|
||||
message="Access denied",
|
||||
error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR,
|
||||
level=ErrorLevel.WARNING,
|
||||
)
|
||||
),
|
||||
),
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"query_dataset",
|
||||
{
|
||||
"request": {
|
||||
"dataset_id": 1,
|
||||
"metrics": ["count"],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["error_type"] == "QueryError"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_dataset_order_by_valid(mcp_server: FastMCP) -> None:
|
||||
"""order_by with valid column/metric names passes through."""
|
||||
dataset = _make_dataset()
|
||||
result_data = _mock_command_result()
|
||||
captured_queries: list[dict[str, Any]] = []
|
||||
|
||||
def capture_create(**kwargs):
|
||||
captured_queries.extend(kwargs.get("queries", []))
|
||||
return MagicMock()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
query_dataset_module,
|
||||
"_resolve_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch(
|
||||
"superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
|
||||
),
|
||||
patch(
|
||||
"superset.commands.chart.data.get_data_command.ChartDataCommand.run",
|
||||
return_value=result_data,
|
||||
),
|
||||
patch(
|
||||
"superset.common.query_context_factory.QueryContextFactory.create",
|
||||
side_effect=capture_create,
|
||||
),
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
await client.call_tool(
|
||||
"query_dataset",
|
||||
{
|
||||
"request": {
|
||||
"dataset_id": 1,
|
||||
"metrics": ["count"],
|
||||
"columns": ["category"],
|
||||
"order_by": ["count"],
|
||||
"order_desc": True,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert len(captured_queries) == 1
|
||||
orderby = captured_queries[0].get("orderby", [])
|
||||
assert len(orderby) == 1
|
||||
assert orderby[0][0] == "count"
|
||||
# order_desc=True -> ascending=False
|
||||
assert orderby[0][1] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_dataset_order_by_invalid(mcp_server: FastMCP) -> None:
|
||||
"""order_by with an unknown name returns validation error."""
|
||||
dataset = _make_dataset()
|
||||
|
||||
with patch.object(
|
||||
query_dataset_module,
|
||||
"_resolve_dataset",
|
||||
return_value=dataset,
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"query_dataset",
|
||||
{
|
||||
"request": {
|
||||
"dataset_id": 1,
|
||||
"metrics": ["count"],
|
||||
"order_by": ["nonexistent"],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["error_type"] == "ValidationError"
|
||||
assert "nonexistent" in data["error"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_dataset_time_column_override(mcp_server: FastMCP) -> None:
|
||||
"""Explicit time_column overrides dataset main_dttm_col."""
|
||||
dataset = _make_dataset(main_dttm_col="order_date")
|
||||
result_data = _mock_command_result()
|
||||
captured_queries: list[dict[str, Any]] = []
|
||||
|
||||
def capture_create(**kwargs):
|
||||
captured_queries.extend(kwargs.get("queries", []))
|
||||
return MagicMock()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
query_dataset_module,
|
||||
"_resolve_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch(
|
||||
"superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
|
||||
),
|
||||
patch(
|
||||
"superset.commands.chart.data.get_data_command.ChartDataCommand.run",
|
||||
return_value=result_data,
|
||||
),
|
||||
patch(
|
||||
"superset.common.query_context_factory.QueryContextFactory.create",
|
||||
side_effect=capture_create,
|
||||
),
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
await client.call_tool(
|
||||
"query_dataset",
|
||||
{
|
||||
"request": {
|
||||
"dataset_id": 1,
|
||||
"metrics": ["count"],
|
||||
"time_range": "Last 30 days",
|
||||
"time_column": "order_date",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert len(captured_queries) == 1
|
||||
query_dict = captured_queries[0]
|
||||
assert query_dict["granularity"] == "order_date"
|
||||
temporal_filters = [f for f in query_dict["filters"] if f["op"] == "TEMPORAL_RANGE"]
|
||||
assert temporal_filters[0]["col"] == "order_date"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_dataset_non_dttm_time_column_warns(mcp_server: FastMCP) -> None:
|
||||
"""Using a non-datetime column for time_range produces a warning."""
|
||||
dataset = _make_dataset(main_dttm_col=None)
|
||||
result_data = _mock_command_result()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
query_dataset_module,
|
||||
"_resolve_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch(
|
||||
"superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
|
||||
),
|
||||
patch(
|
||||
"superset.commands.chart.data.get_data_command.ChartDataCommand.run",
|
||||
return_value=result_data,
|
||||
),
|
||||
patch(
|
||||
"superset.common.query_context_factory.QueryContextFactory.create",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"query_dataset",
|
||||
{
|
||||
"request": {
|
||||
"dataset_id": 1,
|
||||
"metrics": ["count"],
|
||||
"time_range": "Last 7 days",
|
||||
"time_column": "category",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert len(data["warnings"]) > 0
|
||||
assert "not marked as a datetime" in data["warnings"][0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_dataset_invalid_filter_column(mcp_server: FastMCP) -> None:
|
||||
"""Filter on a column that doesn't exist returns validation error."""
|
||||
dataset = _make_dataset()
|
||||
|
||||
with patch.object(
|
||||
query_dataset_module,
|
||||
"_resolve_dataset",
|
||||
return_value=dataset,
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"query_dataset",
|
||||
{
|
||||
"request": {
|
||||
"dataset_id": 1,
|
||||
"metrics": ["count"],
|
||||
"filters": [
|
||||
{
|
||||
"col": "nonexistent",
|
||||
"op": "==",
|
||||
"val": "test",
|
||||
}
|
||||
],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["error_type"] == "ValidationError"
|
||||
assert "nonexistent" in data["error"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_dataset_metadata_access_denied_no_suggestions(
|
||||
mcp_server: FastMCP,
|
||||
) -> None:
|
||||
"""Users without data-model metadata access cannot probe column/metric names.
|
||||
|
||||
The privacy gate must fire before the validation step that returns close-match
|
||||
suggestions, so restricted users cannot enumerate schema details via typos.
|
||||
"""
|
||||
dataset = _make_dataset()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
query_dataset_module,
|
||||
"_resolve_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch.object(
|
||||
query_dataset_module,
|
||||
"user_can_view_data_model_metadata",
|
||||
return_value=False,
|
||||
),
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"query_dataset",
|
||||
{
|
||||
"request": {
|
||||
"dataset_id": 1,
|
||||
# Typo that would normally trigger close-match suggestions
|
||||
"metrics": ["countt"],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
# Must be denied before returning any schema suggestions
|
||||
assert data["error_type"] == "DataModelMetadataRestricted"
|
||||
# Must NOT contain column/metric name suggestions
|
||||
assert "countt" not in data.get("error", "")
|
||||
assert "count" not in data.get("error", "")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_dataset_metadata_access_denied_nonexistent_dataset(
|
||||
mcp_server: FastMCP,
|
||||
) -> None:
|
||||
"""Metadata-restricted users must not be able to probe dataset existence.
|
||||
|
||||
The privacy gate fires before the DAO lookup, so a restricted caller
|
||||
always receives DataModelMetadataRestricted — never NotFound — regardless
|
||||
of whether the requested dataset ID exists.
|
||||
"""
|
||||
with patch.object(
|
||||
query_dataset_module,
|
||||
"user_can_view_data_model_metadata",
|
||||
return_value=False,
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"query_dataset",
|
||||
{
|
||||
"request": {
|
||||
# Use a dataset_id that does not exist
|
||||
"dataset_id": 999999,
|
||||
"metrics": ["count"],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
# Must receive restricted error, not a NotFound that leaks existence
|
||||
assert data["error_type"] == "DataModelMetadataRestricted"
|
||||
assert data["error_type"] != "NotFound"
|
||||
@@ -372,6 +372,43 @@ def test_mcp_auth_hook_preserves_g_user_in_request_context(app) -> None:
|
||||
assert result == "middleware_user"
|
||||
|
||||
|
||||
def test_mcp_auth_hook_removes_stale_db_session_in_sync_wrapper(app) -> None:
|
||||
"""sync_wrapper calls db.session.remove() BEFORE get_user_from_request().
|
||||
|
||||
Thread pool workers reuse threads across requests; db.session is
|
||||
thread-local and may be bound to a different tenant's DB engine from a
|
||||
prior request. Removing it before user lookup ensures a fresh session is
|
||||
created for the current request.
|
||||
|
||||
The ordering is critical: if remove() were called after user lookup,
|
||||
the stale session binding would already have caused a mismatch error.
|
||||
"""
|
||||
fresh_user = _make_mock_user("fresh")
|
||||
|
||||
def dummy_tool():
|
||||
"""Dummy tool."""
|
||||
return g.user.username
|
||||
|
||||
wrapped = mcp_auth_hook(dummy_tool)
|
||||
|
||||
with app.test_request_context():
|
||||
g.user = fresh_user
|
||||
with patch("superset.extensions.db") as mock_db:
|
||||
|
||||
def _assert_remove_already_called() -> MagicMock:
|
||||
"""Verify remove() was called before user resolution runs."""
|
||||
mock_db.session.remove.assert_called_once_with()
|
||||
return fresh_user
|
||||
|
||||
with patch(
|
||||
"superset.mcp_service.auth.get_user_from_request",
|
||||
side_effect=_assert_remove_already_called,
|
||||
):
|
||||
result = wrapped()
|
||||
|
||||
assert result == "fresh"
|
||||
|
||||
|
||||
# -- default_user_resolver --
|
||||
|
||||
|
||||
|
||||
@@ -95,7 +95,7 @@ def test_cache_key_changes_for_new_query_object_same_params():
|
||||
assert query_object2.cache_key() == cache_key1
|
||||
|
||||
|
||||
@patch("superset.common.query_object.feature_flag_manager")
|
||||
@patch("superset.utils.cache_keys.feature_flag_manager")
|
||||
def test_cache_key_cache_query_by_user_on_no_datasource(feature_flag_mock):
|
||||
"""
|
||||
When CACHE_QUERY_BY_USER flag is on and there is no datasource,
|
||||
@@ -112,7 +112,7 @@ def test_cache_key_cache_query_by_user_on_no_datasource(feature_flag_mock):
|
||||
assert query_object.cache_key() == cache_key
|
||||
|
||||
|
||||
@patch("superset.common.query_object.feature_flag_manager")
|
||||
@patch("superset.utils.cache_keys.feature_flag_manager")
|
||||
@patch("superset.common.query_object.logger")
|
||||
def test_cache_key_cache_query_by_user_on_no_user(logger_mock, feature_flag_mock):
|
||||
"""
|
||||
@@ -140,16 +140,13 @@ def test_cache_key_cache_query_by_user_on_no_user(logger_mock, feature_flag_mock
|
||||
logger_mock.debug.assert_called()
|
||||
|
||||
|
||||
@patch("superset.common.query_object.feature_flag_manager")
|
||||
@patch("superset.common.query_object.logger")
|
||||
@patch("superset.utils.cache_keys.feature_flag_manager")
|
||||
@patch("superset.utils.cache_keys.logger")
|
||||
def test_cache_key_cache_query_by_user_on_with_user(logger_mock, feature_flag_mock):
|
||||
"""
|
||||
When the same user is requesting a cache key with CACHE_QUERY_BY_USER
|
||||
flag on, the key will be the same
|
||||
"""
|
||||
# Configure logger to enable DEBUG level for isEnabledFor check
|
||||
logger_mock.isEnabledFor.return_value = True
|
||||
|
||||
datasource = SqlaTable(
|
||||
table_name="test_table",
|
||||
columns=[],
|
||||
@@ -167,17 +164,17 @@ def test_cache_key_cache_query_by_user_on_with_user(logger_mock, feature_flag_mo
|
||||
cache_key1 = query_object.cache_key()
|
||||
assert query_object.cache_key() == cache_key1
|
||||
|
||||
# Should have both impersonation and cache key generation logs
|
||||
# Should have impersonation log emitted by the cache_keys helper
|
||||
logger_mock.debug.assert_has_calls(
|
||||
[
|
||||
call("Adding impersonation key to QueryObject cache dict: %s", "test_user"),
|
||||
call("Adding impersonation key to cache dict: %s", "test_user"),
|
||||
],
|
||||
any_order=True,
|
||||
)
|
||||
|
||||
|
||||
@patch("superset.common.query_object.feature_flag_manager")
|
||||
@patch("superset.common.query_object.logger")
|
||||
@patch("superset.utils.cache_keys.feature_flag_manager")
|
||||
@patch("superset.utils.cache_keys.logger")
|
||||
def test_cache_key_cache_query_by_user_on_with_different_user(
|
||||
logger_mock, feature_flag_mock
|
||||
):
|
||||
@@ -185,9 +182,6 @@ def test_cache_key_cache_query_by_user_on_with_different_user(
|
||||
When two different users are requesting a cache key with CACHE_QUERY_BY_USER
|
||||
flag on, the key will be different
|
||||
"""
|
||||
# Configure logger to enable DEBUG level for isEnabledFor check
|
||||
logger_mock.isEnabledFor.return_value = True
|
||||
|
||||
datasource = SqlaTable(
|
||||
table_name="test_table",
|
||||
columns=[],
|
||||
@@ -209,21 +203,17 @@ def test_cache_key_cache_query_by_user_on_with_different_user(
|
||||
|
||||
assert cache_key1 != cache_key2
|
||||
|
||||
# Should have both impersonation and cache key generation logs (any order)
|
||||
# Should have impersonation logs emitted by the cache_keys helper
|
||||
logger_mock.debug.assert_has_calls(
|
||||
[
|
||||
call(
|
||||
"Adding impersonation key to QueryObject cache dict: %s", "test_user1"
|
||||
),
|
||||
call(
|
||||
"Adding impersonation key to QueryObject cache dict: %s", "test_user2"
|
||||
),
|
||||
call("Adding impersonation key to cache dict: %s", "test_user1"),
|
||||
call("Adding impersonation key to cache dict: %s", "test_user2"),
|
||||
],
|
||||
any_order=True,
|
||||
)
|
||||
|
||||
|
||||
@patch("superset.common.query_object.feature_flag_manager")
|
||||
@patch("superset.utils.cache_keys.feature_flag_manager")
|
||||
@patch("superset.common.query_object.logger")
|
||||
def test_cache_key_cache_impersonation_on_no_user(logger_mock, feature_flag_mock):
|
||||
"""
|
||||
@@ -251,7 +241,7 @@ def test_cache_key_cache_impersonation_on_no_user(logger_mock, feature_flag_mock
|
||||
logger_mock.debug.assert_called()
|
||||
|
||||
|
||||
@patch("superset.common.query_object.feature_flag_manager")
|
||||
@patch("superset.utils.cache_keys.feature_flag_manager")
|
||||
@patch("superset.common.query_object.logger")
|
||||
def test_cache_key_cache_impersonation_on_with_user(logger_mock, feature_flag_mock):
|
||||
"""
|
||||
@@ -290,7 +280,7 @@ def test_cache_key_cache_impersonation_on_with_user(logger_mock, feature_flag_mo
|
||||
assert len(impersonation_calls) == 0
|
||||
|
||||
|
||||
@patch("superset.common.query_object.feature_flag_manager")
|
||||
@patch("superset.utils.cache_keys.feature_flag_manager")
|
||||
@patch("superset.common.query_object.logger")
|
||||
def test_cache_key_cache_impersonation_on_with_different_user(
|
||||
logger_mock, feature_flag_mock
|
||||
@@ -335,8 +325,8 @@ def test_cache_key_cache_impersonation_on_with_different_user(
|
||||
assert len(impersonation_calls) == 0
|
||||
|
||||
|
||||
@patch("superset.common.query_object.feature_flag_manager")
|
||||
@patch("superset.common.query_object.logger")
|
||||
@patch("superset.utils.cache_keys.feature_flag_manager")
|
||||
@patch("superset.utils.cache_keys.logger")
|
||||
def test_cache_key_cache_impersonation_on_with_different_user_and_db_impersonation(
|
||||
logger_mock,
|
||||
feature_flag_mock,
|
||||
@@ -346,9 +336,6 @@ def test_cache_key_cache_impersonation_on_with_different_user_and_db_impersonati
|
||||
flag on, and cache_impersonation is enabled on the database,
|
||||
the keys will be different
|
||||
"""
|
||||
# Configure logger to enable DEBUG level for isEnabledFor check
|
||||
logger_mock.isEnabledFor.return_value = True
|
||||
|
||||
datasource = SqlaTable(
|
||||
table_name="test_table",
|
||||
columns=[],
|
||||
@@ -374,15 +361,11 @@ def test_cache_key_cache_impersonation_on_with_different_user_and_db_impersonati
|
||||
|
||||
assert cache_key1 != cache_key2
|
||||
|
||||
# Should have both impersonation and cache key generation logs (any order)
|
||||
# Should have impersonation logs emitted by the cache_keys helper
|
||||
logger_mock.debug.assert_has_calls(
|
||||
[
|
||||
call(
|
||||
"Adding impersonation key to QueryObject cache dict: %s", "test_user1"
|
||||
),
|
||||
call(
|
||||
"Adding impersonation key to QueryObject cache dict: %s", "test_user2"
|
||||
),
|
||||
call("Adding impersonation key to cache dict: %s", "test_user1"),
|
||||
call("Adding impersonation key to cache dict: %s", "test_user2"),
|
||||
],
|
||||
any_order=True,
|
||||
)
|
||||
|
||||
111
tests/unit_tests/test_viz_cache_key.py
Normal file
111
tests/unit_tests/test_viz_cache_key.py
Normal file
@@ -0,0 +1,111 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Behavioral tests for ``viz.BaseViz.cache_key`` covering per-user cache-key
|
||||
inclusion.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
|
||||
from superset import viz
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.models.core import Database
|
||||
from superset.utils.core import override_user
|
||||
|
||||
QUERY_OBJ: dict[str, Any] = {"row_limit": 100, "from_dttm": None, "to_dttm": None}
|
||||
|
||||
|
||||
def _viz_for(database: Database) -> viz.BaseViz:
|
||||
datasource = SqlaTable(
|
||||
table_name="t",
|
||||
columns=[],
|
||||
metrics=[],
|
||||
main_dttm_col=None,
|
||||
database=database,
|
||||
)
|
||||
return viz.BaseViz(datasource=datasource, form_data={"viz_type": "table"})
|
||||
|
||||
|
||||
def test_no_per_user_opt_in_keys_match_across_users():
|
||||
"""
|
||||
Without any per-user caching opt-in, two different users on the same
|
||||
database/query must produce the *same* cache key (regression guard — we
|
||||
must not accidentally make every cache key per-user).
|
||||
"""
|
||||
database = Database(database_name="d", sqlalchemy_uri="sqlite://")
|
||||
obj = _viz_for(database)
|
||||
|
||||
with override_user(User(username="alice")):
|
||||
key_a = obj.cache_key(QUERY_OBJ)
|
||||
with override_user(User(username="bob")):
|
||||
key_b = obj.cache_key(QUERY_OBJ)
|
||||
|
||||
assert key_a == key_b
|
||||
|
||||
|
||||
def test_per_user_caching_in_extra_yields_distinct_keys_per_user():
|
||||
"""
|
||||
With ``per_user_caching: true`` set on the database, two different users
|
||||
must produce *different* cache keys for the same query.
|
||||
"""
|
||||
database = Database(
|
||||
database_name="d",
|
||||
sqlalchemy_uri="sqlite://",
|
||||
extra='{"per_user_caching": true}',
|
||||
)
|
||||
obj = _viz_for(database)
|
||||
|
||||
with override_user(User(username="alice")):
|
||||
key_a = obj.cache_key(QUERY_OBJ)
|
||||
with override_user(User(username="bob")):
|
||||
key_b = obj.cache_key(QUERY_OBJ)
|
||||
|
||||
assert key_a != key_b
|
||||
|
||||
|
||||
def test_same_user_same_query_idempotent():
|
||||
database = Database(
|
||||
database_name="d",
|
||||
sqlalchemy_uri="sqlite://",
|
||||
extra='{"per_user_caching": true}',
|
||||
)
|
||||
obj = _viz_for(database)
|
||||
|
||||
with override_user(User(username="alice")):
|
||||
assert obj.cache_key(QUERY_OBJ) == obj.cache_key(QUERY_OBJ)
|
||||
|
||||
|
||||
@patch("superset.utils.cache_keys.feature_flag_manager")
|
||||
def test_cache_query_by_user_flag_yields_distinct_keys(feature_flag_mock):
|
||||
"""
|
||||
Global ``CACHE_QUERY_BY_USER`` flag also reaches the legacy viz path.
|
||||
"""
|
||||
feature_flag_mock.is_feature_enabled.side_effect = (
|
||||
lambda feature=None: feature == "CACHE_QUERY_BY_USER"
|
||||
)
|
||||
database = Database(database_name="d", sqlalchemy_uri="sqlite://")
|
||||
obj = _viz_for(database)
|
||||
|
||||
with override_user(User(username="alice")):
|
||||
key_a = obj.cache_key(QUERY_OBJ)
|
||||
with override_user(User(username="bob")):
|
||||
key_b = obj.cache_key(QUERY_OBJ)
|
||||
|
||||
assert key_a != key_b
|
||||
107
tests/unit_tests/utils/test_impersonation_cache_key.py
Normal file
107
tests/unit_tests/utils/test_impersonation_cache_key.py
Normal file
@@ -0,0 +1,107 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
|
||||
from superset.models.core import Database
|
||||
from superset.utils.cache_keys import add_impersonation_cache_key_if_needed
|
||||
from superset.utils.core import override_user
|
||||
|
||||
|
||||
def _flag(name: str):
|
||||
"""Build a feature-flag side_effect that returns True only for ``name``."""
|
||||
|
||||
def side_effect(feature=None):
|
||||
return feature == name
|
||||
|
||||
return side_effect
|
||||
|
||||
|
||||
def _run(database: Database) -> dict[str, Any]:
|
||||
"""Run the helper against a fresh dict and return that dict."""
|
||||
cache_dict: dict[str, Any] = {}
|
||||
add_impersonation_cache_key_if_needed(database, cache_dict)
|
||||
return cache_dict
|
||||
|
||||
|
||||
def test_no_per_user_caching_yields_no_key():
|
||||
database = Database(database_name="d", sqlalchemy_uri="sqlite://")
|
||||
with override_user(User(username="u")):
|
||||
assert "impersonation_key" not in _run(database)
|
||||
|
||||
|
||||
@patch("superset.utils.cache_keys.feature_flag_manager")
|
||||
def test_cache_query_by_user_adds_username(feature_flag_mock):
|
||||
feature_flag_mock.is_feature_enabled.side_effect = _flag("CACHE_QUERY_BY_USER")
|
||||
database = Database(database_name="d", sqlalchemy_uri="sqlite://")
|
||||
with override_user(User(username="alice")):
|
||||
assert _run(database)["impersonation_key"] == "alice"
|
||||
|
||||
|
||||
@patch("superset.utils.cache_keys.feature_flag_manager")
|
||||
def test_cache_query_by_user_distinct_per_user(feature_flag_mock):
|
||||
feature_flag_mock.is_feature_enabled.side_effect = _flag("CACHE_QUERY_BY_USER")
|
||||
database = Database(database_name="d", sqlalchemy_uri="sqlite://")
|
||||
with override_user(User(username="alice")):
|
||||
key_a = _run(database)["impersonation_key"]
|
||||
with override_user(User(username="bob")):
|
||||
key_b = _run(database)["impersonation_key"]
|
||||
assert key_a != key_b
|
||||
|
||||
|
||||
@patch("superset.utils.cache_keys.feature_flag_manager")
|
||||
def test_cache_impersonation_requires_database_flag(feature_flag_mock):
|
||||
"""
|
||||
CACHE_IMPERSONATION alone is not enough; ``database.impersonate_user`` must
|
||||
also be set on the database for the per-user key to apply.
|
||||
"""
|
||||
feature_flag_mock.is_feature_enabled.side_effect = _flag("CACHE_IMPERSONATION")
|
||||
|
||||
db_no_impersonation = Database(database_name="d", sqlalchemy_uri="sqlite://")
|
||||
db_with_impersonation = Database(
|
||||
database_name="d", sqlalchemy_uri="sqlite://", impersonate_user=True
|
||||
)
|
||||
|
||||
with override_user(User(username="alice")):
|
||||
assert "impersonation_key" not in _run(db_no_impersonation)
|
||||
assert _run(db_with_impersonation)["impersonation_key"] == "alice"
|
||||
|
||||
|
||||
def test_per_user_caching_in_extra_json_enables_key():
|
||||
database = Database(
|
||||
database_name="d",
|
||||
sqlalchemy_uri="sqlite://",
|
||||
extra='{"per_user_caching": true}',
|
||||
)
|
||||
with override_user(User(username="alice")):
|
||||
assert _run(database)["impersonation_key"] == "alice"
|
||||
|
||||
|
||||
def test_no_user_yields_no_key(app_context): # noqa: ARG001
|
||||
"""
|
||||
With no logged-in user, the engine spec returns None even when per-user
|
||||
caching is enabled — there's no identity to key on.
|
||||
"""
|
||||
database = Database(
|
||||
database_name="d",
|
||||
sqlalchemy_uri="sqlite://",
|
||||
extra='{"per_user_caching": true}',
|
||||
)
|
||||
# No override_user — g.user is unset
|
||||
assert "impersonation_key" not in _run(database)
|
||||
Reference in New Issue
Block a user