mirror of
https://github.com/apache/superset.git
synced 2026-06-09 17:49:26 +00:00
Compare commits
8 Commits
fix/embedd
...
fix/column
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f1d13e11dd | ||
|
|
caa3d975f6 | ||
|
|
4685d88243 | ||
|
|
9f554ffbb3 | ||
|
|
306120a277 | ||
|
|
d56c852051 | ||
|
|
c8e658026c | ||
|
|
552777a06a |
@@ -111,19 +111,19 @@ test('open with Simple tab selected when there is no column selected', () => {
|
||||
getCurrentTab: jest.fn(),
|
||||
onChange: jest.fn(),
|
||||
});
|
||||
expect(getByText('Saved')).toHaveAttribute('aria-selected', 'false');
|
||||
expect(getByText('Calculated')).toHaveAttribute('aria-selected', 'false');
|
||||
expect(getByText('Simple')).toHaveAttribute('aria-selected', 'true');
|
||||
expect(getByText('Custom SQL')).toHaveAttribute('aria-selected', 'false');
|
||||
});
|
||||
|
||||
// Post-change form of this test: the tab formerly labelled "Saved" is now
// labelled "Calculated", so only the new label is queried. Editing a column
// that carries an `expression` is expected to open the popover on that tab.
test('open with Calculated tab selected when there is a saved column selected', () => {
  const { getByText } = renderPopover({
    columns: [{ column_name: 'year' }],
    editedColumn: { column_name: 'year', expression: 'year - 1' },
    getCurrentTab: jest.fn(),
    onChange: jest.fn(),
  });
  expect(getByText('Calculated')).toHaveAttribute('aria-selected', 'true');
  expect(getByText('Simple')).toHaveAttribute('aria-selected', 'false');
  expect(getByText('Custom SQL')).toHaveAttribute('aria-selected', 'false');
});
|
||||
@@ -139,7 +139,7 @@ test('open with Custom SQL tab selected when there is a custom SQL selected', ()
|
||||
getCurrentTab: jest.fn(),
|
||||
onChange: jest.fn(),
|
||||
});
|
||||
expect(getByText('Saved')).toHaveAttribute('aria-selected', 'false');
|
||||
expect(getByText('Calculated')).toHaveAttribute('aria-selected', 'false');
|
||||
expect(getByText('Simple')).toHaveAttribute('aria-selected', 'false');
|
||||
expect(getByText('Custom SQL')).toHaveAttribute('aria-selected', 'true');
|
||||
});
|
||||
@@ -283,7 +283,7 @@ test('Should filter saved expressions by column_name and verbose_name', async ()
|
||||
fireEvent.click(savedTab!);
|
||||
|
||||
const combobox = screen.getByRole('combobox', {
|
||||
name: 'Saved expressions',
|
||||
name: 'Calculated columns',
|
||||
});
|
||||
|
||||
await userEvent.type(combobox, 'revenue');
|
||||
|
||||
@@ -382,7 +382,7 @@ const ColumnSelectPopover = ({
|
||||
selectedMetric?.metric_name !== undefined ||
|
||||
adhocColumn?.sqlExpression !== initialAdhocColumn?.sqlExpression;
|
||||
|
||||
const savedExpressionsLabel = t('Saved expressions');
|
||||
const savedExpressionsLabel = t('Calculated columns');
|
||||
const simpleColumnsLabel = t('Columns and metrics');
|
||||
const keywords = useMemo(
|
||||
() => sqlKeywords.concat(getColumnKeywords(columns)),
|
||||
@@ -408,7 +408,7 @@ const ColumnSelectPopover = ({
|
||||
: [
|
||||
{
|
||||
key: TABS_KEYS.SAVED,
|
||||
label: t('Saved'),
|
||||
label: t('Calculated'),
|
||||
children: (
|
||||
<>
|
||||
{calculatedColumns.length > 0 ? (
|
||||
@@ -448,7 +448,7 @@ const ColumnSelectPopover = ({
|
||||
title={
|
||||
isTemporal
|
||||
? t('No temporal columns found')
|
||||
: t('No saved expressions found')
|
||||
: t('No calculated columns found')
|
||||
}
|
||||
description={
|
||||
isTemporal
|
||||
@@ -467,7 +467,7 @@ const ColumnSelectPopover = ({
|
||||
title={
|
||||
isTemporal
|
||||
? t('No temporal columns found')
|
||||
: t('No saved expressions found')
|
||||
: t('No calculated columns found')
|
||||
}
|
||||
description={
|
||||
isTemporal ? (
|
||||
|
||||
@@ -76,6 +76,7 @@ Database Connections:
|
||||
Dataset Management:
|
||||
- list_datasets: List datasets with advanced filters (1-based pagination)
|
||||
- get_dataset_info: Get detailed dataset information by ID (includes columns/metrics)
|
||||
- create_dataset: Register an existing physical table as a dataset against a DB connection
|
||||
- create_virtual_dataset: Save a SQL query as a virtual dataset for charting
|
||||
- query_dataset: Query a dataset using its semantic layer (saved metrics, dimensions, filters) without needing a saved chart
|
||||
|
||||
@@ -544,6 +545,7 @@ from superset.mcp_service.database.tool import ( # noqa: F401, E402
|
||||
list_databases,
|
||||
)
|
||||
from superset.mcp_service.dataset.tool import ( # noqa: F401, E402
|
||||
create_dataset,
|
||||
create_virtual_dataset,
|
||||
get_dataset_info,
|
||||
list_datasets,
|
||||
|
||||
@@ -323,6 +323,52 @@ class GetDatasetInfoRequest(MetadataCacheControl):
|
||||
]
|
||||
|
||||
|
||||
class CreateDatasetRequest(BaseModel):
    """Request schema for create_dataset to register a physical table as a dataset."""

    # Accept the schema either under its field name (``schema_name``) or under
    # the ``schema`` alias declared on the field below.
    model_config = ConfigDict(populate_by_name=True)

    # Required: which database connection owns the table.
    database_id: Annotated[
        int,
        Field(
            description="ID of the database connection to register the table against"
        ),
    ]

    # Optional: exposed to callers as ``schema``.
    schema_name: Annotated[
        str | None,
        Field(
            default=None,
            alias="schema",
            description="Schema where the table lives (optional).",
        ),
    ]

    # Required: surrounding whitespace is removed by the validator below.
    table_name: Annotated[
        str,
        Field(description="Name of the physical table to register as a dataset"),
    ]

    # Optional catalog qualifier.
    catalog: Annotated[
        str | None,
        Field(
            default=None,
            description="Catalog where the table lives (optional).",
        ),
    ]

    # Optional explicit owners; left as None when the caller does not pass any.
    owners: Annotated[
        List[int] | None,
        Field(
            default=None,
            description="Optional list of owner user IDs. Defaults to calling user.",
        ),
    ]

    @field_validator("table_name")
    @classmethod
    def table_name_must_not_be_empty(cls, v: str) -> str:
        """Return ``v`` without surrounding whitespace, rejecting blank names.

        Raises:
            ValueError: if ``v`` is empty or contains only whitespace.
        """
        trimmed = v.strip()
        if trimmed:
            return trimmed
        raise ValueError("table_name must not be empty")
|
||||
|
||||
|
||||
class CreateVirtualDatasetRequest(BaseModel):
|
||||
"""Request schema for create_virtual_dataset."""
|
||||
|
||||
|
||||
@@ -15,12 +15,14 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from .create_dataset import create_dataset
|
||||
from .create_virtual_dataset import create_virtual_dataset
|
||||
from .get_dataset_info import get_dataset_info
|
||||
from .list_datasets import list_datasets
|
||||
from .query_dataset import query_dataset
|
||||
|
||||
__all__ = [
|
||||
"create_dataset",
|
||||
"create_virtual_dataset",
|
||||
"list_datasets",
|
||||
"get_dataset_info",
|
||||
|
||||
145
superset/mcp_service/dataset/tool/create_dataset.py
Normal file
145
superset/mcp_service/dataset/tool/create_dataset.py
Normal file
@@ -0,0 +1,145 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from fastmcp import Context
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.daos.dataset import DatasetDAO
|
||||
from superset.exceptions import SupersetSecurityException
|
||||
from superset.extensions import event_logger, security_manager
|
||||
from superset.mcp_service.dataset.schemas import (
|
||||
CreateDatasetRequest,
|
||||
DatasetError,
|
||||
DatasetInfo,
|
||||
serialize_dataset_object,
|
||||
)
|
||||
from superset.sql.parse import Table
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@tool(
    tags=["mutate"],
    class_permission_name="Dataset",
    method_permission_name="write",
    annotations=ToolAnnotations(
        title="Register a physical table as a dataset",
        readOnlyHint=False,
        destructiveHint=False,
    ),
)
async def create_dataset(
    request: CreateDatasetRequest, ctx: Context
) -> DatasetInfo | DatasetError:
    """Register an existing physical table as a Superset dataset.

    Use this tool when the user wants to make a physical database table available
    for charting or exploration. The table must already exist in the target database.

    Workflow:
    1. Call list_databases to find the correct database_id
    2. Call this tool with database_id, schema, and table_name
    3. Use the returned id as dataset_id in generate_chart or generate_explore_link

    Returns DatasetInfo on success or DatasetError with error_type on failure.
    """
    await ctx.info(
        "Registering physical table as dataset: database_id=%s, schema=%r, table=%r"
        % (request.database_id, request.schema_name, request.table_name)
    )

    # Verify the database exists and the caller has table-level access before
    # registering. Mirrors the check in DatabaseRestApi.table_metadata().
    database = DatasetDAO.get_database_by_id(request.database_id)
    if database is None:
        await ctx.warning("Database %s not found" % request.database_id)
        return DatasetError.create(
            error=f"Database {request.database_id} not found",
            error_type="DatabaseNotFoundError",
        )

    table = Table(request.table_name, request.schema_name, request.catalog)
    try:
        security_manager.raise_for_access(database=database, table=table)
    except SupersetSecurityException as exc:
        await ctx.warning("Access denied for table %r: %s" % (str(table), str(exc)))
        return DatasetError.create(error=str(exc), error_type="AccessDeniedError")

    try:
        # Imported lazily inside the function (and inside the try) so that the
        # names are resolved at call time; an import failure is reported to the
        # caller as an InternalError by the catch-all handler below.
        from superset.commands.dataset.create import CreateDatasetCommand
        from superset.commands.dataset.exceptions import (
            DatasetCreateFailedError,
            DatasetExistsValidationError,
            DatasetInvalidError,
            TableNotFoundValidationError,
        )

        # Only forward the properties the caller actually supplied; optional
        # fields left as None are omitted from the command payload.
        dataset_properties: dict[str, Any] = {
            k: v
            for k, v in {
                "database": request.database_id,
                "table_name": request.table_name,
                "schema": request.schema_name,
                "catalog": request.catalog,
                "owners": request.owners,
            }.items()
            if v is not None
        }

        with event_logger.log_context(action="mcp.create_dataset.create"):
            dataset = CreateDatasetCommand(dataset_properties).run()

        result = serialize_dataset_object(dataset)
        if result is None:
            return DatasetError.create(
                error="Dataset was created but could not be serialized",
                error_type="InternalError",
            )

        await ctx.info(
            "Dataset registered: id=%s, table=%r" % (dataset.id, dataset.table_name)
        )
        return result

    except DatasetInvalidError as exc:
        # CreateDatasetCommand.validate() aggregates individual validation errors
        # into DatasetInvalidError; use the public get_list_classnames() helper
        # to identify which specific validation errors are present.
        # NOTE(review): str(exc) is the aggregate error's own message; confirm
        # whether the per-validation messages would be more useful to callers.
        classnames = exc.get_list_classnames()
        if DatasetExistsValidationError.__name__ in classnames:
            await ctx.warning("Dataset already exists: %s" % str(exc))
            return DatasetError.create(error=str(exc), error_type="DatasetExistsError")
        if TableNotFoundValidationError.__name__ in classnames:
            await ctx.warning("Table not found: %s" % str(exc))
            return DatasetError.create(error=str(exc), error_type="TableNotFoundError")
        messages = exc.normalized_messages()
        await ctx.warning("Dataset validation failed: %s" % (messages,))
        return DatasetError.create(error=str(messages), error_type="ValidationError")
    except DatasetCreateFailedError as exc:
        # Keep the traceback in the server log; the client only gets a summary.
        logger.exception(
            "Dataset creation failed for table %r (database_id=%s)",
            request.table_name,
            request.database_id,
        )
        await ctx.error("Dataset creation failed: %s" % str(exc))
        return DatasetError.create(error=str(exc), error_type="CreateFailedError")
    except Exception as exc:
        # Boundary handler: the tool reports failures as DatasetError instead of
        # raising, so record the full traceback here before it is discarded.
        logger.exception(
            "Unexpected error registering table %r as dataset (database_id=%s)",
            request.table_name,
            request.database_id,
        )
        await ctx.error(
            "Unexpected error registering dataset: %s: %s"
            % (type(exc).__name__, str(exc))
        )
        return DatasetError.create(
            error=f"Failed to create dataset: {exc}", error_type="InternalError"
        )
|
||||
438
tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py
Normal file
438
tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py
Normal file
@@ -0,0 +1,438 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Unit tests for create_dataset MCP tool."""
|
||||
|
||||
import logging
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastmcp import Client
|
||||
from fastmcp.exceptions import ToolError
|
||||
|
||||
from superset.mcp_service.app import mcp
|
||||
from superset.utils import json
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Patch at source so lazy imports inside the tool function are intercepted.
|
||||
_CMD_PATH = "superset.commands.dataset.create.CreateDatasetCommand"
|
||||
_DAO_PATH = "superset.mcp_service.dataset.tool.create_dataset.DatasetDAO"
|
||||
_SEC_PATH = "superset.mcp_service.dataset.tool.create_dataset.security_manager"
|
||||
|
||||
|
||||
def _make_mock_dataset(
|
||||
dataset_id: int = 42,
|
||||
table_name: str = "orders",
|
||||
schema: str = "public",
|
||||
database_name: str = "main_db",
|
||||
) -> MagicMock:
|
||||
dataset = MagicMock()
|
||||
dataset.id = dataset_id
|
||||
dataset.table_name = table_name
|
||||
dataset.schema = schema
|
||||
dataset.description = None
|
||||
dataset.certified_by = None
|
||||
dataset.certification_details = None
|
||||
dataset.changed_by = None
|
||||
dataset.changed_on = None
|
||||
dataset.changed_on_humanized = None
|
||||
dataset.created_by = None
|
||||
dataset.created_on = None
|
||||
dataset.created_on_humanized = None
|
||||
dataset.tags = []
|
||||
dataset.owners = []
|
||||
dataset.is_virtual = False
|
||||
dataset.database_id = 1
|
||||
dataset.schema_perm = f"[{database_name}].[{schema}]"
|
||||
dataset.database = MagicMock()
|
||||
dataset.database.database_name = database_name
|
||||
dataset.sql = None
|
||||
dataset.main_dttm_col = None
|
||||
dataset.offset = 0
|
||||
dataset.cache_timeout = 0
|
||||
dataset.params = None
|
||||
dataset.template_params = None
|
||||
dataset.extra = None
|
||||
dataset.uuid = f"dataset-uuid-{dataset_id}"
|
||||
dataset.columns = []
|
||||
dataset.metrics = []
|
||||
return dataset
|
||||
|
||||
|
||||
@pytest.fixture()
def mcp_server():
    """Return the ``mcp`` object imported from ``superset.mcp_service.app``.

    Tests pass this to ``fastmcp.Client`` to call tools.
    """
    return mcp
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def mock_auth():
    """Patch request-user lookup so every test runs as a stub "admin" user.

    Yields the patched ``get_user_from_request`` mock, whose return value is a
    ``Mock`` with ``id == 1`` and ``username == "admin"``.
    """
    admin_user = Mock()
    admin_user.id = 1
    admin_user.username = "admin"

    with patch("superset.mcp_service.auth.get_user_from_request") as get_user_mock:
        get_user_mock.return_value = admin_user
        yield get_user_mock
|
||||
|
||||
|
||||
class TestCreateDataset:
    """Tests for the create_dataset MCP tool."""

    @staticmethod
    async def _invoke(mcp_server, request):
        """Call ``create_dataset`` through a fastmcp client and return the result."""
        async with Client(mcp_server) as client:
            return await client.call_tool("create_dataset", {"request": request})

    @staticmethod
    def _payload(result):
        """Decode the JSON body carried by the first content item of ``result``."""
        return json.loads(result.content[0].text)

    @pytest.fixture(autouse=True)
    def mock_dao_and_security(self):
        """Default: valid database exists and access is granted.

        Patches the pre-command access check so individual tests that only care
        about command behavior don't need to replicate this setup.
        """
        fake_database = MagicMock(id=1, database_name="test_db")
        with patch(_DAO_PATH) as dao_mock, patch(_SEC_PATH) as security_mock:
            dao_mock.get_database_by_id.return_value = fake_database
            yield dao_mock, security_mock

    @patch(_CMD_PATH)
    @pytest.mark.asyncio()
    async def test_create_dataset_success(self, mock_command_class, mcp_server):
        """Happy path: tool creates dataset and returns DatasetInfo."""
        mock_command_class.return_value.run.return_value = _make_mock_dataset()

        result = await self._invoke(
            mcp_server,
            {"database_id": 1, "schema": "public", "table_name": "orders"},
        )

        assert result.content is not None
        data = self._payload(result)
        assert data["id"] == 42
        assert data["table_name"] == "orders"
        assert data["schema"] == "public"

        # The command receives the properties dict as its first positional arg.
        properties = mock_command_class.call_args[0][0]
        assert properties["database"] == 1
        assert properties["schema"] == "public"
        assert properties["table_name"] == "orders"
        assert "owners" not in properties

    @patch(_CMD_PATH)
    @pytest.mark.asyncio()
    async def test_create_dataset_with_owners(self, mock_command_class, mcp_server):
        """Owners list is forwarded to the command when supplied."""
        mock_command_class.return_value.run.return_value = _make_mock_dataset()

        result = await self._invoke(
            mcp_server,
            {
                "database_id": 2,
                "schema": "sales",
                "table_name": "transactions",
                "owners": [5, 10],
            },
        )

        data = self._payload(result)
        assert data["id"] == 42

        properties = mock_command_class.call_args[0][0]
        assert properties["owners"] == [5, 10]

    @patch(_CMD_PATH)
    @pytest.mark.asyncio()
    async def test_create_dataset_already_exists(self, mock_command_class, mcp_server):
        """Returns DatasetExistsError when a dataset for the table already exists.

        CreateDatasetCommand.validate() wraps DatasetExistsValidationError inside
        DatasetInvalidError, so simulate the real command shape.
        """
        from superset.commands.dataset.exceptions import (
            DatasetExistsValidationError,
            DatasetInvalidError,
        )
        from superset.sql.parse import Table

        invalid = DatasetInvalidError()
        invalid.append(DatasetExistsValidationError(Table("orders", "public", None)))
        mock_command_class.return_value.run.side_effect = invalid

        result = await self._invoke(
            mcp_server,
            {"database_id": 1, "schema": "public", "table_name": "orders"},
        )

        data = self._payload(result)
        assert data["error_type"] == "DatasetExistsError"
        assert "error" in data

    @patch(_CMD_PATH)
    @pytest.mark.asyncio()
    async def test_create_dataset_table_not_found(self, mock_command_class, mcp_server):
        """Returns TableNotFoundError when the physical table does not exist in the DB.

        CreateDatasetCommand.validate() wraps TableNotFoundValidationError inside
        DatasetInvalidError, so simulate the real command shape.
        """
        from superset.commands.dataset.exceptions import (
            DatasetInvalidError,
            TableNotFoundValidationError,
        )
        from superset.sql.parse import Table

        missing = Table("missing_table", "public", None)
        invalid = DatasetInvalidError()
        invalid.append(TableNotFoundValidationError(missing))
        mock_command_class.return_value.run.side_effect = invalid

        result = await self._invoke(
            mcp_server,
            {"database_id": 1, "schema": "public", "table_name": "missing_table"},
        )

        data = self._payload(result)
        assert data["error_type"] == "TableNotFoundError"

    @patch(_CMD_PATH)
    @pytest.mark.asyncio()
    async def test_create_dataset_unexpected_error(
        self, mock_command_class, mcp_server
    ):
        """Unexpected exceptions are caught and returned as InternalError."""
        mock_command_class.return_value.run.side_effect = RuntimeError(
            "DB connection lost"
        )

        result = await self._invoke(
            mcp_server,
            {"database_id": 1, "schema": "public", "table_name": "orders"},
        )

        data = self._payload(result)
        assert data["error_type"] == "InternalError"
        assert "DB connection lost" in data["error"]

    @pytest.mark.asyncio()
    async def test_create_dataset_missing_required_fields(self, mcp_server):
        """Missing required fields raise a validation error before the tool runs."""
        # database_id and table_name are omitted intentionally
        incomplete_request = {"schema": "public"}

        # Kept inline so that only the tool call sits inside pytest.raises.
        async with Client(mcp_server) as client:
            with pytest.raises(ToolError):
                await client.call_tool(
                    "create_dataset", {"request": incomplete_request}
                )

    @patch(_CMD_PATH)
    @pytest.mark.asyncio()
    async def test_create_dataset_returns_full_dataset_info(
        self, mock_command_class, mcp_server
    ):
        """The returned DatasetInfo includes columns, metrics, and all core fields."""
        mock_dataset = _make_mock_dataset(
            dataset_id=99, table_name="sales", schema="dw"
        )
        mock_dataset.columns = [
            MagicMock(
                column_name="amount",
                verbose_name="Amount",
                type="NUMERIC",
                is_dttm=False,
                groupby=True,
                filterable=True,
                description="Sale amount",
            )
        ]
        mock_dataset.metrics = [
            MagicMock(
                metric_name="total_sales",
                verbose_name="Total Sales",
                expression="SUM(amount)",
                description="Sum of amounts",
                d3format=None,
            )
        ]
        mock_command_class.return_value.run.return_value = mock_dataset

        result = await self._invoke(
            mcp_server,
            {"database_id": 1, "schema": "dw", "table_name": "sales"},
        )

        data = self._payload(result)
        assert data["id"] == 99
        assert data["table_name"] == "sales"
        assert data["schema"] == "dw"
        assert data["is_virtual"] is False
        assert len(data["columns"]) == 1
        assert data["columns"][0]["column_name"] == "amount"
        assert len(data["metrics"]) == 1
        assert data["metrics"][0]["metric_name"] == "total_sales"

    @pytest.mark.asyncio()
    async def test_create_dataset_database_not_found(
        self, mock_dao_and_security, mcp_server
    ):
        """Returns DatabaseNotFoundError when the database_id does not exist."""
        dao_mock, _ = mock_dao_and_security
        dao_mock.get_database_by_id.return_value = None

        result = await self._invoke(
            mcp_server, {"database_id": 999, "table_name": "orders"}
        )

        data = self._payload(result)
        assert data["error_type"] == "DatabaseNotFoundError"
        assert "999" in data["error"]

    @pytest.mark.asyncio()
    async def test_create_dataset_access_denied(
        self, mock_dao_and_security, mcp_server
    ):
        """Returns AccessDeniedError when the caller lacks table-level access."""
        from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
        from superset.exceptions import SupersetSecurityException

        denied = SupersetSecurityException(
            SupersetError(
                message="Access denied",
                error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR,
                level=ErrorLevel.ERROR,
            )
        )
        # Patch _SEC_PATH again here, with side_effect set up front, so that
        # raise_for_access is certain to raise during the tool call. The autouse
        # mock_dao_and_security fixture still supplies the DAO mock (database
        # found); this inner patch replaces only the security manager mock.
        with patch(_SEC_PATH) as security_override:
            security_override.raise_for_access.side_effect = denied

            result = await self._invoke(
                mcp_server, {"database_id": 1, "table_name": "secret_table"}
            )

        data = self._payload(result)
        assert data["error_type"] == "AccessDeniedError"

    @patch(_CMD_PATH)
    @pytest.mark.asyncio()
    async def test_create_dataset_no_schema(
        self, mock_command_class, mock_dao_and_security, mcp_server
    ):
        """schema is optional; omitting it does not pass it to the command."""
        mock_command_class.return_value.run.return_value = _make_mock_dataset()

        result = await self._invoke(
            mcp_server, {"database_id": 1, "table_name": "orders"}
        )

        data = self._payload(result)
        assert data["id"] == 42

        properties = mock_command_class.call_args[0][0]
        assert "schema" not in properties

    @patch(_CMD_PATH)
    @pytest.mark.asyncio()
    async def test_create_dataset_with_catalog(
        self, mock_command_class, mock_dao_and_security, mcp_server
    ):
        """catalog is forwarded to CreateDatasetCommand when provided."""
        mock_command_class.return_value.run.return_value = _make_mock_dataset()

        result = await self._invoke(
            mcp_server,
            {"database_id": 1, "table_name": "orders", "catalog": "prod_catalog"},
        )

        data = self._payload(result)
        assert data["id"] == 42

        properties = mock_command_class.call_args[0][0]
        assert properties["catalog"] == "prod_catalog"
        assert "schema" not in properties
|
||||
Reference in New Issue
Block a user