mirror of
https://github.com/apache/superset.git
synced 2026-06-28 19:05:31 +00:00
Compare commits
8 Commits
chore/ci/s
...
fix/column
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f1d13e11dd | ||
|
|
caa3d975f6 | ||
|
|
4685d88243 | ||
|
|
9f554ffbb3 | ||
|
|
306120a277 | ||
|
|
d56c852051 | ||
|
|
c8e658026c | ||
|
|
552777a06a |
@@ -111,19 +111,19 @@ test('open with Simple tab selected when there is no column selected', () => {
|
|||||||
getCurrentTab: jest.fn(),
|
getCurrentTab: jest.fn(),
|
||||||
onChange: jest.fn(),
|
onChange: jest.fn(),
|
||||||
});
|
});
|
||||||
expect(getByText('Saved')).toHaveAttribute('aria-selected', 'false');
|
expect(getByText('Calculated')).toHaveAttribute('aria-selected', 'false');
|
||||||
expect(getByText('Simple')).toHaveAttribute('aria-selected', 'true');
|
expect(getByText('Simple')).toHaveAttribute('aria-selected', 'true');
|
||||||
expect(getByText('Custom SQL')).toHaveAttribute('aria-selected', 'false');
|
expect(getByText('Custom SQL')).toHaveAttribute('aria-selected', 'false');
|
||||||
});
|
});
|
||||||
|
|
||||||
test('open with Saved tab selected when there is a saved column selected', () => {
|
test('open with Calculated tab selected when there is a saved column selected', () => {
|
||||||
const { getByText } = renderPopover({
|
const { getByText } = renderPopover({
|
||||||
columns: [{ column_name: 'year' }],
|
columns: [{ column_name: 'year' }],
|
||||||
editedColumn: { column_name: 'year', expression: 'year - 1' },
|
editedColumn: { column_name: 'year', expression: 'year - 1' },
|
||||||
getCurrentTab: jest.fn(),
|
getCurrentTab: jest.fn(),
|
||||||
onChange: jest.fn(),
|
onChange: jest.fn(),
|
||||||
});
|
});
|
||||||
expect(getByText('Saved')).toHaveAttribute('aria-selected', 'true');
|
expect(getByText('Calculated')).toHaveAttribute('aria-selected', 'true');
|
||||||
expect(getByText('Simple')).toHaveAttribute('aria-selected', 'false');
|
expect(getByText('Simple')).toHaveAttribute('aria-selected', 'false');
|
||||||
expect(getByText('Custom SQL')).toHaveAttribute('aria-selected', 'false');
|
expect(getByText('Custom SQL')).toHaveAttribute('aria-selected', 'false');
|
||||||
});
|
});
|
||||||
@@ -139,7 +139,7 @@ test('open with Custom SQL tab selected when there is a custom SQL selected', ()
|
|||||||
getCurrentTab: jest.fn(),
|
getCurrentTab: jest.fn(),
|
||||||
onChange: jest.fn(),
|
onChange: jest.fn(),
|
||||||
});
|
});
|
||||||
expect(getByText('Saved')).toHaveAttribute('aria-selected', 'false');
|
expect(getByText('Calculated')).toHaveAttribute('aria-selected', 'false');
|
||||||
expect(getByText('Simple')).toHaveAttribute('aria-selected', 'false');
|
expect(getByText('Simple')).toHaveAttribute('aria-selected', 'false');
|
||||||
expect(getByText('Custom SQL')).toHaveAttribute('aria-selected', 'true');
|
expect(getByText('Custom SQL')).toHaveAttribute('aria-selected', 'true');
|
||||||
});
|
});
|
||||||
@@ -283,7 +283,7 @@ test('Should filter saved expressions by column_name and verbose_name', async ()
|
|||||||
fireEvent.click(savedTab!);
|
fireEvent.click(savedTab!);
|
||||||
|
|
||||||
const combobox = screen.getByRole('combobox', {
|
const combobox = screen.getByRole('combobox', {
|
||||||
name: 'Saved expressions',
|
name: 'Calculated columns',
|
||||||
});
|
});
|
||||||
|
|
||||||
await userEvent.type(combobox, 'revenue');
|
await userEvent.type(combobox, 'revenue');
|
||||||
|
|||||||
@@ -382,7 +382,7 @@ const ColumnSelectPopover = ({
|
|||||||
selectedMetric?.metric_name !== undefined ||
|
selectedMetric?.metric_name !== undefined ||
|
||||||
adhocColumn?.sqlExpression !== initialAdhocColumn?.sqlExpression;
|
adhocColumn?.sqlExpression !== initialAdhocColumn?.sqlExpression;
|
||||||
|
|
||||||
const savedExpressionsLabel = t('Saved expressions');
|
const savedExpressionsLabel = t('Calculated columns');
|
||||||
const simpleColumnsLabel = t('Columns and metrics');
|
const simpleColumnsLabel = t('Columns and metrics');
|
||||||
const keywords = useMemo(
|
const keywords = useMemo(
|
||||||
() => sqlKeywords.concat(getColumnKeywords(columns)),
|
() => sqlKeywords.concat(getColumnKeywords(columns)),
|
||||||
@@ -408,7 +408,7 @@ const ColumnSelectPopover = ({
|
|||||||
: [
|
: [
|
||||||
{
|
{
|
||||||
key: TABS_KEYS.SAVED,
|
key: TABS_KEYS.SAVED,
|
||||||
label: t('Saved'),
|
label: t('Calculated'),
|
||||||
children: (
|
children: (
|
||||||
<>
|
<>
|
||||||
{calculatedColumns.length > 0 ? (
|
{calculatedColumns.length > 0 ? (
|
||||||
@@ -448,7 +448,7 @@ const ColumnSelectPopover = ({
|
|||||||
title={
|
title={
|
||||||
isTemporal
|
isTemporal
|
||||||
? t('No temporal columns found')
|
? t('No temporal columns found')
|
||||||
: t('No saved expressions found')
|
: t('No calculated columns found')
|
||||||
}
|
}
|
||||||
description={
|
description={
|
||||||
isTemporal
|
isTemporal
|
||||||
@@ -467,7 +467,7 @@ const ColumnSelectPopover = ({
|
|||||||
title={
|
title={
|
||||||
isTemporal
|
isTemporal
|
||||||
? t('No temporal columns found')
|
? t('No temporal columns found')
|
||||||
: t('No saved expressions found')
|
: t('No calculated columns found')
|
||||||
}
|
}
|
||||||
description={
|
description={
|
||||||
isTemporal ? (
|
isTemporal ? (
|
||||||
|
|||||||
@@ -76,6 +76,7 @@ Database Connections:
|
|||||||
Dataset Management:
|
Dataset Management:
|
||||||
- list_datasets: List datasets with advanced filters (1-based pagination)
|
- list_datasets: List datasets with advanced filters (1-based pagination)
|
||||||
- get_dataset_info: Get detailed dataset information by ID (includes columns/metrics)
|
- get_dataset_info: Get detailed dataset information by ID (includes columns/metrics)
|
||||||
|
- create_dataset: Register an existing physical table as a dataset against a DB connection
|
||||||
- create_virtual_dataset: Save a SQL query as a virtual dataset for charting
|
- create_virtual_dataset: Save a SQL query as a virtual dataset for charting
|
||||||
- query_dataset: Query a dataset using its semantic layer (saved metrics, dimensions, filters) without needing a saved chart
|
- query_dataset: Query a dataset using its semantic layer (saved metrics, dimensions, filters) without needing a saved chart
|
||||||
|
|
||||||
@@ -544,6 +545,7 @@ from superset.mcp_service.database.tool import ( # noqa: F401, E402
|
|||||||
list_databases,
|
list_databases,
|
||||||
)
|
)
|
||||||
from superset.mcp_service.dataset.tool import ( # noqa: F401, E402
|
from superset.mcp_service.dataset.tool import ( # noqa: F401, E402
|
||||||
|
create_dataset,
|
||||||
create_virtual_dataset,
|
create_virtual_dataset,
|
||||||
get_dataset_info,
|
get_dataset_info,
|
||||||
list_datasets,
|
list_datasets,
|
||||||
|
|||||||
@@ -323,6 +323,52 @@ class GetDatasetInfoRequest(MetadataCacheControl):
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class CreateDatasetRequest(BaseModel):
|
||||||
|
"""Request schema for create_dataset to register a physical table as a dataset."""
|
||||||
|
|
||||||
|
model_config = ConfigDict(populate_by_name=True)
|
||||||
|
|
||||||
|
database_id: Annotated[
|
||||||
|
int,
|
||||||
|
Field(
|
||||||
|
description="ID of the database connection to register the table against"
|
||||||
|
),
|
||||||
|
]
|
||||||
|
schema_name: Annotated[
|
||||||
|
str | None,
|
||||||
|
Field(
|
||||||
|
default=None,
|
||||||
|
alias="schema",
|
||||||
|
description="Schema where the table lives (optional).",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
table_name: Annotated[
|
||||||
|
str,
|
||||||
|
Field(description="Name of the physical table to register as a dataset"),
|
||||||
|
]
|
||||||
|
catalog: Annotated[
|
||||||
|
str | None,
|
||||||
|
Field(
|
||||||
|
default=None,
|
||||||
|
description="Catalog where the table lives (optional).",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
owners: Annotated[
|
||||||
|
List[int] | None,
|
||||||
|
Field(
|
||||||
|
default=None,
|
||||||
|
description="Optional list of owner user IDs. Defaults to calling user.",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
@field_validator("table_name")
|
||||||
|
@classmethod
|
||||||
|
def table_name_must_not_be_empty(cls, v: str) -> str:
|
||||||
|
if not v.strip():
|
||||||
|
raise ValueError("table_name must not be empty")
|
||||||
|
return v.strip()
|
||||||
|
|
||||||
|
|
||||||
class CreateVirtualDatasetRequest(BaseModel):
|
class CreateVirtualDatasetRequest(BaseModel):
|
||||||
"""Request schema for create_virtual_dataset."""
|
"""Request schema for create_virtual_dataset."""
|
||||||
|
|
||||||
|
|||||||
@@ -15,12 +15,14 @@
|
|||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
|
from .create_dataset import create_dataset
|
||||||
from .create_virtual_dataset import create_virtual_dataset
|
from .create_virtual_dataset import create_virtual_dataset
|
||||||
from .get_dataset_info import get_dataset_info
|
from .get_dataset_info import get_dataset_info
|
||||||
from .list_datasets import list_datasets
|
from .list_datasets import list_datasets
|
||||||
from .query_dataset import query_dataset
|
from .query_dataset import query_dataset
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"create_dataset",
|
||||||
"create_virtual_dataset",
|
"create_virtual_dataset",
|
||||||
"list_datasets",
|
"list_datasets",
|
||||||
"get_dataset_info",
|
"get_dataset_info",
|
||||||
|
|||||||
145
superset/mcp_service/dataset/tool/create_dataset.py
Normal file
145
superset/mcp_service/dataset/tool/create_dataset.py
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastmcp import Context
|
||||||
|
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||||
|
|
||||||
|
from superset.daos.dataset import DatasetDAO
|
||||||
|
from superset.exceptions import SupersetSecurityException
|
||||||
|
from superset.extensions import event_logger, security_manager
|
||||||
|
from superset.mcp_service.dataset.schemas import (
|
||||||
|
CreateDatasetRequest,
|
||||||
|
DatasetError,
|
||||||
|
DatasetInfo,
|
||||||
|
serialize_dataset_object,
|
||||||
|
)
|
||||||
|
from superset.sql.parse import Table
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
|
||||||
|
tags=["mutate"],
|
||||||
|
class_permission_name="Dataset",
|
||||||
|
method_permission_name="write",
|
||||||
|
annotations=ToolAnnotations(
|
||||||
|
title="Register a physical table as a dataset",
|
||||||
|
readOnlyHint=False,
|
||||||
|
destructiveHint=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
async def create_dataset(
|
||||||
|
request: CreateDatasetRequest, ctx: Context
|
||||||
|
) -> DatasetInfo | DatasetError:
|
||||||
|
"""Register an existing physical table as a Superset dataset.
|
||||||
|
|
||||||
|
Use this tool when the user wants to make a physical database table available
|
||||||
|
for charting or exploration. The table must already exist in the target database.
|
||||||
|
|
||||||
|
Workflow:
|
||||||
|
1. Call list_databases to find the correct database_id
|
||||||
|
2. Call this tool with database_id, schema, and table_name
|
||||||
|
3. Use the returned id as dataset_id in generate_chart or generate_explore_link
|
||||||
|
|
||||||
|
Returns DatasetInfo on success or DatasetError with error_type on failure.
|
||||||
|
"""
|
||||||
|
await ctx.info(
|
||||||
|
"Registering physical table as dataset: database_id=%s, schema=%r, table=%r"
|
||||||
|
% (request.database_id, request.schema_name, request.table_name)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the database exists and the caller has table-level access before
|
||||||
|
# registering. Mirrors the check in DatabaseRestApi.table_metadata().
|
||||||
|
database = DatasetDAO.get_database_by_id(request.database_id)
|
||||||
|
if database is None:
|
||||||
|
await ctx.warning("Database %s not found" % request.database_id)
|
||||||
|
return DatasetError.create(
|
||||||
|
error=f"Database {request.database_id} not found",
|
||||||
|
error_type="DatabaseNotFoundError",
|
||||||
|
)
|
||||||
|
|
||||||
|
table = Table(request.table_name, request.schema_name, request.catalog)
|
||||||
|
try:
|
||||||
|
security_manager.raise_for_access(database=database, table=table)
|
||||||
|
except SupersetSecurityException as exc:
|
||||||
|
await ctx.warning("Access denied for table %r: %s" % (str(table), str(exc)))
|
||||||
|
return DatasetError.create(error=str(exc), error_type="AccessDeniedError")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from superset.commands.dataset.create import CreateDatasetCommand
|
||||||
|
from superset.commands.dataset.exceptions import (
|
||||||
|
DatasetCreateFailedError,
|
||||||
|
DatasetExistsValidationError,
|
||||||
|
DatasetInvalidError,
|
||||||
|
TableNotFoundValidationError,
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset_properties: dict[str, Any] = {
|
||||||
|
k: v
|
||||||
|
for k, v in {
|
||||||
|
"database": request.database_id,
|
||||||
|
"table_name": request.table_name,
|
||||||
|
"schema": request.schema_name,
|
||||||
|
"catalog": request.catalog,
|
||||||
|
"owners": request.owners,
|
||||||
|
}.items()
|
||||||
|
if v is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
with event_logger.log_context(action="mcp.create_dataset.create"):
|
||||||
|
dataset = CreateDatasetCommand(dataset_properties).run()
|
||||||
|
|
||||||
|
result = serialize_dataset_object(dataset)
|
||||||
|
if result is None:
|
||||||
|
return DatasetError.create(
|
||||||
|
error="Dataset was created but could not be serialized",
|
||||||
|
error_type="InternalError",
|
||||||
|
)
|
||||||
|
|
||||||
|
await ctx.info(
|
||||||
|
"Dataset registered: id=%s, table=%r" % (dataset.id, dataset.table_name)
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
except DatasetInvalidError as exc:
|
||||||
|
# CreateDatasetCommand.validate() aggregates individual validation errors
|
||||||
|
# into DatasetInvalidError; use the public get_list_classnames() helper
|
||||||
|
# to identify which specific validation errors are present.
|
||||||
|
classnames = exc.get_list_classnames()
|
||||||
|
if DatasetExistsValidationError.__name__ in classnames:
|
||||||
|
await ctx.warning("Dataset already exists: %s" % str(exc))
|
||||||
|
return DatasetError.create(error=str(exc), error_type="DatasetExistsError")
|
||||||
|
if TableNotFoundValidationError.__name__ in classnames:
|
||||||
|
await ctx.warning("Table not found: %s" % str(exc))
|
||||||
|
return DatasetError.create(error=str(exc), error_type="TableNotFoundError")
|
||||||
|
messages = exc.normalized_messages()
|
||||||
|
await ctx.warning("Dataset validation failed: %s" % (messages,))
|
||||||
|
return DatasetError.create(error=str(messages), error_type="ValidationError")
|
||||||
|
except DatasetCreateFailedError as exc:
|
||||||
|
await ctx.error("Dataset creation failed: %s" % str(exc))
|
||||||
|
return DatasetError.create(error=str(exc), error_type="CreateFailedError")
|
||||||
|
except Exception as exc:
|
||||||
|
await ctx.error(
|
||||||
|
"Unexpected error registering dataset: %s: %s"
|
||||||
|
% (type(exc).__name__, str(exc))
|
||||||
|
)
|
||||||
|
return DatasetError.create(
|
||||||
|
error=f"Failed to create dataset: {exc}", error_type="InternalError"
|
||||||
|
)
|
||||||
438
tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py
Normal file
438
tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py
Normal file
@@ -0,0 +1,438 @@
|
|||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
"""Unit tests for create_dataset MCP tool."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from unittest.mock import MagicMock, Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastmcp import Client
|
||||||
|
from fastmcp.exceptions import ToolError
|
||||||
|
|
||||||
|
from superset.mcp_service.app import mcp
|
||||||
|
from superset.utils import json
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Patch at source so lazy imports inside the tool function are intercepted.
|
||||||
|
_CMD_PATH = "superset.commands.dataset.create.CreateDatasetCommand"
|
||||||
|
_DAO_PATH = "superset.mcp_service.dataset.tool.create_dataset.DatasetDAO"
|
||||||
|
_SEC_PATH = "superset.mcp_service.dataset.tool.create_dataset.security_manager"
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mock_dataset(
|
||||||
|
dataset_id: int = 42,
|
||||||
|
table_name: str = "orders",
|
||||||
|
schema: str = "public",
|
||||||
|
database_name: str = "main_db",
|
||||||
|
) -> MagicMock:
|
||||||
|
dataset = MagicMock()
|
||||||
|
dataset.id = dataset_id
|
||||||
|
dataset.table_name = table_name
|
||||||
|
dataset.schema = schema
|
||||||
|
dataset.description = None
|
||||||
|
dataset.certified_by = None
|
||||||
|
dataset.certification_details = None
|
||||||
|
dataset.changed_by = None
|
||||||
|
dataset.changed_on = None
|
||||||
|
dataset.changed_on_humanized = None
|
||||||
|
dataset.created_by = None
|
||||||
|
dataset.created_on = None
|
||||||
|
dataset.created_on_humanized = None
|
||||||
|
dataset.tags = []
|
||||||
|
dataset.owners = []
|
||||||
|
dataset.is_virtual = False
|
||||||
|
dataset.database_id = 1
|
||||||
|
dataset.schema_perm = f"[{database_name}].[{schema}]"
|
||||||
|
dataset.database = MagicMock()
|
||||||
|
dataset.database.database_name = database_name
|
||||||
|
dataset.sql = None
|
||||||
|
dataset.main_dttm_col = None
|
||||||
|
dataset.offset = 0
|
||||||
|
dataset.cache_timeout = 0
|
||||||
|
dataset.params = None
|
||||||
|
dataset.template_params = None
|
||||||
|
dataset.extra = None
|
||||||
|
dataset.uuid = f"dataset-uuid-{dataset_id}"
|
||||||
|
dataset.columns = []
|
||||||
|
dataset.metrics = []
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def mcp_server():
|
||||||
|
return mcp
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_auth():
|
||||||
|
with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
|
||||||
|
mock_user = Mock()
|
||||||
|
mock_user.id = 1
|
||||||
|
mock_user.username = "admin"
|
||||||
|
mock_get_user.return_value = mock_user
|
||||||
|
yield mock_get_user
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateDataset:
|
||||||
|
"""Tests for the create_dataset MCP tool."""
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_dao_and_security(self):
|
||||||
|
"""Default: valid database exists and access is granted.
|
||||||
|
|
||||||
|
Patches the pre-command access check so individual tests that only care
|
||||||
|
about command behavior don't need to replicate this setup.
|
||||||
|
"""
|
||||||
|
with patch(_DAO_PATH) as mock_dao, patch(_SEC_PATH) as mock_sec:
|
||||||
|
mock_dao.get_database_by_id.return_value = MagicMock(
|
||||||
|
id=1, database_name="test_db"
|
||||||
|
)
|
||||||
|
yield mock_dao, mock_sec
|
||||||
|
|
||||||
|
@patch(_CMD_PATH)
|
||||||
|
@pytest.mark.asyncio()
|
||||||
|
async def test_create_dataset_success(self, mock_command_class, mcp_server):
|
||||||
|
"""Happy path: tool creates dataset and returns DatasetInfo."""
|
||||||
|
mock_dataset = _make_mock_dataset()
|
||||||
|
mock_command = MagicMock()
|
||||||
|
mock_command.run.return_value = mock_dataset
|
||||||
|
mock_command_class.return_value = mock_command
|
||||||
|
|
||||||
|
async with Client(mcp_server) as client:
|
||||||
|
result = await client.call_tool(
|
||||||
|
"create_dataset",
|
||||||
|
{
|
||||||
|
"request": {
|
||||||
|
"database_id": 1,
|
||||||
|
"schema": "public",
|
||||||
|
"table_name": "orders",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.content is not None
|
||||||
|
data = json.loads(result.content[0].text)
|
||||||
|
assert data["id"] == 42
|
||||||
|
assert data["table_name"] == "orders"
|
||||||
|
assert data["schema"] == "public"
|
||||||
|
|
||||||
|
call_kwargs = mock_command_class.call_args[0][0]
|
||||||
|
assert call_kwargs["database"] == 1
|
||||||
|
assert call_kwargs["schema"] == "public"
|
||||||
|
assert call_kwargs["table_name"] == "orders"
|
||||||
|
assert "owners" not in call_kwargs
|
||||||
|
|
||||||
|
@patch(_CMD_PATH)
|
||||||
|
@pytest.mark.asyncio()
|
||||||
|
async def test_create_dataset_with_owners(self, mock_command_class, mcp_server):
|
||||||
|
"""Owners list is forwarded to the command when supplied."""
|
||||||
|
mock_dataset = _make_mock_dataset()
|
||||||
|
mock_command = MagicMock()
|
||||||
|
mock_command.run.return_value = mock_dataset
|
||||||
|
mock_command_class.return_value = mock_command
|
||||||
|
|
||||||
|
async with Client(mcp_server) as client:
|
||||||
|
result = await client.call_tool(
|
||||||
|
"create_dataset",
|
||||||
|
{
|
||||||
|
"request": {
|
||||||
|
"database_id": 2,
|
||||||
|
"schema": "sales",
|
||||||
|
"table_name": "transactions",
|
||||||
|
"owners": [5, 10],
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
data = json.loads(result.content[0].text)
|
||||||
|
assert data["id"] == 42
|
||||||
|
|
||||||
|
call_kwargs = mock_command_class.call_args[0][0]
|
||||||
|
assert call_kwargs["owners"] == [5, 10]
|
||||||
|
|
||||||
|
@patch(_CMD_PATH)
|
||||||
|
@pytest.mark.asyncio()
|
||||||
|
async def test_create_dataset_already_exists(self, mock_command_class, mcp_server):
|
||||||
|
"""Returns DatasetExistsError when a dataset for the table already exists.
|
||||||
|
|
||||||
|
CreateDatasetCommand.validate() wraps DatasetExistsValidationError inside
|
||||||
|
DatasetInvalidError, so simulate the real command shape.
|
||||||
|
"""
|
||||||
|
from superset.commands.dataset.exceptions import (
|
||||||
|
DatasetExistsValidationError,
|
||||||
|
DatasetInvalidError,
|
||||||
|
)
|
||||||
|
from superset.sql.parse import Table
|
||||||
|
|
||||||
|
exc = DatasetInvalidError()
|
||||||
|
exc.append(DatasetExistsValidationError(Table("orders", "public", None)))
|
||||||
|
|
||||||
|
mock_command = MagicMock()
|
||||||
|
mock_command.run.side_effect = exc
|
||||||
|
mock_command_class.return_value = mock_command
|
||||||
|
|
||||||
|
async with Client(mcp_server) as client:
|
||||||
|
result = await client.call_tool(
|
||||||
|
"create_dataset",
|
||||||
|
{
|
||||||
|
"request": {
|
||||||
|
"database_id": 1,
|
||||||
|
"schema": "public",
|
||||||
|
"table_name": "orders",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
data = json.loads(result.content[0].text)
|
||||||
|
assert data["error_type"] == "DatasetExistsError"
|
||||||
|
assert "error" in data
|
||||||
|
|
||||||
|
@patch(_CMD_PATH)
|
||||||
|
@pytest.mark.asyncio()
|
||||||
|
async def test_create_dataset_table_not_found(self, mock_command_class, mcp_server):
|
||||||
|
"""Returns TableNotFoundError when the physical table does not exist in the DB.
|
||||||
|
|
||||||
|
CreateDatasetCommand.validate() wraps TableNotFoundValidationError inside
|
||||||
|
DatasetInvalidError, so simulate the real command shape.
|
||||||
|
"""
|
||||||
|
from superset.commands.dataset.exceptions import (
|
||||||
|
DatasetInvalidError,
|
||||||
|
TableNotFoundValidationError,
|
||||||
|
)
|
||||||
|
from superset.sql.parse import Table
|
||||||
|
|
||||||
|
exc = DatasetInvalidError()
|
||||||
|
exc.append(TableNotFoundValidationError(Table("missing_table", "public", None)))
|
||||||
|
|
||||||
|
mock_command = MagicMock()
|
||||||
|
mock_command.run.side_effect = exc
|
||||||
|
mock_command_class.return_value = mock_command
|
||||||
|
|
||||||
|
async with Client(mcp_server) as client:
|
||||||
|
result = await client.call_tool(
|
||||||
|
"create_dataset",
|
||||||
|
{
|
||||||
|
"request": {
|
||||||
|
"database_id": 1,
|
||||||
|
"schema": "public",
|
||||||
|
"table_name": "missing_table",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
data = json.loads(result.content[0].text)
|
||||||
|
assert data["error_type"] == "TableNotFoundError"
|
||||||
|
|
||||||
|
@patch(_CMD_PATH)
|
||||||
|
@pytest.mark.asyncio()
|
||||||
|
async def test_create_dataset_unexpected_error(
|
||||||
|
self, mock_command_class, mcp_server
|
||||||
|
):
|
||||||
|
"""Unexpected exceptions are caught and returned as InternalError."""
|
||||||
|
mock_command = MagicMock()
|
||||||
|
mock_command.run.side_effect = RuntimeError("DB connection lost")
|
||||||
|
mock_command_class.return_value = mock_command
|
||||||
|
|
||||||
|
async with Client(mcp_server) as client:
|
||||||
|
result = await client.call_tool(
|
||||||
|
"create_dataset",
|
||||||
|
{
|
||||||
|
"request": {
|
||||||
|
"database_id": 1,
|
||||||
|
"schema": "public",
|
||||||
|
"table_name": "orders",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
data = json.loads(result.content[0].text)
|
||||||
|
assert data["error_type"] == "InternalError"
|
||||||
|
assert "DB connection lost" in data["error"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio()
|
||||||
|
async def test_create_dataset_missing_required_fields(self, mcp_server):
|
||||||
|
"""Missing required fields raise a validation error before the tool runs."""
|
||||||
|
async with Client(mcp_server) as client:
|
||||||
|
with pytest.raises(ToolError):
|
||||||
|
await client.call_tool(
|
||||||
|
"create_dataset",
|
||||||
|
{
|
||||||
|
"request": {
|
||||||
|
# database_id and table_name are omitted intentionally
|
||||||
|
"schema": "public",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch(_CMD_PATH)
|
||||||
|
@pytest.mark.asyncio()
|
||||||
|
async def test_create_dataset_returns_full_dataset_info(
|
||||||
|
self, mock_command_class, mcp_server
|
||||||
|
):
|
||||||
|
"""The returned DatasetInfo includes columns, metrics, and all core fields."""
|
||||||
|
mock_dataset = _make_mock_dataset(
|
||||||
|
dataset_id=99, table_name="sales", schema="dw"
|
||||||
|
)
|
||||||
|
|
||||||
|
col = MagicMock()
|
||||||
|
col.column_name = "amount"
|
||||||
|
col.verbose_name = "Amount"
|
||||||
|
col.type = "NUMERIC"
|
||||||
|
col.is_dttm = False
|
||||||
|
col.groupby = True
|
||||||
|
col.filterable = True
|
||||||
|
col.description = "Sale amount"
|
||||||
|
mock_dataset.columns = [col]
|
||||||
|
|
||||||
|
metric = MagicMock()
|
||||||
|
metric.metric_name = "total_sales"
|
||||||
|
metric.verbose_name = "Total Sales"
|
||||||
|
metric.expression = "SUM(amount)"
|
||||||
|
metric.description = "Sum of amounts"
|
||||||
|
metric.d3format = None
|
||||||
|
mock_dataset.metrics = [metric]
|
||||||
|
|
||||||
|
mock_command = MagicMock()
|
||||||
|
mock_command.run.return_value = mock_dataset
|
||||||
|
mock_command_class.return_value = mock_command
|
||||||
|
|
||||||
|
async with Client(mcp_server) as client:
|
||||||
|
result = await client.call_tool(
|
||||||
|
"create_dataset",
|
||||||
|
{
|
||||||
|
"request": {
|
||||||
|
"database_id": 1,
|
||||||
|
"schema": "dw",
|
||||||
|
"table_name": "sales",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
data = json.loads(result.content[0].text)
|
||||||
|
assert data["id"] == 99
|
||||||
|
assert data["table_name"] == "sales"
|
||||||
|
assert data["schema"] == "dw"
|
||||||
|
assert data["is_virtual"] is False
|
||||||
|
assert len(data["columns"]) == 1
|
||||||
|
assert data["columns"][0]["column_name"] == "amount"
|
||||||
|
assert len(data["metrics"]) == 1
|
||||||
|
assert data["metrics"][0]["metric_name"] == "total_sales"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio()
|
||||||
|
async def test_create_dataset_database_not_found(
|
||||||
|
self, mock_dao_and_security, mcp_server
|
||||||
|
):
|
||||||
|
"""Returns DatabaseNotFoundError when the database_id does not exist."""
|
||||||
|
mock_dao, _ = mock_dao_and_security
|
||||||
|
mock_dao.get_database_by_id.return_value = None
|
||||||
|
|
||||||
|
async with Client(mcp_server) as client:
|
||||||
|
result = await client.call_tool(
|
||||||
|
"create_dataset",
|
||||||
|
{"request": {"database_id": 999, "table_name": "orders"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
data = json.loads(result.content[0].text)
|
||||||
|
assert data["error_type"] == "DatabaseNotFoundError"
|
||||||
|
assert "999" in data["error"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio()
|
||||||
|
async def test_create_dataset_access_denied(
|
||||||
|
self, mock_dao_and_security, mcp_server
|
||||||
|
):
|
||||||
|
"""Returns AccessDeniedError when the caller lacks table-level access."""
|
||||||
|
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||||
|
from superset.exceptions import SupersetSecurityException
|
||||||
|
|
||||||
|
access_exc = SupersetSecurityException(
|
||||||
|
SupersetError(
|
||||||
|
message="Access denied",
|
||||||
|
error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR,
|
||||||
|
level=ErrorLevel.ERROR,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Patch _SEC_PATH explicitly inside the test with side_effect pre-configured
|
||||||
|
# so the raise_for_access mock is guaranteed to raise during the tool call.
|
||||||
|
# The autouse mock_dao_and_security fixture keeps the DAO mock active (database
|
||||||
|
# found), and this inner patch overrides the security manager mock only.
|
||||||
|
with patch(_SEC_PATH) as mock_sec_override:
|
||||||
|
mock_sec_override.raise_for_access.side_effect = access_exc
|
||||||
|
|
||||||
|
async with Client(mcp_server) as client:
|
||||||
|
result = await client.call_tool(
|
||||||
|
"create_dataset",
|
||||||
|
{"request": {"database_id": 1, "table_name": "secret_table"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
data = json.loads(result.content[0].text)
|
||||||
|
assert data["error_type"] == "AccessDeniedError"
|
||||||
|
|
||||||
|
@patch(_CMD_PATH)
|
||||||
|
@pytest.mark.asyncio()
|
||||||
|
async def test_create_dataset_no_schema(
|
||||||
|
self, mock_command_class, mock_dao_and_security, mcp_server
|
||||||
|
):
|
||||||
|
"""schema is optional; omitting it does not pass it to the command."""
|
||||||
|
mock_dataset = _make_mock_dataset()
|
||||||
|
mock_command = MagicMock()
|
||||||
|
mock_command.run.return_value = mock_dataset
|
||||||
|
mock_command_class.return_value = mock_command
|
||||||
|
|
||||||
|
async with Client(mcp_server) as client:
|
||||||
|
result = await client.call_tool(
|
||||||
|
"create_dataset",
|
||||||
|
{"request": {"database_id": 1, "table_name": "orders"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
data = json.loads(result.content[0].text)
|
||||||
|
assert data["id"] == 42
|
||||||
|
|
||||||
|
call_kwargs = mock_command_class.call_args[0][0]
|
||||||
|
assert "schema" not in call_kwargs
|
||||||
|
|
||||||
|
@patch(_CMD_PATH)
|
||||||
|
@pytest.mark.asyncio()
|
||||||
|
async def test_create_dataset_with_catalog(
|
||||||
|
self, mock_command_class, mock_dao_and_security, mcp_server
|
||||||
|
):
|
||||||
|
"""catalog is forwarded to CreateDatasetCommand when provided."""
|
||||||
|
mock_dataset = _make_mock_dataset()
|
||||||
|
mock_command = MagicMock()
|
||||||
|
mock_command.run.return_value = mock_dataset
|
||||||
|
mock_command_class.return_value = mock_command
|
||||||
|
|
||||||
|
async with Client(mcp_server) as client:
|
||||||
|
result = await client.call_tool(
|
||||||
|
"create_dataset",
|
||||||
|
{
|
||||||
|
"request": {
|
||||||
|
"database_id": 1,
|
||||||
|
"table_name": "orders",
|
||||||
|
"catalog": "prod_catalog",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
data = json.loads(result.content[0].text)
|
||||||
|
assert data["id"] == 42
|
||||||
|
|
||||||
|
call_kwargs = mock_command_class.call_args[0][0]
|
||||||
|
assert call_kwargs["catalog"] == "prod_catalog"
|
||||||
|
assert "schema" not in call_kwargs
|
||||||
Reference in New Issue
Block a user