mirror of
https://github.com/apache/superset.git
synced 2026-06-10 18:19:28 +00:00
Compare commits
2 Commits
feat/run-u
...
fix/nvd3-t
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4debd2d01a | ||
|
|
87be424f9c |
2
.github/workflows/bump-python-package.yml
vendored
2
.github/workflows/bump-python-package.yml
vendored
@@ -46,7 +46,7 @@ jobs:
|
||||
python-version: "3.10"
|
||||
|
||||
- name: Install uv
|
||||
run: pip install uv==0.11.17
|
||||
run: pip install uv
|
||||
|
||||
- name: supersetbot bump-python -p "${{ github.event.inputs.package }}"
|
||||
env:
|
||||
|
||||
2
.github/workflows/dependency-review.yml
vendored
2
.github/workflows/dependency-review.yml
vendored
@@ -43,7 +43,7 @@ jobs:
|
||||
# the latest version. It's MIT: https://github.com/nbubna/store/blob/master/LICENSE-MIT
|
||||
# pkg:npm/node-forge@1.3.1
|
||||
# selecting BSD-3-Clause licensing terms for node-forge to ensure compatibility with Apache
|
||||
allow-dependencies-licenses: pkg:npm/rgbcolor, pkg:npm/jszip@3.10.1
|
||||
allow-dependencies-licenses: pkg:npm/store2@2.14.2, pkg:npm/node-forge@1.3.1, pkg:npm/rgbcolor, pkg:npm/jszip@3.10.1
|
||||
|
||||
python-dependency-liccheck:
|
||||
# NOTE: Configuration for liccheck lives in our pyproject.yml.
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
name: Sync requirements for Python dependency PRs
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, synchronize]
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: read
|
||||
|
||||
jobs:
|
||||
sync-python-dep-requirements:
|
||||
if: ${{ github.event.pull_request.user.login == 'dependabot[bot]' }}
|
||||
runs-on: ubuntu-slim
|
||||
steps:
|
||||
- name: Fetch Dependabot metadata
|
||||
id: dependabot-metadata
|
||||
uses: dependabot/fetch-metadata@25dd0e34f4fe68f24cc83900b1fe3fe149efef98 # v3.1.0
|
||||
|
||||
- name: Checkout source code
|
||||
if: ${{ steps.dependabot-metadata.outputs.package-ecosystem == 'pip' }}
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha }}
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup Python
|
||||
if: ${{ steps.dependabot-metadata.outputs.package-ecosystem == 'pip' }}
|
||||
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0
|
||||
with:
|
||||
python-version-file: 'pyproject.toml'
|
||||
|
||||
- name: Install uv
|
||||
run: pip install uv==0.11.17
|
||||
|
||||
- name: Sync requirements
|
||||
if: ${{ steps.dependabot-metadata.outputs.package-ecosystem == 'pip' }}
|
||||
run: ./scripts/uv-pip-compile.sh
|
||||
|
||||
- name: Push changes to remote PRs
|
||||
if: ${{ steps.dependabot-metadata.outputs.package-ecosystem == 'pip' }}
|
||||
run: |
|
||||
git config user.name 'github-actions[bot]'
|
||||
git config user.email '41898282+github-actions[bot]@users.noreply.github.com'
|
||||
git add requirements
|
||||
git diff --cached --quiet && exit 0
|
||||
git commit --signoff "build(deps) sync pinned requirements for Dependabot pip PRs"
|
||||
git push origin "HEAD:refs/heads/${GITHUB_EVENT_PULL_REQUEST_HEAD_REF}"
|
||||
env:
|
||||
GITHUB_EVENT_PULL_REQUEST_HEAD_REF: ${{github.event.pull_request.head.ref}}
|
||||
@@ -31,7 +31,7 @@ if [ -z "$RUNNING_IN_DOCKER" ]; then
|
||||
-w /app \
|
||||
-e RUNNING_IN_DOCKER=1 \
|
||||
python:${PYTHON_VERSION}-slim \
|
||||
bash -c "pip install uv==0.11.17 && ./scripts/uv-pip-compile.sh $*"
|
||||
bash -c "pip install uv && ./scripts/uv-pip-compile.sh $*"
|
||||
|
||||
exit $?
|
||||
fi
|
||||
|
||||
@@ -162,7 +162,7 @@ export function generateMultiLineTooltipContent(d, xFormatter, yFormatters) {
|
||||
|
||||
tooltip += '</tbody></table>';
|
||||
|
||||
return tooltip;
|
||||
return dompurify.sanitize(tooltip);
|
||||
}
|
||||
|
||||
export function generateTimePivotTooltip(d, xFormatter, yFormatter) {
|
||||
@@ -223,7 +223,7 @@ export function generateBubbleTooltipContent({
|
||||
s += createHTMLRow(getLabel(sizeField), sizeFormatter(point.size));
|
||||
s += '</table>';
|
||||
|
||||
return s;
|
||||
return dompurify.sanitize(s);
|
||||
}
|
||||
|
||||
// shouldRemove indicates whether the nvtooltips should be removed from the DOM
|
||||
@@ -287,9 +287,11 @@ export function tipFactory(layer) {
|
||||
? layer.descriptionColumns.map(c => d[c])
|
||||
: Object.values(d);
|
||||
|
||||
return `<div><strong>${title}</strong></div><br/><div>${body.join(
|
||||
', ',
|
||||
)}</div>`;
|
||||
return dompurify.sanitize(
|
||||
`<div><strong>${title}</strong></div><br/><div>${body.join(
|
||||
', ',
|
||||
)}</div>`,
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -26,6 +26,9 @@ import {
|
||||
computeYDomain,
|
||||
getTimeOrNumberFormatter,
|
||||
formatLabel,
|
||||
generateBubbleTooltipContent,
|
||||
generateMultiLineTooltipContent,
|
||||
tipFactory,
|
||||
} from '../src/utils';
|
||||
|
||||
const DATA = [
|
||||
@@ -181,4 +184,61 @@ describe('nvd3/utils', () => {
|
||||
]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('tooltip HTML sanitization', () => {
|
||||
const identity = (v: unknown) => v;
|
||||
|
||||
test('generateBubbleTooltipContent strips scripts from entity/group', () => {
|
||||
const html = generateBubbleTooltipContent({
|
||||
point: {
|
||||
name: '<img src=x onerror="alert(1)">',
|
||||
group: '<script>alert(2)</script>',
|
||||
color: 'red',
|
||||
x: 1,
|
||||
y: 2,
|
||||
size: 3,
|
||||
},
|
||||
entity: 'name',
|
||||
xField: 'x',
|
||||
yField: 'y',
|
||||
sizeField: 'size',
|
||||
xFormatter: identity,
|
||||
yFormatter: identity,
|
||||
sizeFormatter: identity,
|
||||
});
|
||||
|
||||
expect(html).not.toContain('onerror');
|
||||
expect(html).not.toContain('<script>');
|
||||
});
|
||||
|
||||
test('generateMultiLineTooltipContent strips scripts from series keys', () => {
|
||||
const html = generateMultiLineTooltipContent(
|
||||
{
|
||||
value: 'x',
|
||||
series: [
|
||||
{ key: '<img src=x onerror="alert(1)">', color: 'red', value: 1 },
|
||||
],
|
||||
},
|
||||
identity,
|
||||
[identity],
|
||||
);
|
||||
|
||||
expect(html).not.toContain('onerror');
|
||||
});
|
||||
|
||||
test('tipFactory strips scripts from annotation data values', () => {
|
||||
const tip = tipFactory({
|
||||
titleColumn: 'title',
|
||||
name: 'layer',
|
||||
descriptionColumns: ['desc'],
|
||||
});
|
||||
const html = tip.html()({
|
||||
title: '<img src=x onerror="alert(1)">',
|
||||
desc: '<script>alert(2)</script>',
|
||||
});
|
||||
|
||||
expect(html).not.toContain('onerror');
|
||||
expect(html).not.toContain('<script>');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -159,6 +159,14 @@ User and Role Management:
|
||||
- list_roles: List roles with filtering (1-based pagination, admin only)
|
||||
- get_role_info: Get role details by ID (admin only)
|
||||
|
||||
Row Level Security (Admin only):
|
||||
- list_rls_filters: List RLS filters with filtering and search (1-based pagination)
|
||||
- get_rls_filter_info: Get detailed RLS filter info by ID (tables, roles, clause)
|
||||
|
||||
Plugins (Admin only):
|
||||
- list_plugins: List dynamic plugins with filtering and search (1-based pagination)
|
||||
- get_plugin_info: Get detailed plugin info by ID (name, key, bundle URL)
|
||||
|
||||
Dataset Management:
|
||||
- list_datasets: List datasets with advanced filters (1-based pagination)
|
||||
- get_dataset_info: Get detailed dataset information by ID (includes columns/metrics)
|
||||
@@ -401,9 +409,10 @@ IMPORTANT - Tool-Only Interaction:
|
||||
|
||||
General usage tips:
|
||||
- All listing tools use 1-based pagination (first page is 1)
|
||||
- Use get_schema to discover filterable columns, sortable columns, and default columns
|
||||
for chart/dataset/dashboard/database. For action_log and task tools, consult each
|
||||
tool's docstring — filterable and sortable columns are listed there directly.
|
||||
- Use get_schema (chart/dataset/dashboard/database) to discover filterable columns,
|
||||
sortable columns, and default columns for those resource types
|
||||
- For action_log, task, list_rls_filters, and list_plugins tools, filterable/sortable
|
||||
columns are listed inline in each tool's docstring — get_schema does not cover these
|
||||
- Use 'filters' parameter for advanced queries with filter columns from get_schema
|
||||
- IDs can be integer or UUID format where supported
|
||||
- All tools return structured, Pydantic-typed responses
|
||||
@@ -708,10 +717,18 @@ from superset.mcp_service.dataset.tool import ( # noqa: F401, E402
|
||||
from superset.mcp_service.explore.tool import ( # noqa: F401, E402
|
||||
generate_explore_link,
|
||||
)
|
||||
from superset.mcp_service.plugin.tool import ( # noqa: F401, E402
|
||||
get_plugin_info,
|
||||
list_plugins,
|
||||
)
|
||||
from superset.mcp_service.query.tool import ( # noqa: F401, E402
|
||||
get_query_info,
|
||||
list_queries,
|
||||
)
|
||||
from superset.mcp_service.rls.tool import ( # noqa: F401, E402
|
||||
get_rls_filter_info,
|
||||
list_rls_filters,
|
||||
)
|
||||
from superset.mcp_service.role.tool import ( # noqa: F401, E402
|
||||
get_role_info,
|
||||
list_roles,
|
||||
|
||||
16
superset/mcp_service/plugin/__init__.py
Normal file
16
superset/mcp_service/plugin/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
23
superset/mcp_service/plugin/dao.py
Normal file
23
superset/mcp_service/plugin/dao.py
Normal file
@@ -0,0 +1,23 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from superset.daos.base import BaseDAO
|
||||
from superset.models.dynamic_plugins import DynamicPlugin
|
||||
|
||||
|
||||
class DynamicPluginDAO(BaseDAO[DynamicPlugin]):
|
||||
pass
|
||||
213
superset/mcp_service/plugin/schemas.py
Normal file
213
superset/mcp_service/plugin/schemas.py
Normal file
@@ -0,0 +1,213 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Pydantic schemas for dynamic plugin responses.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Any, Dict, List, Literal
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
field_validator,
|
||||
model_serializer,
|
||||
model_validator,
|
||||
PositiveInt,
|
||||
)
|
||||
|
||||
from superset.daos.base import ColumnOperator, ColumnOperatorEnum
|
||||
from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE
|
||||
from superset.mcp_service.system.schemas import PaginationInfo
|
||||
from superset.mcp_service.utils.schema_utils import (
|
||||
parse_json_or_list,
|
||||
parse_json_or_model_list,
|
||||
)
|
||||
|
||||
DEFAULT_PLUGIN_COLUMNS = ["id", "name", "key", "bundle_url"]
|
||||
|
||||
ALL_PLUGIN_COLUMNS = [
|
||||
"id",
|
||||
"name",
|
||||
"key",
|
||||
"bundle_url",
|
||||
"changed_on",
|
||||
"created_on",
|
||||
]
|
||||
|
||||
SORTABLE_PLUGIN_COLUMNS = ["id", "name", "key", "changed_on", "created_on"]
|
||||
|
||||
|
||||
class PluginColumnFilter(ColumnOperator):
|
||||
"""Filter object for plugin listing."""
|
||||
|
||||
col: Literal["name", "key"] = Field(..., description="Column to filter on.")
|
||||
opr: ColumnOperatorEnum = Field(..., description="Operator to use.")
|
||||
value: str | int | float | bool | List[str | int | float | bool] = Field(
|
||||
..., description="Value to filter by"
|
||||
)
|
||||
|
||||
|
||||
class PluginInfo(BaseModel):
|
||||
id: int | None = Field(None, description="Plugin ID")
|
||||
name: str | None = Field(None, description="Plugin display name")
|
||||
key: str | None = Field(None, description="Plugin key (corresponds to viz_type)")
|
||||
bundle_url: str | None = Field(None, description="URL to the plugin bundle")
|
||||
changed_on: str | datetime | None = Field(
|
||||
None, description="Last modification timestamp"
|
||||
)
|
||||
created_on: str | datetime | None = Field(None, description="Creation timestamp")
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
ser_json_timedelta="iso8601",
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
@model_serializer(mode="wrap")
|
||||
def _filter_fields_by_context(self, serializer: Any, info: Any) -> Dict[str, Any]:
|
||||
data = serializer(self)
|
||||
if info.context and isinstance(info.context, dict):
|
||||
select_columns = info.context.get("select_columns")
|
||||
if select_columns:
|
||||
requested_fields = set(select_columns)
|
||||
return {k: v for k, v in data.items() if k in requested_fields}
|
||||
return data
|
||||
|
||||
|
||||
class PluginList(BaseModel):
|
||||
plugins: List[PluginInfo]
|
||||
count: int
|
||||
total_count: int
|
||||
page: int
|
||||
page_size: int
|
||||
total_pages: int
|
||||
has_previous: bool
|
||||
has_next: bool
|
||||
columns_requested: List[str] = Field(default_factory=list)
|
||||
columns_loaded: List[str] = Field(default_factory=list)
|
||||
columns_available: List[str] = Field(default_factory=list)
|
||||
sortable_columns: List[str] = Field(default_factory=list)
|
||||
filters_applied: List[PluginColumnFilter] = Field(default_factory=list)
|
||||
pagination: PaginationInfo | None = None
|
||||
timestamp: datetime | None = None
|
||||
model_config = ConfigDict(ser_json_timedelta="iso8601")
|
||||
|
||||
|
||||
class ListPluginsRequest(BaseModel):
|
||||
"""Request schema for list_plugins."""
|
||||
|
||||
filters: Annotated[
|
||||
List[PluginColumnFilter],
|
||||
Field(
|
||||
default_factory=list,
|
||||
description="List of filter objects (col, opr, value). "
|
||||
"Cannot be used with search.",
|
||||
),
|
||||
]
|
||||
select_columns: Annotated[
|
||||
List[str],
|
||||
Field(
|
||||
default_factory=list,
|
||||
description="Columns to include in response. Defaults to common columns.",
|
||||
),
|
||||
]
|
||||
search: Annotated[
|
||||
str | None,
|
||||
Field(
|
||||
default=None,
|
||||
description="Text search on plugin name or key. "
|
||||
"Cannot be used with filters.",
|
||||
),
|
||||
]
|
||||
order_column: Annotated[
|
||||
str | None, Field(default=None, description="Column to order results by")
|
||||
]
|
||||
order_direction: Annotated[
|
||||
Literal["asc", "desc"],
|
||||
Field(default="desc", description="Sort direction"),
|
||||
]
|
||||
page: Annotated[
|
||||
PositiveInt,
|
||||
Field(default=1, description="Page number (1-based)"),
|
||||
]
|
||||
page_size: Annotated[
|
||||
int,
|
||||
Field(
|
||||
default=DEFAULT_PAGE_SIZE,
|
||||
gt=0,
|
||||
le=MAX_PAGE_SIZE,
|
||||
description=f"Items per page (max {MAX_PAGE_SIZE})",
|
||||
),
|
||||
]
|
||||
|
||||
@field_validator("filters", mode="before")
|
||||
@classmethod
|
||||
def parse_filters(cls, v: Any) -> List[PluginColumnFilter]:
|
||||
return parse_json_or_model_list(v, PluginColumnFilter, "filters")
|
||||
|
||||
@field_validator("select_columns", mode="before")
|
||||
@classmethod
|
||||
def parse_columns(cls, v: Any) -> List[str]:
|
||||
return parse_json_or_list(v, "select_columns")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_search_and_filters(self) -> "ListPluginsRequest":
|
||||
if self.search and self.filters:
|
||||
raise ValueError("Cannot use both 'search' and 'filters' simultaneously.")
|
||||
return self
|
||||
|
||||
|
||||
class PluginError(BaseModel):
|
||||
error: str = Field(..., description="Error message")
|
||||
error_type: str = Field(..., description="Type of error")
|
||||
timestamp: str | datetime | None = Field(None, description="Error timestamp")
|
||||
model_config = ConfigDict(ser_json_timedelta="iso8601")
|
||||
|
||||
@classmethod
|
||||
def create(cls, error: str, error_type: str) -> "PluginError":
|
||||
from datetime import timezone
|
||||
|
||||
return cls(
|
||||
error=error, error_type=error_type, timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
|
||||
class GetPluginInfoRequest(BaseModel):
|
||||
"""Request schema for get_plugin_info."""
|
||||
|
||||
identifier: Annotated[
|
||||
int,
|
||||
Field(description="Plugin ID"),
|
||||
]
|
||||
|
||||
|
||||
def serialize_plugin_object(plugin: Any) -> PluginInfo | None:
|
||||
if not plugin:
|
||||
return None
|
||||
|
||||
return PluginInfo(
|
||||
id=getattr(plugin, "id", None),
|
||||
name=getattr(plugin, "name", None),
|
||||
key=getattr(plugin, "key", None),
|
||||
bundle_url=getattr(plugin, "bundle_url", None),
|
||||
changed_on=getattr(plugin, "changed_on", None),
|
||||
created_on=getattr(plugin, "created_on", None),
|
||||
)
|
||||
24
superset/mcp_service/plugin/tool/__init__.py
Normal file
24
superset/mcp_service/plugin/tool/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from .get_plugin_info import get_plugin_info
|
||||
from .list_plugins import list_plugins
|
||||
|
||||
__all__ = [
|
||||
"list_plugins",
|
||||
"get_plugin_info",
|
||||
]
|
||||
101
superset/mcp_service/plugin/tool/get_plugin_info.py
Normal file
101
superset/mcp_service/plugin/tool/get_plugin_info.py
Normal file
@@ -0,0 +1,101 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Get plugin info FastMCP tool.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastmcp import Context
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.mcp_core import ModelGetInfoCore
|
||||
from superset.mcp_service.plugin.schemas import (
|
||||
GetPluginInfoRequest,
|
||||
PluginError,
|
||||
PluginInfo,
|
||||
serialize_plugin_object,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@tool(
|
||||
tags=["discovery"],
|
||||
class_permission_name="DynamicPlugin",
|
||||
annotations=ToolAnnotations(
|
||||
title="Get plugin info",
|
||||
readOnlyHint=True,
|
||||
destructiveHint=False,
|
||||
),
|
||||
)
|
||||
async def get_plugin_info(
|
||||
request: GetPluginInfoRequest, ctx: Context
|
||||
) -> PluginInfo | PluginError:
|
||||
"""Get dynamic plugin details by ID. Requires admin access.
|
||||
|
||||
Returns full plugin configuration including name, key, and bundle URL.
|
||||
|
||||
Example usage:
|
||||
```json
|
||||
{"identifier": 1}
|
||||
```
|
||||
"""
|
||||
await ctx.info(
|
||||
"Retrieving plugin information: identifier=%s" % (request.identifier,)
|
||||
)
|
||||
|
||||
try:
|
||||
from superset.mcp_service.plugin.dao import DynamicPluginDAO
|
||||
|
||||
with event_logger.log_context(action="mcp.get_plugin_info.lookup"):
|
||||
get_tool = ModelGetInfoCore(
|
||||
dao_class=DynamicPluginDAO,
|
||||
output_schema=PluginInfo,
|
||||
error_schema=PluginError,
|
||||
serializer=serialize_plugin_object,
|
||||
supports_slug=False,
|
||||
logger=logger,
|
||||
)
|
||||
result = get_tool.run_tool(request.identifier)
|
||||
|
||||
if isinstance(result, PluginInfo):
|
||||
await ctx.info(
|
||||
"Plugin retrieved: id=%s, name=%s, key=%s"
|
||||
% (result.id, result.name, result.key)
|
||||
)
|
||||
else:
|
||||
await ctx.warning(
|
||||
"Plugin retrieval failed: error_type=%s, error=%s"
|
||||
% (result.error_type, result.error)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
await ctx.error(
|
||||
"Plugin info retrieval failed: identifier=%s, error=%s"
|
||||
% (request.identifier, str(e))
|
||||
)
|
||||
return PluginError(
|
||||
error=f"Failed to get plugin info: {str(e)}",
|
||||
error_type="InternalError",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
123
superset/mcp_service/plugin/tool/list_plugins.py
Normal file
123
superset/mcp_service/plugin/tool/list_plugins.py
Normal file
@@ -0,0 +1,123 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
List plugins FastMCP tool.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from fastmcp import Context
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.mcp_core import ModelListCore
|
||||
from superset.mcp_service.plugin.schemas import (
|
||||
ALL_PLUGIN_COLUMNS,
|
||||
DEFAULT_PLUGIN_COLUMNS,
|
||||
ListPluginsRequest,
|
||||
PluginColumnFilter,
|
||||
PluginError,
|
||||
PluginInfo,
|
||||
PluginList,
|
||||
serialize_plugin_object,
|
||||
SORTABLE_PLUGIN_COLUMNS,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_LIST_PLUGINS_REQUEST = ListPluginsRequest()
|
||||
|
||||
|
||||
@tool(
|
||||
tags=["core"],
|
||||
class_permission_name="DynamicPlugin",
|
||||
annotations=ToolAnnotations(
|
||||
title="List plugins",
|
||||
readOnlyHint=True,
|
||||
destructiveHint=False,
|
||||
),
|
||||
)
|
||||
async def list_plugins(
|
||||
request: ListPluginsRequest | None = None,
|
||||
ctx: Context | None = None,
|
||||
) -> PluginList | PluginError:
|
||||
"""List dynamic plugins registered in this Superset instance. Requires admin access.
|
||||
|
||||
Returns plugin metadata including name, key, and bundle URL.
|
||||
|
||||
Sortable columns for order_column: id, name, key, changed_on, created_on
|
||||
"""
|
||||
if ctx is None:
|
||||
raise RuntimeError("FastMCP context is required for list_plugins")
|
||||
|
||||
request = request or _DEFAULT_LIST_PLUGINS_REQUEST.model_copy(deep=True)
|
||||
|
||||
await ctx.info(
|
||||
"Listing plugins: page=%s, page_size=%s, search=%s"
|
||||
% (request.page, request.page_size, request.search)
|
||||
)
|
||||
|
||||
try:
|
||||
from superset.mcp_service.plugin.dao import DynamicPluginDAO
|
||||
|
||||
def _serialize_plugin(obj: object, cols: list[str]) -> PluginInfo | None:
|
||||
return serialize_plugin_object(obj)
|
||||
|
||||
list_tool = ModelListCore(
|
||||
dao_class=DynamicPluginDAO,
|
||||
output_schema=PluginInfo,
|
||||
item_serializer=_serialize_plugin,
|
||||
filter_type=PluginColumnFilter,
|
||||
default_columns=DEFAULT_PLUGIN_COLUMNS,
|
||||
search_columns=["name", "key"],
|
||||
list_field_name="plugins",
|
||||
output_list_schema=PluginList,
|
||||
all_columns=ALL_PLUGIN_COLUMNS,
|
||||
sortable_columns=SORTABLE_PLUGIN_COLUMNS,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
with event_logger.log_context(action="mcp.list_plugins.query"):
|
||||
result = list_tool.run_tool(
|
||||
filters=request.filters,
|
||||
search=request.search,
|
||||
select_columns=request.select_columns,
|
||||
order_column=request.order_column,
|
||||
order_direction=request.order_direction,
|
||||
page=max(request.page - 1, 0),
|
||||
page_size=request.page_size,
|
||||
)
|
||||
|
||||
await ctx.info(
|
||||
"Plugins listed: count=%s, total_count=%s"
|
||||
% (len(result.plugins), result.total_count)
|
||||
)
|
||||
|
||||
columns_to_filter = result.columns_requested
|
||||
with event_logger.log_context(action="mcp.list_plugins.serialization"):
|
||||
return result.model_dump(
|
||||
mode="json",
|
||||
context={"select_columns": columns_to_filter},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
await ctx.error(
|
||||
"Plugin listing failed: error=%s, error_type=%s"
|
||||
% (str(e), type(e).__name__)
|
||||
)
|
||||
raise
|
||||
@@ -140,7 +140,7 @@ def user_can_view_data_model_metadata() -> bool:
|
||||
|
||||
|
||||
def filter_user_directory_fields(data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Remove fields that expose users, roles, owners, or access metadata."""
|
||||
"""Remove fields that expose users, owners, or access metadata."""
|
||||
return {
|
||||
key: value for key, value in data.items() if key not in USER_DIRECTORY_FIELDS
|
||||
}
|
||||
|
||||
16
superset/mcp_service/rls/__init__.py
Normal file
16
superset/mcp_service/rls/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
255
superset/mcp_service/rls/schemas.py
Normal file
255
superset/mcp_service/rls/schemas.py
Normal file
@@ -0,0 +1,255 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Pydantic schemas for row level security filter responses.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Any, Dict, List, Literal
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
field_validator,
|
||||
model_serializer,
|
||||
model_validator,
|
||||
PositiveInt,
|
||||
)
|
||||
|
||||
from superset.daos.base import ColumnOperator, ColumnOperatorEnum
|
||||
from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE
|
||||
from superset.mcp_service.system.schemas import PaginationInfo
|
||||
from superset.mcp_service.utils.schema_utils import (
|
||||
parse_json_or_list,
|
||||
parse_json_or_model_list,
|
||||
)
|
||||
|
||||
DEFAULT_RLS_COLUMNS = ["id", "name", "filter_type", "clause"]
|
||||
|
||||
ALL_RLS_COLUMNS = [
|
||||
"id",
|
||||
"name",
|
||||
"filter_type",
|
||||
"tables",
|
||||
"roles",
|
||||
"clause",
|
||||
"group_key",
|
||||
"changed_on",
|
||||
]
|
||||
|
||||
SORTABLE_RLS_COLUMNS = ["id", "name", "filter_type", "changed_on"]
|
||||
|
||||
|
||||
class RlsColumnFilter(ColumnOperator):
|
||||
"""Filter object for RLS filter listing."""
|
||||
|
||||
col: Literal["name", "filter_type"] = Field(
|
||||
...,
|
||||
description="Column to filter on.",
|
||||
)
|
||||
opr: ColumnOperatorEnum = Field(..., description="Operator to use.")
|
||||
value: str | int | float | bool | List[str | int | float | bool] = Field(
|
||||
..., description="Value to filter by"
|
||||
)
|
||||
|
||||
|
||||
class RlsTableRef(BaseModel):
|
||||
id: int | None = Field(None, description="Table ID")
|
||||
table_name: str | None = Field(None, description="Table name")
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class RlsRoleRef(BaseModel):
|
||||
id: int | None = Field(None, description="Role ID")
|
||||
name: str | None = Field(None, description="Role name")
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class RlsFilterInfo(BaseModel):
|
||||
id: int | None = Field(None, description="RLS filter ID")
|
||||
name: str | None = Field(None, description="RLS filter name")
|
||||
filter_type: str | None = Field(None, description="Filter type: Regular or Base")
|
||||
tables: List[RlsTableRef] | None = Field(
|
||||
None, description="Tables this filter applies to"
|
||||
)
|
||||
roles: List[RlsRoleRef] | None = Field(
|
||||
None, description="Roles this filter applies to"
|
||||
)
|
||||
clause: str | None = Field(None, description="SQL WHERE clause")
|
||||
group_key: str | None = Field(
|
||||
None, description="Group key for Base filter grouping"
|
||||
)
|
||||
changed_on: str | datetime | None = Field(
|
||||
None, description="Last modification timestamp"
|
||||
)
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
ser_json_timedelta="iso8601",
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
@model_serializer(mode="wrap")
|
||||
def _filter_fields_by_context(self, serializer: Any, info: Any) -> Dict[str, Any]:
|
||||
data = serializer(self)
|
||||
if info.context and isinstance(info.context, dict):
|
||||
select_columns = info.context.get("select_columns")
|
||||
if select_columns:
|
||||
requested_fields = set(select_columns)
|
||||
return {k: v for k, v in data.items() if k in requested_fields}
|
||||
return data
|
||||
|
||||
|
||||
class RlsFilterList(BaseModel):
|
||||
rls_filters: List[RlsFilterInfo]
|
||||
count: int
|
||||
total_count: int
|
||||
page: int
|
||||
page_size: int
|
||||
total_pages: int
|
||||
has_previous: bool
|
||||
has_next: bool
|
||||
columns_requested: List[str] = Field(default_factory=list)
|
||||
columns_loaded: List[str] = Field(default_factory=list)
|
||||
columns_available: List[str] = Field(default_factory=list)
|
||||
sortable_columns: List[str] = Field(default_factory=list)
|
||||
filters_applied: List[RlsColumnFilter] = Field(default_factory=list)
|
||||
pagination: PaginationInfo | None = None
|
||||
timestamp: datetime | None = None
|
||||
model_config = ConfigDict(ser_json_timedelta="iso8601")
|
||||
|
||||
|
||||
class ListRlsFiltersRequest(BaseModel):
|
||||
"""Request schema for list_rls_filters."""
|
||||
|
||||
filters: Annotated[
|
||||
List[RlsColumnFilter],
|
||||
Field(
|
||||
default_factory=list,
|
||||
description="List of filter objects (col, opr, value). "
|
||||
"Cannot be used with search.",
|
||||
),
|
||||
]
|
||||
select_columns: Annotated[
|
||||
List[str],
|
||||
Field(
|
||||
default_factory=list,
|
||||
description="Columns to include in response. Defaults to common columns.",
|
||||
),
|
||||
]
|
||||
search: Annotated[
|
||||
str | None,
|
||||
Field(
|
||||
default=None,
|
||||
description="Text search on filter name. Cannot be used with filters.",
|
||||
),
|
||||
]
|
||||
order_column: Annotated[
|
||||
str | None, Field(default=None, description="Column to order results by")
|
||||
]
|
||||
order_direction: Annotated[
|
||||
Literal["asc", "desc"],
|
||||
Field(default="desc", description="Sort direction"),
|
||||
]
|
||||
page: Annotated[
|
||||
PositiveInt,
|
||||
Field(default=1, description="Page number (1-based)"),
|
||||
]
|
||||
page_size: Annotated[
|
||||
int,
|
||||
Field(
|
||||
default=DEFAULT_PAGE_SIZE,
|
||||
gt=0,
|
||||
le=MAX_PAGE_SIZE,
|
||||
description=f"Items per page (max {MAX_PAGE_SIZE})",
|
||||
),
|
||||
]
|
||||
|
||||
@field_validator("filters", mode="before")
|
||||
@classmethod
|
||||
def parse_filters(cls, v: Any) -> List[RlsColumnFilter]:
|
||||
return parse_json_or_model_list(v, RlsColumnFilter, "filters")
|
||||
|
||||
@field_validator("select_columns", mode="before")
|
||||
@classmethod
|
||||
def parse_columns(cls, v: Any) -> List[str]:
|
||||
return parse_json_or_list(v, "select_columns")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_search_and_filters(self) -> "ListRlsFiltersRequest":
|
||||
if self.search and self.filters:
|
||||
raise ValueError("Cannot use both 'search' and 'filters' simultaneously.")
|
||||
return self
|
||||
|
||||
|
||||
class RlsFilterError(BaseModel):
|
||||
error: str = Field(..., description="Error message")
|
||||
error_type: str = Field(..., description="Type of error")
|
||||
timestamp: str | datetime | None = Field(None, description="Error timestamp")
|
||||
model_config = ConfigDict(ser_json_timedelta="iso8601")
|
||||
|
||||
@classmethod
|
||||
def create(cls, error: str, error_type: str) -> "RlsFilterError":
|
||||
from datetime import timezone
|
||||
|
||||
return cls(
|
||||
error=error, error_type=error_type, timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
|
||||
class GetRlsFilterInfoRequest(BaseModel):
|
||||
"""Request schema for get_rls_filter_info."""
|
||||
|
||||
identifier: Annotated[
|
||||
int,
|
||||
Field(description="RLS filter ID"),
|
||||
]
|
||||
|
||||
|
||||
def serialize_rls_filter_object(rls_filter: Any) -> RlsFilterInfo | None:
|
||||
if not rls_filter:
|
||||
return None
|
||||
|
||||
tables = [
|
||||
RlsTableRef(
|
||||
id=getattr(t, "id", None),
|
||||
table_name=getattr(t, "table_name", None),
|
||||
)
|
||||
for t in (getattr(rls_filter, "tables", None) or [])
|
||||
]
|
||||
|
||||
roles = [
|
||||
RlsRoleRef(
|
||||
id=getattr(r, "id", None),
|
||||
name=getattr(r, "name", None),
|
||||
)
|
||||
for r in (getattr(rls_filter, "roles", None) or [])
|
||||
]
|
||||
|
||||
return RlsFilterInfo(
|
||||
id=getattr(rls_filter, "id", None),
|
||||
name=getattr(rls_filter, "name", None),
|
||||
filter_type=getattr(rls_filter, "filter_type", None),
|
||||
tables=tables,
|
||||
roles=roles,
|
||||
clause=getattr(rls_filter, "clause", None),
|
||||
group_key=getattr(rls_filter, "group_key", None),
|
||||
changed_on=getattr(rls_filter, "changed_on", None),
|
||||
)
|
||||
24
superset/mcp_service/rls/tool/__init__.py
Normal file
24
superset/mcp_service/rls/tool/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from .get_rls_filter_info import get_rls_filter_info
|
||||
from .list_rls_filters import list_rls_filters
|
||||
|
||||
__all__ = [
|
||||
"list_rls_filters",
|
||||
"get_rls_filter_info",
|
||||
]
|
||||
101
superset/mcp_service/rls/tool/get_rls_filter_info.py
Normal file
101
superset/mcp_service/rls/tool/get_rls_filter_info.py
Normal file
@@ -0,0 +1,101 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Get RLS filter info FastMCP tool.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastmcp import Context
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.mcp_core import ModelGetInfoCore
|
||||
from superset.mcp_service.rls.schemas import (
|
||||
GetRlsFilterInfoRequest,
|
||||
RlsFilterError,
|
||||
RlsFilterInfo,
|
||||
serialize_rls_filter_object,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@tool(
|
||||
tags=["discovery"],
|
||||
class_permission_name="Row Level Security",
|
||||
annotations=ToolAnnotations(
|
||||
title="Get RLS filter info",
|
||||
readOnlyHint=True,
|
||||
destructiveHint=False,
|
||||
),
|
||||
)
|
||||
async def get_rls_filter_info(
|
||||
request: GetRlsFilterInfoRequest, ctx: Context
|
||||
) -> RlsFilterInfo | RlsFilterError:
|
||||
"""Get row level security filter details by ID. Requires admin access.
|
||||
|
||||
Returns full RLS filter configuration including name, type, tables, roles,
|
||||
and clause.
|
||||
|
||||
Example usage:
|
||||
```json
|
||||
{"identifier": 1}
|
||||
```
|
||||
"""
|
||||
await ctx.info(
|
||||
"Retrieving RLS filter information: identifier=%s" % (request.identifier,)
|
||||
)
|
||||
|
||||
try:
|
||||
from superset.daos.security import RLSDAO
|
||||
|
||||
with event_logger.log_context(action="mcp.get_rls_filter_info.lookup"):
|
||||
get_tool = ModelGetInfoCore(
|
||||
dao_class=RLSDAO,
|
||||
output_schema=RlsFilterInfo,
|
||||
error_schema=RlsFilterError,
|
||||
serializer=serialize_rls_filter_object,
|
||||
supports_slug=False,
|
||||
logger=logger,
|
||||
)
|
||||
result = get_tool.run_tool(request.identifier)
|
||||
|
||||
if isinstance(result, RlsFilterInfo):
|
||||
await ctx.info(
|
||||
"RLS filter retrieved: id=%s, name=%s" % (result.id, result.name)
|
||||
)
|
||||
else:
|
||||
await ctx.warning(
|
||||
"RLS filter retrieval failed: error_type=%s, error=%s"
|
||||
% (result.error_type, result.error)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
await ctx.error(
|
||||
"RLS filter info retrieval failed: identifier=%s, error=%s"
|
||||
% (request.identifier, str(e))
|
||||
)
|
||||
return RlsFilterError(
|
||||
error=f"Failed to get RLS filter info: {str(e)}",
|
||||
error_type="InternalError",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
147
superset/mcp_service/rls/tool/list_rls_filters.py
Normal file
147
superset/mcp_service/rls/tool/list_rls_filters.py
Normal file
@@ -0,0 +1,147 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
List RLS filters FastMCP tool.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from fastmcp import Context
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.mcp_core import ModelListCore
|
||||
from superset.mcp_service.privacy import USER_DIRECTORY_FIELDS
|
||||
from superset.mcp_service.rls.schemas import (
|
||||
ALL_RLS_COLUMNS,
|
||||
DEFAULT_RLS_COLUMNS,
|
||||
ListRlsFiltersRequest,
|
||||
RlsColumnFilter,
|
||||
RlsFilterError,
|
||||
RlsFilterInfo,
|
||||
RlsFilterList,
|
||||
serialize_rls_filter_object,
|
||||
SORTABLE_RLS_COLUMNS,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_LIST_RLS_FILTERS_REQUEST = ListRlsFiltersRequest()
|
||||
|
||||
|
||||
@tool(
|
||||
tags=["core"],
|
||||
class_permission_name="Row Level Security",
|
||||
annotations=ToolAnnotations(
|
||||
title="List RLS filters",
|
||||
readOnlyHint=True,
|
||||
destructiveHint=False,
|
||||
),
|
||||
)
|
||||
async def list_rls_filters(
|
||||
request: ListRlsFiltersRequest | None = None,
|
||||
ctx: Context | None = None,
|
||||
) -> RlsFilterList | RlsFilterError:
|
||||
"""List row level security filters. Requires admin access.
|
||||
|
||||
Returns RLS filter metadata including name, filter type, tables, roles, and clause.
|
||||
|
||||
Sortable columns for order_column: id, name, filter_type, changed_on
|
||||
"""
|
||||
if ctx is None:
|
||||
raise RuntimeError("FastMCP context is required for list_rls_filters")
|
||||
|
||||
request = request or _DEFAULT_LIST_RLS_FILTERS_REQUEST.model_copy(deep=True)
|
||||
|
||||
await ctx.info(
|
||||
"Listing RLS filters: page=%s, page_size=%s, search=%s"
|
||||
% (request.page, request.page_size, request.search)
|
||||
)
|
||||
|
||||
try:
|
||||
from superset.daos.security import RLSDAO
|
||||
|
||||
def _serialize_rls_filter(obj: object, cols: list[str]) -> RlsFilterInfo | None:
|
||||
return serialize_rls_filter_object(obj)
|
||||
|
||||
list_tool = ModelListCore(
|
||||
dao_class=RLSDAO,
|
||||
output_schema=RlsFilterInfo,
|
||||
item_serializer=_serialize_rls_filter,
|
||||
filter_type=RlsColumnFilter,
|
||||
default_columns=DEFAULT_RLS_COLUMNS,
|
||||
search_columns=["name"],
|
||||
list_field_name="rls_filters",
|
||||
output_list_schema=RlsFilterList,
|
||||
all_columns=ALL_RLS_COLUMNS,
|
||||
sortable_columns=SORTABLE_RLS_COLUMNS,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
# Strip USER_DIRECTORY_FIELDS (e.g. 'roles') before handing off to
|
||||
# run_tool, which would raise ValueError if all requested columns are
|
||||
# privacy-filtered. Roles are restored in the model_dump context below.
|
||||
run_select_columns: list[str] | None = None
|
||||
if request.select_columns:
|
||||
filtered = [
|
||||
c for c in request.select_columns if c not in USER_DIRECTORY_FIELDS
|
||||
]
|
||||
run_select_columns = filtered or None
|
||||
|
||||
with event_logger.log_context(action="mcp.list_rls_filters.query"):
|
||||
result = list_tool.run_tool(
|
||||
filters=request.filters,
|
||||
search=request.search,
|
||||
select_columns=run_select_columns,
|
||||
order_column=request.order_column,
|
||||
order_direction=request.order_direction,
|
||||
page=max(request.page - 1, 0),
|
||||
page_size=request.page_size,
|
||||
)
|
||||
|
||||
await ctx.info(
|
||||
"RLS filters listed: count=%s, total_count=%s"
|
||||
% (len(result.rls_filters), result.total_count)
|
||||
)
|
||||
|
||||
# Build column selection using ALL_RLS_COLUMNS as the source of truth,
|
||||
# bypassing the USER_DIRECTORY_FIELDS privacy filter applied by
|
||||
# ModelListCore. 'roles' in an RLS filter is which roles the filter
|
||||
# applies to — core filter data — not user-directory metadata (like
|
||||
# dashboard.roles, which exposes who has access to the resource).
|
||||
if request.select_columns:
|
||||
columns_to_filter = [
|
||||
c for c in request.select_columns if c in ALL_RLS_COLUMNS
|
||||
]
|
||||
if not columns_to_filter:
|
||||
columns_to_filter = list(DEFAULT_RLS_COLUMNS)
|
||||
else:
|
||||
columns_to_filter = list(DEFAULT_RLS_COLUMNS)
|
||||
|
||||
with event_logger.log_context(action="mcp.list_rls_filters.serialization"):
|
||||
return result.model_dump(
|
||||
mode="json",
|
||||
context={"select_columns": columns_to_filter},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
await ctx.error(
|
||||
"RLS filter listing failed: error=%s, error_type=%s"
|
||||
% (str(e), type(e).__name__)
|
||||
)
|
||||
raise
|
||||
16
tests/unit_tests/mcp_service/plugin/__init__.py
Normal file
16
tests/unit_tests/mcp_service/plugin/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
16
tests/unit_tests/mcp_service/plugin/tool/__init__.py
Normal file
16
tests/unit_tests/mcp_service/plugin/tool/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
172
tests/unit_tests/mcp_service/plugin/tool/test_plugin_tools.py
Normal file
172
tests/unit_tests/mcp_service/plugin/tool/test_plugin_tools.py
Normal file
@@ -0,0 +1,172 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import logging
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastmcp import Client
|
||||
from pydantic import ValidationError
|
||||
|
||||
from superset.mcp_service.app import mcp
|
||||
from superset.mcp_service.plugin.schemas import ListPluginsRequest, PluginColumnFilter
|
||||
from superset.utils import json
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_mock_plugin(
|
||||
plugin_id: int = 1,
|
||||
name: str = "My Plugin",
|
||||
key: str = "my_plugin",
|
||||
bundle_url: str = "https://example.com/plugin.js",
|
||||
) -> MagicMock:
|
||||
plugin = MagicMock()
|
||||
plugin.id = plugin_id
|
||||
plugin.name = name
|
||||
plugin.key = key
|
||||
plugin.bundle_url = bundle_url
|
||||
plugin.changed_on = None
|
||||
plugin.created_on = None
|
||||
return plugin
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mcp_server():
|
||||
return mcp
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_auth():
|
||||
with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
|
||||
mock_user = Mock()
|
||||
mock_user.id = 1
|
||||
mock_user.username = "admin"
|
||||
mock_get_user.return_value = mock_user
|
||||
yield mock_get_user
|
||||
|
||||
|
||||
class TestPluginColumnFilterSchema:
|
||||
def test_invalid_filter_column_rejected(self):
|
||||
with pytest.raises(ValidationError):
|
||||
PluginColumnFilter(col="bundle_url", opr="eq", value="test")
|
||||
|
||||
def test_valid_name_filter(self):
|
||||
f = PluginColumnFilter(col="name", opr="eq", value="test")
|
||||
assert f.col == "name"
|
||||
|
||||
def test_valid_key_filter(self):
|
||||
f = PluginColumnFilter(col="key", opr="eq", value="my_plugin")
|
||||
assert f.col == "key"
|
||||
|
||||
|
||||
@patch("superset.mcp_service.plugin.dao.DynamicPluginDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_plugins_basic(mock_list, mcp_server):
|
||||
plugin = create_mock_plugin()
|
||||
mock_list.return_value = ([plugin], 1)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("list_plugins", {})
|
||||
data = json.loads(result.content[0].text)
|
||||
assert "plugins" in data
|
||||
assert len(data["plugins"]) == 1
|
||||
assert data["plugins"][0]["id"] == 1
|
||||
assert data["plugins"][0]["name"] == "My Plugin"
|
||||
assert data["plugins"][0]["key"] == "my_plugin"
|
||||
|
||||
|
||||
@patch("superset.mcp_service.plugin.dao.DynamicPluginDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_plugins_with_request(mock_list, mcp_server):
|
||||
plugin = create_mock_plugin()
|
||||
mock_list.return_value = ([plugin], 1)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
request = ListPluginsRequest(page=1, page_size=10)
|
||||
result = await client.call_tool(
|
||||
"list_plugins", {"request": request.model_dump()}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["count"] == 1
|
||||
assert data["total_count"] == 1
|
||||
|
||||
|
||||
@patch("superset.mcp_service.plugin.dao.DynamicPluginDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_plugins_with_search(mock_list, mcp_server):
|
||||
plugin = create_mock_plugin(name="Custom Chart")
|
||||
mock_list.return_value = ([plugin], 1)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
request = ListPluginsRequest(page=1, page_size=10, search="custom")
|
||||
result = await client.call_tool(
|
||||
"list_plugins", {"request": request.model_dump()}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["plugins"][0]["name"] == "Custom Chart"
|
||||
|
||||
|
||||
@patch("superset.mcp_service.plugin.dao.DynamicPluginDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_plugins_empty(mock_list, mcp_server):
|
||||
mock_list.return_value = ([], 0)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("list_plugins", {})
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["count"] == 0
|
||||
assert data["plugins"] == []
|
||||
|
||||
|
||||
@patch("superset.mcp_service.plugin.dao.DynamicPluginDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_plugin_info_basic(mock_find, mcp_server):
|
||||
plugin = create_mock_plugin()
|
||||
mock_find.return_value = plugin
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_plugin_info", {"request": {"identifier": 1}}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["id"] == 1
|
||||
assert data["name"] == "My Plugin"
|
||||
assert data["key"] == "my_plugin"
|
||||
assert data["bundle_url"] == "https://example.com/plugin.js"
|
||||
|
||||
|
||||
@patch("superset.mcp_service.plugin.dao.DynamicPluginDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_plugin_info_not_found(mock_find, mcp_server):
|
||||
mock_find.return_value = None
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_plugin_info", {"request": {"identifier": 999}}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["error_type"] == "not_found"
|
||||
|
||||
|
||||
def test_list_plugins_request_rejects_search_and_filters():
|
||||
with pytest.raises(ValidationError):
|
||||
ListPluginsRequest(
|
||||
search="test",
|
||||
filters=[{"col": "name", "opr": "eq", "value": "x"}],
|
||||
)
|
||||
16
tests/unit_tests/mcp_service/rls/__init__.py
Normal file
16
tests/unit_tests/mcp_service/rls/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
16
tests/unit_tests/mcp_service/rls/tool/__init__.py
Normal file
16
tests/unit_tests/mcp_service/rls/tool/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
245
tests/unit_tests/mcp_service/rls/tool/test_rls_tools.py
Normal file
245
tests/unit_tests/mcp_service/rls/tool/test_rls_tools.py
Normal file
@@ -0,0 +1,245 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import logging
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastmcp import Client
|
||||
from pydantic import ValidationError
|
||||
|
||||
from superset.mcp_service.app import mcp
|
||||
from superset.mcp_service.rls.schemas import ListRlsFiltersRequest, RlsColumnFilter
|
||||
from superset.utils import json
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_mock_rls_filter(
|
||||
filter_id: int = 1,
|
||||
name: str = "test_filter",
|
||||
filter_type: str = "Regular",
|
||||
clause: str = "user_id = {{current_user_id()}}",
|
||||
group_key: str | None = None,
|
||||
) -> MagicMock:
|
||||
rls_filter = MagicMock()
|
||||
rls_filter.id = filter_id
|
||||
rls_filter.name = name
|
||||
rls_filter.filter_type = filter_type
|
||||
rls_filter.clause = clause
|
||||
rls_filter.group_key = group_key
|
||||
rls_filter.changed_on = None
|
||||
|
||||
table = MagicMock()
|
||||
table.id = 1
|
||||
table.table_name = "sales"
|
||||
rls_filter.tables = [table]
|
||||
|
||||
role = MagicMock()
|
||||
role.id = 1
|
||||
role.name = "Alpha"
|
||||
rls_filter.roles = [role]
|
||||
|
||||
return rls_filter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mcp_server():
|
||||
return mcp
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_auth():
|
||||
with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
|
||||
mock_user = Mock()
|
||||
mock_user.id = 1
|
||||
mock_user.username = "admin"
|
||||
mock_get_user.return_value = mock_user
|
||||
yield mock_get_user
|
||||
|
||||
|
||||
class TestRlsColumnFilterSchema:
|
||||
def test_invalid_filter_column_rejected(self):
|
||||
with pytest.raises(ValidationError):
|
||||
RlsColumnFilter(col="clause", opr="eq", value="test")
|
||||
|
||||
def test_valid_name_filter(self):
|
||||
f = RlsColumnFilter(col="name", opr="eq", value="test")
|
||||
assert f.col == "name"
|
||||
|
||||
def test_valid_filter_type_filter(self):
|
||||
f = RlsColumnFilter(col="filter_type", opr="eq", value="Regular")
|
||||
assert f.col == "filter_type"
|
||||
|
||||
|
||||
@patch("superset.daos.security.RLSDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_rls_filters_basic(mock_list, mcp_server):
|
||||
rls_filter = create_mock_rls_filter()
|
||||
mock_list.return_value = ([rls_filter], 1)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("list_rls_filters", {})
|
||||
assert result.content is not None
|
||||
data = json.loads(result.content[0].text)
|
||||
assert "rls_filters" in data
|
||||
assert len(data["rls_filters"]) == 1
|
||||
assert data["rls_filters"][0]["id"] == 1
|
||||
assert data["rls_filters"][0]["name"] == "test_filter"
|
||||
|
||||
|
||||
@patch("superset.daos.security.RLSDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_rls_filters_with_request(mock_list, mcp_server):
|
||||
rls_filter = create_mock_rls_filter()
|
||||
mock_list.return_value = ([rls_filter], 1)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
request = ListRlsFiltersRequest(page=1, page_size=10)
|
||||
result = await client.call_tool(
|
||||
"list_rls_filters", {"request": request.model_dump()}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["count"] == 1
|
||||
assert data["total_count"] == 1
|
||||
|
||||
|
||||
@patch("superset.daos.security.RLSDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_rls_filters_with_search(mock_list, mcp_server):
|
||||
rls_filter = create_mock_rls_filter(name="user_filter")
|
||||
mock_list.return_value = ([rls_filter], 1)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
request = ListRlsFiltersRequest(page=1, page_size=10, search="user")
|
||||
result = await client.call_tool(
|
||||
"list_rls_filters", {"request": request.model_dump()}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["rls_filters"][0]["name"] == "user_filter"
|
||||
|
||||
|
||||
@patch("superset.daos.security.RLSDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_rls_filters_returns_tables_and_roles(mock_list, mcp_server):
|
||||
rls_filter = create_mock_rls_filter()
|
||||
mock_list.return_value = ([rls_filter], 1)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
request = ListRlsFiltersRequest(
|
||||
page=1,
|
||||
page_size=10,
|
||||
select_columns=["id", "name", "tables", "roles"],
|
||||
)
|
||||
result = await client.call_tool(
|
||||
"list_rls_filters", {"request": request.model_dump()}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
item = data["rls_filters"][0]
|
||||
assert "tables" in item
|
||||
assert item["tables"][0]["table_name"] == "sales"
|
||||
assert "roles" in item
|
||||
assert item["roles"][0]["name"] == "Alpha"
|
||||
|
||||
|
||||
@patch("superset.daos.security.RLSDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_rls_filters_empty(mock_list, mcp_server):
|
||||
mock_list.return_value = ([], 0)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("list_rls_filters", {})
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["count"] == 0
|
||||
assert data["rls_filters"] == []
|
||||
|
||||
|
||||
@patch("superset.daos.security.RLSDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_rls_filter_info_basic(mock_find, mcp_server):
|
||||
rls_filter = create_mock_rls_filter()
|
||||
mock_find.return_value = rls_filter
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_rls_filter_info", {"request": {"identifier": 1}}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["id"] == 1
|
||||
assert data["name"] == "test_filter"
|
||||
assert data["filter_type"] == "Regular"
|
||||
assert data["clause"] == "user_id = {{current_user_id()}}"
|
||||
|
||||
|
||||
@patch("superset.daos.security.RLSDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_rls_filter_info_not_found(mock_find, mcp_server):
|
||||
mock_find.return_value = None
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_rls_filter_info", {"request": {"identifier": 999}}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["error_type"] == "not_found"
|
||||
|
||||
|
||||
@patch("superset.daos.security.RLSDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_rls_filter_info_includes_tables_and_roles(mock_find, mcp_server):
|
||||
rls_filter = create_mock_rls_filter()
|
||||
mock_find.return_value = rls_filter
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_rls_filter_info", {"request": {"identifier": 1}}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["tables"][0]["table_name"] == "sales"
|
||||
assert data["roles"][0]["name"] == "Alpha"
|
||||
|
||||
|
||||
def test_list_rls_filters_request_rejects_search_and_filters():
|
||||
with pytest.raises(ValidationError):
|
||||
ListRlsFiltersRequest(
|
||||
search="test",
|
||||
filters=[{"col": "name", "opr": "eq", "value": "x"}],
|
||||
)
|
||||
|
||||
|
||||
@patch("superset.daos.security.RLSDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_rls_filters_roles_only_select_columns(mock_list, mcp_server):
|
||||
"""Regression: select_columns=['roles'] must not raise ValueError.
|
||||
|
||||
'roles' is in USER_DIRECTORY_FIELDS so ModelListCore would raise if it
|
||||
were the sole column passed to run_tool. The tool must strip it before
|
||||
calling run_tool and restore it in the model_dump context.
|
||||
"""
|
||||
rls_filter = create_mock_rls_filter()
|
||||
mock_list.return_value = ([rls_filter], 1)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
request = ListRlsFiltersRequest(page=1, page_size=10, select_columns=["roles"])
|
||||
result = await client.call_tool(
|
||||
"list_rls_filters", {"request": request.model_dump()}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
item = data["rls_filters"][0]
|
||||
assert "roles" in item
|
||||
assert item["roles"][0]["name"] == "Alpha"
|
||||
Reference in New Issue
Block a user