Compare commits

..

21 Commits

Author SHA1 Message Date
Beto Dealmeida
285d7b1a08 Set query_context on chart creation 2026-05-05 10:33:24 -04:00
Beto Dealmeida
77016b3e89 Fix lint 2026-05-02 09:02:05 -04:00
Beto Dealmeida
32563ffb1d Add coverage 2026-05-02 00:05:49 -04:00
Beto Dealmeida
f79c7aca9d Add license 2026-05-01 23:24:44 -04:00
Beto Dealmeida
80cf2648f2 Fix lint 2026-05-01 22:58:12 -04:00
Beto Dealmeida
230c903e6b Fix lint 2026-05-01 19:28:31 -04:00
Beto Dealmeida
229917b9b0 feat: nodejs sidecar 2026-05-01 19:16:27 -04:00
Elizabeth Thompson
98eaaaa6d6 fix(mcp): clear stale thread-local DB session in sync tool wrapper (#39798)
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-01 09:24:48 -07:00
Jay Masiwal
cb74438865 fix(viz): correct table chart drill-to-detail temporal boundaries and null handling (#39668)
Co-authored-by: Samuelinto <samuel.mantilla@mail.utoronto.ca>
Co-authored-by: Amin Ghadersohi <amin.ghadersohi@gmail.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-01 11:46:18 -04:00
Danylo Korostil
e77fb5e3fc feat(i18n): updated Ukrainian translation (#39720) 2026-05-01 11:12:05 -04:00
dependabot[bot]
1ac113fd44 chore(deps): bump aws-actions/amazon-ecs-render-task-definition from 1.8.4 to 1.8.5 (#39809)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-05-01 06:31:48 -07:00
dependabot[bot]
6bfdee98cd chore(deps-dev): bump @docusaurus/tsconfig from 3.10.0 to 3.10.1 in /docs (#39811)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-05-01 09:31:29 -04:00
dependabot[bot]
de45f3a928 chore(deps): bump aws-actions/amazon-ecs-deploy-task-definition from 2.6.1 to 2.6.2 (#39806)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-05-01 09:30:49 -04:00
dependabot[bot]
2ec53c0694 chore(deps): bump mapbox-gl from 3.22.0 to 3.23.0 in /superset-frontend (#39769)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-05-01 09:30:21 -04:00
Michael S. Molina
d23b0cad92 chore: Bump core packages to 0.1.0 RC3 (#39823) 2026-05-01 09:54:39 -03:00
Evan Rusackas
e585406fff chore(codeowners): notify @sfirke on translation changes (#39794)
Co-authored-by: Claude Code <noreply@anthropic.com>
2026-04-30 23:07:29 -04:00
Amin Ghadersohi
957b298ae1 fix(mcp): add default request parameter to list_charts and list_dashboards (#39730) 2026-04-30 18:04:39 -04:00
Amin Ghadersohi
f29d82b3b1 feat(mcp): add query_dataset tool to query datasets using semantic layer (#39727) 2026-04-30 18:03:41 -04:00
Vitor Avila
3f550f166f fix(GSheets OAuth2): Re-add UnauthenticatedError (#39785) 2026-04-30 18:57:00 -03:00
Vitor Avila
86eb6176d1 fix: Enforce per-user caching on legacy API endpoint (#39789) 2026-04-30 18:04:33 -03:00
Joe Li
4244ae87bf fix(deps): regenerate pinned requirements for psycopg2-binary 2.9.12 (#39790)
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-30 17:46:23 -03:00
67 changed files with 10793 additions and 5762 deletions

4
.github/CODEOWNERS vendored
View File

@@ -36,6 +36,10 @@
**/*.geojson @villebro @rusackas
/superset-frontend/plugins/legacy-plugin-chart-country-map/ @villebro @rusackas
# Notify translation maintainers of changes to translations
/superset/translations/ @sfirke
# Notify PMC members of changes to extension-related files
/docs/developer_portal/extensions/ @michael-s-molina @villebro @rusackas

View File

@@ -265,7 +265,7 @@ jobs:
- name: Fill in the new image ID in the Amazon ECS task definition
id: task-def
uses: aws-actions/amazon-ecs-render-task-definition@77954e213ba1f9f9cb016b86a1d4f6fcdea0d57e # v1
uses: aws-actions/amazon-ecs-render-task-definition@6853cfae8c3a7d978fbf68b5a55453395541dfbb # v1
with:
task-definition: .github/workflows/ecs-task-definition.json
container-name: superset-ci
@@ -300,7 +300,7 @@ jobs:
--tags key=pr,value=$PR_NUMBER key=github_user,value=${{ github.actor }}
- name: Deploy Amazon ECS task definition
id: deploy-task
uses: aws-actions/amazon-ecs-deploy-task-definition@fc8fc60f3a60ffd500fcb13b209c59d221ac8c8c # v2
uses: aws-actions/amazon-ecs-deploy-task-definition@a310a830f5c14e583e35d84e4e1ec7dd177c3c9c # v2
with:
task-definition: ${{ steps.task-def.outputs.task-definition }}
service: pr-${{ github.event.inputs.issue_number || github.event.pull_request.number }}-service

View File

@@ -104,6 +104,8 @@ services:
depends_on:
superset-init:
condition: service_completed_successfully
query-context-sidecar:
condition: service_started
volumes: *superset-volumes
superset-websocket:
@@ -138,6 +140,19 @@ services:
- REDIS_PORT=6379
- REDIS_SSL=false
query-context-sidecar:
build:
context: .
dockerfile: query-context-sidecar/Dockerfile
restart: unless-stopped
ports:
- "127.0.0.1:${QUERY_CONTEXT_SIDECAR_PORT:-3030}:3030"
environment:
- PORT=3030
- QUERY_CONTEXT_MAX_BODY_BYTES=10485760
depends_on:
- superset-node
superset-init:
build:
<<: *common-build
@@ -152,6 +167,8 @@ services:
condition: service_started
redis:
condition: service_started
query-context-sidecar:
condition: service_started
user: *superset-user
volumes: *superset-volumes
healthcheck:

View File

@@ -26,6 +26,7 @@ DEV_MODE=true
# SUPERSET_PORT=8088
# NODE_PORT=9000
# WEBSOCKET_PORT=8080
# QUERY_CONTEXT_SIDECAR_PORT=3030
# CYPRESS_PORT=8081
# DATABASE_PORT=5432
# REDIS_PORT=6379
@@ -74,6 +75,7 @@ SUPERSET_LOAD_EXAMPLES=yes
CYPRESS_CONFIG=false
SUPERSET_PORT=8088
MAPBOX_API_KEY=''
QUERY_CONTEXT_SIDECAR_URL=http://query-context-sidecar:3030
# Make sure you set this to a unique secure random value on production
SUPERSET_SECRET_KEY=TEST_NON_DEV_SECRET

View File

@@ -93,7 +93,7 @@
},
"devDependencies": {
"@docusaurus/module-type-aliases": "^3.10.0",
"@docusaurus/tsconfig": "^3.10.0",
"@docusaurus/tsconfig": "^3.10.1",
"@eslint/js": "^9.39.2",
"@types/js-yaml": "^4.0.9",
"@types/react": "^19.1.8",

View File

@@ -2036,10 +2036,10 @@
fs-extra "^11.1.1"
tslib "^2.6.0"
"@docusaurus/tsconfig@^3.10.0":
version "3.10.0"
resolved "https://registry.yarnpkg.com/@docusaurus/tsconfig/-/tsconfig-3.10.0.tgz#f40a57248828f0503a5f355cf30aa59941c9baaa"
integrity sha512-TXdC3WXuPrdQAexLvjUJfnYf3YKEgEqAs5nK0Q88pRBCW7t7oN4ILvWYb3A5Z1wlSXyXGWW/mCUmLEhdWsjnDQ==
"@docusaurus/tsconfig@^3.10.1":
version "3.10.1"
resolved "https://registry.yarnpkg.com/@docusaurus/tsconfig/-/tsconfig-3.10.1.tgz#1db31b4a4a5c914bdffa80070a35b6365d34f2e8"
integrity sha512-rYvB7yqkdqWIpAbDzQljGfM4cDBkLTbhmagZBEcsyj6oPUsz47lmW2pYdN1j+7sGFgltbAmQH62xfbrij4Eh6Q==
"@docusaurus/types@3.10.0":
version "3.10.0"

2
query-context-sidecar/.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
node_modules/
dist/

View File

@@ -0,0 +1,55 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# Multi-stage build for the query-context sidecar.
# All COPY source paths (superset-frontend/, query-context-sidecar/) are
# relative to the build context, so the context must contain both directories.
# Stage 1: Install superset-frontend dependencies
FROM node:20-alpine AS deps
WORKDIR /app
# Copy full superset-frontend tree so workspace dependency resolution stays consistent
COPY superset-frontend/ ./superset-frontend/
WORKDIR /app/superset-frontend
# --ignore-scripts: install packages without running their lifecycle scripts
RUN npm ci --ignore-scripts
# Stage 2: Build the webpack bundle
FROM node:20-alpine AS builder
WORKDIR /app
# Copy installed node_modules from deps stage
COPY --from=deps /app/superset-frontend/node_modules ./superset-frontend/node_modules
# Copy superset-frontend source
COPY superset-frontend/ ./superset-frontend/
# Copy sidecar source and config
# The trailing "*" lets this COPY succeed even when package-lock.json is absent.
# NOTE(review): `npm ci` below still needs a lockfile to run - confirm it is always committed.
COPY query-context-sidecar/package.json query-context-sidecar/package-lock.json* ./query-context-sidecar/
COPY query-context-sidecar/webpack.config.js query-context-sidecar/tsconfig.json ./query-context-sidecar/
COPY query-context-sidecar/src/ ./query-context-sidecar/src/
WORKDIR /app/query-context-sidecar
# Install the sidecar's own build tooling, then produce dist/ via the "build" script
RUN npm ci
RUN npm run build
# Stage 3: Minimal runtime
FROM node:20-alpine
ENV NODE_ENV=production
WORKDIR /app
# Only the built dist/ directory is carried over; no node_modules are copied,
# so dist/index.js presumably has to be a self-contained bundle - TODO confirm
# against webpack.config.js.
COPY --from=builder /app/query-context-sidecar/dist ./dist
# Run as the unprivileged "node" user rather than root
USER node
CMD ["node", "dist/index.js"]

2130
query-context-sidecar/package-lock.json generated Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,21 @@
{
"name": "query-context-sidecar",
"version": "1.0.0",
"description": "Node.js sidecar that converts form_data to query_context using Superset frontend buildQuery functions",
"private": true,
"scripts": {
"build": "webpack --mode production",
"build:dev": "webpack --mode development",
"start": "node dist/index.js",
"dev": "webpack --mode development --watch"
},
"devDependencies": {
"css-loader": "^6.8.1",
"null-loader": "^4.0.1",
"style-loader": "^3.3.3",
"ts-loader": "^9.5.1",
"typescript": "^5.3.3",
"webpack": "^5.89.0",
"webpack-cli": "^5.1.4"
}
}

View File

@@ -0,0 +1,55 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import { QueryFormData } from '@superset-ui/core';
import { getBuildQuery } from '../runtimeRegistry';
/**
 * Build the query for a cartodiagram chart by delegating to the buildQuery
 * function registered for the chart embedded in `selected_chart`.
 *
 * `formData.selected_chart` is a JSON string whose `viz_type` names the
 * embedded chart and whose `params` is itself a JSON string holding that
 * chart's form data. The embedded form data is adjusted in two ways before
 * delegation:
 *   - the outer `extra_form_data` is merged over the embedded one (outer keys
 *     win on conflict);
 *   - `geom_column` is placed first in `groupby`, ahead of any existing
 *     entries (a non-array `groupby` is treated as empty).
 *
 * @param formData cartodiagram form data
 * @returns whatever the embedded chart's buildQuery returns
 * @throws Error when no buildQuery is registered for the embedded viz_type;
 *   JSON.parse errors propagate when either JSON string is malformed
 */
export default function buildCartodiagramQuery(formData: QueryFormData) {
  type CartodiagramFormData = QueryFormData & {
    selected_chart: string;
    geom_column: string;
    extra_form_data?: Record<string, unknown>;
  };
  const cartoFormData = formData as CartodiagramFormData;
  const geometryColumn = cartoFormData.geom_column;
  const outerExtras = cartoFormData.extra_form_data || {};

  // Both levels are JSON-encoded: the selected chart, then its params.
  const embeddedChart = JSON.parse(cartoFormData.selected_chart);
  const vizType = embeddedChart.viz_type as string;
  const embeddedFormData = JSON.parse(embeddedChart.params) as Record<
    string,
    unknown
  >;

  const embeddedExtras = embeddedFormData.extra_form_data as Record<
    string,
    unknown
  >;
  embeddedFormData.extra_form_data = { ...embeddedExtras, ...outerExtras };

  // The geometry column always leads the groupby list.
  const currentGroupby = embeddedFormData.groupby;
  const trailingColumns: string[] = Array.isArray(currentGroupby)
    ? (currentGroupby as string[])
    : [];
  embeddedFormData.groupby = [geometryColumn].concat(trailingColumns);

  const delegate = getBuildQuery(vizType);
  if (delegate) {
    return delegate(embeddedFormData);
  }
  throw new Error(`Unsupported selected chart viz_type: ${vizType}`);
}

View File

@@ -0,0 +1,26 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import './polyfills';
import { registerAllBuildQueries } from './registry';
import { startServer } from './server';
// Process entry point.
// Fill the viz_type -> buildQuery registry first so that lookups succeed as
// soon as the HTTP server starts handling requests.
registerAllBuildQueries();
// Start the HTTP server.
startServer();

View File

@@ -0,0 +1,87 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
// Minimal browser-global shims installed on globalThis. Each shim is only
// installed when the corresponding global is not already defined; the one
// exception is `window.featureFlags`, which is always reset to an empty object.
const g = globalThis as any;

const noop = () => {};

/** Assign `make()` to the global `name` unless that global already exists. */
const installIfMissing = (name: string, make: () => unknown): void => {
  if (typeof g[name] === 'undefined') {
    g[name] = make();
  }
};

// `window` aliases the global object itself.
installIfMissing('window', () => g);
g.window.featureFlags = {};

installIfMissing('document', () => ({
  getElementById: () => null,
  createElement: () => ({
    setAttribute: noop,
    style: {},
    appendChild: noop,
  }),
  createTextNode: () => ({}),
  head: { appendChild: noop },
  body: { appendChild: noop },
  addEventListener: noop,
  removeEventListener: noop,
  querySelectorAll: () => [],
  querySelector: () => null,
}));

installIfMissing('navigator', () => ({
  userAgent: 'node.js',
  language: 'en',
}));

installIfMissing('HTMLElement', () => class HTMLElement {});

installIfMissing('location', () => ({
  href: '',
  origin: '',
  protocol: 'http:',
  host: 'localhost',
  hostname: 'localhost',
  port: '',
  pathname: '/',
  search: '',
  hash: '',
}));

installIfMissing('getComputedStyle', () => () => ({}));

// Callbacks are deferred with a zero-delay timer.
installIfMissing(
  'requestAnimationFrame',
  () => (cb: () => void) => setTimeout(cb, 0),
);

installIfMissing('matchMedia', () => () => ({
  matches: false,
  addListener: noop,
  removeListener: noop,
  addEventListener: noop,
  removeEventListener: noop,
}));

View File

@@ -0,0 +1,114 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import bigNumberBuildQuery from '@superset-ui/plugin-chart-echarts/BigNumber/BigNumberWithTrendline/buildQuery';
import bigNumberPoPBuildQuery from '@superset-ui/plugin-chart-echarts/BigNumber/BigNumberPeriodOverPeriod/buildQuery';
import bigNumberTotalBuildQuery from '@superset-ui/plugin-chart-echarts/BigNumber/BigNumberTotal/buildQuery';
import boxPlotBuildQuery from '@superset-ui/plugin-chart-echarts/BoxPlot/buildQuery';
import bubbleBuildQuery from '@superset-ui/plugin-chart-echarts/Bubble/buildQuery';
import funnelBuildQuery from '@superset-ui/plugin-chart-echarts/Funnel/buildQuery';
import ganttBuildQuery from '@superset-ui/plugin-chart-echarts/Gantt/buildQuery';
import gaugeBuildQuery from '@superset-ui/plugin-chart-echarts/Gauge/buildQuery';
import graphBuildQuery from '@superset-ui/plugin-chart-echarts/Graph/buildQuery';
import heatmapBuildQuery from '@superset-ui/plugin-chart-echarts/Heatmap/buildQuery';
import histogramBuildQuery from '@superset-ui/plugin-chart-echarts/Histogram/buildQuery';
import mixedTimeseriesBuildQuery from '@superset-ui/plugin-chart-echarts/MixedTimeseries/buildQuery';
import pieBuildQuery from '@superset-ui/plugin-chart-echarts/Pie/buildQuery';
import radarBuildQuery from '@superset-ui/plugin-chart-echarts/Radar/buildQuery';
import sankeyBuildQuery from '@superset-ui/plugin-chart-echarts/Sankey/buildQuery';
import sunburstBuildQuery from '@superset-ui/plugin-chart-echarts/Sunburst/buildQuery';
import timeseriesBuildQuery from '@superset-ui/plugin-chart-echarts/Timeseries/buildQuery';
import treeBuildQuery from '@superset-ui/plugin-chart-echarts/Tree/buildQuery';
import treemapBuildQuery from '@superset-ui/plugin-chart-echarts/Treemap/buildQuery';
import waterfallBuildQuery from '@superset-ui/plugin-chart-echarts/Waterfall/buildQuery';
import handlebarsBuildQuery from '@superset-ui/plugin-chart-handlebars/plugin/buildQuery';
import pivotTableBuildQuery from '@superset-ui/plugin-chart-pivot-table/plugin/buildQuery';
import wordCloudBuildQuery from '@superset-ui/plugin-chart-word-cloud/plugin/buildQuery';
import tableBuildQuery from '@superset-ui/plugin-chart-table/buildQuery';
import agGridTableBuildQuery from '@superset-ui/plugin-chart-ag-grid-table/buildQuery';
import pointClusterMapBuildQuery from '@superset-ui/plugin-chart-point-cluster-map/buildQuery';
import deckArcBuildQuery from '@superset-ui/preset-chart-deckgl/layers/Arc/buildQuery';
import deckContourBuildQuery from '@superset-ui/preset-chart-deckgl/layers/Contour/buildQuery';
import deckGridBuildQuery from '@superset-ui/preset-chart-deckgl/layers/Grid/buildQuery';
import deckHeatmapBuildQuery from '@superset-ui/preset-chart-deckgl/layers/Heatmap/buildQuery';
import deckHexBuildQuery from '@superset-ui/preset-chart-deckgl/layers/Hex/buildQuery';
import deckPathBuildQuery from '@superset-ui/preset-chart-deckgl/layers/Path/buildQuery';
import deckPolygonBuildQuery from '@superset-ui/preset-chart-deckgl/layers/Polygon/buildQuery';
import deckScatterBuildQuery from '@superset-ui/preset-chart-deckgl/layers/Scatter/buildQuery';
import deckScreengridBuildQuery from '@superset-ui/preset-chart-deckgl/layers/Screengrid/buildQuery';
import filterRangeBuildQuery from 'src/filters/components/Range/buildQuery';
import filterSelectBuildQuery from 'src/filters/components/Select/buildQuery';
import filterTimeColumnBuildQuery from 'src/filters/components/TimeColumn/buildQuery';
import filterTimeGrainBuildQuery from 'src/filters/components/TimeGrain/buildQuery';
import cartodiagramBuildQuery from './buildQuery/cartodiagram';
import { registerBuildQuery } from './runtimeRegistry';
/**
 * Register every bundled buildQuery implementation under its viz_type key.
 *
 * The table below is walked top to bottom, so entries are registered in the
 * listed order. All `echarts_*` timeseries variants share one implementation.
 */
export function registerAllBuildQueries(): void {
  const buildQueriesByVizType: Array<[string, unknown]> = [
    ['big_number', bigNumberBuildQuery],
    ['big_number_total', bigNumberTotalBuildQuery],
    ['pop_kpi', bigNumberPoPBuildQuery],
    ['box_plot', boxPlotBuildQuery],
    ['bubble_v2', bubbleBuildQuery],
    ['funnel', funnelBuildQuery],
    ['gantt_chart', ganttBuildQuery],
    ['gauge_chart', gaugeBuildQuery],
    ['graph_chart', graphBuildQuery],
    ['heatmap_v2', heatmapBuildQuery],
    ['histogram_v2', histogramBuildQuery],
    ['mixed_timeseries', mixedTimeseriesBuildQuery],
    ['pie', pieBuildQuery],
    ['radar', radarBuildQuery],
    ['sankey_v2', sankeyBuildQuery],
    ['sunburst_v2', sunburstBuildQuery],
    ['tree_chart', treeBuildQuery],
    ['treemap_v2', treemapBuildQuery],
    ['waterfall', waterfallBuildQuery],
    ['echarts_timeseries', timeseriesBuildQuery],
    ['echarts_area', timeseriesBuildQuery],
    ['echarts_timeseries_bar', timeseriesBuildQuery],
    ['echarts_timeseries_line', timeseriesBuildQuery],
    ['echarts_timeseries_smooth', timeseriesBuildQuery],
    ['echarts_timeseries_scatter', timeseriesBuildQuery],
    ['echarts_timeseries_step', timeseriesBuildQuery],
    ['pivot_table_v2', pivotTableBuildQuery],
    ['table', tableBuildQuery],
    ['ag-grid-table', agGridTableBuildQuery],
    ['point_cluster', pointClusterMapBuildQuery],
    ['handlebars', handlebarsBuildQuery],
    ['word_cloud', wordCloudBuildQuery],
    ['cartodiagram', cartodiagramBuildQuery],
    ['deck_arc', deckArcBuildQuery],
    ['deck_contour', deckContourBuildQuery],
    ['deck_grid', deckGridBuildQuery],
    ['deck_heatmap', deckHeatmapBuildQuery],
    ['deck_hex', deckHexBuildQuery],
    ['deck_path', deckPathBuildQuery],
    ['deck_polygon', deckPolygonBuildQuery],
    ['deck_scatter', deckScatterBuildQuery],
    ['deck_screengrid', deckScreengridBuildQuery],
    ['filter_select', filterSelectBuildQuery],
    ['filter_range', filterRangeBuildQuery],
    ['filter_timecolumn', filterTimeColumnBuildQuery],
    ['filter_timegrain', filterTimeGrainBuildQuery],
  ];

  buildQueriesByVizType.forEach(([vizType, buildQuery]) => {
    // The plugin signatures are narrower than BuildQueryFn, hence the cast.
    registerBuildQuery(vizType, buildQuery as any);
  });
}

View File

@@ -0,0 +1,34 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/** A function that turns chart form data into a query payload. */
export type BuildQueryFn = (formData: Record<string, unknown>) => unknown;

// Module-private lookup table: viz_type -> buildQuery implementation.
const registry: Map<string, BuildQueryFn> = new Map();

/**
 * Associate `fn` with `vizType`, replacing any function already stored
 * under that key.
 */
export function registerBuildQuery(vizType: string, fn: BuildQueryFn): void {
  registry.set(vizType, fn);
}

/**
 * Look up the buildQuery function stored for `vizType`.
 *
 * @returns the registered function, or `undefined` when none is registered
 */
export function getBuildQuery(vizType: string): BuildQueryFn | undefined {
  const found = registry.get(vizType);
  return found;
}

/**
 * List every registered viz_type.
 *
 * @returns a new array of keys, sorted with the default Array.prototype.sort order
 */
export function listVizTypes(): string[] {
  const names: string[] = [];
  registry.forEach((_fn, vizType) => {
    names.push(vizType);
  });
  return names.sort();
}

View File

@@ -0,0 +1,28 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import { getBuildQuery } from './runtimeRegistry';
/**
 * Return an object exposing a single `get(vizType)` method that resolves
 * buildQuery functions through the runtime registry.
 *
 * A fresh wrapper object is created on every call; all wrappers read the
 * same underlying registry via `getBuildQuery`.
 */
export default function getChartBuildQueryRegistry() {
  const registryView = {
    get: (vizType: string) => getBuildQuery(vizType),
  };
  return registryView;
}

View File

@@ -0,0 +1,166 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import http from 'http';
import { URL } from 'url';
import buildQueryContext from './stubs/buildQueryContext';
import { getBuildQuery, listVizTypes } from './runtimeRegistry';
// Fallback body-size cap used when QUERY_CONTEXT_MAX_BODY_BYTES is unset/empty.
const DEFAULT_MAX_BODY_BYTES = 10 * 1024 * 1024;

// TCP port to listen on; PORT env var, falling back to 3030.
const PORT = parseInt(process.env.PORT || '3030', 10);

// Largest request body (in bytes) that readBody will buffer.
const MAX_BODY_BYTES = parseInt(
  process.env.QUERY_CONTEXT_MAX_BODY_BYTES || String(DEFAULT_MAX_BODY_BYTES),
  10,
);

// Origins parsed from the comma-separated QUERY_CONTEXT_ALLOWED_ORIGINS env
// var. Entries are trimmed and blank entries are dropped, so an unset or
// empty variable yields an empty set.
const ALLOWED_ORIGINS = new Set<string>();
for (const rawOrigin of (process.env.QUERY_CONTEXT_ALLOWED_ORIGINS || '').split(
  ',',
)) {
  const origin = rawOrigin.trim();
  if (origin) {
    ALLOWED_ORIGINS.add(origin);
  }
}
/**
 * Error that carries the HTTP status code to send back to the client.
 *
 * `statusCode` is stored as a public instance property; `message` is passed
 * through to the base Error.
 */
class HttpRequestError extends Error {
  constructor(
    public statusCode: number,
    message: string,
  ) {
    super(message);
  }
}
/**
 * Buffer the whole request body and return it decoded as a string
 * (Buffer#toString default encoding).
 *
 * Once more than MAX_BODY_BYTES have been received the promise is rejected
 * with `HttpRequestError(413)`. The request is deliberately NOT destroyed in
 * that case: destroying the IncomingMessage destroys the underlying socket,
 * which would prevent the caller from delivering the 413 response. Instead,
 * buffering stops, the chunks gathered so far are released, and the remaining
 * upload is drained and discarded.
 *
 * @param req incoming HTTP request
 * @returns promise resolving to the request body text
 * @throws (rejects) HttpRequestError with status 413 when the body exceeds
 *   MAX_BODY_BYTES; any stream 'error' is forwarded as the rejection reason
 */
function readBody(req: http.IncomingMessage): Promise<string> {
  return new Promise((resolve, reject) => {
    const chunks: Buffer[] = [];
    let totalBytes = 0;

    const onData = (chunk: Buffer): void => {
      totalBytes += chunk.length;
      if (totalBytes > MAX_BODY_BYTES) {
        // Stop accumulating and free what was buffered, but keep the socket
        // alive so the 413 can still be written by the caller.
        req.off('data', onData);
        chunks.length = 0;
        // Keep the stream flowing so the rest of the upload is discarded
        // instead of piling up unread.
        req.resume();
        reject(new HttpRequestError(413, 'Request body too large'));
        return;
      }
      chunks.push(chunk);
    };

    req.on('data', onData);
    // After an overflow rejection this resolve is a no-op: a promise settles
    // only once.
    req.on('end', () => resolve(Buffer.concat(chunks).toString()));
    req.on('error', reject);
  });
}
/**
 * Decide whether a request with the given Origin header value may proceed.
 *
 * A request is accepted when it has no (or an empty) origin, when the
 * ALLOWED_ORIGINS set is empty, or when the origin is an exact member of
 * ALLOWED_ORIGINS.
 *
 * @param origin value of the Origin header, if any
 * @returns true when the request is allowed
 */
function isAllowedOrigin(origin?: string): boolean {
  const hasOrigin = Boolean(origin);
  const allowlistConfigured = ALLOWED_ORIGINS.size !== 0;
  if (hasOrigin && allowlistConfigured) {
    return ALLOWED_ORIGINS.has(origin as string);
  }
  return true;
}
/**
 * Send `data` as a JSON response with the given status code.
 *
 * The status line and Content-Type header are written first; `data` is
 * serialized with JSON.stringify afterwards and used as the response body.
 *
 * @param res response to write to
 * @param status HTTP status code
 * @param data value to serialize as the body
 */
function jsonResponse(res: http.ServerResponse, status: number, data: unknown): void {
  const headers = { 'Content-Type': 'application/json' };
  res.writeHead(status, headers);
  const payload = JSON.stringify(data);
  res.end(payload);
}
/**
 * Handle POST /api/v1/build-query-context.
 *
 * Expects a JSON body of the form `{ "form_data": { "viz_type": ..., ... } }`
 * and responds with `{ "query_context": ... }`. The query context comes from
 * the buildQuery function registered for `form_data.viz_type`, or from the
 * generic `buildQueryContext` when none is registered.
 *
 * Responses:
 *   - 403 when the Origin header is rejected by `isAllowedOrigin`
 *   - the status carried by an HttpRequestError raised while reading the
 *     body (e.g. 413)
 *   - 400 when the body is not valid JSON, or `form_data` /
 *     `form_data.viz_type` is missing (this includes bodies such as `null`)
 *   - 500 when building the query context throws
 *   - 200 with the query context otherwise
 *
 * @param req incoming request
 * @param res response to write to
 * @throws rethrows any body-reading error that is not an HttpRequestError
 */
async function handleBuildQueryContext(
  req: http.IncomingMessage,
  res: http.ServerResponse,
): Promise<void> {
  if (!isAllowedOrigin(req.headers.origin)) {
    jsonResponse(res, 403, { error: 'Origin not allowed' });
    return;
  }

  let body: string;
  try {
    body = await readBody(req);
  } catch (err: any) {
    if (err instanceof HttpRequestError) {
      jsonResponse(res, err.statusCode, { error: err.message });
      return;
    }
    throw err;
  }

  let parsed: any;
  try {
    parsed = JSON.parse(body);
  } catch {
    jsonResponse(res, 400, { error: 'Invalid JSON body' });
    return;
  }

  // `null` is valid JSON, so `parsed` may be null here; optional chaining
  // turns that into the 400 below instead of a TypeError (and a 500).
  const formData = parsed?.form_data;
  if (!formData || !formData.viz_type) {
    jsonResponse(res, 400, {
      error: 'Missing form_data or form_data.viz_type',
    });
    return;
  }

  try {
    const buildQuery = getBuildQuery(formData.viz_type);
    const queryContext = buildQuery
      ? buildQuery(formData)
      : buildQueryContext(formData);
    jsonResponse(res, 200, { query_context: queryContext });
  } catch (err: any) {
    console.error('Error building query context for %s:', formData.viz_type, err);
    // Anything can be thrown, not only Error instances; avoid reporting
    // "undefined" for non-Error values.
    const reason = err instanceof Error ? err.message : String(err);
    jsonResponse(res, 500, {
      error: `Failed to build query context: ${reason}`,
    });
  }
}
/**
 * Respond 200 with the registered viz types and how many there are:
 * `{ viz_types: string[], count: number }`.
 */
function handleVizTypes(res: http.ServerResponse): void {
  const registered = listVizTypes();
  const payload = { viz_types: registered, count: registered.length };
  jsonResponse(res, 200, payload);
}
/** Health probe: respond 200 with the plain-text body `OK`. */
function handleHealth(res: http.ServerResponse): void {
  const headers = { 'Content-Type': 'text/plain' };
  res.writeHead(200, headers);
  res.end('OK');
}
/**
 * Create the sidecar HTTP server and start listening on `PORT`.
 *
 * Routes:
 *   - GET|HEAD /health                     -> handleHealth
 *   - GET      /api/v1/viz-types           -> handleVizTypes
 *   - POST     /api/v1/build-query-context -> handleBuildQueryContext
 *   - anything else                        -> 404 JSON error
 *
 * A request target that cannot be parsed as a URL (for example `//`) is
 * answered with 400. Any other failure is logged and answered with 500, or,
 * if response headers were already written, by destroying the response.
 */
export function startServer(): void {
  const server = http.createServer(async (req, res) => {
    // An exception escaping this async listener would surface as an
    // unhandled promise rejection, so all work happens inside the try.
    try {
      const method = req.method || '';

      let url = '';
      if (req.url) {
        try {
          // Only the path is used for routing; the base merely lets the
          // relative request target be parsed.
          url = new URL(req.url, `http://localhost:${PORT}`).pathname;
        } catch {
          jsonResponse(res, 400, { error: 'Invalid request URL' });
          return;
        }
      }

      if (url === '/health' && (method === 'GET' || method === 'HEAD')) {
        handleHealth(res);
      } else if (url === '/api/v1/viz-types' && method === 'GET') {
        handleVizTypes(res);
      } else if (url === '/api/v1/build-query-context' && method === 'POST') {
        await handleBuildQueryContext(req, res);
      } else {
        jsonResponse(res, 404, { error: 'Not found' });
      }
    } catch (err) {
      console.error('Unhandled error:', err);
      if (res.headersSent) {
        // A second writeHead would throw; the response cannot be turned into
        // a clean error any more, so abort it.
        res.destroy();
        return;
      }
      jsonResponse(res, 500, { error: 'Internal server error' });
    }
  });

  server.listen(PORT, () => {
    console.log(`Query context sidecar listening on port ${PORT}`);
  });
}

View File

@@ -0,0 +1,68 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import buildQueryObject from '@superset-ui/core/query/buildQueryObject';
import DatasourceKey from '@superset-ui/core/query/DatasourceKey';
import { normalizeTimeColumn } from '@superset-ui/core/query/normalizeTimeColumn';
import { isXAxisSet } from '@superset-ui/core/query/getXAxis';
import {
QueryFieldAliases,
QueryFormData,
} from '@superset-ui/core/query/types/QueryFormData';
import { QueryContext, QueryObject } from '@superset-ui/core/query/types/Query';
/** Default finalizer: the base query object becomes the only query. */
const WRAP_IN_ARRAY = (baseQueryObject: QueryObject) => [baseQueryObject];

/** Expands one base query object into the final list of query objects. */
type BuildFinalQueryObjects = (baseQueryObject: QueryObject) => QueryObject[];

/**
 * Build a QueryContext from chart form data.
 *
 * @param formData chart form data
 * @param options either a finalizer function, or an object with an optional
 *   `buildQuery` finalizer and optional `queryFields` aliases. When no
 *   finalizer is supplied the base query object is wrapped in a one-element
 *   array.
 * @returns the query context; `force`, `result_format` and `result_type`
 *   fall back to `false`, `'json'` and `'full'` when falsy in `formData`
 */
export default function buildQueryContext(
  formData: QueryFormData,
  options?:
    | {
        buildQuery?: BuildFinalQueryObjects;
        queryFields?: QueryFieldAliases;
      }
    | BuildFinalQueryObjects,
): QueryContext {
  // Normalize the two accepted shapes of `options` into one object.
  const normalizedOptions =
    typeof options === 'function'
      ? { buildQuery: options, queryFields: {} }
      : options || {};
  const { queryFields, buildQuery = WRAP_IN_ARRAY } = normalizedOptions;

  const baseQueryObject = buildQueryObject(formData, queryFields);
  const finalizedQueries = buildQuery(baseQueryObject);

  // Drop falsy entries from each post_processing list, updating the query
  // objects in place.
  for (const query of finalizedQueries) {
    if (Array.isArray(query.post_processing)) {
      query.post_processing = query.post_processing.filter(Boolean);
    }
  }

  // Time-column normalization only applies when the form data sets an x-axis.
  const queries = isXAxisSet(formData)
    ? finalizedQueries.map(query => normalizeTimeColumn(formData, query))
    : finalizedQueries;

  return {
    datasource: new DatasourceKey(formData.datasource).toObject(),
    force: formData.force || false,
    queries,
    form_data: formData,
    result_format: formData.result_format || 'json',
    result_type: formData.result_type || 'full',
  };
}

View File

@@ -0,0 +1,20 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
// Empty module: the default export is a plain object with no members.
// NOTE(review): presumably substituted for imports the sidecar does not need
// at runtime (e.g. via a bundler alias) - confirm against webpack.config.js.
export default {};

View File

@@ -0,0 +1,35 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
// Re-exports each operator helper, plus the extractExtraMetrics and isTimeComparison
// utilities, from its individual module under '@superset-ui/chart-controls/operators'
// instead of from the package root.
export { aggregationOperator } from '@superset-ui/chart-controls/operators/aggregateOperator';
export { boxplotOperator } from '@superset-ui/chart-controls/operators/boxplotOperator';
export { contributionOperator } from '@superset-ui/chart-controls/operators/contributionOperator';
export { flattenOperator } from '@superset-ui/chart-controls/operators/flattenOperator';
export { histogramOperator } from '@superset-ui/chart-controls/operators/histogramOperator';
export { pivotOperator } from '@superset-ui/chart-controls/operators/pivotOperator';
export { prophetOperator } from '@superset-ui/chart-controls/operators/prophetOperator';
export { rankOperator } from '@superset-ui/chart-controls/operators/rankOperator';
export { renameOperator } from '@superset-ui/chart-controls/operators/renameOperator';
export { resampleOperator } from '@superset-ui/chart-controls/operators/resampleOperator';
export { rollingWindowOperator } from '@superset-ui/chart-controls/operators/rollingWindowOperator';
export { sortOperator } from '@superset-ui/chart-controls/operators/sortOperator';
export { timeCompareOperator } from '@superset-ui/chart-controls/operators/timeCompareOperator';
export { timeComparePivotOperator } from '@superset-ui/chart-controls/operators/timeComparePivotOperator';
// Utilities that live one level deeper, under operators/utils.
export { extractExtraMetrics } from '@superset-ui/chart-controls/operators/utils/extractExtraMetrics';
export { isTimeComparison } from '@superset-ui/chart-controls/operators/utils/isTimeComparison';

View File

@@ -0,0 +1,28 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
// Barrel module: exposes the local buildQueryContext, the chart buildQuery registry
// accessor from '../runtimeRegistryAdapter', the BuildQuery type, and everything
// exported by the query, utils, validator and color submodules of '@superset-ui/core'.
export { default as buildQueryContext } from './buildQueryContext';
export { default as getChartBuildQueryRegistry } from '../runtimeRegistryAdapter';
// Type-only re-export; it adds no runtime binding.
export type { BuildQuery } from '@superset-ui/core/chart/registries/ChartBuildQueryRegistrySingleton';
export * from '@superset-ui/core/query';
export * from '@superset-ui/core/utils';
export * from '@superset-ui/core/validator';
export * from '@superset-ui/core/color';

View File

@@ -0,0 +1,35 @@
{
"compilerOptions": {
"target": "ES2019",
"module": "ESNext",
"moduleResolution": "node",
"esModuleInterop": true,
"allowSyntheticDefaultImports": true,
"strict": false,
"skipLibCheck": true,
"resolveJsonModule": true,
"jsx": "react",
"outDir": "dist",
"baseUrl": ".",
"paths": {
"@superset-ui/core": ["../superset-frontend/packages/superset-ui-core/src"],
"@superset-ui/core/*": ["../superset-frontend/packages/superset-ui-core/src/*"],
"@apache-superset/core": ["../superset-frontend/packages/superset-core/src"],
"@apache-superset/core/*": ["../superset-frontend/packages/superset-core/src/*"],
"@superset-ui/chart-controls": ["../superset-frontend/packages/superset-ui-chart-controls/src"],
"@superset-ui/plugin-chart-echarts/*": ["../superset-frontend/plugins/plugin-chart-echarts/src/*"],
"@superset-ui/plugin-chart-table/*": ["../superset-frontend/plugins/plugin-chart-table/src/*"],
"@superset-ui/plugin-chart-pivot-table/*": ["../superset-frontend/plugins/plugin-chart-pivot-table/src/*"],
"@superset-ui/plugin-chart-handlebars/*": ["../superset-frontend/plugins/plugin-chart-handlebars/src/*"],
"@superset-ui/plugin-chart-word-cloud/*": ["../superset-frontend/plugins/plugin-chart-word-cloud/src/*"],
"@superset-ui/plugin-chart-cartodiagram/*": ["../superset-frontend/plugins/plugin-chart-cartodiagram/src/*"],
"@superset-ui/plugin-chart-ag-grid-table/*": ["../superset-frontend/plugins/plugin-chart-ag-grid-table/src/*"],
"@superset-ui/plugin-chart-point-cluster-map/*": ["../superset-frontend/plugins/plugin-chart-point-cluster-map/src/*"],
"@superset-ui/preset-chart-deckgl/*": ["../superset-frontend/plugins/preset-chart-deckgl/src/*"],
"@superset-ui/legacy-preset-chart-nvd3/*": ["../superset-frontend/plugins/legacy-preset-chart-nvd3/src/*"],
"src/*": ["../superset-frontend/src/*"]
}
},
"include": ["src/**/*"],
"exclude": ["node_modules", "dist"]
}

View File

@@ -0,0 +1,137 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
const path = require('path');
const webpack = require('webpack');
const FRONTEND_DIR = path.resolve(__dirname, '../superset-frontend');
module.exports = {
target: 'node',
mode: 'production',
entry: './src/index.ts',
output: {
filename: 'index.js',
path: path.resolve(__dirname, 'dist'),
libraryTarget: 'commonjs2',
},
resolve: {
extensions: ['.ts', '.tsx', '.js', '.jsx', '.json'],
modules: [path.join(FRONTEND_DIR, 'node_modules'), FRONTEND_DIR, 'node_modules'],
alias: {
'@superset-ui/core': path.join(FRONTEND_DIR, 'packages/superset-ui-core/src'),
'@superset-ui/chart-controls': path.join(
FRONTEND_DIR,
'packages/superset-ui-chart-controls/src',
),
'@superset-ui/switchboard': path.join(
FRONTEND_DIR,
'packages/superset-ui-switchboard/src',
),
'@apache-superset/core': path.join(FRONTEND_DIR, 'packages/superset-core/src'),
'@superset-ui/plugin-chart-echarts': path.join(
FRONTEND_DIR,
'plugins/plugin-chart-echarts/src',
),
'@superset-ui/plugin-chart-table': path.join(
FRONTEND_DIR,
'plugins/plugin-chart-table/src',
),
'@superset-ui/plugin-chart-pivot-table': path.join(
FRONTEND_DIR,
'plugins/plugin-chart-pivot-table/src',
),
'@superset-ui/plugin-chart-handlebars': path.join(
FRONTEND_DIR,
'plugins/plugin-chart-handlebars/src',
),
'@superset-ui/plugin-chart-word-cloud': path.join(
FRONTEND_DIR,
'plugins/plugin-chart-word-cloud/src',
),
'@superset-ui/plugin-chart-cartodiagram': path.join(
FRONTEND_DIR,
'plugins/plugin-chart-cartodiagram/src',
),
'@superset-ui/plugin-chart-ag-grid-table': path.join(
FRONTEND_DIR,
'plugins/plugin-chart-ag-grid-table/src',
),
'@superset-ui/plugin-chart-point-cluster-map': path.join(
FRONTEND_DIR,
'plugins/plugin-chart-point-cluster-map/src',
),
'@superset-ui/preset-chart-deckgl': path.join(
FRONTEND_DIR,
'plugins/preset-chart-deckgl/src',
),
},
},
module: {
rules: [
{
test: /\.tsx?$/,
use: {
loader: 'ts-loader',
options: {
transpileOnly: true,
configFile: path.resolve(__dirname, 'tsconfig.json'),
},
},
exclude: /node_modules/,
},
{
test: /\.(png|jpe?g|gif|svg|ico)$/i,
use: 'null-loader',
},
{
test: /\.(css|less|scss|sass)$/i,
use: 'null-loader',
},
],
},
plugins: [
new webpack.NormalModuleReplacementPlugin(
/^@superset-ui\/core$/,
path.resolve(__dirname, 'src/stubs/superset-ui-core.ts'),
),
new webpack.NormalModuleReplacementPlugin(
/^@superset-ui\/chart-controls$/,
path.resolve(__dirname, 'src/stubs/superset-ui-chart-controls.ts'),
),
new webpack.NormalModuleReplacementPlugin(
/react-markdown/,
path.resolve(__dirname, 'src/stubs/empty.ts'),
),
new webpack.NormalModuleReplacementPlugin(
/remark-rehype/,
path.resolve(__dirname, 'src/stubs/empty.ts'),
),
new webpack.NormalModuleReplacementPlugin(
/remark-gfm/,
path.resolve(__dirname, 'src/stubs/empty.ts'),
),
new webpack.DefinePlugin({
'process.env.NODE_ENV': JSON.stringify('production'),
}),
],
optimization: {
minimize: false,
},
};

View File

@@ -707,7 +707,7 @@ protobuf==4.25.8
# proto-plus
psutil==6.1.0
# via apache-superset
psycopg2-binary==2.9.9
psycopg2-binary==2.9.12
# via apache-superset
py-key-value-aio==0.4.4
# via fastmcp

View File

@@ -18,7 +18,7 @@
[project]
name = "apache-superset-core"
version = "0.1.0rc2"
version = "0.1.0rc3"
description = "Core Python package for building Apache Superset backend extensions and integrations"
readme = "README.md"
authors = [

View File

@@ -17,7 +17,7 @@
[project]
name = "apache-superset-extensions-cli"
version = "0.1.0rc2"
version = "0.1.0rc3"
description = "Official command-line interface for building, bundling, and managing Apache Superset extensions"
readme = "README.md"
authors = [

View File

@@ -102,7 +102,7 @@
"json-bigint": "^1.0.0",
"json-stringify-pretty-compact": "^2.0.0",
"lodash": "^4.18.1",
"mapbox-gl": "^3.22.0",
"mapbox-gl": "^3.23.0",
"markdown-to-jsx": "^9.7.16",
"match-sorter": "^8.3.0",
"memoize-one": "^5.2.1",
@@ -35939,9 +35939,9 @@
"license": "MIT"
},
"node_modules/mapbox-gl": {
"version": "3.22.0",
"resolved": "https://registry.npmjs.org/mapbox-gl/-/mapbox-gl-3.22.0.tgz",
"integrity": "sha512-ZIpF+oAMcQoDlvABmiRkHoydyBR9zI6CyDeVRa2/iyua0/B2+rPuIzoCV/CgN14P5F0HVk53GIZw220WSqJPyA==",
"version": "3.23.0",
"resolved": "https://registry.npmjs.org/mapbox-gl/-/mapbox-gl-3.23.0.tgz",
"integrity": "sha512-zzjNAaMNvXnAVEUrYpOWmRVEBCIWgDAMLRPvSOoKY3smKvrINFVrRK/1jEpUDbEa7Ppf5Q/nwC6E07tz/i7IKw==",
"license": "SEE LICENSE IN LICENSE.txt",
"workspaces": [
"src/style-spec",
@@ -51021,7 +51021,7 @@
"dependencies": {
"chalk": "^5.6.2",
"lodash-es": "^4.18.1",
"yeoman-generator": "^8.1.2",
"yeoman-generator": "^8.2.2",
"yosay": "^3.0.0"
},
"devDependencies": {
@@ -51395,7 +51395,7 @@
},
"packages/superset-core": {
"name": "@apache-superset/core",
"version": "0.1.0-rc2",
"version": "0.1.0-rc3",
"license": "Apache-2.0",
"devDependencies": {
"@babel/cli": "^7.28.6",
@@ -53062,7 +53062,7 @@
"license": "Apache-2.0",
"dependencies": {
"@math.gl/web-mercator": "^4.1.0",
"mapbox-gl": "^3.22.0",
"mapbox-gl": "^3.23.0",
"maplibre-gl": "^5.24.0",
"react-map-gl": "^8.1.0",
"supercluster": "^8.0.1"
@@ -53473,103 +53473,6 @@
"version": "1.0.0",
"extraneous": true,
"license": "Apache-2.0"
},
"node_modules/mem-fs-editor/node_modules/array-differ": {
"version": "4.0.0",
"resolved": "https://registry.npmjs.org/array-differ/-/array-differ-4.0.0.tgz",
"integrity": "sha512-Q6VPTLMsmXZ47ENG3V+wQyZS1ZxXMxFyYzA+Z/GMrJ6yIutAIEf9wTyroTzmGjNfox9/h3GdGBCVh43GVFx4Uw==",
"license": "MIT",
"optional": true,
"peer": true,
"engines": {
"node": "^12.20.0 || ^14.13.1 || >=16.0.0"
},
"funding": {
"url": "https://github.com/sponsors/sindresorhus"
}
},
"node_modules/mem-fs-editor/node_modules/array-union": {
"version": "3.0.1",
"resolved": "https://registry.npmjs.org/array-union/-/array-union-3.0.1.tgz",
"integrity": "sha512-1OvF9IbWwaeiM9VhzYXVQacMibxpXOMYVNIvMtKRyX9SImBXpKcFr8XvFDeEslCyuH/t6KRt7HEO94AlP8Iatw==",
"license": "MIT",
"optional": true,
"peer": true,
"engines": {
"node": ">=12"
},
"funding": {
"url": "https://github.com/sponsors/sindresorhus"
}
},
"node_modules/mem-fs-editor/node_modules/globby": {
"version": "14.0.2",
"resolved": "https://registry.npmjs.org/globby/-/globby-14.0.2.tgz",
"integrity": "sha512-s3Fq41ZVh7vbbe2PN3nrW7yC7U7MFVc5c98/iTl9c2GawNMKx/J648KQRW6WKkuU8GIbbh2IXfIRQjOZnXcTnw==",
"license": "MIT",
"optional": true,
"peer": true,
"dependencies": {
"@sindresorhus/merge-streams": "^2.1.0",
"fast-glob": "^3.3.2",
"ignore": "^5.2.4",
"path-type": "^5.0.0",
"slash": "^5.1.0",
"unicorn-magic": "^0.1.0"
},
"engines": {
"node": ">=18"
},
"funding": {
"url": "https://github.com/sponsors/sindresorhus"
}
},
"node_modules/mem-fs-editor/node_modules/multimatch": {
"version": "7.0.0",
"resolved": "https://registry.npmjs.org/multimatch/-/multimatch-7.0.0.tgz",
"integrity": "sha512-SYU3HBAdF4psHEL/+jXDKHO95/m5P2RvboHT2Y0WtTttvJLP4H/2WS9WlQPFvF6C8d6SpLw8vjCnQOnVIVOSJQ==",
"license": "MIT",
"optional": true,
"peer": true,
"dependencies": {
"array-differ": "^4.0.0",
"array-union": "^3.0.1",
"minimatch": "^9.0.3"
},
"engines": {
"node": ">=18"
},
"funding": {
"url": "https://github.com/sponsors/sindresorhus"
}
},
"node_modules/mem-fs-editor/node_modules/path-type": {
"version": "5.0.0",
"resolved": "https://registry.npmjs.org/path-type/-/path-type-5.0.0.tgz",
"integrity": "sha512-5HviZNaZcfqP95rwpv+1HDgUamezbqdSYTyzjTvwtJSnIH+3vnbmWsItli8OFEndS984VT55M3jduxZbX351gg==",
"license": "MIT",
"optional": true,
"peer": true,
"engines": {
"node": ">=12"
},
"funding": {
"url": "https://github.com/sponsors/sindresorhus"
}
},
"node_modules/mem-fs-editor/node_modules/slash": {
"version": "5.1.0",
"resolved": "https://registry.npmjs.org/slash/-/slash-5.1.0.tgz",
"integrity": "sha512-ZA6oR3T/pEyuqwMgAKT0/hAv8oAXckzbkmR0UkUosQ+Mc4RxGoJkRmwHgHufaenlyAgE1Mxgpdcrf75y6XcnDg==",
"license": "MIT",
"optional": true,
"peer": true,
"engines": {
"node": ">=14.16"
},
"funding": {
"url": "https://github.com/sponsors/sindresorhus"
}
}
}
}

View File

@@ -183,7 +183,7 @@
"json-bigint": "^1.0.0",
"json-stringify-pretty-compact": "^2.0.0",
"lodash": "^4.18.1",
"mapbox-gl": "^3.22.0",
"mapbox-gl": "^3.23.0",
"markdown-to-jsx": "^9.7.16",
"match-sorter": "^8.3.0",
"memoize-one": "^5.2.1",

View File

@@ -1,6 +1,6 @@
{
"name": "@apache-superset/core",
"version": "0.1.0-rc2",
"version": "0.1.0-rc3",
"description": "This package contains UI elements, APIs, and utility functions used by Superset.",
"sideEffects": false,
"main": "lib/index.js",

View File

@@ -1400,25 +1400,6 @@ test('getAxisType with forced categorical', () => {
);
});
test('getAxisType treats numeric as category for bar charts', () => {
expect(
getAxisType(
false,
false,
GenericDataType.Numeric,
EchartsTimeseriesSeriesType.Bar,
),
).toEqual(AxisType.Category);
expect(
getAxisType(
false,
false,
GenericDataType.Numeric,
EchartsTimeseriesSeriesType.Line,
),
).toEqual(AxisType.Value);
});
test('getMinAndMaxFromBounds returns empty object when not truncating', () => {
expect(
getMinAndMaxFromBounds(

View File

@@ -27,7 +27,7 @@
],
"dependencies": {
"@math.gl/web-mercator": "^4.1.0",
"mapbox-gl": "^3.22.0",
"mapbox-gl": "^3.23.0",
"maplibre-gl": "^5.24.0",
"react-map-gl": "^8.1.0",
"supercluster": "^8.0.1"

View File

@@ -50,6 +50,7 @@ import {
getTimeFormatterForGranularity,
BinaryQueryObjectFilterClause,
extractTextFromHTML,
TimeGranularity,
} from '@superset-ui/core';
import {
styled,
@@ -309,6 +310,67 @@ function SelectPageSize({
const getNoResultsMessage = (filter: string) =>
filter ? t('No matching records found') : t('No records found');
/**
 * Computes the half-open interval `[start, end)` covered by a time bucket.
 *
 * Sub-day grains add a fixed number of milliseconds to `startTime`. Day and
 * larger grains end at a UTC midnight derived from the UTC calendar fields of
 * `startTime`. Except for the week-ending grains, the returned start is
 * `startTime` itself, unmodified.
 *
 * @param startTime - timestamp that labels the bucket
 * @param granularity - time grain of the bucket
 * @returns `[start, end]`, with `end` exclusive
 */
function getTimeRangeFromGranularity(
  startTime: Date,
  granularity: TimeGranularity,
): [Date, Date] {
  const MS_PER_SECOND = 1000;
  const MS_PER_MINUTE = 60 * MS_PER_SECOND;
  const MS_PER_HOUR = 60 * MS_PER_MINUTE;

  const epochMs = startTime.getTime();
  const utcYear = startTime.getUTCFullYear();
  const utcMonth = startTime.getUTCMonth();
  const utcDay = startTime.getUTCDate();

  // UTC midnight for the given calendar fields; Date.UTC normalises
  // out-of-range day and month values.
  const utcMidnight = (year: number, month: number, day: number): Date =>
    new Date(Date.UTC(year, month, day));

  // Bucket of fixed duration that begins at startTime.
  const fixedSpan = (durationMs: number): [Date, Date] => [
    startTime,
    new Date(epochMs + durationMs),
  ];

  // Bucket that begins at startTime and ends at the given UTC midnight.
  const until = (end: Date): [Date, Date] => [startTime, end];

  switch (granularity) {
    case TimeGranularity.SECOND:
      return fixedSpan(MS_PER_SECOND);
    case TimeGranularity.MINUTE:
      return fixedSpan(MS_PER_MINUTE);
    case TimeGranularity.FIVE_MINUTES:
      return fixedSpan(MS_PER_MINUTE * 5);
    case TimeGranularity.TEN_MINUTES:
      return fixedSpan(MS_PER_MINUTE * 10);
    case TimeGranularity.FIFTEEN_MINUTES:
      return fixedSpan(MS_PER_MINUTE * 15);
    case TimeGranularity.THIRTY_MINUTES:
      return fixedSpan(MS_PER_MINUTE * 30);
    case TimeGranularity.HOUR:
      return fixedSpan(MS_PER_HOUR);
    case TimeGranularity.DAY:
    case TimeGranularity.DATE:
      return until(utcMidnight(utcYear, utcMonth, utcDay + 1));
    case TimeGranularity.WEEK:
    case TimeGranularity.WEEK_STARTING_SUNDAY:
    case TimeGranularity.WEEK_STARTING_MONDAY:
      return until(utcMidnight(utcYear, utcMonth, utcDay + 7));
    case TimeGranularity.WEEK_ENDING_SATURDAY:
    case TimeGranularity.WEEK_ENDING_SUNDAY:
      // Week-ending buckets are labeled by the bucket's final day, so the
      // range reaches back six days and ends the day after the label.
      return [
        utcMidnight(utcYear, utcMonth, utcDay - 6),
        utcMidnight(utcYear, utcMonth, utcDay + 1),
      ];
    case TimeGranularity.MONTH:
      return until(utcMidnight(utcYear, utcMonth + 1, 1));
    case TimeGranularity.QUARTER:
      // First day of the quarter following the one containing utcMonth.
      return until(utcMidnight(utcYear, Math.floor(utcMonth / 3) * 3 + 3, 1));
    case TimeGranularity.YEAR:
      return until(utcMidnight(utcYear + 1, 0, 1));
    default:
      // Any other grain falls back to a one-day bucket.
      return until(utcMidnight(utcYear, utcMonth, utcDay + 1));
  }
}
export default function TableChart<D extends DataRecord = DataRecord>(
props: TableChartTransformedProps<D> & {
sticky?: DataTableProps<D>['sticky'];
@@ -471,7 +533,12 @@ export default function TableChart<D extends DataRecord = DataRecord>(
// so that cross-filters work on the receiving chart
const resolvedCol = columnLabelToNameMap[col] ?? col;
const val = ensureIsArray(updatedFilters?.[col]);
if (!val.length)
if (
!val.length ||
val[0] === null ||
(val[0] instanceof DateWithFormatter &&
val[0].input === null)
)
return {
col: resolvedCol,
op: 'IS NULL' as const,
@@ -578,15 +645,47 @@ export default function TableChart<D extends DataRecord = DataRecord>(
const drillToDetailFilters: BinaryQueryObjectFilterClause[] = [];
filteredColumnsMeta.forEach(col => {
if (!col.isMetric) {
let dataRecordValue = value[col.key];
dataRecordValue = extractTextFromHTML(dataRecordValue);
const dataRecordValue = value[col.key];
drillToDetailFilters.push({
col: col.key,
op: '==',
val: dataRecordValue as string | number | boolean,
formattedVal: formatColumnValue(col, dataRecordValue)[1],
});
// FIX: Explicitly handle NULL values for temporal and non-temporal columns
// DateWithFormatter objects wrap nulls, so we must check both
if (
dataRecordValue == null ||
(dataRecordValue instanceof DateWithFormatter &&
dataRecordValue.input == null)
) {
drillToDetailFilters.push({
col: col.key,
op: 'IS NULL' as any,
val: null,
});
} else if (col.dataType === GenericDataType.Temporal && timeGrain) {
const startTime =
dataRecordValue instanceof Date
? dataRecordValue
: new Date(dataRecordValue as string | number);
const [rangeStartTime, rangeEndTime] =
getTimeRangeFromGranularity(startTime, timeGrain);
const timeRangeValue = `${rangeStartTime.toISOString()} : ${rangeEndTime.toISOString()}`;
drillToDetailFilters.push({
col: col.key,
op: 'TEMPORAL_RANGE',
val: timeRangeValue,
grain: timeGrain,
formattedVal: formatColumnValue(col, dataRecordValue)[1],
});
} else {
// Non-temporal columns use exact match
const sanitizedValue = extractTextFromHTML(dataRecordValue);
drillToDetailFilters.push({
col: col.key,
op: '==',
val: sanitizedValue as string | number | boolean,
formattedVal: formatColumnValue(col, sanitizedValue)[1],
});
}
}
});
onContextMenu(clientX, clientY, {
@@ -600,7 +699,11 @@ export default function TableChart<D extends DataRecord = DataRecord>(
filters: [
{
col: cellPoint.key,
op: '==',
op: (cellPoint.value == null ||
(cellPoint.value instanceof DateWithFormatter &&
cellPoint.value.input == null)
? 'IS NULL'
: '==') as any,
val: extractTextFromHTML(cellPoint.value),
},
],
@@ -615,6 +718,7 @@ export default function TableChart<D extends DataRecord = DataRecord>(
isRawRecords,
filteredColumnsMeta,
getCrossFilterDataMask,
timeGrain,
]);
const getHeaderColumns = useCallback(

View File

@@ -2360,3 +2360,76 @@ describe('plugin-chart-table', () => {
});
});
});
/**
* DRILL-TO-DETAIL FIX VERIFICATION (#23847)
*
* Checks the drill-to-detail filter that TableChart hands to onContextMenu for the
* `__timestamp` column, under several time grains and for a null cell value.
*/
describe('Drill-to-Detail Temporal Range Logic', () => {
// Renders the basic fixture, right-clicks the first cell of the second rowgroup and
// returns the drillToDetail filter whose col is '__timestamp'.
// timeGrain, when given, is written to rawFormData.time_grain_sqla.
// timestampValue, when not undefined, replaces the first row's __timestamp.
const renderChartAndOpenContextMenu = (
timeGrain?: TimeGranularity,
timestampValue?: string | number | null,
) => {
const onContextMenu = jest.fn();
// Deep copy so that the override below does not leak into the shared fixture.
const data = cloneDeep(testData.basic);
// undefined means "keep the fixture value"; null is a real override.
if (timestampValue !== undefined) {
data.queriesData[0].data[0].__timestamp = timestampValue;
}
const props = transformProps({
...data,
rawFormData: {
...data.rawFormData,
...(timeGrain ? { time_grain_sqla: timeGrain } : {}),
},
hooks: { onAddFilter: jest.fn(), onContextMenu, setDataMask: jest.fn() },
});
render(<TableChart {...props} sticky={false} />);
// NOTE(review): index 1 is presumably the table body, index 0 the header — confirm.
const tbody = screen.getAllByRole('rowgroup')[1];
fireEvent.contextMenu(tbody.querySelectorAll('td')[0]);
// The third argument of the first onContextMenu call carries the drillToDetail filters.
const [, , { drillToDetail }] = onContextMenu.mock.calls[0];
return drillToDetail.find((f: any) => f.col === '__timestamp');
};
// Month grain: the range starts at a timestamp that is not truncated to the start of
// the month (presumably the fixture's own value) and ends at the first instant of the
// following month.
test('uses TEMPORAL_RANGE for monthly grain', () => {
const filter = renderChartAndOpenContextMenu(TimeGranularity.MONTH);
expect(filter.op).toBe('TEMPORAL_RANGE');
expect(filter.val).toContain(
'2020-01-01T12:34:56.000Z : 2020-02-01T00:00:00.000Z',
);
});
// Week-ending grains: the cell value labels the last day of the bucket, so the expected
// range starts six days earlier and ends the day after the label.
test('uses the full bucket for week ending sunday grain', () => {
const filter = renderChartAndOpenContextMenu(
TimeGranularity.WEEK_ENDING_SUNDAY,
'2020-01-05T00:00:00',
);
expect(filter.op).toBe('TEMPORAL_RANGE');
expect(filter.val).toBe(
'2019-12-30T00:00:00.000Z : 2020-01-06T00:00:00.000Z',
);
});
test('uses the full bucket for week ending saturday grain', () => {
const filter = renderChartAndOpenContextMenu(
TimeGranularity.WEEK_ENDING_SATURDAY,
'2020-01-04T00:00:00',
);
expect(filter.op).toBe('TEMPORAL_RANGE');
expect(filter.val).toBe(
'2019-12-29T00:00:00.000Z : 2020-01-05T00:00:00.000Z',
);
});
// A null timestamp must yield an IS NULL filter with a null value, not a date range.
test('correctly handles NULL values by emitting IS NULL instead of 1970 timestamp', () => {
const filter = renderChartAndOpenContextMenu(TimeGranularity.MONTH, null);
expect(filter.op).toBe('IS NULL');
expect(filter.val).toBeNull();
});
});

View File

@@ -19,15 +19,13 @@
import { SHARED_COLUMN_CONFIG_PROPS } from './constants';
const tokenSeparators =
SHARED_COLUMN_CONFIG_PROPS.d3NumberFormat.tokenSeparators;
const { d3NumberFormat } = SHARED_COLUMN_CONFIG_PROPS;
test('should allow commas in D3 format inputs', () => {
expect(tokenSeparators).toBeDefined();
expect(tokenSeparators).not.toContain(',');
test('should keep D3 format input creatable', () => {
expect(d3NumberFormat.creatable).toBe(true);
});
test('should have correct default token separators', () => {
const expectedSeparators = ['\r\n', '\n', '\t', ';'];
expect(tokenSeparators).toEqual(expectedSeparators);
test('should expose expected D3 format options', () => {
expect(Array.isArray(d3NumberFormat.options)).toBe(true);
expect((d3NumberFormat.options ?? []).length).toBeGreaterThan(0);
});

View File

@@ -58,8 +58,6 @@ const d3NumberFormat: ControlFormItemSpec<'Select'> = {
creatable: true,
minWidth: '14em',
debounceDelay: 500,
// default value tokenSeparators in superset-frontend/packages/superset-ui-core/src/components/Select/constants.ts
tokenSeparators: ['\r\n', '\n', '\t', ';'],
};
const d3TimeFormat: ControlFormItemSpec<'Select'> = {

View File

@@ -36,6 +36,11 @@ from superset.charts.data.dashboard_filter_context import (
get_dashboard_filter_context,
)
from superset.charts.data.query_context_cache_loader import QueryContextCacheLoader
from superset.charts.data.query_context_sidecar import (
DEFAULT_QUERY_CONTEXT_SIDECAR_TIMEOUT,
fetch_query_context_from_sidecar,
QueryContextSidecarError,
)
from superset.charts.schemas import ChartDataQueryContextSchema
from superset.commands.chart.data.create_async_job_command import (
CreateAsyncChartDataJobCommand,
@@ -57,7 +62,7 @@ from superset.constants import (
)
from superset.daos.exceptions import DatasourceNotFound
from superset.exceptions import QueryObjectValidationError, SupersetSecurityException
from superset.extensions import event_logger
from superset.extensions import db, event_logger
from superset.models.sql_lab import Query
from superset.utils import json
from superset.utils.core import (
@@ -65,7 +70,7 @@ from superset.utils.core import (
DatasourceType,
get_user_id,
)
from superset.utils.decorators import logs_context
from superset.utils.decorators import logs_context, transaction
from superset.views.base import CsvResponse, generate_download_headers, XlsxResponse
from superset.views.base_api import statsd_metrics
@@ -74,12 +79,17 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# Message for the 400 responses returned when a chart has no usable saved query
# context and a new one cannot be generated for it.
MISSING_QUERY_CONTEXT_MESSAGE = (
"Chart has no query context saved. Please save the chart again."
)
class ChartDataRestApi(ChartRestApi):
include_route_methods = {"get_data", "data", "data_from_cache"}
@expose("/<int:pk>/data/", methods=("GET",))
@protect()
@transaction()
@statsd_metrics
@event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.data",
@@ -161,24 +171,50 @@ class ChartDataRestApi(ChartRestApi):
if not chart:
return self.response_404()
try:
json_body = json.loads(chart.query_context)
except (TypeError, json.JSONDecodeError):
json_body = None
force_refresh = self._is_force_refresh_requested()
sidecar_url = app.config.get("QUERY_CONTEXT_SIDECAR_URL")
should_refresh_query_context = force_refresh and bool(sidecar_url)
json_body = (
None
if should_refresh_query_context
else self._load_saved_query_context(chart)
)
if json_body is None:
return self.response_400(
message=_(
"Chart has no query context saved. Please save the chart again."
)
if not chart.params:
return self.response_400(message=_(MISSING_QUERY_CONTEXT_MESSAGE))
if not sidecar_url:
return self.response_400(message=_(MISSING_QUERY_CONTEXT_MESSAGE))
try:
form_data = json.loads(chart.params)
except (TypeError, json.JSONDecodeError):
return self.response_400(message=_(MISSING_QUERY_CONTEXT_MESSAGE))
timeout = app.config.get(
"QUERY_CONTEXT_SIDECAR_TIMEOUT",
DEFAULT_QUERY_CONTEXT_SIDECAR_TIMEOUT,
)
try:
json_body = fetch_query_context_from_sidecar(
sidecar_url=sidecar_url,
form_data=form_data,
timeout=timeout,
)
except QueryContextSidecarError as ex:
return self.response_502(message=str(ex))
chart.query_context = json.dumps(json_body)
chart.last_saved_at = datetime.now()
db.session.flush()
# override saved query context
json_body["result_format"] = request.args.get(
"format", ChartDataResultFormat.JSON
)
json_body["result_type"] = request.args.get("type", ChartDataResultType.FULL)
json_body["force"] = request.args.get("force")
json_body["force"] = force_refresh
# Apply dashboard filter context when filters_dashboard_id is provided
dashboard_filter_context: DashboardFilterContext | None = None
@@ -282,6 +318,18 @@ class ChartDataRestApi(ChartRestApi):
dashboard_filter_context=dashboard_filter_context,
)
def _is_force_refresh_requested(self) -> bool:
    """Return whether the ``force`` request argument asks for a refresh.

    Only the exact values ``"1"``, ``"true"``, ``"True"`` and ``"force"``
    count; a missing argument or any other value yields ``False``.
    """
    force_arg = request.args.get("force")
    return force_arg in {"1", "true", "True", "force"}
def _load_saved_query_context(self, chart: Any) -> dict[str, Any] | None:
try:
json_body = json.loads(chart.query_context)
except (TypeError, json.JSONDecodeError):
return None
if isinstance(json_body, dict):
return json_body
return None
@expose("/data", methods=("POST",))
@protect()
@statsd_metrics

View File

@@ -0,0 +1,110 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import logging
from typing import Any
import requests
from flask import current_app as app
from superset.utils import json
logger = logging.getLogger(__name__)
# Fallback used when QUERY_CONTEXT_SIDECAR_TIMEOUT is absent from the app config;
# the value is passed to requests.post as its ``timeout`` argument.
DEFAULT_QUERY_CONTEXT_SIDECAR_TIMEOUT = 30
class QueryContextSidecarError(Exception):
    """Signal that the sidecar could not produce a query context."""
def maybe_generate_query_context(model: Any, params_json: str | None) -> None:
    """Best-effort generation of ``query_context`` via the sidecar service.

    Sets ``model.query_context`` on success. Failures are logged but never
    re-raised so chart saves are not blocked.

    :param model: object that receives ``query_context``; its ``id``
        attribute, when present, is used only in log messages
    :param params_json: JSON-encoded form data of the chart; nothing is done
        when it is empty or when ``QUERY_CONTEXT_SIDECAR_URL`` is not set
    """
    sidecar_url = app.config.get("QUERY_CONTEXT_SIDECAR_URL")
    if not sidecar_url or not params_json:
        return

    try:
        form_data = json.loads(params_json)
    except (TypeError, json.JSONDecodeError):
        logger.warning("Could not parse chart params for sidecar query context")
        return

    # Valid JSON can still be a list, scalar or null; the sidecar call is
    # declared to take a dict, so such params are not forwarded.
    if not isinstance(form_data, dict):
        logger.warning(
            "Chart params are not a JSON object; skipping sidecar query context"
        )
        return

    timeout = app.config.get(
        "QUERY_CONTEXT_SIDECAR_TIMEOUT",
        DEFAULT_QUERY_CONTEXT_SIDECAR_TIMEOUT,
    )

    try:
        result = fetch_query_context_from_sidecar(
            sidecar_url=sidecar_url,
            form_data=form_data,
            timeout=timeout,
        )
        model.query_context = json.dumps(result)
    except QueryContextSidecarError as ex:
        # Include the reason (unavailable / error status / invalid response)
        # so that the log line is actionable.
        logger.warning(
            "Failed to generate query context via sidecar for chart %s: %s",
            getattr(model, "id", "?"),
            ex,
        )
    except Exception:  # pylint: disable=broad-except
        # Deliberately broad: this is a best-effort step that must never
        # make the surrounding chart save fail.
        logger.warning(
            "Unexpected error generating query context via sidecar for chart %s",
            getattr(model, "id", "?"),
            exc_info=True,
        )
def fetch_query_context_from_sidecar(
    *,
    sidecar_url: str,
    form_data: dict[str, Any],
    timeout: int,
) -> dict[str, Any]:
    """POST ``form_data`` to the sidecar and return the query context it builds.

    :param sidecar_url: base URL of the sidecar; a trailing slash is tolerated
    :param form_data: chart form data, sent as ``{"form_data": ...}``
    :param timeout: passed to ``requests.post`` as its ``timeout`` argument
    :returns: the ``query_context`` object from the response body
    :raises QueryContextSidecarError: when the request fails, the status code
        is not 200, the body is not JSON, or the body does not contain a
        ``query_context`` object
    """
    endpoint = f"{sidecar_url.rstrip('/')}/api/v1/build-query-context"

    try:
        response = requests.post(
            endpoint,
            json={"form_data": form_data},
            timeout=timeout,
        )
    except requests.RequestException as ex:
        raise QueryContextSidecarError("Query context sidecar unavailable") from ex

    if response.status_code != 200:
        raise QueryContextSidecarError("Query context sidecar error")

    try:
        payload = response.json()
    except ValueError as ex:
        raise QueryContextSidecarError(
            "Query context sidecar returned invalid response"
        ) from ex

    # A body that parses as JSON may still be a list, string, number or null.
    # Calling ``.get`` on it would raise AttributeError, which callers that
    # only catch QueryContextSidecarError would not handle.
    query_context = payload.get("query_context") if isinstance(payload, dict) else None
    if not isinstance(query_context, dict):
        raise QueryContextSidecarError(
            "Query context sidecar returned invalid response"
        )

    return query_context

View File

@@ -24,6 +24,7 @@ from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
from superset import security_manager
from superset.charts.data.query_context_sidecar import maybe_generate_query_context
from superset.commands.base import BaseCommand, CreateMixin
from superset.commands.chart.exceptions import (
ChartCreateFailedError,
@@ -48,7 +49,12 @@ class CreateChartCommand(CreateMixin, BaseCommand):
self.validate()
self._properties["last_saved_at"] = datetime.now()
self._properties["last_saved_by"] = g.user
return ChartDAO.create(attributes=self._properties)
chart = ChartDAO.create(attributes=self._properties)
if not self._properties.get("query_context"):
maybe_generate_query_context(chart, self._properties.get("params"))
return chart
def validate(self) -> None:
exceptions = []

View File

@@ -24,6 +24,7 @@ from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
from superset import security_manager
from superset.charts.data.query_context_sidecar import maybe_generate_query_context
from superset.commands.base import BaseCommand, UpdateMixin
from superset.commands.chart.exceptions import (
ChartForbiddenError,
@@ -70,7 +71,16 @@ class UpdateChartCommand(UpdateMixin, BaseCommand):
self._properties["last_saved_at"] = datetime.now()
self._properties["last_saved_by"] = g.user
return ChartDAO.update(self._model, self._properties)
chart = ChartDAO.update(self._model, self._properties)
if (
"params" in self._properties
and not self._properties.get("query_context")
and not self._properties.get("query_context_generation")
):
maybe_generate_query_context(chart, self._properties["params"])
return chart
def _validate_new_dashboard_access(
self, requested_dashboards: list[Dashboard], exceptions: list[Exception]

View File

@@ -22,7 +22,6 @@ from datetime import datetime
from pprint import pformat
from typing import Any, NamedTuple, TYPE_CHECKING
from flask import g
from flask_babel import gettext as _
from jinja2.exceptions import TemplateError
from pandas import DataFrame
@@ -38,6 +37,7 @@ from superset.extensions import event_logger
from superset.sql.parse import sanitize_clause, transpile_to_dialect
from superset.superset_typing import Column, Metric, OrderBy, QueryObjectDict
from superset.utils import json, pandas_postprocessing
from superset.utils.cache_keys import add_impersonation_cache_key_if_needed
from superset.utils.core import (
DTTM_ALIAS,
find_duplicates,
@@ -479,24 +479,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
# or if the CACHE_QUERY_BY_USER flag is on or per_user_caching is enabled on
# the database
try:
database = self.datasource.database # type: ignore
extra = json.loads(database.extra or "{}")
if (
(
feature_flag_manager.is_feature_enabled("CACHE_IMPERSONATION")
and database.impersonate_user
)
or feature_flag_manager.is_feature_enabled("CACHE_QUERY_BY_USER")
or extra.get("per_user_caching", False)
):
if key := database.db_engine_spec.get_impersonation_key(
getattr(g, "user", None)
):
logger.debug(
"Adding impersonation key to QueryObject cache dict: %s", key
)
cache_dict["impersonation_key"] = key
add_impersonation_cache_key_if_needed(self.datasource.database, cache_dict) # type: ignore
except AttributeError:
# datasource or database do not exist
pass

View File

@@ -2322,6 +2322,11 @@ GLOBAL_ASYNC_QUERIES_POLLING_DELAY = int(
)
GLOBAL_ASYNC_QUERIES_WEBSOCKET_URL = "ws://127.0.0.1:8080/"
# Optional internal service URL used to generate chart query_context from form_data
# when query_context is missing (or refresh is explicitly forced).
QUERY_CONTEXT_SIDECAR_URL: str | None = None
QUERY_CONTEXT_SIDECAR_TIMEOUT = 30
# Global async queries cache backend configuration options:
# - Set 'CACHE_TYPE' to 'RedisCache' for RedisCacheBackend.
# - Set 'CACHE_TYPE' to 'RedisSentinelCache' for RedisSentinelCacheBackend.

View File

@@ -590,7 +590,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
# Driver-specific params to be included in the `get_oauth2_token` request body
oauth2_additional_token_request_params: dict[str, Any] = {}
# Driver-specific exception that should be mapped to OAuth2RedirectError
oauth2_exception = OAuth2RedirectError
oauth2_exception: type[Exception] | tuple[type[Exception], ...] = (
OAuth2RedirectError
)
# Does the query id related to the connection?
# The default value is True, which means that the query id is determined when

View File

@@ -31,6 +31,7 @@ from marshmallow import fields, Schema
from marshmallow.exceptions import ValidationError
from requests import Session
from shillelagh.adapters.api.gsheets.lib import SCOPES
from shillelagh.exceptions import UnauthenticatedError
from sqlalchemy.engine import create_engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
@@ -40,7 +41,7 @@ from superset.databases.schemas import encrypted_field_properties, EncryptedStri
from superset.db_engine_specs.base import DatabaseCategory
from superset.db_engine_specs.shillelagh import ShillelaghEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetException
from superset.exceptions import OAuth2TokenRefreshError, SupersetException
from superset.utils import json
from superset.utils.oauth2 import get_oauth2_access_token
@@ -151,6 +152,7 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
"https://accounts.google.com/o/oauth2/v2/auth"
)
oauth2_token_request_uri = "https://oauth2.googleapis.com/token" # noqa: S105
oauth2_exception = (UnauthenticatedError, OAuth2TokenRefreshError)
@classmethod
def get_oauth2_authorization_uri(

View File

@@ -62,6 +62,7 @@ Dataset Management:
- list_datasets: List datasets with advanced filters (1-based pagination)
- get_dataset_info: Get detailed dataset information by ID (includes columns/metrics)
- create_virtual_dataset: Save a SQL query as a virtual dataset for charting
- query_dataset: Query a dataset using its semantic layer (saved metrics, dimensions, filters) without needing a saved chart
Chart Management:
- list_charts: List charts with advanced filters (1-based pagination)
@@ -164,6 +165,17 @@ Use created_by_me for authorship, owned_by_me for edit ownership, or both
together for the union. All flags can be combined with 'filters' but not
with 'search'.
To query a dataset's semantic layer (metrics, dimensions):
1. list_datasets(request={{}}) -> find a dataset
2. get_dataset_info(request={{"identifier": <id>}}) -> examine columns AND metrics
3. query_dataset(request={{
"dataset_id": <id>,
"metrics": ["count", "avg_revenue"],
"columns": ["category"],
"time_range": "Last 7 days",
"row_limit": 100
}}) -> returns tabular data using saved metrics and dimensions
To explore data with SQL:
1. list_datasets(request={{}}) -> find a dataset and note its database_id
2. execute_sql(request={{"database_id": <id>, "sql": "SELECT ..."}})
@@ -520,6 +532,7 @@ from superset.mcp_service.dataset.tool import ( # noqa: F401, E402
create_virtual_dataset,
get_dataset_info,
list_datasets,
query_dataset,
)
from superset.mcp_service.explore.tool import ( # noqa: F401, E402
generate_explore_link,

View File

@@ -632,6 +632,15 @@ def mcp_auth_hook(tool_func: F) -> F: # noqa: C901
@functools.wraps(tool_func)
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
with _get_app_context_manager():
# Clear any stale thread-local SQLAlchemy session before user lookup.
# Thread pool workers reuse threads across requests; db.session is
# scoped by thread (not ContextVar), so a prior request's session may
# still be bound to a different tenant's DB engine. Removing it here
# ensures the next DB access creates a fresh session bound to the
# correct engine for the current request.
from superset.extensions import db
db.session.remove()
user = _setup_user_context()
# No Flask context - this is a FastMCP internal operation

View File

@@ -70,6 +70,8 @@ SORTABLE_CHART_COLUMNS = [
"created_on",
]
_DEFAULT_LIST_CHARTS_REQUEST = ListChartsRequest()
@tool(
tags=["core"],
@@ -81,7 +83,8 @@ SORTABLE_CHART_COLUMNS = [
),
)
async def list_charts(
request: ListChartsRequest, ctx: Context
request: ListChartsRequest | None = None,
ctx: Context = None,
) -> ChartList | ChartError:
"""List charts with filtering and search.
@@ -91,6 +94,7 @@ async def list_charts(
Sortable columns for order_column: id, slice_name, viz_type, description,
changed_on, created_on
"""
request = request or _DEFAULT_LIST_CHARTS_REQUEST.model_copy(deep=True)
await ctx.info(
"Listing charts: page=%s, page_size=%s, search=%s"
% (

View File

@@ -65,6 +65,8 @@ SORTABLE_DASHBOARD_COLUMNS = [
"created_on",
]
_DEFAULT_LIST_DASHBOARDS_REQUEST = ListDashboardsRequest()
@tool(
tags=["core"],
@@ -76,7 +78,8 @@ SORTABLE_DASHBOARD_COLUMNS = [
),
)
async def list_dashboards(
request: ListDashboardsRequest, ctx: Context
request: ListDashboardsRequest | None = None,
ctx: Context = None,
) -> DashboardList:
"""List dashboards with filtering and search. Returns dashboard metadata
including title, slug, URL, and last modified time. Use select_columns to
@@ -85,6 +88,7 @@ async def list_dashboards(
Sortable columns for order_column: id, dashboard_title, slug, published,
changed_on, created_on
"""
request = request or _DEFAULT_LIST_DASHBOARDS_REQUEST.model_copy(deep=True)
await ctx.info(
"Listing dashboards: page=%s, page_size=%s, search=%s"
% (

View File

@@ -36,10 +36,13 @@ from pydantic import (
)
from superset.daos.base import ColumnOperator, ColumnOperatorEnum
from superset.mcp_service.chart.schemas import DataColumn, PerformanceMetadata
from superset.mcp_service.common.cache_schemas import (
CacheStatus,
CreatedByMeMixin,
MetadataCacheControl,
OwnedByMeMixin,
QueryCacheControl,
)
from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE
from superset.mcp_service.privacy import filter_user_directory_fields
@@ -393,6 +396,146 @@ class CreateVirtualDatasetResponse(BaseModel):
)
# Operators accepted by ``QueryDatasetFilter.op``. The members and their order
# are unchanged; they are only grouped by family for readability.
VALID_FILTER_OPS = Literal[
    # Comparison
    "==", "!=", ">", "<", ">=", "<=",
    # Pattern matching
    "LIKE", "NOT LIKE", "ILIKE", "NOT ILIKE",
    # Membership
    "IN", "NOT IN",
    # Null / boolean checks
    "IS NULL", "IS NOT NULL", "IS TRUE", "IS FALSE",
    # Time filters
    "TEMPORAL_RANGE",
]
class QueryDatasetFilter(BaseModel):
    """A single filter condition for dataset queries."""

    # Required: the column the condition is evaluated against.
    col: str = Field(..., description="Column name to filter on")

    # Required: one of the operators listed in VALID_FILTER_OPS.
    op: VALID_FILTER_OPS = Field(
        ...,
        description=(
            "Filter operator. "
            'Use "==" for equals, "!=" for not equals, '
            '"IN" / "NOT IN" for membership, "IS NULL" / "IS NOT NULL", '
            '"LIKE" for pattern matching, "TEMPORAL_RANGE" for time filters.'
        ),
    )

    # Optional: untyped because the expected shape depends on the operator.
    val: Any = Field(
        default=None,
        description="Filter value (omit for IS NULL/IS NOT NULL)",
    )
class QueryDatasetRequest(QueryCacheControl):
    """Request schema for query_dataset tool."""

    # --- What to query -----------------------------------------------------
    dataset_id: int | str = Field(
        ...,
        description="Dataset identifier — numeric ID or UUID string.",
    )
    metrics: List[str] = Field(
        default_factory=list,
        description=(
            "Saved metric names to compute "
            "(e.g. ['count', 'avg_revenue']). "
            "Use get_dataset_info to discover available metrics."
        ),
    )
    columns: List[str] = Field(
        default_factory=list,
        description=(
            "Column/dimension names for GROUP BY or SELECT (e.g. "
            "['category', 'region']). Use get_dataset_info to discover "
            "available columns."
        ),
    )

    # --- How to restrict it ------------------------------------------------
    filters: List[QueryDatasetFilter] = Field(
        default_factory=list,
        description=(
            "Filter conditions "
            '(e.g. [{"col": "status", "op": "==", "val": "active"}]).'
        ),
    )
    time_range: str | None = Field(
        default=None,
        description=(
            "Time range filter "
            "(e.g. 'Last 7 days', 'Last month', '2024-01-01 : 2024-12-31'). "
            "Requires a temporal column on the dataset."
        ),
    )
    time_column: str | None = Field(
        default=None,
        description=(
            "Temporal column to apply time_range to. Defaults to the "
            "dataset's main datetime column."
        ),
    )

    # --- How to order and size the result ------------------------------------
    order_by: List[str] | None = Field(
        default=None,
        description="Column or metric names to sort results by.",
    )
    order_desc: bool = Field(
        default=True,
        description="Sort descending (True) or ascending (False).",
    )
    row_limit: int = Field(
        default=1000,
        ge=1,
        le=50000,
        description="Maximum number of rows to return (default 1000, max 50000).",
    )

    @model_validator(mode="after")
    def validate_metrics_or_columns(self) -> "QueryDatasetRequest":
        """At least one of metrics or columns must be provided."""
        if self.metrics or self.columns:
            return self
        # Both lists are empty, so there is nothing to select.
        raise ValueError(
            "At least one of 'metrics' or 'columns' must be provided. "
            "Use get_dataset_info to discover available metrics and columns."
        )
class QueryDatasetResponse(BaseModel):
    """Response schema for query_dataset tool."""

    model_config = ConfigDict(ser_json_timedelta="iso8601")

    # --- Which dataset was queried -------------------------------------------
    dataset_id: int = Field(..., description="Dataset ID")
    dataset_name: str = Field(..., description="Dataset name")

    # --- Result payload --------------------------------------------------------
    columns: List[DataColumn] = Field(
        default_factory=list,
        description="Column metadata for returned data",
    )
    data: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="Query result rows",
    )
    row_count: int = Field(default=0, description="Number of rows returned")
    total_rows: int | None = Field(
        default=None,
        description="Total row count from the query engine",
    )
    summary: str = Field(
        default="",
        description="Human-readable summary of the results",
    )

    # --- Execution details -----------------------------------------------------
    performance: PerformanceMetadata | None = Field(
        default=None,
        description="Query performance metadata",
    )
    cache_status: CacheStatus | None = Field(
        default=None,
        description="Cache hit/miss information",
    )
    applied_filters: List[QueryDatasetFilter] = Field(
        default_factory=list,
        description="Filters that were applied to the query",
    )
    warnings: List[str] = Field(
        default_factory=list,
        description="Any warnings encountered during execution",
    )
def _parse_json_field(obj: Any, field_name: str) -> Dict[str, Any] | None:
"""Parse a field that may be stored as a JSON string into a dict."""
value = getattr(obj, field_name, None)

View File

@@ -18,9 +18,11 @@
from .create_virtual_dataset import create_virtual_dataset
from .get_dataset_info import get_dataset_info
from .list_datasets import list_datasets
from .query_dataset import query_dataset
__all__ = [
"create_virtual_dataset",
"list_datasets",
"get_dataset_info",
"query_dataset",
]

View File

@@ -0,0 +1,489 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
MCP tool: query_dataset
Query a dataset using its semantic layer (saved metrics, calculated columns,
dimensions) without requiring a saved chart.
"""
import difflib
import logging
import time
from typing import Any
from fastmcp import Context
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import joinedload, subqueryload
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.commands.exceptions import CommandException
from superset.exceptions import OAuth2Error, OAuth2RedirectError, SupersetException
from superset.extensions import event_logger
from superset.mcp_service.chart.schemas import DataColumn, PerformanceMetadata
from superset.mcp_service.dataset.schemas import (
DatasetError,
QueryDatasetFilter,
QueryDatasetRequest,
QueryDatasetResponse,
)
from superset.mcp_service.privacy import (
DATA_MODEL_METADATA_ERROR_TYPE,
requires_data_model_metadata_access,
user_can_view_data_model_metadata,
)
from superset.mcp_service.utils import _is_uuid
from superset.mcp_service.utils.cache_utils import get_cache_status_from_result
from superset.mcp_service.utils.oauth2_utils import build_oauth2_redirect_message
logger = logging.getLogger(__name__)
def _resolve_dataset(identifier: int | str, eager_options: list[Any]) -> Any | None:
    """Look up a dataset from either a numeric ID or a UUID string.

    Mirrors the identifier handling of ``ModelGetInfoCore._find_object()``:
    integers are used directly, numeric strings are converted to integers, and
    anything else is tried as a UUID. Returns ``None`` when the identifier is
    neither numeric nor a UUID.
    """
    from superset.daos.dataset import DatasetDAO

    # An empty options list is passed on as ``None``.
    query_options = eager_options or None

    if not isinstance(identifier, int):
        # Numeric strings ("42") are treated as primary keys. The lookup stays
        # inside the ``try`` so the same exceptions are swallowed as before.
        try:
            return DatasetDAO.find_by_id(
                int(identifier), query_options=query_options
            )
        except (ValueError, TypeError):
            pass

        if _is_uuid(str(identifier)):
            return DatasetDAO.find_by_id(
                identifier, id_column="uuid", query_options=query_options
            )
        return None

    return DatasetDAO.find_by_id(identifier, query_options=query_options)
def _validate_names(
requested: list[str],
valid: set[str],
kind: str,
) -> list[str]:
"""Return list of error messages for names not found in *valid*.
Includes close-match suggestions when available.
"""
errors: list[str] = []
for name in requested:
if name not in valid:
suggestions = difflib.get_close_matches(name, valid, n=3, cutoff=0.6)
msg = f"Unknown {kind}: '{name}'"
if suggestions:
msg += f". Did you mean: {', '.join(suggestions)}?"
errors.append(msg)
return errors
@requires_data_model_metadata_access
@tool(
    tags=["data"],
    class_permission_name="Dataset",
    annotations=ToolAnnotations(
        title="Query dataset",
        readOnlyHint=True,
        destructiveHint=False,
    ),
)
async def query_dataset(  # noqa: C901
    request: QueryDatasetRequest, ctx: Context
) -> QueryDatasetResponse | DatasetError:
    """Query a dataset using its semantic layer (saved metrics, dimensions, filters).

    Returns tabular data without requiring a saved chart. Use this when you want
    to compute saved metrics, group by dimensions, or apply filters directly
    against a dataset's curated semantic layer.

    Workflow:
    1. list_datasets -> find a dataset
    2. get_dataset_info -> discover available columns and metrics
    3. query_dataset -> query using metric names and column names

    Example:
    ```json
    {
        "dataset_id": 123,
        "metrics": ["count", "avg_revenue"],
        "columns": ["product_category"],
        "time_range": "Last 7 days",
        "row_limit": 100
    }
    ```
    """
    # NOTE(review): the docstring above is presumably surfaced as the tool
    # description by the @tool decorator — confirm before rewording it.
    #
    # Every failure path returns a DatasetError (never raises): access denied,
    # dataset not found, validation errors, empty query result, and each of the
    # exception families handled at the bottom of the function.
    await ctx.info(
        "Starting dataset query: dataset_id=%s, metrics=%s, columns=%s, "
        "row_limit=%s"
        % (
            request.dataset_id,
            request.metrics,
            request.columns,
            request.row_limit,
        )
    )
    try:
        # Imported inside the function rather than at module level.
        from superset.commands.chart.data.get_data_command import ChartDataCommand
        from superset.common.query_context_factory import QueryContextFactory
        from superset.connectors.sqla.models import SqlaTable

        # ------------------------------------------------------------------
        # Step 1: Check data-model metadata access BEFORE the dataset lookup.
        # Doing this first prevents leaking dataset existence — restricted
        # users always receive DataModelMetadataRestricted, never NotFound.
        # The decorator hides this tool from search; this check enforces
        # direct calls that bypass tool discovery.
        # ------------------------------------------------------------------
        if not user_can_view_data_model_metadata():
            await ctx.warning("Dataset metadata access blocked by privacy controls")
            return DatasetError.create(
                error=(
                    "You don't have permission to access dataset details for your role."
                ),
                error_type=DATA_MODEL_METADATA_ERROR_TYPE,
            )

        # ------------------------------------------------------------------
        # Step 2: Resolve dataset
        # ------------------------------------------------------------------
        await ctx.report_progress(1, 5, "Looking up dataset")
        # Load columns, metrics and the database together with the dataset;
        # all three are read further down (validation and query execution).
        eager_options = [
            subqueryload(SqlaTable.columns),
            subqueryload(SqlaTable.metrics),
            joinedload(SqlaTable.database),
        ]
        with event_logger.log_context(action="mcp.query_dataset.lookup"):
            dataset = _resolve_dataset(request.dataset_id, eager_options)
        if dataset is None:
            await ctx.error("Dataset not found: identifier=%s" % (request.dataset_id,))
            return DatasetError.create(
                error=f"No dataset found with identifier: {request.dataset_id}",
                error_type="NotFound",
            )
        # Fall back to a generic label when the dataset has no table_name.
        dataset_name = getattr(dataset, "table_name", None) or f"Dataset {dataset.id}"
        await ctx.info(
            "Dataset found: id=%s, name=%s, columns=%s, metrics=%s"
            % (
                dataset.id,
                dataset_name,
                len(dataset.columns),
                len(dataset.metrics),
            )
        )

        # ------------------------------------------------------------------
        # Step 3: Validate requested columns and metrics
        # ------------------------------------------------------------------
        await ctx.report_progress(2, 5, "Validating columns and metrics")
        valid_columns = {c.column_name for c in dataset.columns}
        valid_metrics = {m.metric_name for m in dataset.metrics}
        # All problems are collected and reported together in one error.
        validation_errors: list[str] = []
        validation_errors.extend(
            _validate_names(request.columns, valid_columns, "column")
        )
        validation_errors.extend(
            _validate_names(request.metrics, valid_metrics, "metric")
        )
        # Validate filter column names against dataset columns
        filter_cols = [f.col for f in request.filters]
        validation_errors.extend(
            _validate_names(filter_cols, valid_columns, "filter column")
        )
        # Validate order_by names against columns + metrics
        if request.order_by:
            valid_orderby = valid_columns | valid_metrics
            validation_errors.extend(
                _validate_names(request.order_by, valid_orderby, "order_by")
            )
        if validation_errors:
            error_msg = "; ".join(validation_errors)
            await ctx.error("Validation failed: %s" % (error_msg,))
            return DatasetError.create(
                error=error_msg,
                error_type="ValidationError",
            )

        # ------------------------------------------------------------------
        # Step 4: Build filters and time range
        # ------------------------------------------------------------------
        warnings: list[str] = []
        # Plain-dict form of the request filters, as placed in the query dict.
        query_filters: list[dict[str, Any]] = [
            {"col": f.col, "op": f.op, "val": f.val} for f in request.filters
        ]
        # Track all applied filters (including synthesized ones) for the response.
        effective_filters: list[QueryDatasetFilter] = list(request.filters)
        granularity: str | None = None
        if request.time_range:
            # An explicit time_column wins; otherwise use the dataset's
            # main_dttm_col when it has one.
            temporal_col = request.time_column or getattr(
                dataset, "main_dttm_col", None
            )
            if not temporal_col:
                await ctx.error("time_range provided but no temporal column available")
                return DatasetError.create(
                    error=(
                        "time_range was provided but no temporal column is available. "
                        "Either set time_column explicitly or ensure the dataset has "
                        "a main datetime column configured."
                    ),
                    error_type="ValidationError",
                )
            # Validate that the temporal column actually exists on the dataset
            if temporal_col not in valid_columns:
                await ctx.error("time_column '%s' not found on dataset" % temporal_col)
                return DatasetError.create(
                    error=(
                        f"time_column '{temporal_col}' does not exist on this dataset."
                    ),
                    error_type="ValidationError",
                )
            # Warn if the chosen temporal column isn't marked as datetime
            dttm_cols = {c.column_name for c in dataset.columns if c.is_dttm}
            if temporal_col not in dttm_cols:
                warnings.append(
                    f"Column '{temporal_col}' is not marked as a datetime "
                    f"column on this dataset. Time filtering may not work "
                    f"as expected."
                )
            # The time range is expressed as a synthesized TEMPORAL_RANGE
            # filter, recorded both in the query and in the reported filters.
            query_filters.append(
                {
                    "col": temporal_col,
                    "op": "TEMPORAL_RANGE",
                    "val": request.time_range,
                }
            )
            effective_filters.append(
                QueryDatasetFilter(
                    col=temporal_col,
                    op="TEMPORAL_RANGE",
                    val=request.time_range,
                )
            )
            granularity = temporal_col
            await ctx.debug(
                "Time filter: column=%s, range=%s" % (temporal_col, request.time_range)
            )

        # ------------------------------------------------------------------
        # Step 5: Build query dict
        # ------------------------------------------------------------------
        await ctx.report_progress(3, 5, "Building query")
        query_dict: dict[str, Any] = {
            "filters": query_filters,
            "columns": request.columns,
            "metrics": request.metrics,
            "row_limit": request.row_limit,
            "order_desc": request.order_desc,
        }
        # granularity is only set when a time_range was supplied (step 4).
        if granularity:
            query_dict["granularity"] = granularity
        if request.order_by:
            # OrderBy = tuple[Metric | Column, bool] where bool is ascending
            query_dict["orderby"] = [
                (col, not request.order_desc) for col in request.order_by
            ]
        await ctx.debug("Query dict keys: %s" % (sorted(query_dict.keys()),))

        # ------------------------------------------------------------------
        # Step 6: Create QueryContext and execute
        # ------------------------------------------------------------------
        await ctx.report_progress(4, 5, "Executing query")
        # The measured duration covers context creation, validation and the run.
        start_time = time.time()
        with event_logger.log_context(action="mcp.query_dataset.execute"):
            factory = QueryContextFactory()
            # datasource_type is "table" because this tool queries SqlaTable
            # datasets (Superset's built-in semantic layer). External semantic
            # layers (dbt, Snowflake Cortex, etc.) use "semantic_view" and have
            # a different query path — see SemanticView + mapper.py.
            query_context = factory.create(
                datasource={"id": dataset.id, "type": "table"},
                queries=[query_dict],
                form_data={},
                # Force execution when the caller disabled caching or asked
                # for a refresh.
                force=not request.use_cache or request.force_refresh,
                custom_cache_timeout=request.cache_timeout,
            )
            command = ChartDataCommand(query_context)
            command.validate()
            result = command.run()
        query_duration_ms = int((time.time() - start_time) * 1000)
        if not result or "queries" not in result or len(result["queries"]) == 0:
            await ctx.warning("Query returned no results for dataset %s" % dataset.id)
            return DatasetError.create(
                error="Query returned no results.",
                error_type="EmptyQuery",
            )

        # ------------------------------------------------------------------
        # Step 7: Format response
        # ------------------------------------------------------------------
        await ctx.report_progress(5, 5, "Formatting results")
        # Only one query was submitted, so only the first result is read.
        query_result = result["queries"][0]
        data = query_result.get("data", [])
        raw_columns = query_result.get("colnames", [])
        # A query that ran but matched no rows is a successful, empty response
        # (unlike the "EmptyQuery" error above, where no query result exists).
        if not data:
            return QueryDatasetResponse(
                dataset_id=dataset.id,
                dataset_name=dataset_name,
                columns=[],
                data=[],
                row_count=0,
                total_rows=0,
                summary=f"Query on '{dataset_name}' returned no data.",
                performance=PerformanceMetadata(
                    query_duration_ms=query_duration_ms,
                    cache_status="no_data",
                ),
                cache_status=get_cache_status_from_result(
                    query_result, force_refresh=request.force_refresh
                ),
                applied_filters=effective_filters,
                warnings=warnings,
            )
        # Build column metadata in a single pass per column.
        # Cap stats computation at stats_sample_size rows to avoid O(rows*cols)
        # overhead on large result sets (row_limit allows up to 50k).
        stats_sample_size = 5000
        stats_rows = data[:stats_sample_size]
        columns_meta: list[DataColumn] = []
        for col_name in raw_columns:
            # The type is inferred from the non-null values of the first three
            # rows only; with no such values it stays "string".
            sample_values = [
                row.get(col_name) for row in data[:3] if row.get(col_name) is not None
            ]
            data_type = "string"
            if sample_values:
                # bool is tested first: bool values also pass the int check.
                if all(isinstance(v, bool) for v in sample_values):
                    data_type = "boolean"
                elif all(isinstance(v, (int, float)) for v in sample_values):
                    data_type = "numeric"
            # Compute null_count and unique non-null values in one pass
            null_count = 0
            unique_vals: set[str] = set()
            for row in stats_rows:
                val = row.get(col_name)
                if val is None:
                    null_count += 1
                else:
                    # Values are compared by their string form.
                    unique_vals.add(str(val))
            columns_meta.append(
                DataColumn(
                    name=col_name,
                    display_name=col_name.replace("_", " ").title(),
                    data_type=data_type,
                    sample_values=sample_values[:3],
                    null_count=null_count,
                    unique_count=len(unique_vals),
                )
            )
        cache_status = get_cache_status_from_result(
            query_result, force_refresh=request.force_refresh
        )
        cache_label = "cached" if cache_status and cache_status.cache_hit else "fresh"
        summary = (
            f"Dataset '{dataset_name}': {len(data)} rows, "
            f"{len(raw_columns)} columns ({cache_label})."
        )
        await ctx.info(
            "Query complete: rows=%s, columns=%s, duration=%sms"
            % (len(data), len(raw_columns), query_duration_ms)
        )
        return QueryDatasetResponse(
            dataset_id=dataset.id,
            dataset_name=dataset_name,
            columns=columns_meta,
            data=data,
            row_count=len(data),
            # row_count is the number of rows returned here; total_rows is
            # whatever the query result reports as "rowcount".
            total_rows=query_result.get("rowcount"),
            summary=summary,
            performance=PerformanceMetadata(
                query_duration_ms=query_duration_ms,
                cache_status=cache_label,
            ),
            cache_status=cache_status,
            applied_filters=effective_filters,
            warnings=warnings,
        )
    # NOTE(review): the handlers run from most specific to most general; the
    # OAuth2 exceptions presumably subclass SupersetException, so they must
    # stay ahead of that handler — confirm before reordering.
    except OAuth2RedirectError as exc:
        redirect_msg = build_oauth2_redirect_message(exc)
        await ctx.error("OAuth2 redirect required: %s" % (redirect_msg,))
        return DatasetError.create(
            error=redirect_msg,
            error_type="OAuth2Redirect",
        )
    except OAuth2Error as exc:
        await ctx.error("OAuth2 error: %s" % (str(exc),))
        return DatasetError.create(
            error=f"OAuth2 authentication error: {exc}",
            error_type="OAuth2Error",
        )
    except (CommandException, SupersetException) as exc:
        await ctx.error("Query failed: %s" % (str(exc),))
        return DatasetError.create(
            error=f"Query execution failed: {exc}",
            error_type="QueryError",
        )
    except SQLAlchemyError as exc:
        await ctx.error("Database error: %s" % (str(exc),))
        return DatasetError.create(
            error=f"Database error: {exc}",
            error_type="DatabaseError",
        )
    except Exception as exc:
        # Last-resort boundary: the traceback is logged and the details go to
        # ctx.error, but the returned error message stays generic.
        logger.exception(
            "Unexpected error while querying dataset: %s: %s",
            type(exc).__name__,
            str(exc),
        )
        await ctx.error("Unexpected error: %s: %s" % (type(exc).__name__, str(exc)))
        return DatasetError.create(
            error="An unexpected error occurred while querying the dataset.",
            error_type="UnexpectedError",
        )

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,54 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import logging
from typing import Any, TYPE_CHECKING
from flask import g
from superset import feature_flag_manager
from superset.utils.json import loads as json_loads
if TYPE_CHECKING:
from superset.models.core import Database
logger = logging.getLogger(__name__)
def add_impersonation_cache_key_if_needed(
    database: Database,
    cache_dict: dict[str, Any],
) -> None:
    """
    Store a per-user ``impersonation_key`` in *cache_dict* when applicable.

    The key is added when any of the following holds: the
    ``CACHE_IMPERSONATION`` feature flag is on and the database impersonates
    users, the ``CACHE_QUERY_BY_USER`` feature flag is on, or the database
    ``extra`` JSON enables ``per_user_caching``. Otherwise *cache_dict* is
    left untouched. It is also left untouched when the engine spec returns a
    falsy key.
    """
    db_extra = json_loads(database.extra or "{}")

    # The conditions are evaluated in the original order with short-circuiting.
    per_user_cache = (
        (
            feature_flag_manager.is_feature_enabled("CACHE_IMPERSONATION")
            and database.impersonate_user
        )
        or feature_flag_manager.is_feature_enabled("CACHE_QUERY_BY_USER")
        or db_extra.get("per_user_caching", False)
    )
    if not per_user_cache:
        return

    current_user = getattr(g, "user", None)
    key = database.db_engine_spec.get_impersonation_key(current_user)
    if not key:
        return

    logger.debug("Adding impersonation key to cache dict: %s", key)
    cache_dict["impersonation_key"] = key

View File

@@ -65,6 +65,7 @@ from superset.superset_typing import (
)
from superset.utils import core as utils, csv, json
from superset.utils.cache import set_and_log_cache
from superset.utils.cache_keys import add_impersonation_cache_key_if_needed
from superset.utils.core import (
apply_max_row_limit,
DateColumn,
@@ -472,6 +473,16 @@ class BaseViz: # pylint: disable=too-many-public-methods
cache_dict["extra_cache_keys"] = self.datasource.get_extra_cache_keys(query_obj)
cache_dict["rls"] = security_manager.get_rls_cache_key(self.datasource)
cache_dict["changed_on"] = self.datasource.changed_on
# Add an impersonation key to cache if impersonation is enabled on the db
# or if the CACHE_QUERY_BY_USER flag is on or per_user_caching is enabled on
# the database
try:
add_impersonation_cache_key_if_needed(self.datasource.database, cache_dict)
except AttributeError:
# datasource or database do not exist
pass
json_data = self.json_dumps(cache_dict, sort_keys=True)
return hash_from_str(json_data)

View File

@@ -1180,6 +1180,107 @@ class TestGetChartDataApi(BaseTestChartDataApi):
"message": "Chart has no query context saved. Please save the chart again."
}
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@with_config({"QUERY_CONTEXT_SIDECAR_URL": "http://sidecar.internal"})
@mock.patch("superset.charts.data.api.ChartDataRestApi._get_data_response")
@mock.patch("superset.charts.data.api.ChartDataCommand.validate")
@mock.patch(
    "superset.charts.data.api.ChartDataRestApi._create_query_context_from_form"
)
@mock.patch("superset.charts.data.api.fetch_query_context_from_sidecar")
def test_get_data_fetches_missing_query_context_from_sidecar(
    self,
    mock_fetch_query_context_from_sidecar,
    mock_create_query_context_from_form,
    mock_validate,
    mock_get_data_response,
):
    """
    A chart saved without a query context gets one from the sidecar on GET,
    and the fetched context is persisted on the chart.
    """
    # Arrange: a chart with no stored query context.
    chart = db.session.query(Slice).filter_by(slice_name="Genders").one()
    chart.query_context = None
    db.session.commit()

    expected_datasource = {"id": chart.table.id, "type": "table"}

    # Arrange: the sidecar answers with a minimal query context, and the rest
    # of the data pipeline is stubbed out.
    mock_fetch_query_context_from_sidecar.return_value = {
        "datasource": dict(expected_datasource),
        "force": False,
        "queries": [],
        "form_data": chart.form_data,
        "result_format": "json",
        "result_type": "full",
    }
    mock_create_query_context_from_form.return_value = mock.MagicMock()
    mock_validate.return_value = None
    mock_get_data_response.return_value = Response(
        response="{}",
        status=200,
        mimetype="application/json",
    )

    # Act
    rv = self.get_assert_metric(f"api/v1/chart/{chart.id}/data/", "get_data")

    # Assert: the sidecar was consulted once and its context was stored.
    assert rv.status_code == 200
    mock_fetch_query_context_from_sidecar.assert_called_once()
    db.session.refresh(chart)
    stored_context = json.loads(chart.query_context or "{}")
    assert stored_context.get("datasource") == expected_datasource
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@with_config({"QUERY_CONTEXT_SIDECAR_URL": "http://sidecar.internal"})
@mock.patch("superset.charts.data.api.ChartDataRestApi._get_data_response")
@mock.patch("superset.charts.data.api.ChartDataCommand.validate")
@mock.patch(
    "superset.charts.data.api.ChartDataRestApi._create_query_context_from_form"
)
@mock.patch("superset.charts.data.api.fetch_query_context_from_sidecar")
def test_get_data_force_refreshes_query_context_from_sidecar(
    self,
    mock_fetch_query_context_from_sidecar,
    mock_create_query_context_from_form,
    mock_validate,
    mock_get_data_response,
):
    """
    GET chart data with ``?force=true`` for a chart that already has a
    saved query context.

    The sidecar must still be called exactly once and the ``queries`` of
    the payload it returns must replace the ones previously stored on the
    chart.
    """
    # Stub the collaborators that do not depend on the chart first.
    mock_create_query_context_from_form.return_value = mock.MagicMock()
    mock_validate.return_value = None
    mock_get_data_response.return_value = Response(
        response="{}",
        status=200,
        mimetype="application/json",
    )

    # Seed the chart with a query context that differs from the sidecar's.
    chart = db.session.query(Slice).filter_by(slice_name="Genders").one()
    stale_context = {
        "datasource": {"id": chart.table.id, "type": "table"},
        "force": False,
        "queries": [{"metrics": ["sum__num"]}],
        "result_format": "json",
        "result_type": "full",
    }
    chart.query_context = json.dumps(stale_context)
    db.session.commit()

    mock_fetch_query_context_from_sidecar.return_value = {
        "datasource": {"id": chart.table.id, "type": "table"},
        "force": False,
        "queries": [{"metrics": ["count"]}],
        "form_data": chart.form_data,
        "result_format": "json",
        "result_type": "full",
    }

    response = self.get_assert_metric(
        f"api/v1/chart/{chart.id}/data/?force=true",
        "get_data",
    )

    assert response.status_code == 200
    mock_fetch_query_context_from_sidecar.assert_called_once()

    # Reload the row to observe what the endpoint actually persisted.
    db.session.refresh(chart)
    stored_context = json.loads(chart.query_context or "{}")
    assert stored_context.get("queries") == [{"metrics": ["count"]}]
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_get(self):
"""

View File

@@ -0,0 +1,214 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from typing import Any
from unittest import mock
import pytest
import requests
from superset.charts.data.query_context_sidecar import (
fetch_query_context_from_sidecar,
maybe_generate_query_context,
QueryContextSidecarError,
)
@mock.patch("superset.charts.data.query_context_sidecar.requests.post")
def test_fetch_query_context_from_sidecar_success(mock_post: mock.MagicMock) -> None:
    """A 200 reply yields the ``query_context`` member of the JSON body."""
    sidecar_reply = mock_post.return_value
    sidecar_reply.status_code = 200
    sidecar_reply.json.return_value = {"query_context": {"foo": "bar"}}

    result = fetch_query_context_from_sidecar(
        sidecar_url="http://sidecar.internal",
        form_data={"viz_type": "pie"},
        timeout=15,
    )

    # The helper must hit the build endpoint with the form data and timeout.
    mock_post.assert_called_once_with(
        "http://sidecar.internal/api/v1/build-query-context",
        json={"form_data": {"viz_type": "pie"}},
        timeout=15,
    )
    assert result == {"foo": "bar"}
@mock.patch("superset.charts.data.query_context_sidecar.requests.post")
def test_fetch_query_context_from_sidecar_connection_error(
    mock_post: mock.MagicMock,
) -> None:
    """A ``requests`` failure is re-raised as ``QueryContextSidecarError``."""
    # Make the HTTP call itself blow up.
    mock_post.side_effect = requests.RequestException()

    expected_error = pytest.raises(
        QueryContextSidecarError,
        match="sidecar unavailable",
    )
    with expected_error:
        fetch_query_context_from_sidecar(
            sidecar_url="http://sidecar.internal",
            form_data={"viz_type": "pie"},
            timeout=15,
        )
@mock.patch("superset.charts.data.query_context_sidecar.requests.post")
def test_fetch_query_context_from_sidecar_bad_status(mock_post: mock.MagicMock) -> None:
    """A 500 status from the sidecar raises ``QueryContextSidecarError``."""
    sidecar_reply = mock_post.return_value
    sidecar_reply.status_code = 500

    expected_error = pytest.raises(QueryContextSidecarError, match="sidecar error")
    with expected_error:
        fetch_query_context_from_sidecar(
            sidecar_url="http://sidecar.internal",
            form_data={"viz_type": "pie"},
            timeout=15,
        )
@mock.patch("superset.charts.data.query_context_sidecar.requests.post")
def test_fetch_query_context_from_sidecar_invalid_payload(
    mock_post: mock.MagicMock,
) -> None:
    """A 200 reply lacking ``query_context`` raises ``QueryContextSidecarError``."""
    sidecar_reply = mock_post.return_value
    sidecar_reply.status_code = 200
    # Well-formed JSON, but without the key the helper is looking for.
    sidecar_reply.json.return_value = {"not_query_context": {}}

    expected_error = pytest.raises(QueryContextSidecarError, match="invalid response")
    with expected_error:
        fetch_query_context_from_sidecar(
            sidecar_url="http://sidecar.internal",
            form_data={"viz_type": "pie"},
            timeout=15,
        )
# ---------------------------------------------------------------------------
# Tests for maybe_generate_query_context
# ---------------------------------------------------------------------------
class _FakeApp:
"""Minimal stand-in for the Flask app proxy used by the sidecar module."""
def __init__(self, config: dict[str, Any] | None = None) -> None:
self.config = config or {}
@mock.patch(
    "superset.charts.data.query_context_sidecar.app",
    new=_FakeApp({}),
)
def test_maybe_generate_noop_when_no_sidecar_url() -> None:
    """Without ``QUERY_CONTEXT_SIDECAR_URL`` configured nothing must happen.

    The sidecar fetch helper must not be invoked and the model's
    ``query_context`` must be left untouched.
    """
    model = mock.MagicMock()
    # A MagicMock fabricates a child mock for any attribute that is read, so
    # pin the attribute to None to make "left untouched" observable.
    model.query_context = None

    with mock.patch(
        "superset.charts.data.query_context_sidecar.fetch_query_context_from_sidecar"
    ) as mock_fetch:
        maybe_generate_query_context(model, '{"viz_type": "pie"}')

    mock_fetch.assert_not_called()
    assert model.query_context is None
@mock.patch(
    "superset.charts.data.query_context_sidecar.app",
    new=_FakeApp({"QUERY_CONTEXT_SIDECAR_URL": "http://sidecar.internal"}),
)
def test_maybe_generate_noop_when_params_json_is_none() -> None:
    """With a sidecar configured but no chart params nothing must happen.

    The sidecar fetch helper must not be invoked and the model's
    ``query_context`` must be left untouched.
    """
    model = mock.MagicMock()
    # A MagicMock fabricates a child mock for any attribute that is read, so
    # pin the attribute to None to make "left untouched" observable.
    model.query_context = None

    with mock.patch(
        "superset.charts.data.query_context_sidecar.fetch_query_context_from_sidecar"
    ) as mock_fetch:
        maybe_generate_query_context(model, None)

    mock_fetch.assert_not_called()
    assert model.query_context is None
@mock.patch(
    "superset.charts.data.query_context_sidecar.fetch_query_context_from_sidecar"
)
@mock.patch(
    "superset.charts.data.query_context_sidecar.app",
    new=_FakeApp(
        {
            "QUERY_CONTEXT_SIDECAR_URL": "http://sidecar.internal",
            "QUERY_CONTEXT_SIDECAR_TIMEOUT": 10,
        }
    ),
)
def test_maybe_generate_sets_query_context_on_success(
    mock_fetch: mock.MagicMock,
) -> None:
    """A successful sidecar call stores a query context on the model.

    The fetch helper must receive the configured URL and timeout together
    with the decoded chart params.
    """
    mock_fetch.return_value = {"datasource": {"id": 1}, "queries": []}
    model = mock.MagicMock()
    # Reading an unset attribute on a MagicMock returns a fresh (non-None)
    # child mock, which would make the final assertion pass even if the
    # code under test never assigned anything. Start from an explicit None.
    model.query_context = None

    maybe_generate_query_context(model, '{"viz_type": "pie"}')

    mock_fetch.assert_called_once_with(
        sidecar_url="http://sidecar.internal",
        form_data={"viz_type": "pie"},
        timeout=10,
    )
    assert model.query_context is not None
@mock.patch(
    "superset.charts.data.query_context_sidecar.fetch_query_context_from_sidecar"
)
@mock.patch(
    "superset.charts.data.query_context_sidecar.app",
    new=_FakeApp(
        {
            "QUERY_CONTEXT_SIDECAR_URL": "http://sidecar.internal",
            "QUERY_CONTEXT_SIDECAR_TIMEOUT": 10,
        }
    ),
)
def test_maybe_generate_logs_on_sidecar_error(
    mock_fetch: mock.MagicMock,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """A ``QueryContextSidecarError`` from the fetch helper is logged, not raised."""
    chart_model = mock.MagicMock()
    chart_model.id = 42
    mock_fetch.side_effect = QueryContextSidecarError("boom")

    with caplog.at_level("WARNING"):
        maybe_generate_query_context(chart_model, '{"viz_type": "pie"}')

    captured = caplog.text
    assert "Failed to generate query context" in captured
@mock.patch(
    "superset.charts.data.query_context_sidecar.app",
    new=_FakeApp({"QUERY_CONTEXT_SIDECAR_URL": "http://sidecar.internal"}),
)
def test_maybe_generate_logs_on_invalid_json(
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Chart params that are not valid JSON produce a log message."""
    chart_model = mock.MagicMock()
    malformed_params = "not-valid-json{{{"

    with caplog.at_level("WARNING"):
        maybe_generate_query_context(chart_model, malformed_params)

    captured = caplog.text
    assert "Could not parse chart params" in captured
@mock.patch(
    "superset.charts.data.query_context_sidecar.fetch_query_context_from_sidecar"
)
@mock.patch(
    "superset.charts.data.query_context_sidecar.app",
    new=_FakeApp(
        {
            "QUERY_CONTEXT_SIDECAR_URL": "http://sidecar.internal",
            "QUERY_CONTEXT_SIDECAR_TIMEOUT": 10,
        }
    ),
)
def test_maybe_generate_logs_on_unexpected_error(
    mock_fetch: mock.MagicMock,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """A non-sidecar exception from the fetch helper is logged, not raised."""
    chart_model = mock.MagicMock()
    chart_model.id = 99
    mock_fetch.side_effect = RuntimeError("unexpected")

    with caplog.at_level("WARNING"):
        maybe_generate_query_context(chart_model, '{"viz_type": "pie"}')

    captured = caplog.text
    assert "Unexpected error" in captured

View File

@@ -16,6 +16,8 @@
# under the License.
import copy
from typing import Any, cast
from uuid import UUID
import yaml
from pytest_mock import MockerFixture
@@ -153,8 +155,10 @@ def test_import_assets_imports_tags(mocker: MockerFixture, session: Session) ->
ImportAssetsCommand._import(configs, contents=contents)
chart_uuids = {config["uuid"] for config in charts_with_tags.values()}
imported_charts = db.session.query(Slice).filter(Slice.uuid.in_(chart_uuids)).all()
chart_uuids = {UUID(str(config["uuid"])) for config in charts_with_tags.values()}
imported_charts = (
db.session.query(Slice).filter(cast(Any, Slice.uuid).in_(chart_uuids)).all()
)
assert len(imported_charts) == len(chart_uuids)
for chart in imported_charts:
assocs = (
@@ -165,9 +169,13 @@ def test_import_assets_imports_tags(mocker: MockerFixture, session: Session) ->
assert len(assocs) == 1
assert assocs[0].tag.name == "chart_tag"
dashboard_uuids = {config["uuid"] for config in dashboards_with_tags.values()}
dashboard_uuids = {
UUID(str(config["uuid"])) for config in dashboards_with_tags.values()
}
imported_dashboards = (
db.session.query(Dashboard).filter(Dashboard.uuid.in_(dashboard_uuids)).all()
db.session.query(Dashboard)
.filter(cast(Any, Dashboard.uuid).in_(dashboard_uuids))
.all()
)
assert len(imported_dashboards) == len(dashboard_uuids)
for dashboard in imported_dashboards:

View File

@@ -24,6 +24,7 @@ import pandas as pd
import pytest
from pytest_mock import MockerFixture
from requests.exceptions import HTTPError
from shillelagh.exceptions import UnauthenticatedError
from sqlalchemy.engine.url import make_url
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
@@ -789,6 +790,36 @@ def test_needs_oauth2_with_other_error(mocker: MockerFixture) -> None:
assert GSheetsEngineSpec.needs_oauth2(ex) is False
def test_needs_oauth2_with_shillelagh_unauthenticated_error(
    mocker: MockerFixture,
) -> None:
    """
    Shillelagh's ``UnauthenticatedError`` must be reported as needing OAuth2.
    """
    from superset.db_engine_specs.gsheets import GSheetsEngineSpec

    # Provide a logged-in user on the patched Flask ``g`` object.
    patched_g = mocker.patch("superset.db_engine_specs.gsheets.g")
    patched_g.user = mocker.MagicMock()

    revoked_token_error = UnauthenticatedError("Token has been revoked")
    outcome = GSheetsEngineSpec.needs_oauth2(revoked_token_error)

    assert outcome is True
def test_needs_oauth2_with_unrelated_exception_type(
    mocker: MockerFixture,
) -> None:
    """
    An exception of an unrelated type whose message matches nothing must not
    be reported as needing OAuth2.
    """
    from superset.db_engine_specs.gsheets import GSheetsEngineSpec

    # Provide a logged-in user on the patched Flask ``g`` object.
    patched_g = mocker.patch("superset.db_engine_specs.gsheets.g")
    patched_g.user = mocker.MagicMock()

    unrelated_error = ValueError("unrelated")
    outcome = GSheetsEngineSpec.needs_oauth2(unrelated_error)

    assert outcome is False
def test_get_oauth2_fresh_token_success(
mocker: MockerFixture,
oauth2_config: OAuth2ClientConfig,

View File

@@ -320,66 +320,13 @@ class TestChartDataModelMetadataPrivacy:
assert data["error_type"] == DATA_MODEL_METADATA_ERROR_TYPE
class TestListChartsCreatedByMe:
    """Schema-level checks for ``ListChartsRequest.created_by_me``."""

    def test_created_by_me_default_is_false(self):
        """The flag is off unless explicitly requested."""
        default_request = ListChartsRequest()
        assert default_request.created_by_me is False

    def test_created_by_me_true_accepted(self):
        """The flag can be switched on."""
        flagged_request = ListChartsRequest(created_by_me=True)
        assert flagged_request.created_by_me is True

    def test_created_by_me_combined_with_filters(self):
        """The flag coexists with ordinary column filters."""
        name_filter = ChartFilter(col="slice_name", opr="sw", value="My")
        combined = ListChartsRequest(created_by_me=True, filters=[name_filter])
        assert combined.created_by_me is True
        assert len(combined.filters) == 1

    def test_created_by_me_with_search_raises(self):
        """Combining the flag with ``search`` fails validation."""
        from pydantic import ValidationError

        with pytest.raises(ValidationError, match="created_by_me"):
            ListChartsRequest(created_by_me=True, search="My charts")

    def test_chart_filter_rejects_created_by_fk(self):
        """created_by_fk is not a public filter column; use created_by_me instead."""
        from pydantic import ValidationError

        with pytest.raises(ValidationError):
            ChartFilter(col="created_by_fk", opr="eq", value=1)
class TestListChartsOwnedByMe:
    """Schema-level checks for ``ListChartsRequest.owned_by_me``."""

    def test_owned_by_me_default_is_false(self):
        """The flag is off unless explicitly requested."""
        default_request = ListChartsRequest()
        assert default_request.owned_by_me is False

    def test_owned_by_me_true_accepted(self):
        """The flag can be switched on."""
        flagged_request = ListChartsRequest(owned_by_me=True)
        assert flagged_request.owned_by_me is True

    def test_owned_by_me_combined_with_filters(self):
        """The flag coexists with ordinary column filters."""
        name_filter = ChartFilter(col="slice_name", opr="sw", value="My")
        combined = ListChartsRequest(owned_by_me=True, filters=[name_filter])
        assert combined.owned_by_me is True
        assert len(combined.filters) == 1

    def test_owned_by_me_with_search_raises(self):
        """Combining the flag with ``search`` fails validation."""
        from pydantic import ValidationError

        with pytest.raises(ValidationError, match="owned_by_me"):
            ListChartsRequest(owned_by_me=True, search="My charts")

    def test_owned_by_me_and_created_by_me_allowed(self):
        """Both flags together are valid (OR logic — creator or owner)."""
        both_flags = ListChartsRequest(owned_by_me=True, created_by_me=True)
        assert both_flags.owned_by_me is True
        assert both_flags.created_by_me is True
@patch("superset.daos.chart.ChartDAO.list")
@pytest.mark.asyncio
async def test_list_charts_no_arguments(mock_list, mcp_server):
    """Regression test: list_charts must accept zero arguments without raising
    pydantic_core.ValidationError: Missing required argument: request."""
    # The DAO reports an empty page: no rows, total count of zero.
    mock_list.return_value = ([], 0)

    async with Client(mcp_server) as client:
        tool_result = await client.call_tool("list_charts", {})

    payload = json.loads(tool_result.content[0].text)
    assert "charts" in payload

View File

@@ -236,6 +236,7 @@ class TestNormalizeColumnNames:
normalized = DatasetValidator.normalize_column_names(config, dataset_id=18)
assert normalized.x is not None
assert normalized.x.name == "OrderDate"
assert normalized.y[0].name == "Sales"
assert normalized.filters is not None
@@ -278,6 +279,7 @@ class TestNormalizeColumnNames:
normalized = DatasetValidator.normalize_column_names(config, dataset_id=999)
# Should return original config unchanged
assert normalized.x is not None
assert normalized.x.name == "orderdate"
assert normalized.y[0].name == "sales"
@@ -318,11 +320,13 @@ class TestTimeSeriesFilterPromptFix:
normalized = DatasetValidator.normalize_column_names(config, dataset_id=18)
# After normalization, x.name should match the filter column exactly
assert normalized.x is not None
assert normalized.x.name == "OrderDate"
assert normalized.filters is not None
assert normalized.filters[0].column == "OrderDate"
# This equality is what the frontend checks - now they match!
assert normalized.x is not None
assert normalized.x.name == normalized.filters[0].column
@@ -394,6 +398,7 @@ class TestNormalizeUppercaseDataset:
normalized = DatasetValidator.normalize_column_names(config, dataset_id=24)
assert normalized.x is not None
assert normalized.x.name == "ds"
assert normalized.y[0].name == "DISTANCE"
assert normalized.group_by is not None
@@ -417,6 +422,7 @@ class TestNormalizeUppercaseDataset:
normalized = DatasetValidator.normalize_column_names(config, dataset_id=24)
assert normalized.x is not None
assert normalized.x.name == "ds"
assert normalized.y[0].name == "DEPARTURE_DELAY"
@@ -459,6 +465,7 @@ class TestNormalizeEdgeCases:
normalized = DatasetValidator.normalize_column_names(config, dataset_id=18)
assert normalized.x is not None
assert normalized.x.name == "OrderDate"
assert normalized.y[0].name == "Sales"
assert normalized.filters is None
@@ -480,6 +487,7 @@ class TestNormalizeEdgeCases:
normalized = DatasetValidator.normalize_column_names(config, dataset_id=18)
assert normalized.x is not None
assert normalized.x.name == "OrderDate"
assert normalized.filters is not None
assert len(normalized.filters) == 0
@@ -500,6 +508,7 @@ class TestNormalizeEdgeCases:
normalized = DatasetValidator.normalize_column_names(config, dataset_id=18)
assert normalized.x is not None
assert normalized.x.name == "OrderDate"
assert normalized.group_by is None
@@ -527,6 +536,7 @@ class TestNormalizeEdgeCases:
normalized = DatasetValidator.normalize_column_names(config, dataset_id=18)
assert normalized.x is not None
assert normalized.x.name == "OrderDate"
assert normalized.y[0].name == "Sales"
assert normalized.y[1].name == "quantity_ordered"
@@ -554,6 +564,8 @@ class TestNormalizeEdgeCases:
first = DatasetValidator.normalize_column_names(config, dataset_id=18)
second = DatasetValidator.normalize_column_names(first, dataset_id=18)
assert first.x is not None
assert second.x is not None
assert first.x.name == second.x.name == "OrderDate"
assert first.y[0].name == second.y[0].name == "Sales"
assert first.filters is not None
@@ -636,6 +648,7 @@ class TestNormalizeXAxisFilterConsistency:
normalized = DatasetValidator.normalize_column_names(config, dataset_id=18)
assert normalized.filters is not None
assert normalized.x is not None
assert normalized.x.name == normalized.filters[0].column == "OrderDate"
@patch.object(DatasetValidator, "_get_dataset_context")
@@ -656,6 +669,7 @@ class TestNormalizeXAxisFilterConsistency:
normalized = DatasetValidator.normalize_column_names(config, dataset_id=24)
assert normalized.filters is not None
assert normalized.x is not None
assert normalized.x.name == normalized.filters[0].column == "ds"
@patch.object(DatasetValidator, "_get_dataset_context")

View File

@@ -30,7 +30,6 @@ from flask import g
from superset.mcp_service.app import mcp
from superset.mcp_service.dashboard.schemas import (
DashboardFilter,
ListDashboardsRequest,
)
from superset.mcp_service.dashboard.tool.get_dashboard_info import (
@@ -1355,66 +1354,13 @@ class TestDashboardSortableColumns:
assert col in list_dashboards.__doc__
class TestListDashboardsCreatedByMe:
    """Schema-level checks for ``ListDashboardsRequest.created_by_me``."""

    def test_created_by_me_default_is_false(self):
        """The flag is off unless explicitly requested."""
        default_request = ListDashboardsRequest()
        assert default_request.created_by_me is False

    def test_created_by_me_true_accepted(self):
        """The flag can be switched on."""
        flagged_request = ListDashboardsRequest(created_by_me=True)
        assert flagged_request.created_by_me is True

    def test_created_by_me_combined_with_filters(self):
        """The flag coexists with ordinary column filters."""
        published_filter = DashboardFilter(col="published", opr="eq", value=True)
        combined = ListDashboardsRequest(
            created_by_me=True,
            filters=[published_filter],
        )
        assert combined.created_by_me is True
        assert len(combined.filters) == 1

    def test_created_by_me_with_search_raises(self):
        """Combining the flag with ``search`` fails validation."""
        from pydantic import ValidationError

        with pytest.raises(ValidationError, match="created_by_me"):
            ListDashboardsRequest(created_by_me=True, search="My dashboards")

    def test_dashboard_filter_rejects_created_by_fk(self):
        """created_by_fk is not a public filter column; use created_by_me instead."""
        from pydantic import ValidationError

        with pytest.raises(ValidationError):
            DashboardFilter(col="created_by_fk", opr="eq", value=1)
class TestListDashboardsOwnedByMe:
    """Schema-level checks for ``ListDashboardsRequest.owned_by_me``."""

    def test_owned_by_me_default_is_false(self):
        """The flag is off unless explicitly requested."""
        default_request = ListDashboardsRequest()
        assert default_request.owned_by_me is False

    def test_owned_by_me_true_accepted(self):
        """The flag can be switched on."""
        flagged_request = ListDashboardsRequest(owned_by_me=True)
        assert flagged_request.owned_by_me is True

    def test_owned_by_me_combined_with_filters(self):
        """The flag coexists with ordinary column filters."""
        published_filter = DashboardFilter(col="published", opr="eq", value=True)
        combined = ListDashboardsRequest(
            owned_by_me=True,
            filters=[published_filter],
        )
        assert combined.owned_by_me is True
        assert len(combined.filters) == 1

    def test_owned_by_me_with_search_raises(self):
        """Combining the flag with ``search`` fails validation."""
        from pydantic import ValidationError

        with pytest.raises(ValidationError, match="owned_by_me"):
            ListDashboardsRequest(owned_by_me=True, search="My dashboards")

    def test_owned_by_me_and_created_by_me_allowed(self):
        """Both flags together are valid (OR logic — creator or owner)."""
        both_flags = ListDashboardsRequest(owned_by_me=True, created_by_me=True)
        assert both_flags.owned_by_me is True
        assert both_flags.created_by_me is True
@patch("superset.daos.dashboard.DashboardDAO.list")
@pytest.mark.asyncio
async def test_list_dashboards_no_arguments(mock_list, mcp_server):
    """Regression test: list_dashboards must accept zero arguments without raising
    pydantic_core.ValidationError: Missing required argument: request."""
    # The DAO reports an empty page: no rows, total count of zero.
    mock_list.return_value = ([], 0)

    async with Client(mcp_server) as client:
        tool_result = await client.call_tool("list_dashboards", {})

    payload = json.loads(tool_result.content[0].text)
    assert "dashboards" in payload

View File

@@ -0,0 +1,831 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for the query_dataset MCP tool."""
from __future__ import annotations
import importlib
from collections.abc import Generator
from typing import Any
from unittest.mock import MagicMock, Mock, patch
import pytest
from fastmcp import Client, FastMCP
from superset.mcp_service.app import mcp
from superset.utils import json
# Module object of the tool under test; the tests below use it as the target
# of ``patch.object`` (e.g. for ``_resolve_dataset``).
# NOTE(review): presumably loaded via importlib rather than a plain import so
# the name resolves to the module and not a re-exported tool — confirm.
query_dataset_module = importlib.import_module(
    "superset.mcp_service.dataset.tool.query_dataset"
)
@pytest.fixture
def mcp_server() -> FastMCP:
    """Hand the imported ``mcp`` server instance to the tests."""
    server: FastMCP = mcp
    return server
@pytest.fixture(autouse=True)
def mock_auth() -> Generator[MagicMock, None, None]:
    """Patch user lookup and the metadata-access check for every test.

    Yields the mock that replaces ``get_user_from_request``; it returns a
    stub user with ``id`` 1 and ``username`` "admin".
    """
    admin_user = Mock()
    admin_user.id = 1
    admin_user.username = "admin"

    with (
        patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user,
        patch.object(
            query_dataset_module,
            "user_can_view_data_model_metadata",
            return_value=True,
        ),
    ):
        mock_get_user.return_value = admin_user
        yield mock_get_user
def _make_column(name: str, is_dttm: bool = False) -> MagicMock:
"""Build a mock SqlaTable column with the given name and datetime flag."""
col = MagicMock()
col.column_name = name
col.is_dttm = is_dttm
col.verbose_name = None
col.type = "VARCHAR"
col.groupby = True
col.filterable = True
col.description = None
return col
def _make_metric(name: str, expression: str = "COUNT(*)") -> MagicMock:
"""Build a mock SqlMetric with the given name and SQL expression."""
metric = MagicMock()
metric.metric_name = name
metric.verbose_name = None
metric.expression = expression
metric.description = None
metric.d3format = None
return metric
def _make_dataset(
dataset_id: int = 1,
table_name: str = "orders",
columns: list[Any] | None = None,
metrics: list[Any] | None = None,
main_dttm_col: str | None = None,
) -> MagicMock:
"""Build a mock SqlaTable dataset with default columns and metrics."""
ds = MagicMock()
ds.id = dataset_id
ds.table_name = table_name
ds.uuid = f"test-uuid-{dataset_id}"
ds.main_dttm_col = main_dttm_col
ds.database = MagicMock()
ds.database.database_name = "examples"
ds.columns = columns or [
_make_column("category"),
_make_column("region"),
_make_column("order_date", is_dttm=True),
]
ds.metrics = metrics or [
_make_metric("count", "COUNT(*)"),
_make_metric("total_revenue", "SUM(revenue)"),
]
return ds
def _mock_command_result(
data: list[dict[str, Any]] | None = None,
colnames: list[str] | None = None,
) -> dict[str, Any]:
"""Build the result dict that ChartDataCommand.run() returns."""
data = data or [
{"category": "Electronics", "count": 42},
{"category": "Clothing", "count": 17},
]
colnames = colnames or ["category", "count"]
return {
"queries": [
{
"data": data,
"colnames": colnames,
"rowcount": len(data),
"cache_key": "abc123",
"is_cached": False,
"cached_dttm": None,
"cache_timeout": 300,
}
]
}
@pytest.mark.asyncio
async def test_query_dataset_success(mcp_server: FastMCP) -> None:
    """Requesting one metric and one column returns the mocked rows."""
    validate_target = (
        "superset.commands.chart.data.get_data_command.ChartDataCommand.validate"
    )
    run_target = "superset.commands.chart.data.get_data_command.ChartDataCommand.run"
    factory_target = "superset.common.query_context_factory.QueryContextFactory.create"

    tool_args = {
        "request": {
            "dataset_id": 1,
            "metrics": ["count"],
            "columns": ["category"],
        }
    }

    with (
        patch.object(
            query_dataset_module,
            "_resolve_dataset",
            return_value=_make_dataset(),
        ),
        patch(validate_target),
        patch(run_target, return_value=_mock_command_result()),
        patch(factory_target, return_value=MagicMock()),
    ):
        async with Client(mcp_server) as client:
            tool_result = await client.call_tool("query_dataset", tool_args)

    payload = json.loads(tool_result.content[0].text)
    assert payload["dataset_id"] == 1
    assert payload["dataset_name"] == "orders"
    assert payload["row_count"] == 2
    assert len(payload["data"]) == 2
    assert payload["data"][0]["category"] == "Electronics"
@pytest.mark.asyncio
async def test_query_dataset_not_found(mcp_server: FastMCP) -> None:
    """An unresolvable dataset id produces a ``NotFound`` error payload."""
    tool_args = {
        "request": {
            "dataset_id": 999,
            "metrics": ["count"],
        }
    }

    # Simulate a failed lookup.
    unresolved = patch.object(
        query_dataset_module,
        "_resolve_dataset",
        return_value=None,
    )
    with unresolved:
        async with Client(mcp_server) as client:
            tool_result = await client.call_tool("query_dataset", tool_args)

    payload = json.loads(tool_result.content[0].text)
    assert payload["error_type"] == "NotFound"
    assert "999" in payload["error"]
@pytest.mark.asyncio
async def test_query_dataset_invalid_metric(mcp_server: FastMCP) -> None:
    """Unknown metric name returns validation error with suggestions."""
    dataset = _make_dataset()
    with patch.object(
        query_dataset_module,
        "_resolve_dataset",
        return_value=dataset,
    ):
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "query_dataset",
                {
                    "request": {
                        "dataset_id": 1,
                        "metrics": ["countt"],  # typo
                    }
                },
            )
    data = json.loads(result.content[0].text)
    assert data["error_type"] == "ValidationError"
    assert "countt" in data["error"]
    # Should suggest "count" as a close match. "count" is a substring of the
    # typo "countt", so strip the typo first; otherwise this check is already
    # implied by the assertion above and could never fail.
    assert "count" in data["error"].replace("countt", "")
@pytest.mark.asyncio
async def test_query_dataset_invalid_column(mcp_server: FastMCP) -> None:
    """A column absent from the dataset produces a ``ValidationError`` payload."""
    tool_args = {
        "request": {
            "dataset_id": 1,
            "columns": ["nonexistent_col"],
            "metrics": ["count"],
        }
    }

    resolved = patch.object(
        query_dataset_module,
        "_resolve_dataset",
        return_value=_make_dataset(),
    )
    with resolved:
        async with Client(mcp_server) as client:
            tool_result = await client.call_tool("query_dataset", tool_args)

    payload = json.loads(tool_result.content[0].text)
    assert payload["error_type"] == "ValidationError"
    assert "nonexistent_col" in payload["error"]
@pytest.mark.asyncio
async def test_query_dataset_no_metrics_no_columns(mcp_server: FastMCP) -> None:
    """Empty ``metrics`` together with empty ``columns`` raises ``ToolError``."""
    from fastmcp.exceptions import ToolError

    # Neither list names anything to select.
    tool_args = {
        "request": {
            "dataset_id": 1,
            "metrics": [],
            "columns": [],
        }
    }

    async with Client(mcp_server) as client:
        with pytest.raises(ToolError, match="metrics.*columns"):
            await client.call_tool("query_dataset", tool_args)
@pytest.mark.asyncio
async def test_query_dataset_with_time_range(mcp_server: FastMCP) -> None:
    """``time_range`` becomes a TEMPORAL_RANGE filter plus a granularity.

    The query handed to ``QueryContextFactory.create`` is captured and
    inspected, and the response's ``applied_filters`` must echo the same
    synthesized filter.
    """
    validate_target = (
        "superset.commands.chart.data.get_data_command.ChartDataCommand.validate"
    )
    run_target = "superset.commands.chart.data.get_data_command.ChartDataCommand.run"
    factory_target = "superset.common.query_context_factory.QueryContextFactory.create"

    captured_queries: list[dict[str, Any]] = []

    def _record_queries(**kwargs):
        # Keep whatever query dicts the tool passes to the factory.
        captured_queries.extend(kwargs.get("queries", []))
        return MagicMock()

    tool_args = {
        "request": {
            "dataset_id": 1,
            "metrics": ["count"],
            "time_range": "Last 7 days",
        }
    }

    with (
        patch.object(
            query_dataset_module,
            "_resolve_dataset",
            return_value=_make_dataset(main_dttm_col="order_date"),
        ),
        patch(validate_target),
        patch(run_target, return_value=_mock_command_result()),
        patch(factory_target, side_effect=_record_queries),
    ):
        async with Client(mcp_server) as client:
            tool_result = await client.call_tool("query_dataset", tool_args)

    assert len(captured_queries) == 1
    built_query = captured_queries[0]

    # The query must carry exactly one TEMPORAL_RANGE filter.
    query_temporal = [
        flt for flt in built_query["filters"] if flt["op"] == "TEMPORAL_RANGE"
    ]
    assert len(query_temporal) == 1
    assert query_temporal[0]["col"] == "order_date"
    assert query_temporal[0]["val"] == "Last 7 days"

    # The granularity must point at the dataset's main datetime column.
    assert built_query["granularity"] == "order_date"

    # The response must report the synthesized filter as applied.
    payload = json.loads(tool_result.content[0].text)
    response_temporal = [
        flt for flt in payload["applied_filters"] if flt["op"] == "TEMPORAL_RANGE"
    ]
    assert len(response_temporal) == 1
    assert response_temporal[0]["col"] == "order_date"
    assert response_temporal[0]["val"] == "Last 7 days"
@pytest.mark.asyncio
async def test_query_dataset_time_range_no_temporal_column(mcp_server: FastMCP) -> None:
    """``time_range`` on a dataset without ``main_dttm_col`` is a validation error."""
    tool_args = {
        "request": {
            "dataset_id": 1,
            "metrics": ["count"],
            "time_range": "Last 7 days",
        }
    }

    resolved = patch.object(
        query_dataset_module,
        "_resolve_dataset",
        return_value=_make_dataset(main_dttm_col=None),
    )
    with resolved:
        async with Client(mcp_server) as client:
            tool_result = await client.call_tool("query_dataset", tool_args)

    payload = json.loads(tool_result.content[0].text)
    assert payload["error_type"] == "ValidationError"
    assert "temporal column" in payload["error"].lower()
@pytest.mark.asyncio
async def test_query_dataset_with_filters(mcp_server: FastMCP) -> None:
    """A caller-supplied filter reaches the query given to the factory."""
    validate_target = (
        "superset.commands.chart.data.get_data_command.ChartDataCommand.validate"
    )
    run_target = "superset.commands.chart.data.get_data_command.ChartDataCommand.run"
    factory_target = "superset.common.query_context_factory.QueryContextFactory.create"

    captured_queries: list[dict[str, Any]] = []

    def _record_queries(**kwargs):
        # Keep whatever query dicts the tool passes to the factory.
        captured_queries.extend(kwargs.get("queries", []))
        return MagicMock()

    tool_args = {
        "request": {
            "dataset_id": 1,
            "metrics": ["count"],
            "filters": [{"col": "category", "op": "==", "val": "Electronics"}],
        }
    }

    with (
        patch.object(
            query_dataset_module,
            "_resolve_dataset",
            return_value=_make_dataset(),
        ),
        patch(validate_target),
        patch(run_target, return_value=_mock_command_result()),
        patch(factory_target, side_effect=_record_queries),
    ):
        async with Client(mcp_server) as client:
            await client.call_tool("query_dataset", tool_args)

    assert len(captured_queries) == 1
    passed_filters = captured_queries[0]["filters"]
    assert len(passed_filters) == 1

    only_filter = passed_filters[0]
    assert only_filter["col"] == "category"
    assert only_filter["op"] == "=="
    assert only_filter["val"] == "Electronics"
@pytest.mark.asyncio
async def test_query_dataset_empty_results(mcp_server: FastMCP) -> None:
    """A query yielding no rows reports ``row_count`` 0 and a "no data" summary."""
    validate_target = (
        "superset.commands.chart.data.get_data_command.ChartDataCommand.validate"
    )
    run_target = "superset.commands.chart.data.get_data_command.ChartDataCommand.run"
    factory_target = "superset.common.query_context_factory.QueryContextFactory.create"

    # Command result with a single query that produced nothing.
    rowless_query = {
        "data": [],
        "colnames": [],
        "rowcount": 0,
        "is_cached": False,
        "cached_dttm": None,
        "cache_timeout": 300,
    }
    rowless_result = {"queries": [rowless_query]}

    tool_args = {
        "request": {
            "dataset_id": 1,
            "metrics": ["count"],
        }
    }

    with (
        patch.object(
            query_dataset_module,
            "_resolve_dataset",
            return_value=_make_dataset(),
        ),
        patch(validate_target),
        patch(run_target, return_value=rowless_result),
        patch(factory_target, return_value=MagicMock()),
    ):
        async with Client(mcp_server) as client:
            tool_result = await client.call_tool("query_dataset", tool_args)

    payload = json.loads(tool_result.content[0].text)
    assert payload["row_count"] == 0
    assert payload["data"] == []
    assert "no data" in payload["summary"].lower()
@pytest.mark.asyncio
async def test_query_dataset_by_uuid(mcp_server: FastMCP) -> None:
    """A UUID string given as ``dataset_id`` is forwarded to the resolver."""
    validate_target = (
        "superset.commands.chart.data.get_data_command.ChartDataCommand.validate"
    )
    run_target = "superset.commands.chart.data.get_data_command.ChartDataCommand.run"
    factory_target = "superset.common.query_context_factory.QueryContextFactory.create"

    tool_args = {
        "request": {
            "dataset_id": "a1b2c3d4-5678-90ab-cdef-1234567890ab",
            "metrics": ["count"],
        }
    }

    with (
        patch.object(
            query_dataset_module,
            "_resolve_dataset",
            return_value=_make_dataset(),
        ) as mock_resolve,
        patch(validate_target),
        patch(run_target, return_value=_mock_command_result()),
        patch(factory_target, return_value=MagicMock()),
    ):
        async with Client(mcp_server) as client:
            tool_result = await client.call_tool("query_dataset", tool_args)

    # The resolver must have been handed the UUID as its first positional
    # argument, exactly once.
    mock_resolve.assert_called_once()
    first_positional = mock_resolve.call_args.args[0]
    assert first_positional == "a1b2c3d4-5678-90ab-cdef-1234567890ab"

    payload = json.loads(tool_result.content[0].text)
    assert payload["dataset_id"] == 1
@pytest.mark.asyncio
async def test_query_dataset_permission_denied(mcp_server: FastMCP) -> None:
    """A security exception from ChartDataCommand.validate() yields a QueryError."""
    from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
    from superset.exceptions import SupersetSecurityException

    dataset = _make_dataset()
    # Exception raised by the patched validate() call.
    access_denied = SupersetSecurityException(
        SupersetError(
            message="Access denied",
            error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR,
            level=ErrorLevel.WARNING,
        )
    )
    request = {"dataset_id": 1, "metrics": ["count"]}

    with (
        patch.object(
            query_dataset_module,
            "_resolve_dataset",
            return_value=dataset,
        ),
        patch(
            "superset.common.query_context_factory.QueryContextFactory.create",
            return_value=MagicMock(),
        ),
        patch(
            "superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
            side_effect=access_denied,
        ),
    ):
        async with Client(mcp_server) as client:
            response = await client.call_tool("query_dataset", {"request": request})

            payload = json.loads(response.content[0].text)
            assert payload["error_type"] == "QueryError"
@pytest.mark.asyncio
async def test_query_dataset_order_by_valid(mcp_server: FastMCP) -> None:
    """A known metric in ``order_by`` reaches the query as a descending sort."""
    dataset = _make_dataset()
    command_output = _mock_command_result()
    seen_queries: list[dict[str, Any]] = []

    def record_queries(**kwargs):
        # Stand-in for QueryContextFactory.create: remember the query dicts
        # it was given so they can be inspected after the tool call.
        seen_queries.extend(kwargs.get("queries", []))
        return MagicMock()

    request = {
        "dataset_id": 1,
        "metrics": ["count"],
        "columns": ["category"],
        "order_by": ["count"],
        "order_desc": True,
    }

    with (
        patch.object(
            query_dataset_module,
            "_resolve_dataset",
            return_value=dataset,
        ),
        patch(
            "superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
        ),
        patch(
            "superset.commands.chart.data.get_data_command.ChartDataCommand.run",
            return_value=command_output,
        ),
        patch(
            "superset.common.query_context_factory.QueryContextFactory.create",
            side_effect=record_queries,
        ),
    ):
        async with Client(mcp_server) as client:
            await client.call_tool("query_dataset", {"request": request})

            assert len(seen_queries) == 1
            sort_spec = seen_queries[0].get("orderby", [])
            assert len(sort_spec) == 1
            first_sort = sort_spec[0]
            assert first_sort[0] == "count"
            # order_desc=True is expected to arrive as ascending=False
            assert first_sort[1] is False
@pytest.mark.asyncio
async def test_query_dataset_order_by_invalid(mcp_server: FastMCP) -> None:
    """An unknown name in ``order_by`` is rejected with a ValidationError."""
    dataset = _make_dataset()
    request = {
        "dataset_id": 1,
        "metrics": ["count"],
        "order_by": ["nonexistent"],
    }
    resolver_patch = patch.object(
        query_dataset_module,
        "_resolve_dataset",
        return_value=dataset,
    )

    with resolver_patch:
        async with Client(mcp_server) as client:
            response = await client.call_tool("query_dataset", {"request": request})

            payload = json.loads(response.content[0].text)
            assert payload["error_type"] == "ValidationError"
            # The offending name must be echoed back in the error text.
            assert "nonexistent" in payload["error"]
@pytest.mark.asyncio
async def test_query_dataset_time_column_override(mcp_server: FastMCP) -> None:
    """Explicit time_column drives both granularity and the temporal filter.

    The query dict passed to ``QueryContextFactory.create`` is captured and
    checked for ``granularity`` and for a ``TEMPORAL_RANGE`` filter on the
    requested column.

    NOTE(review): the dataset's ``main_dttm_col`` and the requested
    ``time_column`` are both ``"order_date"``, so this test cannot tell an
    honoured override apart from the dataset default being used. Consider
    giving the dataset a different ``main_dttm_col`` once it is confirmed
    that ``_make_dataset`` supports that.
    """
    dataset = _make_dataset(main_dttm_col="order_date")
    result_data = _mock_command_result()
    captured_queries: list[dict[str, Any]] = []

    def capture_create(**kwargs):
        # Stand-in for QueryContextFactory.create: record the query dicts.
        captured_queries.extend(kwargs.get("queries", []))
        return MagicMock()

    with (
        patch.object(
            query_dataset_module,
            "_resolve_dataset",
            return_value=dataset,
        ),
        patch(
            "superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
        ),
        patch(
            "superset.commands.chart.data.get_data_command.ChartDataCommand.run",
            return_value=result_data,
        ),
        patch(
            "superset.common.query_context_factory.QueryContextFactory.create",
            side_effect=capture_create,
        ),
    ):
        async with Client(mcp_server) as client:
            await client.call_tool(
                "query_dataset",
                {
                    "request": {
                        "dataset_id": 1,
                        "metrics": ["count"],
                        "time_range": "Last 30 days",
                        "time_column": "order_date",
                    }
                },
            )

            assert len(captured_queries) == 1
            query_dict = captured_queries[0]
            assert query_dict["granularity"] == "order_date"

            temporal_filters = [
                f for f in query_dict["filters"] if f["op"] == "TEMPORAL_RANGE"
            ]
            # Fail with a readable assertion instead of an IndexError when
            # no temporal filter was generated at all.
            assert temporal_filters, "expected a TEMPORAL_RANGE filter in the query"
            assert temporal_filters[0]["col"] == "order_date"
@pytest.mark.asyncio
async def test_query_dataset_non_dttm_time_column_warns(mcp_server: FastMCP) -> None:
    """A time_range applied to the ``category`` column produces a warning."""
    dataset = _make_dataset(main_dttm_col=None)
    command_output = _mock_command_result()
    request = {
        "dataset_id": 1,
        "metrics": ["count"],
        "time_range": "Last 7 days",
        "time_column": "category",
    }

    with (
        patch.object(
            query_dataset_module,
            "_resolve_dataset",
            return_value=dataset,
        ),
        patch(
            "superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
        ),
        patch(
            "superset.commands.chart.data.get_data_command.ChartDataCommand.run",
            return_value=command_output,
        ),
        patch(
            "superset.common.query_context_factory.QueryContextFactory.create",
            return_value=MagicMock(),
        ),
    ):
        async with Client(mcp_server) as client:
            response = await client.call_tool("query_dataset", {"request": request})

            payload = json.loads(response.content[0].text)
            # At least one warning, and the first one names the problem.
            assert len(payload["warnings"]) > 0
            leading_warning = payload["warnings"][0]
            assert "not marked as a datetime" in leading_warning
@pytest.mark.asyncio
async def test_query_dataset_invalid_filter_column(mcp_server: FastMCP) -> None:
    """Filtering on an unknown column is rejected with a ValidationError."""
    dataset = _make_dataset()
    bad_filter = {
        "col": "nonexistent",
        "op": "==",
        "val": "test",
    }
    request = {
        "dataset_id": 1,
        "metrics": ["count"],
        "filters": [bad_filter],
    }
    resolver_patch = patch.object(
        query_dataset_module,
        "_resolve_dataset",
        return_value=dataset,
    )

    with resolver_patch:
        async with Client(mcp_server) as client:
            response = await client.call_tool("query_dataset", {"request": request})

            payload = json.loads(response.content[0].text)
            assert payload["error_type"] == "ValidationError"
            # The unknown column name must be echoed back in the error text.
            assert "nonexistent" in payload["error"]
@pytest.mark.asyncio
async def test_query_dataset_metadata_access_denied_no_suggestions(
    mcp_server: FastMCP,
) -> None:
    """Restricted users must not learn column/metric names through typos.

    With ``user_can_view_data_model_metadata`` patched to return ``False``,
    a request carrying a misspelled metric has to come back as
    ``DataModelMetadataRestricted``, and the error text must contain neither
    the typo nor the real metric name.
    """
    dataset = _make_dataset()
    # "countt" is a typo that would normally trigger close-match suggestions
    request = {"dataset_id": 1, "metrics": ["countt"]}

    with (
        patch.object(
            query_dataset_module,
            "_resolve_dataset",
            return_value=dataset,
        ),
        patch.object(
            query_dataset_module,
            "user_can_view_data_model_metadata",
            return_value=False,
        ),
    ):
        async with Client(mcp_server) as client:
            response = await client.call_tool("query_dataset", {"request": request})

            payload = json.loads(response.content[0].text)
            # Denial has to win over any schema suggestion
            assert payload["error_type"] == "DataModelMetadataRestricted"
            # No column/metric names may leak through the error text
            error_text = payload.get("error", "")
            assert "countt" not in error_text
            assert "count" not in error_text
@pytest.mark.asyncio
async def test_query_dataset_metadata_access_denied_nonexistent_dataset(
    mcp_server: FastMCP,
) -> None:
    """Restricted users must not be able to probe whether a dataset exists.

    With ``user_can_view_data_model_metadata`` patched to return ``False``
    and the dataset resolver left unpatched, the caller has to receive
    ``DataModelMetadataRestricted`` and never ``NotFound``.
    """
    # ID chosen so that no dataset is expected to match it
    request = {"dataset_id": 999999, "metrics": ["count"]}
    metadata_gate = patch.object(
        query_dataset_module,
        "user_can_view_data_model_metadata",
        return_value=False,
    )

    with metadata_gate:
        async with Client(mcp_server) as client:
            response = await client.call_tool("query_dataset", {"request": request})

            payload = json.loads(response.content[0].text)
            reported_type = payload["error_type"]
            # A NotFound here would leak dataset existence
            assert reported_type == "DataModelMetadataRestricted"
            assert reported_type != "NotFound"

View File

@@ -372,6 +372,43 @@ def test_mcp_auth_hook_preserves_g_user_in_request_context(app) -> None:
assert result == "middleware_user"
def test_mcp_auth_hook_removes_stale_db_session_in_sync_wrapper(app) -> None:
    """db.session.remove() must already have run when the user is resolved.

    The patched ``get_user_from_request`` asserts, at the moment it is
    invoked, that ``db.session.remove()`` has been called exactly once.
    A remove() that only happens after user lookup therefore fails the test.

    NOTE(review): the ordering assertion lives inside the side effect, so it
    only runs if the wrapper actually calls ``get_user_from_request``;
    confirm that it does on this code path.
    """
    fresh_user = _make_mock_user("fresh")

    def dummy_tool():
        """Dummy tool."""
        return g.user.username

    wrapped = mcp_auth_hook(dummy_tool)

    def _resolve_user_after_remove() -> MagicMock:
        # db_mock is bound by the ``with`` below before this can be called.
        db_mock.session.remove.assert_called_once_with()
        return fresh_user

    with app.test_request_context():
        g.user = fresh_user

        with (
            patch("superset.extensions.db") as db_mock,
            patch(
                "superset.mcp_service.auth.get_user_from_request",
                side_effect=_resolve_user_after_remove,
            ),
        ):
            result = wrapped()

    assert result == "fresh"
# -- default_user_resolver --

View File

@@ -774,7 +774,7 @@ def test_raw_connection_oauth_engine(mocker: MockerFixture) -> None:
sqlalchemy_uri="sqlite://",
encrypted_extra=json.dumps(oauth2_client_info),
)
database.db_engine_spec.oauth2_exception = OAuth2Error # type: ignore
database.db_engine_spec.oauth2_exception = OAuth2Error
_get_sqla_engine = mocker.patch.object(database, "_get_sqla_engine")
_get_sqla_engine.side_effect = OAuth2Error("OAuth2 required")
@@ -805,7 +805,7 @@ def test_raw_connection_oauth_connection(mocker: MockerFixture) -> None:
sqlalchemy_uri="sqlite://",
encrypted_extra=json.dumps(oauth2_client_info),
)
database.db_engine_spec.oauth2_exception = OAuth2Error # type: ignore
database.db_engine_spec.oauth2_exception = OAuth2Error
get_sqla_engine = mocker.patch.object(database, "get_sqla_engine")
get_sqla_engine().__enter__().raw_connection.side_effect = OAuth2Error(
"OAuth2 required"
@@ -838,7 +838,7 @@ def test_raw_connection_oauth_execute(mocker: MockerFixture) -> None:
sqlalchemy_uri="sqlite://",
encrypted_extra=json.dumps(oauth2_client_info),
)
database.db_engine_spec.oauth2_exception = OAuth2Error # type: ignore
database.db_engine_spec.oauth2_exception = OAuth2Error
get_sqla_engine = mocker.patch.object(database, "get_sqla_engine")
get_sqla_engine().__enter__().raw_connection().cursor().execute.side_effect = (
OAuth2Error("OAuth2 required")

View File

@@ -95,7 +95,7 @@ def test_cache_key_changes_for_new_query_object_same_params():
assert query_object2.cache_key() == cache_key1
@patch("superset.common.query_object.feature_flag_manager")
@patch("superset.utils.cache_keys.feature_flag_manager")
def test_cache_key_cache_query_by_user_on_no_datasource(feature_flag_mock):
"""
When CACHE_QUERY_BY_USER flag is on and there is no datasource,
@@ -112,7 +112,7 @@ def test_cache_key_cache_query_by_user_on_no_datasource(feature_flag_mock):
assert query_object.cache_key() == cache_key
@patch("superset.common.query_object.feature_flag_manager")
@patch("superset.utils.cache_keys.feature_flag_manager")
@patch("superset.common.query_object.logger")
def test_cache_key_cache_query_by_user_on_no_user(logger_mock, feature_flag_mock):
"""
@@ -140,16 +140,13 @@ def test_cache_key_cache_query_by_user_on_no_user(logger_mock, feature_flag_mock
logger_mock.debug.assert_called()
@patch("superset.common.query_object.feature_flag_manager")
@patch("superset.common.query_object.logger")
@patch("superset.utils.cache_keys.feature_flag_manager")
@patch("superset.utils.cache_keys.logger")
def test_cache_key_cache_query_by_user_on_with_user(logger_mock, feature_flag_mock):
"""
When the same user is requesting a cache key with CACHE_QUERY_BY_USER
flag on, the key will be the same
"""
# Configure logger to enable DEBUG level for isEnabledFor check
logger_mock.isEnabledFor.return_value = True
datasource = SqlaTable(
table_name="test_table",
columns=[],
@@ -167,17 +164,17 @@ def test_cache_key_cache_query_by_user_on_with_user(logger_mock, feature_flag_mo
cache_key1 = query_object.cache_key()
assert query_object.cache_key() == cache_key1
# Should have both impersonation and cache key generation logs
# Should have impersonation log emitted by the cache_keys helper
logger_mock.debug.assert_has_calls(
[
call("Adding impersonation key to QueryObject cache dict: %s", "test_user"),
call("Adding impersonation key to cache dict: %s", "test_user"),
],
any_order=True,
)
@patch("superset.common.query_object.feature_flag_manager")
@patch("superset.common.query_object.logger")
@patch("superset.utils.cache_keys.feature_flag_manager")
@patch("superset.utils.cache_keys.logger")
def test_cache_key_cache_query_by_user_on_with_different_user(
logger_mock, feature_flag_mock
):
@@ -185,9 +182,6 @@ def test_cache_key_cache_query_by_user_on_with_different_user(
When two different users are requesting a cache key with CACHE_QUERY_BY_USER
flag on, the key will be different
"""
# Configure logger to enable DEBUG level for isEnabledFor check
logger_mock.isEnabledFor.return_value = True
datasource = SqlaTable(
table_name="test_table",
columns=[],
@@ -209,21 +203,17 @@ def test_cache_key_cache_query_by_user_on_with_different_user(
assert cache_key1 != cache_key2
# Should have both impersonation and cache key generation logs (any order)
# Should have impersonation logs emitted by the cache_keys helper
logger_mock.debug.assert_has_calls(
[
call(
"Adding impersonation key to QueryObject cache dict: %s", "test_user1"
),
call(
"Adding impersonation key to QueryObject cache dict: %s", "test_user2"
),
call("Adding impersonation key to cache dict: %s", "test_user1"),
call("Adding impersonation key to cache dict: %s", "test_user2"),
],
any_order=True,
)
@patch("superset.common.query_object.feature_flag_manager")
@patch("superset.utils.cache_keys.feature_flag_manager")
@patch("superset.common.query_object.logger")
def test_cache_key_cache_impersonation_on_no_user(logger_mock, feature_flag_mock):
"""
@@ -251,7 +241,7 @@ def test_cache_key_cache_impersonation_on_no_user(logger_mock, feature_flag_mock
logger_mock.debug.assert_called()
@patch("superset.common.query_object.feature_flag_manager")
@patch("superset.utils.cache_keys.feature_flag_manager")
@patch("superset.common.query_object.logger")
def test_cache_key_cache_impersonation_on_with_user(logger_mock, feature_flag_mock):
"""
@@ -290,7 +280,7 @@ def test_cache_key_cache_impersonation_on_with_user(logger_mock, feature_flag_mo
assert len(impersonation_calls) == 0
@patch("superset.common.query_object.feature_flag_manager")
@patch("superset.utils.cache_keys.feature_flag_manager")
@patch("superset.common.query_object.logger")
def test_cache_key_cache_impersonation_on_with_different_user(
logger_mock, feature_flag_mock
@@ -335,8 +325,8 @@ def test_cache_key_cache_impersonation_on_with_different_user(
assert len(impersonation_calls) == 0
@patch("superset.common.query_object.feature_flag_manager")
@patch("superset.common.query_object.logger")
@patch("superset.utils.cache_keys.feature_flag_manager")
@patch("superset.utils.cache_keys.logger")
def test_cache_key_cache_impersonation_on_with_different_user_and_db_impersonation(
logger_mock,
feature_flag_mock,
@@ -346,9 +336,6 @@ def test_cache_key_cache_impersonation_on_with_different_user_and_db_impersonati
flag on, and cache_impersonation is enabled on the database,
the keys will be different
"""
# Configure logger to enable DEBUG level for isEnabledFor check
logger_mock.isEnabledFor.return_value = True
datasource = SqlaTable(
table_name="test_table",
columns=[],
@@ -374,15 +361,11 @@ def test_cache_key_cache_impersonation_on_with_different_user_and_db_impersonati
assert cache_key1 != cache_key2
# Should have both impersonation and cache key generation logs (any order)
# Should have impersonation logs emitted by the cache_keys helper
logger_mock.debug.assert_has_calls(
[
call(
"Adding impersonation key to QueryObject cache dict: %s", "test_user1"
),
call(
"Adding impersonation key to QueryObject cache dict: %s", "test_user2"
),
call("Adding impersonation key to cache dict: %s", "test_user1"),
call("Adding impersonation key to cache dict: %s", "test_user2"),
],
any_order=True,
)

View File

@@ -220,7 +220,7 @@ def test_get_sql_results_oauth2(mocker: MockerFixture, app) -> None:
sqlalchemy_uri="sqlite://",
encrypted_extra=json.dumps(oauth2_client_info),
)
database.db_engine_spec.oauth2_exception = OAuth2Error # type: ignore
database.db_engine_spec.oauth2_exception = OAuth2Error
get_sqla_engine = mocker.patch.object(database, "get_sqla_engine")
get_sqla_engine().__enter__().raw_connection.side_effect = OAuth2Error(
"OAuth2 required"

View File

@@ -0,0 +1,111 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Behavioral tests for ``viz.BaseViz.cache_key`` covering per-user cache-key
inclusion.
"""
from typing import Any
from unittest.mock import patch
from flask_appbuilder.security.sqla.models import User
from superset import viz
from superset.connectors.sqla.models import SqlaTable
from superset.models.core import Database
from superset.utils.core import override_user
QUERY_OBJ: dict[str, Any] = {"row_limit": 100, "from_dttm": None, "to_dttm": None}
def _viz_for(database: Database) -> viz.BaseViz:
    """Build a ``table`` BaseViz over an empty ``SqlaTable`` bound to *database*."""
    table_kwargs: dict[str, Any] = {
        "table_name": "t",
        "columns": [],
        "metrics": [],
        "main_dttm_col": None,
        "database": database,
    }
    table = SqlaTable(**table_kwargs)
    return viz.BaseViz(datasource=table, form_data={"viz_type": "table"})
def test_no_per_user_opt_in_keys_match_across_users():
    """
    Regression guard: with no per-user caching opt-in on the database, two
    different users running the same query must share one cache key, i.e.
    cache keys must not silently become per-user.
    """
    chart = _viz_for(Database(database_name="d", sqlalchemy_uri="sqlite://"))

    keys = []
    for username in ("alice", "bob"):
        with override_user(User(username=username)):
            keys.append(chart.cache_key(QUERY_OBJ))

    assert keys[0] == keys[1]
def test_per_user_caching_in_extra_yields_distinct_keys_per_user():
    """
    When the database's ``extra`` JSON sets ``per_user_caching: true``, the
    same query issued by two different users must map to *different* cache
    keys.
    """
    per_user_db = Database(
        database_name="d",
        sqlalchemy_uri="sqlite://",
        extra='{"per_user_caching": true}',
    )
    chart = _viz_for(per_user_db)

    keys = []
    for username in ("alice", "bob"):
        with override_user(User(username=username)):
            keys.append(chart.cache_key(QUERY_OBJ))

    assert keys[0] != keys[1]
def test_same_user_same_query_idempotent():
    """
    With per-user caching enabled, one user asking twice for the key of the
    same query gets the same key both times.
    """
    per_user_db = Database(
        database_name="d",
        sqlalchemy_uri="sqlite://",
        extra='{"per_user_caching": true}',
    )
    chart = _viz_for(per_user_db)

    with override_user(User(username="alice")):
        first_key = chart.cache_key(QUERY_OBJ)
        second_key = chart.cache_key(QUERY_OBJ)
        assert first_key == second_key
@patch("superset.utils.cache_keys.feature_flag_manager")
def test_cache_query_by_user_flag_yields_distinct_keys(feature_flag_mock):
    """
    With only the ``CACHE_QUERY_BY_USER`` feature flag reported as enabled,
    the legacy viz cache key differs between two users.
    """

    def _only_cache_query_by_user(feature=None):
        # Every flag except CACHE_QUERY_BY_USER reads as disabled.
        return feature == "CACHE_QUERY_BY_USER"

    feature_flag_mock.is_feature_enabled.side_effect = _only_cache_query_by_user

    chart = _viz_for(Database(database_name="d", sqlalchemy_uri="sqlite://"))

    keys = []
    for username in ("alice", "bob"):
        with override_user(User(username=username)):
            keys.append(chart.cache_key(QUERY_OBJ))

    assert keys[0] != keys[1]

View File

@@ -0,0 +1,107 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any
from unittest.mock import patch
from flask_appbuilder.security.sqla.models import User
from superset.models.core import Database
from superset.utils.cache_keys import add_impersonation_cache_key_if_needed
from superset.utils.core import override_user
def _flag(name: str):
"""Build a feature-flag side_effect that returns True only for ``name``."""
def side_effect(feature=None):
return feature == name
return side_effect
def _run(database: Database) -> dict[str, Any]:
    """Apply the helper under test to an empty dict and hand that dict back."""
    # Fresh dict per call, passed to the helper together with the database.
    populated: dict[str, Any] = {}
    add_impersonation_cache_key_if_needed(database, populated)
    return populated
def test_no_per_user_caching_yields_no_key():
    """A plain database with a logged-in user gets no ``impersonation_key``."""
    plain_db = Database(database_name="d", sqlalchemy_uri="sqlite://")
    with override_user(User(username="u")):
        cache_dict = _run(plain_db)
        assert "impersonation_key" not in cache_dict
@patch("superset.utils.cache_keys.feature_flag_manager")
def test_cache_query_by_user_adds_username(feature_flag_mock):
    """With CACHE_QUERY_BY_USER on, the key holds the current username."""
    feature_flag_mock.is_feature_enabled.side_effect = _flag("CACHE_QUERY_BY_USER")
    plain_db = Database(database_name="d", sqlalchemy_uri="sqlite://")
    with override_user(User(username="alice")):
        cache_dict = _run(plain_db)
        assert cache_dict["impersonation_key"] == "alice"
@patch("superset.utils.cache_keys.feature_flag_manager")
def test_cache_query_by_user_distinct_per_user(feature_flag_mock):
    """With CACHE_QUERY_BY_USER on, two users get different impersonation keys."""
    feature_flag_mock.is_feature_enabled.side_effect = _flag("CACHE_QUERY_BY_USER")
    plain_db = Database(database_name="d", sqlalchemy_uri="sqlite://")

    keys = []
    for username in ("alice", "bob"):
        with override_user(User(username=username)):
            keys.append(_run(plain_db)["impersonation_key"])

    assert keys[0] != keys[1]
@patch("superset.utils.cache_keys.feature_flag_manager")
def test_cache_impersonation_requires_database_flag(feature_flag_mock):
    """
    The CACHE_IMPERSONATION flag on its own adds no key; the database must
    also have ``impersonate_user`` set for the per-user key to appear.
    """
    feature_flag_mock.is_feature_enabled.side_effect = _flag("CACHE_IMPERSONATION")

    plain_db = Database(database_name="d", sqlalchemy_uri="sqlite://")
    impersonating_db = Database(
        database_name="d", sqlalchemy_uri="sqlite://", impersonate_user=True
    )

    with override_user(User(username="alice")):
        without_db_flag = _run(plain_db)
        assert "impersonation_key" not in without_db_flag

        with_db_flag = _run(impersonating_db)
        assert with_db_flag["impersonation_key"] == "alice"
def test_per_user_caching_in_extra_json_enables_key():
    """``per_user_caching: true`` in ``extra`` adds the username as the key."""
    per_user_db = Database(
        database_name="d",
        sqlalchemy_uri="sqlite://",
        extra='{"per_user_caching": true}',
    )
    with override_user(User(username="alice")):
        cache_dict = _run(per_user_db)
        assert cache_dict["impersonation_key"] == "alice"
def test_no_user_yields_no_key(app_context):  # noqa: ARG001
    """
    Without a logged-in user there is no identity to key on, so no
    ``impersonation_key`` is added even though per-user caching is enabled
    on the database.
    """
    per_user_db = Database(
        database_name="d",
        sqlalchemy_uri="sqlite://",
        extra='{"per_user_caching": true}',
    )
    # Deliberately no override_user here, so no user is installed by the test
    cache_dict = _run(per_user_db)
    assert "impersonation_key" not in cache_dict