mirror of
https://github.com/apache/superset.git
synced 2026-06-09 17:49:26 +00:00
Compare commits
26 Commits
fix/helm-r
...
oss-40340
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9b6ff262fd | ||
|
|
576d40111b | ||
|
|
178fe56c9c | ||
|
|
de4da995b2 | ||
|
|
0f7f92011c | ||
|
|
9441240e5c | ||
|
|
08164e33bb | ||
|
|
894058fe3d | ||
|
|
6bd1b46216 | ||
|
|
ef4514f5ab | ||
|
|
e041f25385 | ||
|
|
d744f5715c | ||
|
|
fb60662353 | ||
|
|
207a7bf7f9 | ||
|
|
09a94fa26b | ||
|
|
7e088792b9 | ||
|
|
b6f545e61e | ||
|
|
952a6f3a23 | ||
|
|
8b551d3f74 | ||
|
|
709ef9b615 | ||
|
|
e9d46d843f | ||
|
|
9cc2deb903 | ||
|
|
03d25277ba | ||
|
|
bbe2f207d2 | ||
|
|
c381677dfd | ||
|
|
09572cd5ef |
4
.github/workflows/codeql-analysis.yml
vendored
4
.github/workflows/codeql-analysis.yml
vendored
@@ -41,7 +41,7 @@ jobs:
|
||||
|
||||
# Initializes the CodeQL tools for scanning.
|
||||
- name: Initialize CodeQL
|
||||
uses: github/codeql-action/init@9e0d7b8d25671d64c341c19c0152d693099fb5ba # v4
|
||||
uses: github/codeql-action/init@7211b7c8077ea37d8641b6271f6a365a22a5fbfa # v4
|
||||
with:
|
||||
languages: ${{ matrix.language }}
|
||||
# If you wish to specify custom queries, you can do so here or in a config file.
|
||||
@@ -53,6 +53,6 @@ jobs:
|
||||
|
||||
- name: Perform CodeQL Analysis
|
||||
if: steps.check.outputs.python || steps.check.outputs.frontend
|
||||
uses: github/codeql-action/analyze@9e0d7b8d25671d64c341c19c0152d693099fb5ba # v4
|
||||
uses: github/codeql-action/analyze@7211b7c8077ea37d8641b6271f6a365a22a5fbfa # v4
|
||||
with:
|
||||
category: "/language:${{matrix.language}}"
|
||||
|
||||
2
.github/workflows/superset-translations.yml
vendored
2
.github/workflows/superset-translations.yml
vendored
@@ -143,7 +143,7 @@ jobs:
|
||||
if: >-
|
||||
github.event_name == 'pull_request' &&
|
||||
steps.regression.outcome == 'failure'
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7
|
||||
with:
|
||||
name: translation-regression
|
||||
path: |
|
||||
|
||||
@@ -29,7 +29,7 @@ maintainers:
|
||||
- name: craig-rueda
|
||||
email: craig@craigrueda.com
|
||||
url: https://github.com/craig-rueda
|
||||
version: 0.16.0 # See [README](https://github.com/apache/superset/blob/master/helm/superset/README.md#versioning) for version details.
|
||||
version: 0.15.5 # See [README](https://github.com/apache/superset/blob/master/helm/superset/README.md#versioning) for version details.
|
||||
dependencies:
|
||||
- name: postgresql
|
||||
version: 16.7.27
|
||||
|
||||
@@ -23,7 +23,7 @@ NOTE: This file is generated by helm-docs: https://github.com/norwoodj/helm-docs
|
||||
|
||||
# superset
|
||||
|
||||

|
||||

|
||||
|
||||
Apache Superset is a modern, enterprise-ready business intelligence web application
|
||||
|
||||
@@ -111,6 +111,9 @@ On helm this can be set on `extraSecretEnv.SUPERSET_SECRET_KEY` or `configOverri
|
||||
| init.resources | object | `{}` | |
|
||||
| init.tolerations | list | `[]` | |
|
||||
| init.topologySpreadConstraints | list | `[]` | TopologySpreadConstraints to be added to the init job |
|
||||
| initImage.pullPolicy | string | `"IfNotPresent"` | |
|
||||
| initImage.repository | string | `"apache/superset"` | |
|
||||
| initImage.tag | string | `"dockerize"` | |
|
||||
| nameOverride | string | `nil` | Provide a name to override the name of the chart |
|
||||
| nodeSelector | object | `{}` | |
|
||||
| postgresql | object | see `values.yaml` | Configuration values for the postgresql dependency. ref: https://github.com/bitnami/charts/tree/main/bitnami/postgresql |
|
||||
|
||||
@@ -194,6 +194,11 @@ image:
|
||||
|
||||
imagePullSecrets: []
|
||||
|
||||
initImage:
|
||||
repository: apache/superset
|
||||
tag: dockerize
|
||||
pullPolicy: IfNotPresent
|
||||
|
||||
service:
|
||||
type: ClusterIP
|
||||
port: 8088
|
||||
@@ -298,28 +303,15 @@ supersetNode:
|
||||
# @default -- a container waiting for postgres
|
||||
initContainers:
|
||||
- name: wait-for-postgres
|
||||
image: "{{ .Values.image.repository }}:{{ .Values.image.tag | default .Chart.AppVersion }}"
|
||||
imagePullPolicy: "{{ .Values.image.pullPolicy }}"
|
||||
image: "{{ .Values.initImage.repository }}:{{ .Values.initImage.tag }}"
|
||||
imagePullPolicy: "{{ .Values.initImage.pullPolicy }}"
|
||||
envFrom:
|
||||
- secretRef:
|
||||
name: "{{ tpl .Values.envFromSecret . }}"
|
||||
command:
|
||||
- /bin/bash
|
||||
- /bin/sh
|
||||
- -c
|
||||
- |
|
||||
# bash's /dev/tcp redirect performs a TCP connect; no external
|
||||
# `dockerize`, `nc`, or busybox needed. SECONDS-based deadline
|
||||
# mirrors the prior `dockerize -timeout 120s` behaviour.
|
||||
SECONDS=0
|
||||
until (echo > /dev/tcp/"$DB_HOST"/"$DB_PORT") 2>/dev/null; do
|
||||
if [ "$SECONDS" -ge 120 ]; then
|
||||
echo "timeout waiting for postgres at $DB_HOST:$DB_PORT after 120s" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "waiting for postgres at $DB_HOST:$DB_PORT (elapsed ${SECONDS}s)"
|
||||
sleep 2
|
||||
done
|
||||
echo "postgres at $DB_HOST:$DB_PORT is up"
|
||||
- dockerize -wait "tcp://$DB_HOST:$DB_PORT" -timeout 120s
|
||||
resources:
|
||||
limits:
|
||||
memory: "256Mi"
|
||||
@@ -415,31 +407,15 @@ supersetWorker:
|
||||
# @default -- a container waiting for postgres and redis
|
||||
initContainers:
|
||||
- name: wait-for-postgres-redis
|
||||
image: "{{ .Values.image.repository }}:{{ .Values.image.tag | default .Chart.AppVersion }}"
|
||||
imagePullPolicy: "{{ .Values.image.pullPolicy }}"
|
||||
image: "{{ .Values.initImage.repository }}:{{ .Values.initImage.tag }}"
|
||||
imagePullPolicy: "{{ .Values.initImage.pullPolicy }}"
|
||||
envFrom:
|
||||
- secretRef:
|
||||
name: "{{ tpl .Values.envFromSecret . }}"
|
||||
command:
|
||||
- /bin/bash
|
||||
- /bin/sh
|
||||
- -c
|
||||
- |
|
||||
# See supersetNode.initContainers for the rationale.
|
||||
SECONDS=0
|
||||
wait_for() {
|
||||
local host=$1 port=$2 name=$3
|
||||
until (echo > /dev/tcp/"$host"/"$port") 2>/dev/null; do
|
||||
if [ "$SECONDS" -ge 120 ]; then
|
||||
echo "timeout waiting for $name at $host:$port after 120s" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "waiting for $name at $host:$port (elapsed ${SECONDS}s)"
|
||||
sleep 2
|
||||
done
|
||||
echo "$name at $host:$port is up"
|
||||
}
|
||||
wait_for "$DB_HOST" "$DB_PORT" postgres
|
||||
wait_for "$REDIS_HOST" "$REDIS_PORT" redis
|
||||
- dockerize -wait "tcp://$DB_HOST:$DB_PORT" -wait "tcp://$REDIS_HOST:$REDIS_PORT" -timeout 120s
|
||||
resources:
|
||||
limits:
|
||||
memory: "256Mi"
|
||||
@@ -519,31 +495,15 @@ supersetCeleryBeat:
|
||||
# @default -- a container waiting for postgres and redis
|
||||
initContainers:
|
||||
- name: wait-for-postgres-redis
|
||||
image: "{{ .Values.image.repository }}:{{ .Values.image.tag | default .Chart.AppVersion }}"
|
||||
imagePullPolicy: "{{ .Values.image.pullPolicy }}"
|
||||
image: "{{ .Values.initImage.repository }}:{{ .Values.initImage.tag }}"
|
||||
imagePullPolicy: "{{ .Values.initImage.pullPolicy }}"
|
||||
envFrom:
|
||||
- secretRef:
|
||||
name: "{{ tpl .Values.envFromSecret . }}"
|
||||
command:
|
||||
- /bin/bash
|
||||
- /bin/sh
|
||||
- -c
|
||||
- |
|
||||
# See supersetNode.initContainers for the rationale.
|
||||
SECONDS=0
|
||||
wait_for() {
|
||||
local host=$1 port=$2 name=$3
|
||||
until (echo > /dev/tcp/"$host"/"$port") 2>/dev/null; do
|
||||
if [ "$SECONDS" -ge 120 ]; then
|
||||
echo "timeout waiting for $name at $host:$port after 120s" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "waiting for $name at $host:$port (elapsed ${SECONDS}s)"
|
||||
sleep 2
|
||||
done
|
||||
echo "$name at $host:$port is up"
|
||||
}
|
||||
wait_for "$DB_HOST" "$DB_PORT" postgres
|
||||
wait_for "$REDIS_HOST" "$REDIS_PORT" redis
|
||||
- dockerize -wait "tcp://$DB_HOST:$DB_PORT" -wait "tcp://$REDIS_HOST:$REDIS_PORT" -timeout 120s
|
||||
resources:
|
||||
limits:
|
||||
memory: "256Mi"
|
||||
@@ -634,31 +594,15 @@ supersetCeleryFlower:
|
||||
# @default -- a container waiting for postgres and redis
|
||||
initContainers:
|
||||
- name: wait-for-postgres-redis
|
||||
image: "{{ .Values.image.repository }}:{{ .Values.image.tag | default .Chart.AppVersion }}"
|
||||
imagePullPolicy: "{{ .Values.image.pullPolicy }}"
|
||||
image: "{{ .Values.initImage.repository }}:{{ .Values.initImage.tag }}"
|
||||
imagePullPolicy: "{{ .Values.initImage.pullPolicy }}"
|
||||
envFrom:
|
||||
- secretRef:
|
||||
name: "{{ tpl .Values.envFromSecret . }}"
|
||||
command:
|
||||
- /bin/bash
|
||||
- /bin/sh
|
||||
- -c
|
||||
- |
|
||||
# See supersetNode.initContainers for the rationale.
|
||||
SECONDS=0
|
||||
wait_for() {
|
||||
local host=$1 port=$2 name=$3
|
||||
until (echo > /dev/tcp/"$host"/"$port") 2>/dev/null; do
|
||||
if [ "$SECONDS" -ge 120 ]; then
|
||||
echo "timeout waiting for $name at $host:$port after 120s" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "waiting for $name at $host:$port (elapsed ${SECONDS}s)"
|
||||
sleep 2
|
||||
done
|
||||
echo "$name at $host:$port is up"
|
||||
}
|
||||
wait_for "$DB_HOST" "$DB_PORT" postgres
|
||||
wait_for "$REDIS_HOST" "$REDIS_PORT" redis
|
||||
- dockerize -wait "tcp://$DB_HOST:$DB_PORT" -wait "tcp://$REDIS_HOST:$REDIS_PORT" -timeout 120s
|
||||
resources:
|
||||
limits:
|
||||
memory: "256Mi"
|
||||
@@ -820,26 +764,15 @@ init:
|
||||
# @default -- a container waiting for postgres
|
||||
initContainers:
|
||||
- name: wait-for-postgres
|
||||
image: "{{ .Values.image.repository }}:{{ .Values.image.tag | default .Chart.AppVersion }}"
|
||||
imagePullPolicy: "{{ .Values.image.pullPolicy }}"
|
||||
image: "{{ .Values.initImage.repository }}:{{ .Values.initImage.tag }}"
|
||||
imagePullPolicy: "{{ .Values.initImage.pullPolicy }}"
|
||||
envFrom:
|
||||
- secretRef:
|
||||
name: "{{ tpl .Values.envFromSecret . }}"
|
||||
command:
|
||||
- /bin/bash
|
||||
- /bin/sh
|
||||
- -c
|
||||
- |
|
||||
# See supersetNode.initContainers for the rationale.
|
||||
SECONDS=0
|
||||
until (echo > /dev/tcp/"$DB_HOST"/"$DB_PORT") 2>/dev/null; do
|
||||
if [ "$SECONDS" -ge 120 ]; then
|
||||
echo "timeout waiting for postgres at $DB_HOST:$DB_PORT after 120s" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "waiting for postgres at $DB_HOST:$DB_PORT (elapsed ${SECONDS}s)"
|
||||
sleep 2
|
||||
done
|
||||
echo "postgres at $DB_HOST:$DB_PORT is up"
|
||||
- dockerize -wait "tcp://$DB_HOST:$DB_PORT" -timeout 120s
|
||||
resources:
|
||||
limits:
|
||||
memory: "256Mi"
|
||||
|
||||
@@ -39,7 +39,7 @@ dependencies = [
|
||||
"apache-superset-core",
|
||||
"backoff>=1.8.0",
|
||||
"celery>=5.3.6, <6.0.0",
|
||||
"click>=8.0.3",
|
||||
"click>=8.4.0",
|
||||
"click-option-group",
|
||||
"colorama",
|
||||
"flask-cors>=6.0.0, <7.0",
|
||||
@@ -103,7 +103,7 @@ dependencies = [
|
||||
"sqlalchemy-utils>=0.38.0, <0.43", # expanding lowerbound to work with pydoris
|
||||
"sqlglot>=30.8.0, <31",
|
||||
# newer pandas needs 0.9+
|
||||
"tabulate>=0.9.0, <1.0",
|
||||
"tabulate>=0.10.0, <1.0",
|
||||
"typing-extensions>=4, <5",
|
||||
"waitress; sys_platform == 'win32'",
|
||||
"watchdog>=6.0.0",
|
||||
@@ -139,7 +139,7 @@ denodo = ["denodo-sqlalchemy>=1.0.6,<2.1.0"]
|
||||
dremio = ["sqlalchemy-dremio>=1.2.1, <4"]
|
||||
drill = ["sqlalchemy-drill>=1.1.10, <2"]
|
||||
druid = ["pydruid>=0.6.5,<0.7"]
|
||||
duckdb = ["duckdb>=1.4.2,<2", "duckdb-engine>=0.17.0"]
|
||||
duckdb = ["duckdb>=1.5.2,<2", "duckdb-engine>=0.17.0"]
|
||||
dynamodb = ["pydynamodb>=0.4.2"]
|
||||
solr = ["sqlalchemy-solr >= 0.2.0"]
|
||||
elasticsearch = ["elasticsearch-dbapi>=0.2.13, <0.3.0"]
|
||||
|
||||
@@ -60,7 +60,7 @@ cffi==2.0.0
|
||||
# pynacl
|
||||
charset-normalizer==3.4.2
|
||||
# via requests
|
||||
click==8.2.1
|
||||
click==8.4.1
|
||||
# via
|
||||
# apache-superset (pyproject.toml)
|
||||
# celery
|
||||
@@ -421,7 +421,7 @@ sqlglot==30.8.0
|
||||
# apache-superset-core
|
||||
sshtunnel==0.4.0
|
||||
# via apache-superset (pyproject.toml)
|
||||
tabulate==0.9.0
|
||||
tabulate==0.10.0
|
||||
# via apache-superset (pyproject.toml)
|
||||
trio==0.30.0
|
||||
# via
|
||||
|
||||
@@ -130,7 +130,7 @@ charset-normalizer==3.4.2
|
||||
# via
|
||||
# -c requirements/base-constraint.txt
|
||||
# requests
|
||||
click==8.2.1
|
||||
click==8.4.1
|
||||
# via
|
||||
# -c requirements/base-constraint.txt
|
||||
# apache-superset
|
||||
@@ -219,7 +219,7 @@ docstring-parser==0.17.0
|
||||
# via cyclopts
|
||||
docutils==0.22.2
|
||||
# via rich-rst
|
||||
duckdb==1.4.2
|
||||
duckdb==1.5.3
|
||||
# via
|
||||
# apache-superset
|
||||
# duckdb-engine
|
||||
@@ -1006,7 +1006,7 @@ statsd==4.0.1
|
||||
# via apache-superset
|
||||
syntaqlite==0.1.0
|
||||
# via apache-superset
|
||||
tabulate==0.9.0
|
||||
tabulate==0.10.0
|
||||
# via
|
||||
# -c requirements/base-constraint.txt
|
||||
# apache-superset
|
||||
|
||||
@@ -31,6 +31,7 @@ PATTERNS = {
|
||||
r"^superset/",
|
||||
r"^scripts/",
|
||||
r"^setup\.py",
|
||||
r"^pyproject\.toml$",
|
||||
r"^requirements/.+\.txt",
|
||||
r"^.pylintrc",
|
||||
],
|
||||
|
||||
4
superset-frontend/package-lock.json
generated
4
superset-frontend/package-lock.json
generated
@@ -49973,7 +49973,7 @@
|
||||
"acorn": "^8.16.0",
|
||||
"d3-array": "^3.2.4",
|
||||
"lodash": "^4.18.1",
|
||||
"zod": "^4.4.1"
|
||||
"zod": "^4.4.3"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@apache-superset/core": "*",
|
||||
@@ -50174,7 +50174,7 @@
|
||||
"license": "Apache-2.0",
|
||||
"dependencies": {
|
||||
"@types/d3-scale": "^4.0.9",
|
||||
"d3-cloud": "^1.2.9",
|
||||
"d3-cloud": "^1.2.8",
|
||||
"d3-scale": "^4.0.2"
|
||||
},
|
||||
"devDependencies": {
|
||||
|
||||
@@ -56,7 +56,7 @@
|
||||
"react-js-cron": "^5.2.0",
|
||||
"react-markdown": "^8.0.7",
|
||||
"react-resize-detector": "^7.1.2",
|
||||
"react-syntax-highlighter": "^16.1.1",
|
||||
"react-syntax-highlighter": "^16.1.0",
|
||||
"react-ultimate-pagination": "^1.3.2",
|
||||
"regenerator-runtime": "^0.14.1",
|
||||
"rehype-raw": "^7.0.0",
|
||||
|
||||
@@ -29,7 +29,7 @@
|
||||
"acorn": "^8.16.0",
|
||||
"d3-array": "^3.2.4",
|
||||
"lodash": "^4.18.1",
|
||||
"zod": "^4.4.1"
|
||||
"zod": "^4.4.3"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@apache-superset/core": "*",
|
||||
|
||||
@@ -21,8 +21,8 @@ import { TreePathInfo } from '../types';
|
||||
|
||||
export const COLOR_SATURATION = [0.7, 0.4];
|
||||
export const LABEL_FONTSIZE = 11;
|
||||
export const BORDER_WIDTH = 2;
|
||||
export const GAP_WIDTH = 2;
|
||||
export const BORDER_WIDTH = 0;
|
||||
export const GAP_WIDTH = 0;
|
||||
|
||||
export const extractTreePathInfo = (
|
||||
treePathInfo: TreePathInfo[] | undefined,
|
||||
|
||||
@@ -214,7 +214,8 @@ export default function transformProps(
|
||||
colorAlpha: OpacityEnum.SemiTransparent,
|
||||
color: theme.colorText,
|
||||
borderColor: theme.colorBgBase,
|
||||
borderWidth: 2,
|
||||
borderWidth: BORDER_WIDTH,
|
||||
gapWidth: GAP_WIDTH,
|
||||
},
|
||||
label: {
|
||||
...labelProps,
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
*/
|
||||
import { ChartProps } from '@superset-ui/core';
|
||||
import { supersetTheme } from '@apache-superset/core/theme';
|
||||
import { OpacityEnum } from '../../src/constants';
|
||||
import { EchartsTreemapChartProps } from '../../src/Treemap/types';
|
||||
import transformProps from '../../src/Treemap/transformProps';
|
||||
|
||||
@@ -74,4 +75,44 @@ describe('Treemap transformProps', () => {
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
test('should not render gaps between treemap nodes when filtered', () => {
|
||||
const filteredChartProps = new ChartProps({
|
||||
...chartProps,
|
||||
filterState: { selectedValues: ['Sylvester,bar1'] },
|
||||
});
|
||||
|
||||
expect(
|
||||
transformProps(filteredChartProps as EchartsTreemapChartProps),
|
||||
).toEqual(
|
||||
expect.objectContaining({
|
||||
echartOptions: expect.objectContaining({
|
||||
series: [
|
||||
expect.objectContaining({
|
||||
data: expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
children: expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
name: 'Arnold',
|
||||
children: expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
name: 'bar2',
|
||||
itemStyle: expect.objectContaining({
|
||||
borderWidth: 0,
|
||||
gapWidth: 0,
|
||||
colorAlpha: OpacityEnum.SemiTransparent,
|
||||
}),
|
||||
label: expect.objectContaining({}),
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
],
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -29,7 +29,7 @@
|
||||
"@math.gl/web-mercator": "^4.1.0",
|
||||
"mapbox-gl": "^3.24.0",
|
||||
"maplibre-gl": "^5.24.0",
|
||||
"react-map-gl": "^8.1.1",
|
||||
"react-map-gl": "^8.1.0",
|
||||
"supercluster": "^8.0.1"
|
||||
},
|
||||
"peerDependencies": {
|
||||
|
||||
@@ -468,9 +468,12 @@ export function saveDashboardRequest(
|
||||
);
|
||||
const cleanedData: JsonObject = {
|
||||
...data,
|
||||
certified_by: certified_by || '',
|
||||
certification_details:
|
||||
certified_by && certification_details ? certification_details : '',
|
||||
...(certified_by !== undefined && {
|
||||
certified_by,
|
||||
certification_details: certified_by
|
||||
? (certification_details ?? '')
|
||||
: '',
|
||||
}),
|
||||
css: css || '',
|
||||
dashboard_title: dashboard_title || t('[ untitled dashboard ]'),
|
||||
owners: ensureIsArray(owners as JsonObject[]).map((o: JsonObject) =>
|
||||
|
||||
@@ -185,6 +185,16 @@ class ExportDashboardsCommand(ExportModelsCommand):
|
||||
# Add theme UUID for proper cross-system imports
|
||||
payload["theme_uuid"] = str(model.theme.uuid) if model.theme else None
|
||||
|
||||
# Include role assignments (DASHBOARD_RBAC). Role IDs are
|
||||
# environment-local, so emit names — the import side resolves them
|
||||
# back to roles in the destination environment. The key is omitted
|
||||
# entirely when there are no role restrictions; older import code
|
||||
# treats "missing" as "no restriction" and an empty list could
|
||||
# confuse importers that distinguish the two states.
|
||||
role_names = sorted(role.name for role in (model.roles or []))
|
||||
if role_names:
|
||||
payload["roles"] = role_names
|
||||
|
||||
payload["version"] = EXPORT_VERSION
|
||||
|
||||
# Check if the TAGGING_SYSTEM feature is enabled
|
||||
|
||||
@@ -281,6 +281,11 @@ def import_dashboard( # noqa: C901
|
||||
|
||||
# Note: theme_id handling moved to higher level import logic
|
||||
|
||||
# Pop roles before handing config to import_from_dict — it's a
|
||||
# relationship, not a column, and the standard SQLAlchemy import path
|
||||
# doesn't resolve role *names* into role objects. We re-attach below.
|
||||
role_names = config.pop("roles", None)
|
||||
|
||||
for key, new_name in JSON_KEYS.items():
|
||||
if config.get(key) is not None:
|
||||
value = config.pop(key)
|
||||
@@ -296,4 +301,25 @@ def import_dashboard( # noqa: C901
|
||||
if (user := get_user()) and user not in dashboard.owners:
|
||||
dashboard.owners.append(user)
|
||||
|
||||
# Re-attach DASHBOARD_RBAC role assignments by name. Role IDs are
|
||||
# environment-local; names are how exports cross environments. Roles
|
||||
# that don't exist in the destination are skipped with a warning
|
||||
# rather than failing the import — admins may need to create them
|
||||
# before the access restriction takes effect.
|
||||
if isinstance(role_names, list) and role_names:
|
||||
resolved_roles = []
|
||||
for name in role_names:
|
||||
role = security_manager.find_role(name)
|
||||
if role is not None:
|
||||
resolved_roles.append(role)
|
||||
else:
|
||||
logger.warning(
|
||||
"Dashboard '%s': role %r referenced in export does not "
|
||||
"exist in this environment; access restriction will not "
|
||||
"be applied for that role",
|
||||
dashboard.dashboard_title,
|
||||
name,
|
||||
)
|
||||
dashboard.roles = resolved_roles
|
||||
|
||||
return dashboard
|
||||
|
||||
@@ -519,6 +519,7 @@ class ImportV1DashboardSchema(Schema):
|
||||
tags = fields.List(fields.String(), allow_none=True)
|
||||
theme_uuid = fields.UUID(allow_none=True)
|
||||
theme_id = fields.Integer(allow_none=True)
|
||||
roles = fields.List(fields.String(), allow_none=True)
|
||||
|
||||
|
||||
class EmbeddedDashboardConfigSchema(Schema):
|
||||
|
||||
@@ -130,20 +130,17 @@ Dashboard Management:
|
||||
- generate_dashboard: Create a dashboard from chart IDs (requires write access)
|
||||
- add_chart_to_existing_dashboard: Add a chart to an existing dashboard (requires write access)
|
||||
|
||||
Database Connections:
|
||||
- list_databases: List database connections with advanced filters (1-based pagination)
|
||||
- get_database_info: Get detailed database connection info by ID (backend, capabilities)
|
||||
|
||||
Dataset Management:
|
||||
- list_datasets: List datasets with advanced filters (1-based pagination)
|
||||
- get_dataset_info: Get detailed dataset information by ID (includes columns/metrics)
|
||||
- create_dataset: Register a physical table as a dataset against an existing DB connection (requires write access)
|
||||
- create_virtual_dataset: Save a SQL query as a virtual dataset for charting (requires write access)
|
||||
- query_dataset: Query a dataset using its semantic layer (saved metrics, dimensions, filters) without needing a saved chart
|
||||
|
||||
Chart Management:
|
||||
- list_charts: List charts with advanced filters (1-based pagination)
|
||||
- get_chart_info: Get detailed chart information by ID
|
||||
- get_chart_preview: Get a visual preview of a chart as formatted content or URL
|
||||
- get_chart_preview: Get a visual preview of a chart with image URL
|
||||
- get_chart_data: Get underlying chart data in text-friendly format
|
||||
- get_chart_sql: Get the rendered SQL query for a chart (without executing it)
|
||||
- generate_chart: Create and save a new chart permanently (requires write access)
|
||||
@@ -163,30 +160,25 @@ System Information:
|
||||
- get_instance_info: Get instance-wide statistics, metadata, and current user identity
|
||||
- find_users: Resolve a person's name to user IDs for use as a filter value
|
||||
- health_check: Simple health check tool (takes NO parameters, call without arguments)
|
||||
- generate_bug_report: Build a PII-sanitized bug report to send to Preset support
|
||||
(use when the user says the MCP is broken or asks how to report an issue)
|
||||
|
||||
Available Resources:
|
||||
- instance://metadata: Instance configuration, stats, and available dataset IDs
|
||||
- chart://configs: Valid chart configuration examples and best practices
|
||||
- instance/metadata: Access instance configuration and metadata
|
||||
- chart/templates: Access chart configuration templates
|
||||
|
||||
Available Prompts:
|
||||
- quickstart: Interactive guide for getting started with the MCP service
|
||||
- create_chart_guided: Step-by-step chart creation wizard
|
||||
|
||||
IMPORTANT - Using Saved Metrics vs Columns:
|
||||
When get_dataset_info returns a dataset, it includes both 'columns' and 'metrics'.
|
||||
- 'columns' are raw database columns (e.g., order_date, product_name, revenue)
|
||||
- 'metrics' are pre-defined saved metrics with SQL expressions
|
||||
(e.g., count, total_revenue)
|
||||
Common Chart Types (viz_type) and Behaviors:
|
||||
|
||||
When building chart configurations
|
||||
(generate_chart, generate_explore_link, update_chart):
|
||||
- For raw columns: use {{"name": "col_name", "aggregate": "SUM"}}
|
||||
- For saved metrics: use {{"name": "metric", "saved_metric": true}}
|
||||
Do NOT add an aggregate when using saved_metric=true
|
||||
(it's already defined in the metric).
|
||||
Do NOT use a saved metric name as if it were a column — it will fail.
|
||||
Interactive Charts (support sorting, filtering, drill-down):
|
||||
- table: Standard table view with sorting and filtering
|
||||
- pivot_table_v2: Pivot table with grouping and aggregations
|
||||
- echarts_timeseries_line: Time series line chart
|
||||
- echarts_timeseries_bar: Time series bar chart
|
||||
- echarts_timeseries_area: Time series area chart
|
||||
- echarts_timeseries_scatter: Time series scatter plot
|
||||
- mixed_timeseries: Combined line/bar time series
|
||||
|
||||
Example: If get_dataset_info returns metrics=[{{"metric_name": "count", ...}}], use:
|
||||
{{"name": "count", "saved_metric": true}} ← CORRECT
|
||||
@@ -315,52 +307,11 @@ Chart Types in Existing Charts (viewable via list_charts/get_chart_info):
|
||||
- word_cloud, world_map, box_plot, bubble, mixed_timeseries
|
||||
|
||||
Query Examples:
|
||||
- List all tables:
|
||||
list_charts(request={{"filters": [{{"col": "viz_type",
|
||||
"opr": "in",
|
||||
"value": ["table", "pivot_table_v2"]}}]}})
|
||||
- List all interactive tables:
|
||||
filters=[{{"col": "viz_type", "opr": "in", "value": ["table", "pivot_table_v2"]}}]
|
||||
- List time series charts:
|
||||
list_charts(request={{"filters": [{{"col": "viz_type",
|
||||
"opr": "sw", "value": "echarts_timeseries"}}]}})
|
||||
- Search by name: list_charts(request={{"search": "sales"}})
|
||||
- My charts: list_charts(request={{"created_by_me": true}})
|
||||
- My dashboards: list_dashboards(request={{"created_by_me": true}})
|
||||
- My databases: list_databases(request={{"created_by_me": true}})
|
||||
To modify an existing chart (add filters, change metrics, etc.):
|
||||
1. get_chart_info(request={{"identifier": <chart_id>}})
|
||||
-> examine current configuration
|
||||
2. update_chart(request={{
|
||||
"identifier": <chart_id>, "config": {{...}}
|
||||
}}) -> apply changes
|
||||
Do NOT use execute_sql for chart modifications.
|
||||
Use update_chart instead.
|
||||
|
||||
CRITICAL RULES - NEVER VIOLATE:
|
||||
- NEVER fabricate or invent URLs. ALL URLs must come from tool call results.
|
||||
If you need a link, call the appropriate tool (generate_explore_link, generate_chart,
|
||||
open_sql_lab_with_context, etc.) and use the URL it returns.
|
||||
- NEVER call generate_dashboard when the user wants to add a chart to an EXISTING
|
||||
dashboard. Always use add_chart_to_existing_dashboard. Only call generate_dashboard
|
||||
to create a brand-new dashboard, or after the user explicitly confirms they want
|
||||
a new one (e.g., after a permission_denied=True response from
|
||||
add_chart_to_existing_dashboard).
|
||||
- To modify an existing chart's filters, metrics, or dimensions, use update_chart.
|
||||
Do NOT use execute_sql for chart modifications.
|
||||
- Parameter name reminders: ALWAYS use the EXACT parameter names from the tool schema.
|
||||
Do NOT use Superset's internal form_data names.
|
||||
|
||||
IMPORTANT - Tool-Only Interaction:
|
||||
- Do NOT generate code artifacts, HTML pages, JavaScript snippets, or any code intended
|
||||
for the user to run. All visualization, data retrieval, and authentication are handled
|
||||
by the provided MCP tools.
|
||||
- Always call the appropriate tool directly instead of writing code. For example, use
|
||||
generate_chart to create visualizations rather than generating plotting code.
|
||||
- When a tool returns a URL (chart URL, dashboard URL, explore link, SQL Lab link),
|
||||
return that URL to the user. Do NOT attempt to recreate the visualization in code.
|
||||
- Do NOT generate HTML dashboards, embed scripts, or custom frontend code. Use
|
||||
generate_dashboard and add_chart_to_existing_dashboard for dashboard operations.
|
||||
- If a user asks for something the tools cannot do, explain the limitation and suggest
|
||||
the closest available tool rather than generating code as a workaround.
|
||||
filters=[{{"col": "viz_type", "opr": "sw", "value": "echarts_timeseries"}}]
|
||||
- Search by name: search="sales"
|
||||
|
||||
General usage tips:
|
||||
- All listing tools use 1-based pagination (first page is 1)
|
||||
@@ -368,7 +319,7 @@ General usage tips:
|
||||
- Use 'filters' parameter for advanced queries with filter columns from get_schema
|
||||
- IDs can be integer or UUID format where supported
|
||||
- All tools return structured, Pydantic-typed responses
|
||||
- Chart previews can return ASCII text, Explore URLs, table data, or Vega-Lite specs
|
||||
- Chart previews are served as PNG images via custom screenshot endpoints
|
||||
|
||||
Input format:
|
||||
- Tool request parameters accept structured objects (dicts/JSON)
|
||||
@@ -377,10 +328,11 @@ Input format:
|
||||
{_feature_availability}Permission Awareness:
|
||||
{_instance_info_role_bullet}- ALWAYS check the user's roles BEFORE suggesting write operations (creating datasets,
|
||||
charts, or dashboards). SQL execution is a separate permission — see execute_sql below.
|
||||
- Write tools (generate_chart, generate_dashboard, update_chart, create_virtual_dataset,
|
||||
save_sql_query, add_chart_to_existing_dashboard, update_chart_preview) require write
|
||||
permissions. These tools are only listed for users who have the necessary access.
|
||||
If a write tool does not appear in the tool list, the current user lacks write access.
|
||||
- Write tools (generate_chart, generate_dashboard, update_chart, create_dataset,
|
||||
create_virtual_dataset, save_sql_query, add_chart_to_existing_dashboard,
|
||||
update_chart_preview) require write permissions. These tools are only listed for
|
||||
users who have the necessary access. If a write tool does not appear in the tool
|
||||
list, the current user lacks write access.
|
||||
- execute_sql requires SQL Lab access (execute_sql_query permission), which is separate
|
||||
from write access. A user may have SQL Lab access without having write access to charts
|
||||
or dashboards, and vice versa.
|
||||
@@ -584,39 +536,13 @@ def create_mcp_app(
|
||||
|
||||
|
||||
# Create default MCP instance for backward compatibility
|
||||
# Tool modules can import this and use @mcp.tool decorators
|
||||
mcp = create_mcp_app()
|
||||
|
||||
# Initialize MCP dependency injection BEFORE importing tools/prompts
|
||||
# This replaces the abstract @tool and @prompt decorators in superset_core.api.mcp
|
||||
# with concrete implementations that can register with the mcp instance
|
||||
from superset.core.mcp.core_mcp_injection import ( # noqa: E402
|
||||
initialize_core_mcp_dependencies,
|
||||
)
|
||||
|
||||
initialize_core_mcp_dependencies()
|
||||
|
||||
# Suppress known third-party deprecation warnings that leak to MCP clients.
|
||||
# The MCP SDK captures Python warnings and forwards them to clients via
|
||||
# server log entries, wasting LLM tokens and causing clients to act on
|
||||
# irrelevant internal warnings. These warnings come from transitive imports
|
||||
# triggered by tool/schema registration below.
|
||||
import warnings # noqa: E402
|
||||
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
category=DeprecationWarning,
|
||||
module=r"marshmallow\..*",
|
||||
)
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
category=FutureWarning,
|
||||
module=r"google\..*",
|
||||
)
|
||||
|
||||
# Import all MCP tools to register them with the mcp instance
|
||||
# NOTE: Always add new tool imports here when creating new MCP tools.
|
||||
# Tools use the @tool decorator from `superset-core` and register automatically
|
||||
# on import. Import prompts and resources to register them with the mcp instance
|
||||
# Tools use @mcp.tool decorators and register automatically on import.
|
||||
# Import prompts and resources to register them with the mcp instance
|
||||
# NOTE: Always add new prompt/resource imports here when creating new prompts/resources.
|
||||
# Prompts use @mcp.prompt decorators and resources use @mcp.resource decorators.
|
||||
# They register automatically on import, similar to tools.
|
||||
@@ -646,6 +572,7 @@ from superset.mcp_service.database.tool import ( # noqa: F401, E402
|
||||
list_databases,
|
||||
)
|
||||
from superset.mcp_service.dataset.tool import ( # noqa: F401, E402
|
||||
create_dataset,
|
||||
create_virtual_dataset,
|
||||
get_dataset_info,
|
||||
list_datasets,
|
||||
|
||||
@@ -487,7 +487,17 @@ class AddChartToDashboardRequest(BaseModel):
|
||||
)
|
||||
chart_id: int = Field(..., description="ID of the chart to add to the dashboard")
|
||||
target_tab: str | None = Field(
|
||||
None, description="Target tab name (if dashboard has tabs)"
|
||||
None,
|
||||
min_length=1,
|
||||
description=(
|
||||
"Tab to place the chart in. Accepts a tab display name "
|
||||
"(e.g. 'Sales') or a tab component ID (e.g. 'TAB-abc123'). "
|
||||
"Display-name matching is case-insensitive and strips all emoji; "
|
||||
"component ID matching is case-sensitive and exact. "
|
||||
"When not found, the error response lists all available tab names. "
|
||||
"When omitted on a tabbed dashboard the chart is placed in the "
|
||||
"first tab."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -514,6 +524,19 @@ class AddChartToDashboardResponse(BaseModel):
|
||||
),
|
||||
)
|
||||
|
||||
@field_validator("error")
@classmethod
def sanitize_error_for_llm_context(cls, value: str | None) -> str | None:
    """Run a populated ``error`` through ``sanitize_for_llm_context``.

    The message may repeat the caller's ``target_tab`` as well as tab labels
    stored on the dashboard; both are passed through the sanitizer so the
    LLM reads them as data rather than as instructions.

    Args:
        value: The raw error text, or ``None`` when there is no error.

    Returns:
        ``None`` unchanged, otherwise the sanitizer's output for ``value``.
    """
    return (
        None
        if value is None
        else sanitize_for_llm_context(value, field_path=("error",))
    )
|
||||
|
||||
|
||||
class GenerateDashboardRequest(BaseModel):
|
||||
"""Request schema for generating a dashboard."""
|
||||
|
||||
@@ -149,29 +149,50 @@ def _first_tab_from_groups(
|
||||
return None
|
||||
|
||||
|
||||
def _collect_available_tab_names(layout: Dict[str, Any]) -> list[str]:
    """Describe every TAB component in *layout* as ``"<label> (<component id>)"``.

    A tab whose ``meta.text`` is missing or empty is listed by its component
    ID alone. The ID is always part of the entry so that a caller can retry
    with an unambiguous value even when several tabs share one label.

    Args:
        layout: Dashboard position mapping keyed by component ID.

    Returns:
        One entry per TAB child, in the order the tab groups yield them.
    """
    described: list[str] = []
    for child_ids in _collect_tabs_groups(layout):
        for component_id in child_ids:
            component = layout.get(component_id)
            # Skip dangling IDs and children that are not TAB components.
            if component and component.get("type") == "TAB":
                label = (component.get("meta") or {}).get("text", "")
                described.append(
                    f"{label} ({component_id})" if label else component_id
                )
    return described
|
||||
|
||||
|
||||
def _find_tab_insert_target(
|
||||
layout: Dict[str, Any], target_tab: str | None = None
|
||||
) -> str | None:
|
||||
"""
|
||||
Detect if the dashboard uses tabs and return the appropriate tab's ID.
|
||||
|
||||
If *target_tab* is provided the function first tries to match it against
|
||||
tab ``meta.text`` (display name) or the raw component ID. When no match
|
||||
is found (or *target_tab* is ``None``) the first ``TAB`` child is used as
|
||||
a fallback so that new rows are still placed inside the tab structure
|
||||
rather than directly under ``GRID_ID``.
|
||||
When *target_tab* is ``None`` the function returns the first TAB child so
|
||||
that new rows are placed inside the tab structure rather than directly
|
||||
under ``GRID_ID``.
|
||||
|
||||
When *target_tab* is provided the function tries to match it against tab
|
||||
``meta.text`` (display name) or the raw component ID. If no match is
|
||||
found ``None`` is returned — the caller is responsible for surfacing an
|
||||
error rather than silently placing the chart in the wrong tab.
|
||||
|
||||
Returns:
|
||||
The ID of the matched (or first) TAB component, or ``None`` if the
|
||||
dashboard does not use top-level tabs.
|
||||
The ID of the matched (or first) TAB component, or ``None``.
|
||||
"""
|
||||
groups = _collect_tabs_groups(layout)
|
||||
|
||||
if target_tab:
|
||||
if target_tab is not None:
|
||||
for tabs_children in groups:
|
||||
matched = _match_tab_in_children(layout, tabs_children, target_tab)
|
||||
if matched:
|
||||
return matched
|
||||
# target_tab specified but not found — signal mismatch to the caller.
|
||||
return None
|
||||
|
||||
return _first_tab_from_groups(layout, groups)
|
||||
|
||||
@@ -316,6 +337,45 @@ def _ensure_layout_structure(
|
||||
layout["DASHBOARD_VERSION_KEY"] = "v2"
|
||||
|
||||
|
||||
def _resolve_parent_container(
    layout: Dict[str, Any],
    dashboard_id: int,
    target_tab: str | None,
) -> tuple[str, None] | tuple[None, AddChartToDashboardResponse]:
    """Pick the layout component that will receive the new row.

    Args:
        layout: Dashboard position mapping keyed by component ID.
        dashboard_id: Only used to build the error text.
        target_tab: Tab requested by the caller, or ``None`` if not given.

    Returns:
        ``(parent_id, None)`` on success, where ``parent_id`` is the resolved
        tab ID or ``"GRID_ID"`` when no tab was resolved.
        ``(None, error_response)`` when *target_tab* was given but could not
        be resolved; the error either lists the available tabs or states
        that the dashboard has none.
    """
    resolved = _find_tab_insert_target(layout, target_tab=target_tab)

    requested_but_missing = target_tab is not None and resolved is None
    if not requested_but_missing:
        return (resolved or "GRID_ID", None)

    known_tabs = _collect_available_tab_names(layout)
    if known_tabs:
        message = (
            f"Tab '{target_tab}' not found in dashboard {dashboard_id}. "
            f"Available tabs: {', '.join(known_tabs)}."
        )
    else:
        message = (
            f"Dashboard {dashboard_id} has no tabs. "
            "Remove the target_tab parameter to add the chart to "
            "the default grid layout."
        )

    return None, AddChartToDashboardResponse(
        dashboard=None,
        dashboard_url=None,
        position=None,
        error=message,
    )
|
||||
|
||||
|
||||
def _find_and_authorize_dashboard(
|
||||
dashboard_id: int,
|
||||
) -> tuple[Any, AddChartToDashboardResponse | None]:
|
||||
@@ -369,7 +429,7 @@ def _find_and_authorize_dashboard(
|
||||
destructiveHint=False,
|
||||
),
|
||||
)
|
||||
def add_chart_to_existing_dashboard(
|
||||
def add_chart_to_existing_dashboard( # noqa: C901 — complexity is structural (layout traversal + multi-step authorization), not accidental
|
||||
request: AddChartToDashboardRequest, ctx: Context
|
||||
) -> AddChartToDashboardResponse:
|
||||
"""
|
||||
@@ -443,11 +503,16 @@ def add_chart_to_existing_dashboard(
|
||||
# Generate a unique ROW ID for the new row
|
||||
row_key = _find_next_row_position(current_layout)
|
||||
|
||||
# Detect tabbed dashboards and resolve target_tab by name or ID
|
||||
tab_target = _find_tab_insert_target(
|
||||
current_layout, target_tab=request.target_tab
|
||||
# Detect tabbed dashboards and resolve target_tab by name or ID.
|
||||
parent_id, tab_error = _resolve_parent_container(
|
||||
current_layout, request.dashboard_id, request.target_tab
|
||||
)
|
||||
parent_id = tab_target if tab_target else "GRID_ID"
|
||||
if tab_error is not None:
|
||||
return tab_error
|
||||
if parent_id is None:
|
||||
raise RuntimeError(
|
||||
"unreachable: tab_error is None implies parent_id is str"
|
||||
)
|
||||
|
||||
# Add chart, column, and row to layout
|
||||
chart_key, column_key, row_key = _add_chart_to_layout(
|
||||
|
||||
@@ -324,6 +324,37 @@ class GetDatasetInfoRequest(MetadataCacheControl):
|
||||
]
|
||||
|
||||
|
||||
class CreateDatasetRequest(BaseModel):
    """Input for ``create_dataset``: the physical table to register as a dataset."""

    # Database connection the table belongs to (required).
    database_id: Annotated[
        int,
        Field(
            description="ID of the database connection to register the table against"
        ),
    ]
    # Optional schema/namespace; ``None`` means "use the database default".
    # NOTE(review): a field called ``schema`` may shadow an attribute of
    # pydantic's ``BaseModel`` and emit a warning at class creation — confirm.
    schema: Annotated[
        str | None,
        Field(
            description=(
                "Schema (namespace) where the table lives, "
                "e.g. 'public'. Optional: omit to use the "
                "database default schema."
            ),
            default=None,
        ),
    ]
    # Name of the table to register (required).
    table_name: Annotated[
        str,
        Field(description="Name of the physical table to register as a dataset"),
    ]
    # Optional explicit owner IDs; ``None`` leaves ownership to the default.
    owners: Annotated[
        List[int] | None,
        Field(
            description=(
                "Optional list of owner user IDs. Defaults to the calling user."
            ),
            default=None,
        ),
    ]
|
||||
|
||||
|
||||
class CreateVirtualDatasetRequest(BaseModel):
|
||||
"""Request schema for create_virtual_dataset."""
|
||||
|
||||
|
||||
@@ -15,14 +15,16 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from .create_dataset import create_dataset
|
||||
from .create_virtual_dataset import create_virtual_dataset
|
||||
from .get_dataset_info import get_dataset_info
|
||||
from .list_datasets import list_datasets
|
||||
from .query_dataset import query_dataset
|
||||
|
||||
__all__ = [
|
||||
"create_dataset",
|
||||
"create_virtual_dataset",
|
||||
"list_datasets",
|
||||
"get_dataset_info",
|
||||
"list_datasets",
|
||||
"query_dataset",
|
||||
]
|
||||
|
||||
142
superset/mcp_service/dataset/tool/create_dataset.py
Normal file
142
superset/mcp_service/dataset/tool/create_dataset.py
Normal file
@@ -0,0 +1,142 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Create dataset FastMCP tool
|
||||
|
||||
Registers a physical table as a Superset dataset against an existing
|
||||
database connection — the programmatic equivalent of Data → Datasets → +Dataset.
|
||||
Returns the same DatasetInfo shape as get_dataset_info so the caller can feed
|
||||
the resulting dataset_id directly into generate_chart.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from fastmcp import Context
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.dataset.schemas import (
|
||||
CreateDatasetRequest,
|
||||
DatasetError,
|
||||
DatasetInfo,
|
||||
serialize_dataset_object,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@tool(
    tags=["mutate"],
    class_permission_name="Dataset",
    method_permission_name="write",
    annotations=ToolAnnotations(
        title="Create dataset",
        readOnlyHint=False,
        destructiveHint=False,
    ),
)
async def create_dataset(
    request: CreateDatasetRequest, ctx: Context
) -> DatasetInfo | DatasetError:
    """Register a physical table as a Superset dataset.

    Wraps POST /api/v1/dataset/ — the same endpoint the UI uses when you click
    Data → Datasets → +Dataset. Returns full dataset metadata (same shape as
    get_dataset_info) so you can pass the resulting dataset_id straight into
    generate_chart.

    Required fields:
    - database_id: ID of the existing database connection
    - table_name: Exact name of the physical table to register

    Optional fields:
    - schema: Schema/namespace where the table lives (e.g. "public")
    - owners: List of user IDs to set as owners (defaults to calling user)

    Example:
    ```json
    {
        "database_id": 1,
        "schema": "public",
        "table_name": "orders"
    }
    ```

    Returns DatasetInfo on success or DatasetError on failure.
    Use list_databases to find the correct database_id.
    """
    await ctx.info(
        "Creating dataset: database_id=%s, schema=%r, table_name=%r"
        % (request.database_id, request.schema, request.table_name)
    )

    # These stay function-local as before, but must be bound *before* the
    # ``try``. The ``except`` clauses below name these classes; if an import
    # failed inside the ``try``, evaluating the first ``except`` clause would
    # raise UnboundLocalError, hiding the real import error and skipping the
    # generic ``except Exception`` fallback entirely.
    from superset.commands.dataset.create import CreateDatasetCommand
    from superset.commands.dataset.exceptions import (
        DatasetCreateFailedError,
        DatasetExistsValidationError,
        DatasetInvalidError,
        TableNotFoundValidationError,
    )

    try:
        dataset_properties: dict[str, Any] = {
            "database": request.database_id,
            "schema": request.schema,
            "table_name": request.table_name,
        }
        # Only forward ``owners`` when the caller supplied it, so the command
        # sees the key as absent rather than as an explicit ``None``.
        if request.owners is not None:
            dataset_properties["owners"] = request.owners

        with event_logger.log_context(action="mcp.create_dataset"):
            command = CreateDatasetCommand(dataset_properties)
            dataset = command.run()

        result = serialize_dataset_object(dataset)
        if result is None:
            # The dataset was created at this point; only the response
            # payload could not be built.
            return DatasetError.create(
                error="Dataset was created but could not be serialized",
                error_type="SerializationError",
            )

        logger.info(
            "Created dataset id=%s table=%s.%s",
            dataset.id,
            request.schema,
            request.table_name,
        )
        return result

    # Known command failures are mapped to a typed DatasetError and reported
    # to the client; none of them re-raise.
    except DatasetExistsValidationError as e:
        await ctx.error("Dataset already exists: %s" % (str(e),))
        return DatasetError.create(error=str(e), error_type="DatasetExistsError")
    except TableNotFoundValidationError as e:
        await ctx.error("Table not found: %s" % (str(e),))
        return DatasetError.create(error=str(e), error_type="TableNotFoundError")
    except DatasetInvalidError as e:
        await ctx.error("Dataset validation failed: %s" % (str(e),))
        return DatasetError.create(error=str(e), error_type="ValidationError")
    except DatasetCreateFailedError as e:
        await ctx.error("Dataset creation failed: %s" % (str(e),))
        return DatasetError.create(error=str(e), error_type="CreateFailedError")
    except Exception as e:
        # Anything unexpected is logged with its traceback and surfaced as an
        # InternalError instead of propagating.
        logger.error("Failed to create dataset: %s", e, exc_info=True)
        await ctx.error("Unexpected error: %s: %s" % (type(e).__name__, str(e)))
        return DatasetError.create(
            error=f"Failed to create dataset: {str(e)}",
            error_type="InternalError",
        )
|
||||
@@ -188,10 +188,15 @@ def _sanitize_params(params: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Remove sensitive fields from params before logging."""
|
||||
if not isinstance(params, dict):
|
||||
return params
|
||||
return {
|
||||
k: "[REDACTED]" if k.lower() in _SENSITIVE_PARAM_KEYS else v
|
||||
for k, v in params.items()
|
||||
}
|
||||
result: dict[str, Any] = {}
|
||||
for k, v in params.items():
|
||||
if k.lower() in _SENSITIVE_PARAM_KEYS:
|
||||
result[k] = "[REDACTED]"
|
||||
elif k == "arguments" and isinstance(v, dict):
|
||||
result[k] = _sanitize_params(v)
|
||||
else:
|
||||
result[k] = v
|
||||
return result
|
||||
|
||||
|
||||
class LoggingMiddleware(Middleware):
|
||||
@@ -204,8 +209,17 @@ class LoggingMiddleware(Middleware):
|
||||
Tool calls are handled in on_call_tool() which wraps execution to capture
|
||||
duration_ms. Non-tool messages (resource reads, prompts, etc.) are handled
|
||||
in on_message().
|
||||
|
||||
When tool search is enabled (progressive discovery), the MCP client calls
|
||||
``call_tool`` proxies instead of individual tools. This middleware resolves
|
||||
the underlying tool name from ``call_tool`` arguments so that analytics
|
||||
queries can filter by the actual tool (stored as ``mcp_tool`` in the curated
|
||||
payload).
|
||||
"""
|
||||
|
||||
#: Proxy name used by FastMCP tool-search transforms.
|
||||
_CALL_TOOL_PROXY = "call_tool"
|
||||
|
||||
def _is_error_response(self, result: ToolResult) -> bool:
|
||||
"""Check if a tool result contains an error schema response.
|
||||
|
||||
@@ -244,6 +258,28 @@ class LoggingMiddleware(Middleware):
|
||||
dataset_id = params.get("dataset_id")
|
||||
return agent_id, user_id, dashboard_id, slice_id, dataset_id, params
|
||||
|
||||
@staticmethod
|
||||
def _resolve_tool_name(tool_name: str | None, params: Any) -> str | None:
|
||||
"""Resolve the underlying tool name from call_tool proxy arguments.
|
||||
|
||||
When tool search is enabled, the MCP client uses the ``call_tool``
|
||||
proxy and passes the real tool name as the ``name`` argument. This
|
||||
helper extracts that value so we can log which tool was actually
|
||||
executed rather than just ``"call_tool"``.
|
||||
|
||||
Returns:
|
||||
The resolved tool name if *tool_name* is the call_tool proxy and
|
||||
``params["name"]`` is a non-empty string, otherwise ``None``.
|
||||
"""
|
||||
if (
|
||||
tool_name == LoggingMiddleware._CALL_TOOL_PROXY
|
||||
and isinstance(params, dict)
|
||||
and isinstance(params.get("name"), str)
|
||||
and params["name"]
|
||||
):
|
||||
return params["name"]
|
||||
return None
|
||||
|
||||
async def on_call_tool(
|
||||
self,
|
||||
context: MiddlewareContext,
|
||||
@@ -254,11 +290,13 @@ class LoggingMiddleware(Middleware):
|
||||
self._extract_context_info(context)
|
||||
)
|
||||
tool_name = getattr(context.message, "name", None)
|
||||
mcp_tool = self._resolve_tool_name(tool_name, params)
|
||||
|
||||
mcp_call_id = secrets.token_hex(16)
|
||||
_mcp_call_id_var.set(mcp_call_id)
|
||||
start_time = time.time()
|
||||
success = False
|
||||
error_type: str | None = None
|
||||
try:
|
||||
result = await call_next(context)
|
||||
success = not self._is_error_response(result)
|
||||
@@ -270,11 +308,27 @@ class LoggingMiddleware(Middleware):
|
||||
structured_content=result.structured_content,
|
||||
)
|
||||
return result
|
||||
except Exception:
|
||||
except Exception as exc:
|
||||
error_type = type(exc).__name__
|
||||
success = False
|
||||
raise
|
||||
finally:
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
payload: dict[str, Any] = {
|
||||
"mcp_call_id": mcp_call_id,
|
||||
"tool": tool_name,
|
||||
"agent_id": agent_id,
|
||||
"params": _sanitize_params(params),
|
||||
"method": context.method,
|
||||
"dashboard_id": dashboard_id,
|
||||
"slice_id": slice_id,
|
||||
"dataset_id": dataset_id,
|
||||
"success": success,
|
||||
}
|
||||
if mcp_tool is not None:
|
||||
payload["mcp_tool"] = mcp_tool
|
||||
if error_type is not None:
|
||||
payload["error_type"] = error_type
|
||||
if has_app_context():
|
||||
event_logger.log(
|
||||
user_id=user_id,
|
||||
@@ -283,22 +337,18 @@ class LoggingMiddleware(Middleware):
|
||||
duration_ms=duration_ms,
|
||||
slice_id=slice_id,
|
||||
referrer=None,
|
||||
curated_payload={
|
||||
"mcp_call_id": mcp_call_id,
|
||||
"tool": tool_name,
|
||||
"agent_id": agent_id,
|
||||
"params": _sanitize_params(params),
|
||||
"method": context.method,
|
||||
"dashboard_id": dashboard_id,
|
||||
"slice_id": slice_id,
|
||||
"dataset_id": dataset_id,
|
||||
"success": success,
|
||||
},
|
||||
curated_payload=payload,
|
||||
)
|
||||
extra_parts = []
|
||||
if mcp_tool is not None:
|
||||
extra_parts.append(f"mcp_tool={mcp_tool}")
|
||||
if error_type is not None:
|
||||
extra_parts.append(f"error_type={error_type}")
|
||||
extra = (", " + ", ".join(extra_parts)) if extra_parts else ""
|
||||
logger.info(
|
||||
"MCP tool call: tool=%s, agent_id=%s, user_id=%s, method=%s, "
|
||||
"dashboard_id=%s, slice_id=%s, dataset_id=%s, duration_ms=%s, "
|
||||
"success=%s, mcp_call_id=%s",
|
||||
"success=%s, mcp_call_id=%s%s",
|
||||
tool_name,
|
||||
agent_id,
|
||||
user_id,
|
||||
@@ -309,6 +359,7 @@ class LoggingMiddleware(Middleware):
|
||||
duration_ms,
|
||||
success,
|
||||
mcp_call_id,
|
||||
extra,
|
||||
)
|
||||
|
||||
async def on_message(
|
||||
@@ -388,7 +439,14 @@ class StructuredContentStripperMiddleware(Middleware):
|
||||
context: MiddlewareContext[mt.ListToolsRequest],
|
||||
call_next: CallNext[mt.ListToolsRequest, Sequence[Tool]],
|
||||
) -> Sequence[Tool]:
|
||||
tools = await call_next(context)
|
||||
try:
|
||||
tools = await call_next(context)
|
||||
except Exception:
|
||||
# ToolError raised by inner middleware (e.g. GlobalErrorHandlerMiddleware)
|
||||
# cannot be encoded by the MCP SDK in a tools/list response — it expects a
|
||||
# list, not an error object — causing "encoding without a string argument".
|
||||
# Return an empty list; GlobalErrorHandlerMiddleware already logged it.
|
||||
return []
|
||||
return [
|
||||
t.model_copy(update={"output_schema": None})
|
||||
if t.output_schema is not None
|
||||
|
||||
@@ -395,47 +395,66 @@ class TestSqlaTableModel(SupersetTestCase):
|
||||
def test_get_timestamp_expression(self):
|
||||
tbl = self.get_table(name="birth_names")
|
||||
ds_col = tbl.get_column("ds")
|
||||
sqla_literal = ds_col.get_timestamp_expression(None)
|
||||
assert str(sqla_literal.compile()) == "ds"
|
||||
try:
|
||||
sqla_literal = ds_col.get_timestamp_expression(None)
|
||||
assert str(sqla_literal.compile()) == "ds"
|
||||
|
||||
sqla_literal = ds_col.get_timestamp_expression("P1D")
|
||||
compiled = f"{sqla_literal.compile()}"
|
||||
if tbl.database.backend == "mysql":
|
||||
assert compiled == "DATE(ds)"
|
||||
sqla_literal = ds_col.get_timestamp_expression("P1D")
|
||||
compiled = f"{sqla_literal.compile()}"
|
||||
if tbl.database.backend == "mysql":
|
||||
assert compiled == "DATE(ds)"
|
||||
|
||||
prev_ds_expr = ds_col.expression
|
||||
ds_col.expression = "DATE_ADD(ds, 1)"
|
||||
sqla_literal = ds_col.get_timestamp_expression("P1D")
|
||||
compiled = f"{sqla_literal.compile()}"
|
||||
if tbl.database.backend == "mysql":
|
||||
assert compiled == "DATE(DATE_ADD(ds, 1))"
|
||||
ds_col.expression = prev_ds_expr
|
||||
prev_ds_expr = ds_col.expression
|
||||
ds_col.expression = "DATE_ADD(ds, 1)"
|
||||
sqla_literal = ds_col.get_timestamp_expression("P1D")
|
||||
compiled = f"{sqla_literal.compile()}"
|
||||
if tbl.database.backend == "mysql":
|
||||
assert compiled == "DATE(DATE_ADD(ds, 1))"
|
||||
ds_col.expression = prev_ds_expr
|
||||
finally:
|
||||
# Discard the in-memory attribute history so the next session
|
||||
# autoflush doesn't see this row as dirty. The test only
|
||||
# exercises the in-memory compile path; any persisted write
|
||||
# would be accidental. ``rollback`` rather than ``expire`` —
|
||||
# the latter doesn't reliably clear SA's per-attribute history
|
||||
# tracking for already-loaded objects.
|
||||
metadata_db.session.rollback()
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_get_timestamp_expression_epoch(self):
|
||||
tbl = self.get_table(name="birth_names")
|
||||
ds_col = tbl.get_column("ds")
|
||||
|
||||
ds_col.expression = None
|
||||
ds_col.python_date_format = "epoch_s"
|
||||
sqla_literal = ds_col.get_timestamp_expression(None)
|
||||
compiled = f"{sqla_literal.compile()}"
|
||||
if tbl.database.backend == "mysql":
|
||||
assert compiled == "from_unixtime(ds)"
|
||||
try:
|
||||
ds_col.expression = None
|
||||
ds_col.python_date_format = "epoch_s"
|
||||
sqla_literal = ds_col.get_timestamp_expression(None)
|
||||
compiled = f"{sqla_literal.compile()}"
|
||||
if tbl.database.backend == "mysql":
|
||||
assert compiled == "from_unixtime(ds)"
|
||||
|
||||
ds_col.python_date_format = "epoch_s"
|
||||
sqla_literal = ds_col.get_timestamp_expression("P1D")
|
||||
compiled = f"{sqla_literal.compile()}"
|
||||
if tbl.database.backend == "mysql":
|
||||
assert compiled == "DATE(from_unixtime(ds))"
|
||||
ds_col.python_date_format = "epoch_s"
|
||||
sqla_literal = ds_col.get_timestamp_expression("P1D")
|
||||
compiled = f"{sqla_literal.compile()}"
|
||||
if tbl.database.backend == "mysql":
|
||||
assert compiled == "DATE(from_unixtime(ds))"
|
||||
|
||||
prev_ds_expr = ds_col.expression
|
||||
ds_col.expression = "DATE_ADD(ds, 1)"
|
||||
sqla_literal = ds_col.get_timestamp_expression("P1D")
|
||||
compiled = f"{sqla_literal.compile()}"
|
||||
if tbl.database.backend == "mysql":
|
||||
assert compiled == "DATE(from_unixtime(DATE_ADD(ds, 1)))"
|
||||
ds_col.expression = prev_ds_expr
|
||||
prev_ds_expr = ds_col.expression
|
||||
ds_col.expression = "DATE_ADD(ds, 1)"
|
||||
sqla_literal = ds_col.get_timestamp_expression("P1D")
|
||||
compiled = f"{sqla_literal.compile()}"
|
||||
if tbl.database.backend == "mysql":
|
||||
assert compiled == "DATE(from_unixtime(DATE_ADD(ds, 1)))"
|
||||
ds_col.expression = prev_ds_expr
|
||||
finally:
|
||||
# Discard the in-memory attribute history so the next session
|
||||
# autoflush doesn't see this row as dirty —
|
||||
# ``python_date_format`` isn't restored above and the test
|
||||
# never commits, so the mutation would otherwise leak.
|
||||
# ``rollback`` rather than ``expire`` — the latter doesn't
|
||||
# reliably clear SA's per-attribute history tracking for
|
||||
# already-loaded objects.
|
||||
metadata_db.session.rollback()
|
||||
|
||||
def query_with_expr_helper(self, is_timeseries, inner_join=True):
|
||||
tbl = self.get_table(name="birth_names")
|
||||
|
||||
@@ -37,9 +37,9 @@ def test_pivot_df_no_cols_no_rows_single_metric():
|
||||
assert (
|
||||
df.to_markdown()
|
||||
== """
|
||||
| | SUM(num) |
|
||||
|---:|------------:|
|
||||
| 0 | 8.06797e+07 |
|
||||
| | SUM(num) |
|
||||
|---:|-----------:|
|
||||
| 0 | 80679663 |
|
||||
""".strip()
|
||||
)
|
||||
|
||||
@@ -60,7 +60,7 @@ def test_pivot_df_no_cols_no_rows_single_metric():
|
||||
== f"""
|
||||
| | ('SUM(num)',) |
|
||||
|:-----------------|----------------:|
|
||||
| ('{_("Total")} (Sum)',) | 8.06797e+07 |
|
||||
| ('{_("Total")} (Sum)',) | 80679663 |
|
||||
""".strip()
|
||||
)
|
||||
|
||||
@@ -82,7 +82,7 @@ def test_pivot_df_no_cols_no_rows_single_metric():
|
||||
== """
|
||||
| | ('SUM(num)',) |
|
||||
|:-----------------|----------------:|
|
||||
| ('Total (Sum)',) | 8.06797e+07 |
|
||||
| ('Total (Sum)',) | 80679663 |
|
||||
""".strip()
|
||||
)
|
||||
|
||||
@@ -105,7 +105,7 @@ def test_pivot_df_no_cols_no_rows_single_metric():
|
||||
== f"""
|
||||
| | ('{_("Total")} (Sum)',) |
|
||||
|:--------------|-------------------:|
|
||||
| ('SUM(num)',) | 8.06797e+07 |
|
||||
| ('SUM(num)',) | 80679663 |
|
||||
""".strip()
|
||||
)
|
||||
|
||||
@@ -127,7 +127,7 @@ def test_pivot_df_no_cols_no_rows_single_metric():
|
||||
== f"""
|
||||
| | ('SUM(num)',) | ('Total (Sum)',) |
|
||||
|:-----------------|----------------:|-------------------:|
|
||||
| ('{_("Total")} (Sum)',) | 8.06797e+07 | 8.06797e+07 |
|
||||
| ('{_("Total")} (Sum)',) | 80679663 | 80679663 |
|
||||
""".strip()
|
||||
)
|
||||
|
||||
@@ -142,9 +142,9 @@ def test_pivot_df_no_cols_no_rows_two_metrics():
|
||||
assert (
|
||||
df.to_markdown()
|
||||
== """
|
||||
| | SUM(num) | MAX(num) |
|
||||
|---:|------------:|-----------:|
|
||||
| 0 | 8.06797e+07 | 37296 |
|
||||
| | SUM(num) | MAX(num) |
|
||||
|---:|-----------:|-----------:|
|
||||
| 0 | 80679663 | 37296 |
|
||||
""".strip()
|
||||
)
|
||||
|
||||
@@ -165,7 +165,7 @@ def test_pivot_df_no_cols_no_rows_two_metrics():
|
||||
== f"""
|
||||
| | ('SUM(num)',) | ('MAX(num)',) |
|
||||
|:-----------------|----------------:|----------------:|
|
||||
| ('{_("Total")} (Sum)',) | 8.06797e+07 | 37296 |
|
||||
| ('{_("Total")} (Sum)',) | 80679663 | 37296 |
|
||||
""".strip()
|
||||
)
|
||||
|
||||
@@ -187,7 +187,7 @@ def test_pivot_df_no_cols_no_rows_two_metrics():
|
||||
== """
|
||||
| | ('SUM(num)',) | ('MAX(num)',) |
|
||||
|:-----------------|----------------:|----------------:|
|
||||
| ('Total (Sum)',) | 8.06797e+07 | 37296 |
|
||||
| ('Total (Sum)',) | 80679663 | 37296 |
|
||||
""".strip()
|
||||
)
|
||||
|
||||
@@ -210,8 +210,8 @@ def test_pivot_df_no_cols_no_rows_two_metrics():
|
||||
== f"""
|
||||
| | ('{_("Total")} (Sum)',) |
|
||||
|:--------------|-------------------:|
|
||||
| ('SUM(num)',) | 8.06797e+07 |
|
||||
| ('MAX(num)',) | 37296 |
|
||||
| ('SUM(num)',) | 80679663 |
|
||||
| ('MAX(num)',) | 37296 |
|
||||
""".strip()
|
||||
)
|
||||
|
||||
@@ -234,7 +234,7 @@ def test_pivot_df_no_cols_no_rows_two_metrics():
|
||||
== f"""
|
||||
| | ('SUM(num)',) | ('MAX(num)',) | ('{_("Total")} (Sum)',) |
|
||||
|:-----------------|----------------:|----------------:|-------------------:|
|
||||
| ('{_("Total")} (Sum)',) | 8.06797e+07 | 37296 | 8.0717e+07 |
|
||||
| ('{_("Total")} (Sum)',) | 80679663 | 37296 | 80716959 |
|
||||
""".strip()
|
||||
)
|
||||
|
||||
@@ -1839,8 +1839,8 @@ def test_table():
|
||||
assert (
|
||||
formatted.to_markdown()
|
||||
== """
|
||||
| | count |
|
||||
|---:|:-----------|
|
||||
| | count |
|
||||
|---:|-----------:|
|
||||
| 0 | 80,679,663 |
|
||||
""".strip()
|
||||
)
|
||||
|
||||
@@ -188,6 +188,74 @@ def test_file_content_null_chart_customization_config_does_not_raise():
|
||||
assert result["metadata"]["chart_customization_config"] is None
|
||||
|
||||
|
||||
def test_file_content_includes_roles_for_dashboard_with_role_restrictions():
    """
    Regression guard for #21000: role assignments must appear in the export.

    A dashboard restricted through DASHBOARD_RBAC has to carry its roles in
    the exported YAML. If they are dropped, importing the bundle elsewhere
    recreates the dashboard without any role restriction, which silently
    widens access.

    The assertion is on role *names*, not IDs: IDs are local to one
    environment, and the import side maps names back to its own roles.
    """
    from superset.commands.dashboard.export import ExportDashboardsCommand

    # Build one mock role per name; ``name`` is set after construction
    # because ``MagicMock(name=...)`` configures the mock itself.
    restricted_to = []
    for role_name in ("Finance", "Executives"):
        role = MagicMock()
        role.name = role_name
        restricted_to.append(role)

    dashboard = _make_mock_dashboard({"native_filter_configuration": []})
    dashboard.roles = restricted_to

    flag_target = (
        "superset.commands.dashboard.export.feature_flag_manager.is_feature_enabled"
    )
    with patch(flag_target, return_value=False):
        exported = ExportDashboardsCommand._file_content(dashboard)

    parsed = yaml.safe_load(exported)
    assert "roles" in parsed, (
        "Dashboard export must include role names; without them, importing "
        "into a fresh environment loses the role-based access restriction "
        "and the dashboard becomes accessible to all roles by default."
    )
    # Order in the export is not part of the contract, so compare sorted.
    assert sorted(parsed["roles"]) == ["Executives", "Finance"]
|
||||
|
||||
|
||||
def test_file_content_omits_roles_field_when_dashboard_has_no_roles():
    """
    An unrestricted dashboard must export without any ``roles`` key.

    Older bundles were written without the key, and the import side reads
    "missing" as "no restriction". An empty ``roles: []`` could trip
    importers that tell the two states apart, so it must not be emitted.
    """
    from superset.commands.dashboard.export import ExportDashboardsCommand

    dashboard = _make_mock_dashboard({"native_filter_configuration": []})
    dashboard.roles = []

    flag_target = (
        "superset.commands.dashboard.export.feature_flag_manager.is_feature_enabled"
    )
    with patch(flag_target, return_value=False):
        exported = ExportDashboardsCommand._file_content(dashboard)

    parsed = yaml.safe_load(exported)
    # Strict check: the key has to be absent, not merely empty.
    assert "roles" not in parsed
|
||||
|
||||
|
||||
def test_file_content_missing_dataset_preserves_dataset_id():
|
||||
"""
|
||||
When DatasetDAO.find_by_id returns None for a display control target,
|
||||
|
||||
@@ -295,3 +295,55 @@ async def test_successful_add(
|
||||
assert "/superset/dashboard/1/" in content["dashboard_url"]
|
||||
assert content["position"] is not None
|
||||
assert "chart_key" in content["position"]
|
||||
|
||||
|
||||
def test_empty_target_tab_rejected_by_schema() -> None:
    """The request schema rejects ``target_tab=""`` but accepts ``None``."""
    from pydantic import ValidationError

    from superset.mcp_service.dashboard.schemas import AddChartToDashboardRequest

    ids = {"dashboard_id": 1, "chart_id": 10}

    # An empty string must fail at the schema layer rather than surfacing
    # later as a "Tab not found" error.
    with pytest.raises(ValidationError):
        AddChartToDashboardRequest(**ids, target_tab="")

    # Leaving the tab out (None) remains valid.
    without_tab = AddChartToDashboardRequest(**ids, target_tab=None)
    assert without_tab.target_tab is None
|
||||
|
||||
|
||||
def test_add_chart_response_error_is_sanitized_for_llm_context() -> None:
    """``AddChartToDashboardResponse.error`` is wrapped before reaching the LLM.

    The error text repeats the caller-supplied ``target_tab`` and the
    dashboard's own tab labels. Both are untrusted, so the field must carry
    the UNTRUSTED-CONTENT delimiters that mark the text as data rather than
    instructions. A ``None`` error must stay ``None``.
    """
    from superset.mcp_service.dashboard.schemas import AddChartToDashboardResponse
    from superset.mcp_service.utils.sanitization import (
        LLM_CONTEXT_CLOSE_DELIMITER,
        LLM_CONTEXT_OPEN_DELIMITER,
    )

    def respond_with(error):
        # Only the error field matters here; everything else stays empty.
        return AddChartToDashboardResponse(
            dashboard=None, dashboard_url=None, position=None, error=error
        )

    raw_error = (
        "Tab 'malicious tab <script>alert(1)</script>' not found in dashboard 42. "
        "Available tabs: Sales (TAB-abc), <b>Marketing</b> (TAB-xyz)."
    )
    wrapped = respond_with(raw_error).error

    assert wrapped is not None
    for delimiter in (LLM_CONTEXT_OPEN_DELIMITER, LLM_CONTEXT_CLOSE_DELIMITER):
        assert delimiter in wrapped
    # The original message is still readable inside the wrapper.
    for fragment in ("not found", "Available tabs"):
        assert fragment in wrapped

    # No error in, no error out.
    assert respond_with(None).error is None
|
||||
|
||||
@@ -32,6 +32,7 @@ from superset.mcp_service.chart.chart_utils import DatasetValidationResult
|
||||
from superset.mcp_service.dashboard.constants import generate_id
|
||||
from superset.mcp_service.dashboard.tool.add_chart_to_existing_dashboard import (
|
||||
_add_chart_to_layout,
|
||||
_collect_available_tab_names,
|
||||
_ensure_layout_structure,
|
||||
_find_next_row_position,
|
||||
_find_tab_insert_target,
|
||||
@@ -1059,6 +1060,112 @@ class TestAddChartToExistingDashboard:
|
||||
assert "TAB-tab2" in chart_parents
|
||||
assert "TAB-tab1" not in chart_parents
|
||||
|
||||
@patch("superset.commands.dashboard.update.UpdateDashboardCommand")
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
@patch("superset.db.session")
@pytest.mark.asyncio
async def test_add_chart_target_tab_not_found(
    self, mock_db_session, mock_find_dashboard, mock_update_command, mcp_server
) -> None:
    """An unknown ``target_tab`` yields a descriptive error, not a fallback.

    The error has to name the requested tab and list every available tab by
    label and component ID, and the update command must never be invoked.
    """
    tab_ancestry = ["ROOT_ID", "GRID_ID", "TABS-abc123"]
    layout = {
        "ROOT_ID": {"children": ["GRID_ID"], "id": "ROOT_ID", "type": "ROOT"},
        "GRID_ID": {
            "children": ["TABS-abc123"],
            "id": "GRID_ID",
            "parents": ["ROOT_ID"],
            "type": "GRID",
        },
        "TABS-abc123": {
            "children": ["TAB-tab1", "TAB-tab2"],
            "id": "TABS-abc123",
            "parents": ["ROOT_ID", "GRID_ID"],
            "type": "TABS",
        },
        "TAB-tab1": {
            "children": [],
            "id": "TAB-tab1",
            "meta": {"text": "Overview"},
            "parents": list(tab_ancestry),
            "type": "TAB",
        },
        "TAB-tab2": {
            "children": [],
            "id": "TAB-tab2",
            "meta": {"text": "Details"},
            "parents": list(tab_ancestry),
            "type": "TAB",
        },
        "DASHBOARD_VERSION_KEY": "v2",
    }

    tabbed = _mock_dashboard(id=3, title="Tabbed Dashboard")
    tabbed.slices = [_mock_chart(id=10)]
    tabbed.position_json = json.dumps(layout)
    mock_db_session.get.return_value = _mock_chart(id=30)
    mock_find_dashboard.return_value = tabbed

    arguments = {
        "request": {"dashboard_id": 3, "chart_id": 30, "target_tab": "Nonexistent Tab"}
    }
    async with Client(mcp_server) as client:
        result = await client.call_tool("add_chart_to_existing_dashboard", arguments)

    error = result.structured_content["error"]
    assert error is not None
    assert "Nonexistent Tab" in error
    assert "not found" in error
    # Every available tab is listed with both its label and its component ID.
    for expected in ("Overview", "Details", "TAB-tab1", "TAB-tab2"):
        assert expected in error
    # Nothing may have been persisted.
    mock_update_command.assert_not_called()
|
||||
|
||||
@patch("superset.commands.dashboard.update.UpdateDashboardCommand")
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
@patch("superset.db.session")
@pytest.mark.asyncio
async def test_add_chart_target_tab_on_non_tabbed_dashboard(
    self, mock_db_session, mock_find_dashboard, mock_update_command, mcp_server
) -> None:
    """Passing ``target_tab`` for a dashboard without tabs returns an error."""
    flat_layout = {
        "ROOT_ID": {"children": ["GRID_ID"], "id": "ROOT_ID", "type": "ROOT"},
        "GRID_ID": {
            "children": [],
            "id": "GRID_ID",
            "parents": ["ROOT_ID"],
            "type": "GRID",
        },
        "DASHBOARD_VERSION_KEY": "v2",
    }

    flat = _mock_dashboard(id=5, title="Flat Dashboard")
    flat.slices = []
    flat.position_json = json.dumps(flat_layout)
    mock_db_session.get.return_value = _mock_chart(id=99)
    mock_find_dashboard.return_value = flat

    arguments = {"request": {"dashboard_id": 5, "chart_id": 99, "target_tab": "Sales"}}
    async with Client(mcp_server) as client:
        result = await client.call_tool("add_chart_to_existing_dashboard", arguments)

    error = result.structured_content["error"]
    assert error is not None
    assert "no tabs" in error.lower()
    assert "target_tab" in error
    # Nothing may have been persisted.
    mock_update_command.assert_not_called()
|
||||
|
||||
@patch("superset.commands.dashboard.update.UpdateDashboardCommand")
|
||||
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
|
||||
@patch("superset.db.session")
|
||||
@@ -1311,9 +1418,9 @@ class TestLayoutHelpers:
|
||||
}
|
||||
assert _find_tab_insert_target(layout, target_tab="TAB-second") == "TAB-second"
|
||||
|
||||
def test_find_tab_insert_target_unmatched_falls_back_to_first(self):
|
||||
"""Test _find_tab_insert_target falls back to first tab when target_tab
|
||||
doesn't match any tab name or ID."""
|
||||
def test_find_tab_insert_target_unmatched_returns_none(self):
|
||||
"""Test _find_tab_insert_target returns None when target_tab doesn't
|
||||
match any tab name or ID, so the caller can return a descriptive error."""
|
||||
layout = {
|
||||
"GRID_ID": {"children": ["TABS-main"], "type": "GRID"},
|
||||
"TABS-main": {"children": ["TAB-first", "TAB-second"], "type": "TABS"},
|
||||
@@ -1328,11 +1435,19 @@ class TestLayoutHelpers:
|
||||
"meta": {"text": "Tab 2"},
|
||||
},
|
||||
}
|
||||
assert (
|
||||
_find_tab_insert_target(layout, target_tab="Nonexistent Tab") == "TAB-first"
|
||||
)
|
||||
assert _find_tab_insert_target(layout, target_tab="Nonexistent Tab") is None
|
||||
|
||||
def test_find_tab_insert_target_tabs_under_root(self):
|
||||
def test_find_tab_insert_target_empty_string_returns_none(self) -> None:
    """``target_tab=""`` counts as a requested tab that was not found.

    It is not the same as omitting the tab, so the result is ``None`` rather
    than the first tab.
    """
    single_tab_layout = {
        "GRID_ID": {"children": ["TABS-main"], "type": "GRID"},
        "TABS-main": {"children": ["TAB-first"], "type": "TABS"},
        "TAB-first": {"children": [], "type": "TAB", "meta": {"text": "Tab 1"}},
    }
    target = _find_tab_insert_target(single_tab_layout, target_tab="")
    assert target is None
|
||||
|
||||
def test_find_tab_insert_target_tabs_under_root(self) -> None:
|
||||
"""Test _find_tab_insert_target when TABS are under ROOT_ID (real layout)."""
|
||||
layout = {
|
||||
"ROOT_ID": {"children": ["TABS-xxx"], "type": "ROOT"},
|
||||
@@ -1343,7 +1458,7 @@ class TestLayoutHelpers:
|
||||
}
|
||||
assert _find_tab_insert_target(layout) == "TAB-a"
|
||||
|
||||
def test_find_tab_insert_target_tabs_under_root_by_name(self):
|
||||
def test_find_tab_insert_target_tabs_under_root_by_name(self) -> None:
|
||||
"""Test _find_tab_insert_target matches tab name when TABS under ROOT_ID."""
|
||||
layout = {
|
||||
"ROOT_ID": {"children": ["TABS-xxx"], "type": "ROOT"},
|
||||
@@ -1354,10 +1469,51 @@ class TestLayoutHelpers:
|
||||
}
|
||||
assert _find_tab_insert_target(layout, target_tab="Details") == "TAB-b"
|
||||
|
||||
def test_find_tab_insert_target_no_grid(self):
|
||||
def test_find_tab_insert_target_no_grid(self) -> None:
    """A layout with no GRID_ID entry gives ``None`` as the insert target."""
    layout_without_grid = {"ROOT_ID": {"type": "ROOT"}}
    assert _find_tab_insert_target(layout_without_grid) is None
|
||||
|
||||
def test_collect_available_tab_names_returns_display_names(self) -> None:
    """Each tab is reported as ``"<label> (<component ID>)"`` in layout order."""
    two_tab_layout = {
        "GRID_ID": {"children": ["TABS-x"], "type": "GRID"},
        "TABS-x": {"children": ["TAB-a", "TAB-b"], "type": "TABS"},
        "TAB-a": {"children": [], "type": "TAB", "meta": {"text": "Overview"}},
        "TAB-b": {"children": [], "type": "TAB", "meta": {"text": "Details"}},
    }
    expected = ["Overview (TAB-a)", "Details (TAB-b)"]
    assert _collect_available_tab_names(two_tab_layout) == expected
|
||||
|
||||
def test_collect_available_tab_names_falls_back_to_id(self) -> None:
    """A tab whose ``meta`` has no text is listed by its component ID alone."""
    unlabeled_layout = {
        "GRID_ID": {"children": ["TABS-x"], "type": "GRID"},
        "TABS-x": {"children": ["TAB-a"], "type": "TABS"},
        "TAB-a": {"children": [], "type": "TAB", "meta": {}},
    }
    assert _collect_available_tab_names(unlabeled_layout) == ["TAB-a"]
|
||||
|
||||
def test_collect_available_tab_names_duplicate_names(self) -> None:
    """Tabs sharing a label stay distinguishable through their component ID."""
    same_label_layout = {
        "GRID_ID": {"children": ["TABS-x"], "type": "GRID"},
        "TABS-x": {"children": ["TAB-a", "TAB-b"], "type": "TABS"},
        "TAB-a": {"children": [], "type": "TAB", "meta": {"text": "Sales"}},
        "TAB-b": {"children": [], "type": "TAB", "meta": {"text": "Sales"}},
    }
    entries = _collect_available_tab_names(same_label_layout)
    assert entries == ["Sales (TAB-a)", "Sales (TAB-b)"]
    first, second = entries[0], entries[1]
    assert first != second
|
||||
|
||||
def test_collect_available_tab_names_no_tabs(self) -> None:
    """A layout containing only rows produces an empty list of tab names."""
    rows_only_layout = {
        "GRID_ID": {"children": ["ROW-1"], "type": "GRID"},
        "ROW-1": {"children": [], "type": "ROW"},
    }
    tab_names = _collect_available_tab_names(rows_only_layout)
    assert tab_names == []
|
||||
|
||||
def test_add_chart_to_layout_creates_column(self):
|
||||
"""Test that _add_chart_to_layout creates ROW > COLUMN > CHART."""
|
||||
layout = {
|
||||
|
||||
330
tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py
Normal file
330
tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py
Normal file
@@ -0,0 +1,330 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Unit tests for create_dataset MCP tool."""
|
||||
|
||||
import logging
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastmcp import Client
|
||||
from fastmcp.exceptions import ToolError
|
||||
|
||||
from superset.mcp_service.app import mcp
|
||||
from superset.utils import json
|
||||
|
||||
# Configure root logging at DEBUG and create this module's logger.
# NOTE(review): basicConfig runs at import time, and `logger` is not
# referenced anywhere else in this file as shown — confirm both are wanted.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _make_mock_dataset(
|
||||
dataset_id: int = 42,
|
||||
table_name: str = "orders",
|
||||
schema: str = "public",
|
||||
database_name: str = "main_db",
|
||||
) -> MagicMock:
|
||||
dataset = MagicMock()
|
||||
dataset.id = dataset_id
|
||||
dataset.table_name = table_name
|
||||
dataset.schema = schema
|
||||
dataset.description = None
|
||||
dataset.changed_by_name = "admin"
|
||||
dataset.changed_on = None
|
||||
dataset.changed_on_humanized = None
|
||||
dataset.created_by_name = "admin"
|
||||
dataset.created_on = None
|
||||
dataset.created_on_humanized = None
|
||||
dataset.tags = []
|
||||
dataset.owners = []
|
||||
dataset.is_virtual = False
|
||||
dataset.database_id = 1
|
||||
dataset.certified_by = None
|
||||
dataset.certification_details = None
|
||||
dataset.schema_perm = f"[{database_name}].[{schema}]"
|
||||
dataset.url = f"/tablemodelview/edit/{dataset_id}"
|
||||
dataset.database = MagicMock()
|
||||
dataset.database.database_name = database_name
|
||||
dataset.sql = None
|
||||
dataset.main_dttm_col = None
|
||||
dataset.offset = 0
|
||||
dataset.cache_timeout = 0
|
||||
dataset.params = {}
|
||||
dataset.template_params = {}
|
||||
dataset.extra = {}
|
||||
dataset.uuid = f"dataset-uuid-{dataset_id}"
|
||||
dataset.columns = []
|
||||
dataset.metrics = []
|
||||
return dataset
|
||||
|
||||
|
||||
@pytest.fixture
def mcp_server():
    """Return the ``mcp`` app object imported from ``superset.mcp_service.app``."""
    return mcp
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def mock_auth():
    """Patch the request-user lookup for every test in this module.

    Yields the patched ``get_user_from_request`` mock, which returns a stub
    user with ``id=1`` and ``username="admin"``.
    """
    admin_user = Mock()
    admin_user.id = 1
    admin_user.username = "admin"

    with patch(
        "superset.mcp_service.auth.get_user_from_request",
        return_value=admin_user,
    ) as mock_get_user:
        yield mock_get_user
|
||||
|
||||
|
||||
class TestCreateDataset:
    """Exercises the ``create_dataset`` MCP tool through an in-memory client."""

    @pytest.mark.asyncio
    async def test_create_dataset_success(self, mcp_server):
        """A valid request creates the dataset and returns its DatasetInfo."""
        command = MagicMock()
        command.run.return_value = _make_mock_dataset()
        arguments = {
            "request": {
                "database_id": 1,
                "schema": "public",
                "table_name": "orders",
            }
        }

        with (
            patch(
                "superset.commands.dataset.create.CreateDatasetCommand",
                return_value=command,
            ) as command_cls,
            patch(
                "superset.mcp_service.utils.url_utils.get_superset_base_url",
                return_value="http://localhost:8088",
            ),
        ):
            async with Client(mcp_server) as client:
                result = await client.call_tool("create_dataset", arguments)

        assert result.content is not None
        data = json.loads(result.content[0].text)
        assert data["id"] == 42
        assert data["table_name"] == "orders"
        assert data["schema"] == "public"

        # The command is constructed with the properties dict as its first
        # positional argument; owners must not be injected when not supplied.
        properties = command_cls.call_args[0][0]
        assert properties["database"] == 1
        assert properties["schema"] == "public"
        assert properties["table_name"] == "orders"
        assert "owners" not in properties

    @pytest.mark.asyncio
    async def test_create_dataset_with_owners(self, mcp_server):
        """A supplied owners list is passed through to the command."""
        command = MagicMock()
        command.run.return_value = _make_mock_dataset()
        arguments = {
            "request": {
                "database_id": 2,
                "schema": "sales",
                "table_name": "transactions",
                "owners": [5, 10],
            }
        }

        with (
            patch(
                "superset.commands.dataset.create.CreateDatasetCommand",
                return_value=command,
            ) as command_cls,
            patch(
                "superset.mcp_service.utils.url_utils.get_superset_base_url",
                return_value="http://localhost:8088",
            ),
        ):
            async with Client(mcp_server) as client:
                result = await client.call_tool("create_dataset", arguments)

        data = json.loads(result.content[0].text)
        assert data["id"] == 42

        properties = command_cls.call_args[0][0]
        assert properties["owners"] == [5, 10]

    @pytest.mark.asyncio
    async def test_create_dataset_already_exists(self, mcp_server):
        """An existing dataset for the table comes back as DatasetExistsError."""
        from superset.commands.dataset.exceptions import DatasetExistsValidationError
        from superset.sql.parse import Table

        command = MagicMock()
        command.run.side_effect = DatasetExistsValidationError(
            Table("orders", "public", None)
        )
        arguments = {
            "request": {
                "database_id": 1,
                "schema": "public",
                "table_name": "orders",
            }
        }

        with patch(
            "superset.commands.dataset.create.CreateDatasetCommand",
            return_value=command,
        ):
            async with Client(mcp_server) as client:
                result = await client.call_tool("create_dataset", arguments)

        data = json.loads(result.content[0].text)
        assert data["error_type"] == "DatasetExistsError"
        assert "error" in data

    @pytest.mark.asyncio
    async def test_create_dataset_table_not_found(self, mcp_server):
        """A missing physical table comes back as TableNotFoundError."""
        from superset.commands.dataset.exceptions import TableNotFoundValidationError
        from superset.sql.parse import Table

        command = MagicMock()
        command.run.side_effect = TableNotFoundValidationError(
            Table("missing_table", "public", None)
        )
        arguments = {
            "request": {
                "database_id": 1,
                "schema": "public",
                "table_name": "missing_table",
            }
        }

        with patch(
            "superset.commands.dataset.create.CreateDatasetCommand",
            return_value=command,
        ):
            async with Client(mcp_server) as client:
                result = await client.call_tool("create_dataset", arguments)

        data = json.loads(result.content[0].text)
        assert data["error_type"] == "TableNotFoundError"

    @pytest.mark.asyncio
    async def test_create_dataset_unexpected_error(self, mcp_server):
        """Any other exception is reported as InternalError with its message."""
        command = MagicMock()
        command.run.side_effect = RuntimeError("DB connection lost")
        arguments = {
            "request": {
                "database_id": 1,
                "schema": "public",
                "table_name": "orders",
            }
        }

        with patch(
            "superset.commands.dataset.create.CreateDatasetCommand",
            return_value=command,
        ):
            async with Client(mcp_server) as client:
                result = await client.call_tool("create_dataset", arguments)

        data = json.loads(result.content[0].text)
        assert data["error_type"] == "InternalError"
        assert "DB connection lost" in data["error"]

    @pytest.mark.asyncio
    async def test_create_dataset_missing_required_fields(self, mcp_server):
        """A request lacking required fields is rejected with a ToolError."""
        # database_id and table_name are omitted intentionally
        incomplete = {"request": {"schema": "public"}}

        async with Client(mcp_server) as client:
            with pytest.raises(ToolError):
                await client.call_tool("create_dataset", incomplete)

    @pytest.mark.asyncio
    async def test_create_dataset_returns_full_dataset_info(self, mcp_server):
        """Columns, metrics, and the core fields all appear in the result."""
        dataset = _make_mock_dataset(dataset_id=99, table_name="sales", schema="dw")

        amount = MagicMock()
        amount.column_name = "amount"
        amount.verbose_name = "Amount"
        amount.type = "NUMERIC"
        amount.is_dttm = False
        amount.groupby = True
        amount.filterable = True
        amount.description = "Sale amount"
        dataset.columns = [amount]

        total_sales = MagicMock()
        total_sales.metric_name = "total_sales"
        total_sales.verbose_name = "Total Sales"
        total_sales.expression = "SUM(amount)"
        total_sales.description = "Sum of amounts"
        total_sales.d3format = None
        dataset.metrics = [total_sales]

        command = MagicMock()
        command.run.return_value = dataset
        arguments = {
            "request": {
                "database_id": 1,
                "schema": "dw",
                "table_name": "sales",
            }
        }

        with (
            patch(
                "superset.commands.dataset.create.CreateDatasetCommand",
                return_value=command,
            ),
            patch(
                "superset.mcp_service.utils.url_utils.get_superset_base_url",
                return_value="http://localhost:8088",
            ),
        ):
            async with Client(mcp_server) as client:
                result = await client.call_tool("create_dataset", arguments)

        data = json.loads(result.content[0].text)
        assert data["id"] == 99
        assert data["table_name"] == "sales"
        assert data["schema"] == "dw"
        assert data["is_virtual"] is False

        columns = data["columns"]
        assert len(columns) == 1
        assert columns[0]["column_name"] == "amount"

        metrics = data["metrics"]
        assert len(metrics) == 1
        assert metrics[0]["metric_name"] == "total_sales"
|
||||
@@ -20,6 +20,8 @@ Unit tests for LoggingMiddleware on_call_tool() and on_message() methods.
|
||||
|
||||
Tests verify that:
|
||||
- on_call_tool() captures duration_ms and success status
|
||||
- on_call_tool() resolves call_tool proxy to actual tool name (mcp_tool)
|
||||
- on_call_tool() captures error_type on failure
|
||||
- on_message() logs non-tool messages without duration
|
||||
- _extract_context_info() extracts entity IDs from params
|
||||
"""
|
||||
@@ -65,7 +67,7 @@ class TestLoggingMiddlewareOnCallTool:
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_call_tool_logs_duration_and_success(
|
||||
self, mock_get_user_id, mock_event_logger
|
||||
):
|
||||
) -> None:
|
||||
"""on_call_tool records duration_ms and success=True on normal return."""
|
||||
middleware = LoggingMiddleware()
|
||||
ctx = _make_context(name="list_charts")
|
||||
@@ -91,8 +93,8 @@ class TestLoggingMiddlewareOnCallTool:
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_call_tool_logs_failure_on_exception(
|
||||
self, mock_get_user_id, mock_event_logger
|
||||
):
|
||||
"""on_call_tool records success=False when tool raises."""
|
||||
) -> None:
|
||||
"""on_call_tool records success=False and error_type when tool raises."""
|
||||
middleware = LoggingMiddleware()
|
||||
ctx = _make_context(name="execute_sql")
|
||||
call_next = AsyncMock(side_effect=ValueError("boom"))
|
||||
@@ -104,6 +106,7 @@ class TestLoggingMiddlewareOnCallTool:
|
||||
mock_event_logger.log.assert_called_once()
|
||||
call_kwargs = mock_event_logger.log.call_args[1]
|
||||
assert call_kwargs["curated_payload"]["success"] is False
|
||||
assert call_kwargs["curated_payload"]["error_type"] == "ValueError"
|
||||
assert call_kwargs["duration_ms"] >= 0
|
||||
|
||||
@patch("superset.mcp_service.middleware.event_logger")
|
||||
@@ -111,7 +114,7 @@ class TestLoggingMiddlewareOnCallTool:
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_call_tool_logs_failure_on_tool_error(
|
||||
self, mock_get_user_id, mock_event_logger
|
||||
):
|
||||
) -> None:
|
||||
"""on_call_tool records success=False when GlobalErrorHandler raises ToolError.
|
||||
|
||||
This simulates the real middleware chain: GlobalErrorHandler catches
|
||||
@@ -137,7 +140,7 @@ class TestLoggingMiddlewareOnCallTool:
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_call_tool_includes_mcp_call_id_in_curated_payload(
|
||||
self, mock_get_user_id, mock_event_logger
|
||||
):
|
||||
) -> None:
|
||||
"""on_call_tool adds mcp_call_id to curated_payload."""
|
||||
middleware = LoggingMiddleware()
|
||||
ctx = _make_context(name="list_charts")
|
||||
@@ -155,7 +158,7 @@ class TestLoggingMiddlewareOnCallTool:
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_call_tool_injects_mcp_call_id_into_tool_result_meta(
|
||||
self, mock_get_user_id, mock_event_logger
|
||||
):
|
||||
) -> None:
|
||||
"""on_call_tool injects mcp_call_id into ToolResult.meta."""
|
||||
middleware = LoggingMiddleware()
|
||||
ctx = _make_context(name="list_charts")
|
||||
@@ -173,7 +176,7 @@ class TestLoggingMiddlewareOnCallTool:
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_call_tool_preserves_existing_meta(
|
||||
self, mock_get_user_id, mock_event_logger
|
||||
):
|
||||
) -> None:
|
||||
"""on_call_tool merges mcp_call_id with existing ToolResult.meta."""
|
||||
middleware = LoggingMiddleware()
|
||||
ctx = _make_context(name="list_charts")
|
||||
@@ -193,7 +196,7 @@ class TestLoggingMiddlewareOnCallTool:
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_call_tool_extracts_entity_ids(
|
||||
self, mock_get_user_id, mock_event_logger
|
||||
):
|
||||
) -> None:
|
||||
"""on_call_tool extracts dashboard_id, chart_id, dataset_id from params."""
|
||||
middleware = LoggingMiddleware()
|
||||
ctx = _make_context(
|
||||
@@ -222,7 +225,7 @@ class TestLoggingMiddlewareOnMessage:
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_logs_without_duration(
|
||||
self, mock_get_user_id, mock_event_logger
|
||||
):
|
||||
) -> None:
|
||||
"""on_message logs with action=mcp_message and duration_ms=None."""
|
||||
middleware = LoggingMiddleware()
|
||||
ctx = _make_context(method="resources/read", name="instance/metadata")
|
||||
@@ -240,12 +243,124 @@ class TestLoggingMiddlewareOnMessage:
|
||||
# on_message should NOT have success field
|
||||
assert "success" not in call_kwargs["curated_payload"]
|
||||
|
||||
@patch("superset.mcp_service.middleware.event_logger")
@patch("superset.mcp_service.middleware.get_user_id", return_value=42)
@pytest.mark.asyncio
async def test_on_call_tool_no_error_type_on_success(
    self, mock_get_user_id: MagicMock, mock_event_logger: MagicMock
) -> None:
    """A successful tool call logs a payload with no ``error_type`` key."""
    logging_middleware = LoggingMiddleware()
    context = _make_context(name="list_charts")
    downstream = AsyncMock(return_value="ok")

    await logging_middleware.on_call_tool(context, downstream)

    logged_kwargs = mock_event_logger.log.call_args[1]
    assert "error_type" not in logged_kwargs["curated_payload"]
|
||||
|
||||
@patch("superset.mcp_service.middleware.event_logger")
@patch("superset.mcp_service.middleware.get_user_id", return_value=42)
@pytest.mark.asyncio
async def test_on_call_tool_resolves_call_tool_proxy(
    self, mock_get_user_id: MagicMock, mock_event_logger: MagicMock
) -> None:
    """A ``call_tool`` proxy call logs the real tool name under ``mcp_tool``."""
    logging_middleware = LoggingMiddleware()
    proxy_params = {"name": "list_datasets", "arguments": {"page": 1}}
    context = _make_context(name="call_tool", params=proxy_params)
    downstream = AsyncMock(return_value="datasets")

    await logging_middleware.on_call_tool(context, downstream)

    logged = mock_event_logger.log.call_args[1]["curated_payload"]
    # `tool` keeps the proxy name; `mcp_tool` carries the resolved target.
    assert logged["tool"] == "call_tool"
    assert logged["mcp_tool"] == "list_datasets"
|
||||
|
||||
@patch("superset.mcp_service.middleware.event_logger")
@patch("superset.mcp_service.middleware.get_user_id", return_value=42)
@pytest.mark.asyncio
async def test_on_call_tool_no_mcp_tool_for_direct_calls(
    self, mock_get_user_id: MagicMock, mock_event_logger: MagicMock
) -> None:
    """A tool invoked directly (no proxy) logs no ``mcp_tool`` key."""
    logging_middleware = LoggingMiddleware()
    context = _make_context(name="list_charts")
    downstream = AsyncMock(return_value="charts")

    await logging_middleware.on_call_tool(context, downstream)

    logged = mock_event_logger.log.call_args[1]["curated_payload"]
    assert logged["tool"] == "list_charts"
    assert "mcp_tool" not in logged
|
||||
|
||||
@patch("superset.mcp_service.middleware.event_logger")
@patch("superset.mcp_service.middleware.get_user_id", return_value=42)
@pytest.mark.asyncio
async def test_on_call_tool_proxy_failure_captures_both_fields(
    self, mock_get_user_id: MagicMock, mock_event_logger: MagicMock
) -> None:
    """A failing proxied call logs both ``mcp_tool`` and ``error_type``."""
    logging_middleware = LoggingMiddleware()
    proxy_params = {"name": "get_chart_data", "arguments": {"chart_id": 1}}
    context = _make_context(name="call_tool", params=proxy_params)
    failing_downstream = AsyncMock(side_effect=PermissionError("access denied"))

    # The middleware is expected to log and then let the exception propagate.
    with pytest.raises(PermissionError):
        await logging_middleware.on_call_tool(context, failing_downstream)

    logged = mock_event_logger.log.call_args[1]["curated_payload"]
    assert logged["tool"] == "call_tool"
    assert logged["mcp_tool"] == "get_chart_data"
    assert logged["success"] is False
    assert logged["error_type"] == "PermissionError"
|
||||
|
||||
|
||||
class TestResolveToolName:
    """Covers how ``LoggingMiddleware._resolve_tool_name()`` treats its inputs."""

    def test_resolves_call_tool_proxy(self) -> None:
        """A ``call_tool`` proxy call resolves to the tool named in its params."""
        proxy_params = {"name": "list_datasets", "arguments": {}}
        resolved = LoggingMiddleware._resolve_tool_name("call_tool", proxy_params)
        assert resolved == "list_datasets"

    def test_returns_none_for_direct_tool(self) -> None:
        """A tool called directly (not through the proxy) resolves to None."""
        resolved = LoggingMiddleware._resolve_tool_name("list_charts", {"page": 1})
        assert resolved is None

    def test_returns_none_when_name_missing(self) -> None:
        """``call_tool`` params with no 'name' key resolve to None."""
        resolved = LoggingMiddleware._resolve_tool_name("call_tool", {"foo": "bar"})
        assert resolved is None

    def test_returns_none_for_empty_name(self) -> None:
        """``call_tool`` params with an empty 'name' resolve to None."""
        resolved = LoggingMiddleware._resolve_tool_name("call_tool", {"name": ""})
        assert resolved is None

    def test_returns_none_for_non_string_name(self) -> None:
        """``call_tool`` params whose 'name' is not a string resolve to None."""
        resolved = LoggingMiddleware._resolve_tool_name("call_tool", {"name": 123})
        assert resolved is None

    def test_returns_none_for_search_tools(self) -> None:
        """The ``search_tools`` proxy has no underlying tool name to resolve."""
        search_params = {"query": "datasets"}
        resolved = LoggingMiddleware._resolve_tool_name("search_tools", search_params)
        assert resolved is None
|
||||
|
||||
|
||||
class TestExtractContextInfo:
|
||||
"""Tests for LoggingMiddleware._extract_context_info()."""
|
||||
|
||||
@patch("superset.mcp_service.middleware.get_user_id", return_value=99)
|
||||
def test_extract_with_metadata_agent_id(self, mock_get_user_id):
|
||||
def test_extract_with_metadata_agent_id(self, mock_get_user_id) -> None:
|
||||
"""Extracts agent_id from context.metadata."""
|
||||
middleware = LoggingMiddleware()
|
||||
ctx = _make_context(metadata={"agent_id": "agent-123"})
|
||||
@@ -261,7 +376,7 @@ class TestExtractContextInfo:
|
||||
"superset.mcp_service.middleware.get_user_id",
|
||||
side_effect=RuntimeError("no Flask request context"),
|
||||
)
|
||||
def test_extract_handles_missing_user(self, mock_get_user_id):
|
||||
def test_extract_handles_missing_user(self, mock_get_user_id) -> None:
|
||||
"""Gracefully handles missing user context."""
|
||||
middleware = LoggingMiddleware()
|
||||
ctx = _make_context()
|
||||
@@ -273,7 +388,7 @@ class TestExtractContextInfo:
|
||||
assert user_id is None
|
||||
|
||||
@patch("superset.mcp_service.middleware.get_user_id", return_value=1)
|
||||
def test_extract_slice_id_from_chart_id(self, mock_get_user_id):
|
||||
def test_extract_slice_id_from_chart_id(self, mock_get_user_id) -> None:
|
||||
"""Extracts slice_id from chart_id param (alias)."""
|
||||
middleware = LoggingMiddleware()
|
||||
ctx = _make_context(params={"chart_id": 55})
|
||||
@@ -283,7 +398,7 @@ class TestExtractContextInfo:
|
||||
assert slice_id == 55
|
||||
|
||||
@patch("superset.mcp_service.middleware.get_user_id", return_value=1)
|
||||
def test_extract_slice_id_from_slice_id(self, mock_get_user_id):
|
||||
def test_extract_slice_id_from_slice_id(self, mock_get_user_id) -> None:
|
||||
"""Extracts slice_id from slice_id param (fallback)."""
|
||||
middleware = LoggingMiddleware()
|
||||
ctx = _make_context(params={"slice_id": 66})
|
||||
@@ -296,7 +411,7 @@ class TestExtractContextInfo:
|
||||
class TestIsErrorResponse:
|
||||
"""Tests for LoggingMiddleware._is_error_response()."""
|
||||
|
||||
def test_detects_error_schema_response(self):
|
||||
def test_detects_error_schema_response(self) -> None:
|
||||
"""Detects ToolResult containing a serialized error schema
|
||||
(ChartError, DashboardError, etc.) via "error_type" field."""
|
||||
middleware = LoggingMiddleware()
|
||||
@@ -308,7 +423,7 @@ class TestIsErrorResponse:
|
||||
result = ToolResult(content=[mt.TextContent(type="text", text=error_json)])
|
||||
assert middleware._is_error_response(result) is True
|
||||
|
||||
def test_success_response_not_detected_as_error(self):
|
||||
def test_success_response_not_detected_as_error(self) -> None:
|
||||
"""Normal ToolResult is not detected as error."""
|
||||
middleware = LoggingMiddleware()
|
||||
result = ToolResult(
|
||||
@@ -316,7 +431,7 @@ class TestIsErrorResponse:
|
||||
)
|
||||
assert middleware._is_error_response(result) is False
|
||||
|
||||
def test_empty_content_not_detected_as_error(self):
|
||||
def test_empty_content_not_detected_as_error(self) -> None:
|
||||
"""ToolResult with empty content is not detected as error."""
|
||||
middleware = LoggingMiddleware()
|
||||
assert middleware._is_error_response(ToolResult(content=[])) is False
|
||||
@@ -326,7 +441,7 @@ class TestIsErrorResponse:
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_call_tool_logs_failure_for_error_schema(
|
||||
self, mock_get_user_id, mock_event_logger
|
||||
):
|
||||
) -> None:
|
||||
"""on_call_tool logs success=False when tool returns an
|
||||
error schema (e.g. ChartError)."""
|
||||
middleware = LoggingMiddleware()
|
||||
@@ -366,7 +481,7 @@ class TestMiddlewareChainOrder:
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_middleware_chain_logs_exception_as_failure(
|
||||
self, mock_get_user_id, mock_event_logger
|
||||
):
|
||||
) -> None:
|
||||
"""Tool exception is logged as success=False through the
|
||||
real middleware chain from build_middleware_list()."""
|
||||
from superset.mcp_service.server import build_middleware_list
|
||||
@@ -413,7 +528,7 @@ class TestMiddlewareChainOrder:
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_middleware_chain_error_result_has_mcp_call_id(
|
||||
self, mock_get_user_id, mock_event_logger
|
||||
):
|
||||
) -> None:
|
||||
"""When a tool raises, the error ToolResult from
|
||||
StructuredContentStripper still carries mcp_call_id in meta."""
|
||||
from superset.mcp_service.server import build_middleware_list
|
||||
@@ -435,3 +550,31 @@ class TestMiddlewareChainOrder:
|
||||
assert result.meta is not None
|
||||
assert "mcp_call_id" in result.meta
|
||||
assert len(result.meta["mcp_call_id"]) == 32
|
||||
|
||||
@pytest.mark.asyncio
async def test_list_tools_exception_returns_empty_list(self) -> None:
    """Exception during tools/list returns [] instead of causing encoding error.

    ToolError raised by GlobalErrorHandlerMiddleware cannot be encoded
    by the MCP SDK in a tools/list response, producing "encoding without
    a string argument". StructuredContentStripperMiddleware.on_list_tools
    must catch it and return an empty list.
    """
    from superset.mcp_service.server import build_middleware_list

    middleware_list = build_middleware_list()

    async def failing_list_tools(context: Any) -> Any:
        # Innermost handler: stands in for a tools/list call that blows up.
        raise ValueError("auth failed")

    # Wrap the failing handler from the inside out so that the first
    # middleware in the list ends up as the outermost layer of the chain.
    # Annotated as Any because the name is rebound from a coroutine
    # function to functools.partial objects.
    chain: Any = failing_list_tools
    for mw in reversed(middleware_list):
        chain = partial(mw, call_next=chain)

    ctx = _make_context(method="tools/list", name="")
    result = await chain(ctx)

    assert result == [], (
        "on_list_tools must return [] on exception — "
        "ToolError cannot be encoded in a tools/list response."
    )
|
||||
|
||||
Reference in New Issue
Block a user