mirror of
https://github.com/apache/superset.git
synced 2026-06-10 01:59:17 +00:00
Compare commits
11 Commits
enxdev/cha
...
aminghader
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
691fd6b744 | ||
|
|
af1d57ed60 | ||
|
|
46e5b96bc0 | ||
|
|
b96ef27b90 | ||
|
|
75789601e7 | ||
|
|
73684eeb25 | ||
|
|
bb4f4095d1 | ||
|
|
c474cadaa0 | ||
|
|
df7d8afa22 | ||
|
|
db0de419af | ||
|
|
02b8aa4a55 |
@@ -170,6 +170,7 @@ Plugins (Admin only):
|
||||
Dataset Management:
|
||||
- list_datasets: List datasets with advanced filters (1-based pagination)
|
||||
- get_dataset_info: Get detailed dataset information by ID (includes columns/metrics)
|
||||
- create_dataset: Register a physical table as a dataset against an existing DB connection (requires write access)
|
||||
- create_virtual_dataset: Save a SQL query as a virtual dataset for charting (requires write access)
|
||||
- query_dataset: Query a dataset using its semantic layer (saved metrics, dimensions, filters) without needing a saved chart
|
||||
|
||||
@@ -425,7 +426,7 @@ Input format:
|
||||
{_feature_availability}Permission Awareness:
|
||||
{_instance_info_role_bullet}- ALWAYS check the user's roles BEFORE suggesting write operations (creating datasets,
|
||||
charts, or dashboards). SQL execution is a separate permission — see execute_sql below.
|
||||
- Write tools (generate_chart, generate_dashboard, update_chart, create_virtual_dataset,
|
||||
- Write tools (generate_chart, generate_dashboard, update_chart, create_dataset, create_virtual_dataset,
|
||||
save_sql_query, add_chart_to_existing_dashboard, update_chart_preview) require write
|
||||
permissions. These tools are only listed for users who have the necessary access.
|
||||
If a write tool does not appear in the tool list, the current user lacks write access.
|
||||
@@ -634,9 +635,9 @@ def create_mcp_app(
|
||||
# Create default MCP instance for backward compatibility
|
||||
mcp = create_mcp_app()
|
||||
|
||||
# Initialize MCP dependency injection BEFORE importing tools/prompts
|
||||
# This replaces the abstract @tool and @prompt decorators in superset_core.api.mcp
|
||||
# with concrete implementations that can register with the mcp instance
|
||||
# Initialize MCP dependency injection BEFORE importing tools/prompts.
|
||||
# Replaces the stub @tool/@prompt decorators in superset_core.mcp.decorators
|
||||
# with concrete implementations bound to this mcp instance.
|
||||
from superset.core.mcp.core_mcp_injection import ( # noqa: E402
|
||||
initialize_core_mcp_dependencies,
|
||||
)
|
||||
@@ -661,6 +662,7 @@ warnings.filterwarnings(
|
||||
module=r"google\..*",
|
||||
)
|
||||
|
||||
|
||||
# Import all MCP tools to register them with the mcp instance
|
||||
# NOTE: Always add new tool imports here when creating new MCP tools.
|
||||
# Tools use the @tool decorator from `superset-core` and register automatically
|
||||
@@ -709,6 +711,7 @@ from superset.mcp_service.database.tool import ( # noqa: F401, E402
|
||||
list_databases,
|
||||
)
|
||||
from superset.mcp_service.dataset.tool import ( # noqa: F401, E402
|
||||
create_dataset,
|
||||
create_virtual_dataset,
|
||||
get_dataset_info,
|
||||
list_datasets,
|
||||
|
||||
@@ -324,6 +324,48 @@ class GetDatasetInfoRequest(MetadataCacheControl):
|
||||
]
|
||||
|
||||
|
||||
class CreateDatasetRequest(BaseModel):
    """Request schema for create_dataset to register a physical table as a dataset."""

    # NOTE: the class docstring and the Field descriptions below are kept
    # verbatim on purpose; they are runtime metadata, not just comments.

    # Required: the existing database connection the table belongs to.
    database_id: Annotated[
        int,
        Field(
            description=(
                "ID of the database connection to register the table against"
            ),
        ),
    ]

    # Optional: namespace of the table; None when the caller omits it.
    schema: Annotated[
        str | None,
        Field(
            default=None,
            description=(
                "Schema (namespace) where the table lives, e.g. 'public'. "
                "Omit or pass None for databases without schema namespaces "
                "(e.g. SQLite)."
            ),
        ),
    ]

    # Optional: catalog of the table; None when the caller omits it.
    catalog: Annotated[
        str | None,
        Field(
            default=None,
            description=(
                "Catalog where the table lives. "
                "Omit for databases without catalog support."
            ),
        ),
    ]

    # Required: must contain at least one character (min_length=1).
    table_name: Annotated[
        str,
        Field(
            min_length=1,
            description="Name of the physical table to register as a dataset",
        ),
    ]

    # Optional: explicit owner user IDs; None means "not supplied".
    owners: Annotated[
        List[int] | None,
        Field(
            default=None,
            description=(
                "Optional list of owner user IDs. Defaults to the calling user."
            ),
        ),
    ]
|
||||
|
||||
|
||||
class CreateVirtualDatasetRequest(BaseModel):
|
||||
"""Request schema for create_virtual_dataset."""
|
||||
|
||||
|
||||
@@ -15,14 +15,16 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from .create_dataset import create_dataset
|
||||
from .create_virtual_dataset import create_virtual_dataset
|
||||
from .get_dataset_info import get_dataset_info
|
||||
from .list_datasets import list_datasets
|
||||
from .query_dataset import query_dataset
|
||||
|
||||
__all__ = [
|
||||
"create_dataset",
|
||||
"create_virtual_dataset",
|
||||
"list_datasets",
|
||||
"get_dataset_info",
|
||||
"list_datasets",
|
||||
"query_dataset",
|
||||
]
|
||||
|
||||
184
superset/mcp_service/dataset/tool/create_dataset.py
Normal file
184
superset/mcp_service/dataset/tool/create_dataset.py
Normal file
@@ -0,0 +1,184 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Create dataset FastMCP tool
|
||||
|
||||
Registers a physical table as a Superset dataset against an existing
|
||||
database connection — the programmatic equivalent of Data → Datasets → +Dataset.
|
||||
Returns the same DatasetInfo shape as get_dataset_info so the caller can feed
|
||||
the resulting dataset_id directly into generate_chart.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from fastmcp import Context
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.dataset.schemas import (
|
||||
CreateDatasetRequest,
|
||||
DatasetError,
|
||||
DatasetInfo,
|
||||
serialize_dataset_object,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _classify_invalid_error(exc: object) -> DatasetError:
    """Translate a DatasetInvalidError into a typed DatasetError response.

    ``exc`` is expected to expose ``get_list_classnames()`` and
    ``normalized_messages()``. The class names of the wrapped validation
    errors select the ``error_type``; the stringified normalized messages
    become the error text. Anything unrecognised is reported as a generic
    ``ValidationError``.
    """
    wrapped_names = exc.get_list_classnames()  # type: ignore[attr-defined]
    detail = str(exc.normalized_messages())  # type: ignore[attr-defined]

    # Ordered pairs of (wrapped class name, reported error type); the first
    # pair whose class name is present wins.
    typed_errors = (
        ("DatasetExistsValidationError", "DatasetExistsError"),
        ("TableNotFoundValidationError", "TableNotFoundError"),
    )
    error_type = next(
        (reported for wrapped, reported in typed_errors if wrapped in wrapped_names),
        "ValidationError",
    )
    return DatasetError.create(error=detail, error_type=error_type)
|
||||
|
||||
|
||||
@tool(
    tags=["mutate"],
    class_permission_name="Dataset",
    method_permission_name="write",
    annotations=ToolAnnotations(
        title="Register physical table as dataset",
        readOnlyHint=False,
        destructiveHint=False,
    ),
)
async def create_dataset(
    request: CreateDatasetRequest, ctx: Context
) -> DatasetInfo | DatasetError:
    """Register a physical table as a Superset dataset.

    Wraps POST /api/v1/dataset/ — the same endpoint the UI uses when you click
    Data → Datasets → +Dataset. Returns full dataset metadata (same shape as
    get_dataset_info) so you can pass the resulting dataset_id straight into
    generate_chart.

    Required fields:
    - database_id: ID of the existing database connection
    - table_name: Exact name of the physical table to register

    Optional fields:
    - schema: Schema/namespace where the table lives (e.g. "public"). Omit for
      databases without schema namespaces (e.g. SQLite).
    - catalog: Catalog where the table lives. Omit for databases without catalog
      support.
    - owners: List of user IDs to set as owners (defaults to calling user)

    Blank or whitespace-only schema/catalog values are treated as omitted.

    Example:
    ```json
    {
        "database_id": 1,
        "schema": "public",
        "table_name": "orders"
    }
    ```

    Returns DatasetInfo on success or DatasetError on failure.
    Use list_databases to find the correct database_id.
    """
    # Normalize identifiers: strip whitespace and collapse blank values to
    # None. The trailing ``or None`` matters: a whitespace-only string is
    # truthy before stripping but empty afterwards, and must not be forwarded
    # as an empty schema/catalog name.
    schema = (request.schema or "").strip() or None
    catalog = (request.catalog or "").strip() or None
    table_name = request.table_name.strip()

    # The request schema only enforces min_length=1 on the raw value, so a
    # whitespace-only table name would otherwise reach the command as "".
    if not table_name:
        blank_name_error = DatasetError.create(
            error="table_name must not be blank",
            error_type="ValidationError",
        )
        await ctx.warning(
            "Dataset validation failed (%s): %s"
            % (blank_name_error.error_type, blank_name_error.error)
        )
        return blank_name_error

    await ctx.info(
        "Registering physical table as dataset: database_id=%s, table=%s.%s"
        % (request.database_id, schema, table_name)
    )

    # Function-scope imports, done before the ``try`` so that the exception
    # classes named in the ``except`` clauses are always bound when a handler
    # is matched.
    from superset.commands.dataset.create import CreateDatasetCommand
    from superset.commands.dataset.exceptions import (
        DatasetCreateFailedError,
        DatasetInvalidError,
    )
    from superset.daos.database import DatabaseDAO
    from superset.exceptions import SupersetSecurityException
    from superset.extensions import security_manager
    from superset.sql.parse import Table

    try:
        # Look up the database so we can enforce table-level access before
        # forwarding to CreateDatasetCommand, which only checks SQL-path access.
        database = DatabaseDAO.find_by_id(request.database_id)
        if database is None:
            return DatasetError.create(
                error=f"Database with id={request.database_id} not found",
                error_type="DatabaseNotFoundError",
            )

        # Enforce table-level access: prevents users with Dataset.write from
        # registering tables in databases they cannot read.
        table_obj = Table(table_name, schema, catalog)
        try:
            security_manager.raise_for_access(database=database, table=table_obj)
        except SupersetSecurityException as exc:
            await ctx.warning("Access denied to table %s: %s" % (table_obj, str(exc)))
            return DatasetError.create(
                error=str(exc),
                error_type="AccessDeniedError",
            )

        # Only forward optional properties the caller actually supplied.
        dataset_properties: dict[str, object] = {
            "database": request.database_id,
            "table_name": table_name,
        }
        if schema is not None:
            dataset_properties["schema"] = schema
        if catalog is not None:
            dataset_properties["catalog"] = catalog
        if request.owners is not None:
            dataset_properties["owners"] = request.owners

        with event_logger.log_context(action="mcp.create_dataset.create"):
            dataset = CreateDatasetCommand(dataset_properties).run()

        result = serialize_dataset_object(dataset)
        if result is None:
            return DatasetError.create(
                error="Dataset was created but could not be serialized",
                error_type="SerializationError",
            )

        await ctx.info(
            "Dataset registered: id=%s, table=%s.%s" % (dataset.id, schema, table_name)
        )
        return result

    except DatasetInvalidError as exc:
        # CreateDatasetCommand.validate() collects DatasetExistsValidationError and
        # TableNotFoundValidationError into DatasetInvalidError.exceptions, never
        # raising them directly. Inspect the wrapped class names for a typed response.
        error_response = _classify_invalid_error(exc)
        await ctx.warning(
            "Dataset validation failed (%s): %s"
            % (error_response.error_type, error_response.error)
        )
        return error_response
    except DatasetCreateFailedError as exc:
        await ctx.error("Dataset creation failed: %s" % (str(exc),))
        return DatasetError.create(error=str(exc), error_type="CreateFailedError")
    except Exception as exc:
        # Unexpected failures are reported to the client context and then
        # propagated unchanged to the caller.
        await ctx.error(
            "Unexpected error creating dataset: %s: %s" % (type(exc).__name__, str(exc))
        )
        raise
|
||||
434
tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py
Normal file
434
tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py
Normal file
@@ -0,0 +1,434 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Unit tests for create_dataset MCP tool."""
|
||||
|
||||
import logging
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastmcp import Client
|
||||
from fastmcp.exceptions import ToolError
|
||||
|
||||
from superset.mcp_service.app import mcp
|
||||
from superset.utils import json
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _make_mock_dataset(
|
||||
dataset_id: int = 42,
|
||||
table_name: str = "orders",
|
||||
schema: str = "public",
|
||||
database_name: str = "main_db",
|
||||
) -> MagicMock:
|
||||
dataset = MagicMock()
|
||||
dataset.id = dataset_id
|
||||
dataset.table_name = table_name
|
||||
dataset.schema = schema
|
||||
dataset.description = None
|
||||
dataset.certified_by = None
|
||||
dataset.certification_details = None
|
||||
dataset.changed_by_name = "admin"
|
||||
dataset.changed_on = None
|
||||
dataset.changed_on_humanized = None
|
||||
dataset.created_by_name = "admin"
|
||||
dataset.created_on = None
|
||||
dataset.created_on_humanized = None
|
||||
dataset.tags = []
|
||||
dataset.owners = []
|
||||
dataset.is_virtual = False
|
||||
dataset.is_favorite = None
|
||||
dataset.database_id = 1
|
||||
dataset.schema_perm = f"[{database_name}].[{schema}]"
|
||||
dataset.url = f"/tablemodelview/edit/{dataset_id}"
|
||||
dataset.database = MagicMock()
|
||||
dataset.database.database_name = database_name
|
||||
dataset.sql = None
|
||||
dataset.main_dttm_col = None
|
||||
dataset.offset = 0
|
||||
dataset.cache_timeout = 0
|
||||
dataset.params = {}
|
||||
dataset.template_params = {}
|
||||
dataset.extra = {}
|
||||
dataset.uuid = f"dataset-uuid-{dataset_id}"
|
||||
dataset.columns = []
|
||||
dataset.metrics = []
|
||||
return dataset
|
||||
|
||||
|
||||
@pytest.fixture
def mcp_server():
    """Provide the ``mcp`` app object imported by this module to the tests."""
    server = mcp
    return server
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def mock_auth():
    """Patch the MCP auth user lookup so every test runs as a stub admin.

    Yields the patched ``get_user_from_request`` mock, whose return value is
    a ``Mock`` user with ``id=1`` and ``username="admin"``.
    """
    stub_admin = Mock(id=1, username="admin")
    with patch(
        "superset.mcp_service.auth.get_user_from_request",
        return_value=stub_admin,
    ) as patched_lookup:
        yield patched_lookup
|
||||
|
||||
|
||||
class TestCreateDataset:
    """Tests for the create_dataset MCP tool."""

    # Names shared by every test in this class.
    TOOL_NAME = "create_dataset"
    DAO_LOOKUP = "superset.daos.database.DatabaseDAO.find_by_id"
    ACCESS_CHECK = "superset.extensions.security_manager.raise_for_access"

    @classmethod
    async def _invoke(cls, server, request):
        """Open a client against ``server``, call the tool, return its result."""
        async with Client(server) as client:
            return await client.call_tool(cls.TOOL_NAME, {"request": request})

    @staticmethod
    def _decode(result):
        """Parse the JSON text of the first content item of a tool result."""
        return json.loads(result.content[0].text)

    @staticmethod
    def _forwarded_properties(mock_command_class):
        """Return the properties dict passed to the patched command class."""
        return mock_command_class.call_args[0][0]

    @pytest.fixture(autouse=True)
    def mock_database_access(self):
        """Provide a mock database and allow all table access by default."""
        stub_db = MagicMock()
        with (
            patch(self.DAO_LOOKUP, return_value=stub_db),
            patch(self.ACCESS_CHECK),
        ):
            yield stub_db

    @patch("superset.commands.dataset.create.CreateDatasetCommand")
    @pytest.mark.asyncio
    async def test_create_dataset_success(self, mock_command_class, mcp_server):
        """Happy path: the dataset is created and returned as DatasetInfo."""
        mock_command_class.return_value.run.return_value = _make_mock_dataset()

        result = await self._invoke(
            mcp_server,
            {"database_id": 1, "schema": "public", "table_name": "orders"},
        )

        assert result.content is not None
        data = self._decode(result)
        assert data["id"] == 42
        assert data["table_name"] == "orders"
        assert data["schema"] == "public"

        # The command must receive exactly the properties that were supplied.
        forwarded = self._forwarded_properties(mock_command_class)
        assert forwarded["database"] == 1
        assert forwarded["schema"] == "public"
        assert forwarded["table_name"] == "orders"
        assert "owners" not in forwarded

    @patch("superset.commands.dataset.create.CreateDatasetCommand")
    @pytest.mark.asyncio
    async def test_create_dataset_with_owners(self, mock_command_class, mcp_server):
        """A supplied owners list is passed through to the command."""
        mock_command_class.return_value.run.return_value = _make_mock_dataset()

        result = await self._invoke(
            mcp_server,
            {
                "database_id": 2,
                "schema": "sales",
                "table_name": "transactions",
                "owners": [5, 10],
            },
        )

        assert self._decode(result)["id"] == 42
        assert self._forwarded_properties(mock_command_class)["owners"] == [5, 10]

    @patch("superset.commands.dataset.create.CreateDatasetCommand")
    @pytest.mark.asyncio
    async def test_create_dataset_already_exists(self, mock_command_class, mcp_server):
        """Returns DatasetExistsError when the table is already registered.

        The command raises DatasetInvalidError wrapping a
        DatasetExistsValidationError; the tool has to look at the wrapped
        class names to produce the typed error response.
        """
        from superset.commands.dataset.exceptions import (
            DatasetExistsValidationError,
            DatasetInvalidError,
        )
        from superset.sql.parse import Table

        mock_command_class.return_value.run.side_effect = DatasetInvalidError(
            exceptions=[DatasetExistsValidationError(Table("orders", "public", None))]
        )

        result = await self._invoke(
            mcp_server,
            {"database_id": 1, "schema": "public", "table_name": "orders"},
        )

        data = self._decode(result)
        assert data["error_type"] == "DatasetExistsError"
        assert "error" in data

    @patch("superset.commands.dataset.create.CreateDatasetCommand")
    @pytest.mark.asyncio
    async def test_create_dataset_table_not_found(self, mock_command_class, mcp_server):
        """Returns TableNotFoundError when the physical table does not exist.

        The command raises DatasetInvalidError wrapping a
        TableNotFoundValidationError; the tool has to look at the wrapped
        class names to produce the typed error response.
        """
        from superset.commands.dataset.exceptions import (
            DatasetInvalidError,
            TableNotFoundValidationError,
        )
        from superset.sql.parse import Table

        mock_command_class.return_value.run.side_effect = DatasetInvalidError(
            exceptions=[
                TableNotFoundValidationError(Table("missing_table", "public", None))
            ]
        )

        result = await self._invoke(
            mcp_server,
            {"database_id": 1, "schema": "public", "table_name": "missing_table"},
        )

        assert self._decode(result)["error_type"] == "TableNotFoundError"

    @patch("superset.commands.dataset.create.CreateDatasetCommand")
    @pytest.mark.asyncio
    async def test_create_dataset_with_catalog(self, mock_command_class, mcp_server):
        """A padded catalog value is stripped before reaching the command."""
        mock_command_class.return_value.run.return_value = _make_mock_dataset()

        await self._invoke(
            mcp_server,
            {
                "database_id": 1,
                "catalog": "  hive  ",
                "schema": "default",
                "table_name": "events",
            },
        )

        assert self._forwarded_properties(mock_command_class)["catalog"] == "hive"

    @patch("superset.commands.dataset.create.CreateDatasetCommand")
    @pytest.mark.asyncio
    async def test_create_dataset_invalid_error(self, mock_command_class, mcp_server):
        """A bare DatasetInvalidError is reported with type ValidationError."""
        from superset.commands.dataset.exceptions import DatasetInvalidError

        mock_command_class.return_value.run.side_effect = DatasetInvalidError()

        result = await self._invoke(
            mcp_server,
            {"database_id": 1, "schema": "public", "table_name": "orders"},
        )

        data = self._decode(result)
        assert data["error_type"] == "ValidationError"
        assert "error" in data

    @pytest.mark.asyncio
    async def test_create_dataset_database_not_found(self, mcp_server):
        """Returns DatabaseNotFoundError when database_id does not exist."""
        # Overrides the autouse lookup patch for the duration of the call.
        with patch(self.DAO_LOOKUP, return_value=None):
            result = await self._invoke(
                mcp_server,
                {"database_id": 999, "table_name": "orders"},
            )

        data = self._decode(result)
        assert data["error_type"] == "DatabaseNotFoundError"
        assert "999" in data["error"]

    @pytest.mark.asyncio
    async def test_create_dataset_access_denied(self, mcp_server):
        """Returns AccessDeniedError when the table-level access check fails."""
        from superset.exceptions import SupersetSecurityException

        denial = SupersetSecurityException(MagicMock(message="Access is Denied"))
        with (
            patch(self.DAO_LOOKUP, return_value=MagicMock()),
            patch(self.ACCESS_CHECK, side_effect=denial),
        ):
            result = await self._invoke(
                mcp_server,
                {
                    "database_id": 1,
                    "schema": "secret",
                    "table_name": "restricted_table",
                },
            )

        assert self._decode(result)["error_type"] == "AccessDeniedError"

    @patch("superset.commands.dataset.create.CreateDatasetCommand")
    @pytest.mark.asyncio
    async def test_create_dataset_unexpected_error(
        self, mock_command_class, mcp_server
    ):
        """An unexpected exception from the command surfaces as ToolError."""
        mock_command_class.return_value.run.side_effect = RuntimeError(
            "DB connection lost"
        )

        # The client is opened outside pytest.raises so only the tool call
        # itself is expected to raise.
        async with Client(mcp_server) as client:
            with pytest.raises(ToolError):
                await client.call_tool(
                    self.TOOL_NAME,
                    {
                        "request": {
                            "database_id": 1,
                            "schema": "public",
                            "table_name": "orders",
                        }
                    },
                )

    @pytest.mark.asyncio
    async def test_create_dataset_missing_required_fields(self, mcp_server):
        """A request without the required fields is rejected with ToolError."""
        async with Client(mcp_server) as client:
            with pytest.raises(ToolError):
                await client.call_tool(
                    self.TOOL_NAME,
                    {
                        "request": {
                            # database_id and table_name are omitted intentionally
                            "schema": "public",
                        }
                    },
                )

    @patch("superset.commands.dataset.create.CreateDatasetCommand")
    @pytest.mark.asyncio
    async def test_create_dataset_returns_full_dataset_info(
        self, mock_command_class, mcp_server
    ):
        """The returned DatasetInfo carries columns, metrics and core fields."""
        amount_column = MagicMock()
        amount_column.column_name = "amount"
        amount_column.verbose_name = "Amount"
        amount_column.type = "NUMERIC"
        amount_column.is_dttm = False
        amount_column.groupby = True
        amount_column.filterable = True
        amount_column.description = "Sale amount"

        total_metric = MagicMock()
        total_metric.metric_name = "total_sales"
        total_metric.verbose_name = "Total Sales"
        total_metric.expression = "SUM(amount)"
        total_metric.description = "Sum of amounts"
        total_metric.d3format = None

        mock_dataset = _make_mock_dataset(
            dataset_id=99, table_name="sales", schema="dw"
        )
        mock_dataset.columns = [amount_column]
        mock_dataset.metrics = [total_metric]
        mock_command_class.return_value.run.return_value = mock_dataset

        result = await self._invoke(
            mcp_server,
            {"database_id": 1, "schema": "dw", "table_name": "sales"},
        )

        data = self._decode(result)
        assert data["id"] == 99
        assert data["table_name"] == "sales"
        assert data["schema"] == "dw"
        assert data["is_virtual"] is False
        assert len(data["columns"]) == 1
        assert data["columns"][0]["column_name"] == "amount"
        assert len(data["metrics"]) == 1
        assert data["metrics"][0]["metric_name"] == "total_sales"
|
||||
Reference in New Issue
Block a user