mirror of
https://github.com/apache/superset.git
synced 2026-05-19 14:55:13 +00:00
Compare commits
16 Commits
enxdev/fix
...
enxdev/fea
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d40a5cad5d | ||
|
|
38546d7a3d | ||
|
|
6e5dfa0dd4 | ||
|
|
70419e9d8f | ||
|
|
34281f54a6 | ||
|
|
53d5c41a72 | ||
|
|
453f49ce33 | ||
|
|
b66c104fde | ||
|
|
61b77fa35d | ||
|
|
0da0767780 | ||
|
|
e2ff2d5d41 | ||
|
|
6a6be4c385 | ||
|
|
cf831388d8 | ||
|
|
684a66aee6 | ||
|
|
80a200820c | ||
|
|
f47300102c |
4
.github/workflows/codeql-analysis.yml
vendored
4
.github/workflows/codeql-analysis.yml
vendored
@@ -41,7 +41,7 @@ jobs:
|
||||
|
||||
# Initializes the CodeQL tools for scanning.
|
||||
- name: Initialize CodeQL
|
||||
uses: github/codeql-action/init@68bde559dea0fdcac2102bfdf6230c5f70eb485e # v4
|
||||
uses: github/codeql-action/init@9e0d7b8d25671d64c341c19c0152d693099fb5ba # v4
|
||||
with:
|
||||
languages: ${{ matrix.language }}
|
||||
# If you wish to specify custom queries, you can do so here or in a config file.
|
||||
@@ -53,6 +53,6 @@ jobs:
|
||||
|
||||
- name: Perform CodeQL Analysis
|
||||
if: steps.check.outputs.python || steps.check.outputs.frontend
|
||||
uses: github/codeql-action/analyze@68bde559dea0fdcac2102bfdf6230c5f70eb485e # v4
|
||||
uses: github/codeql-action/analyze@9e0d7b8d25671d64c341c19c0152d693099fb5ba # v4
|
||||
with:
|
||||
category: "/language:${{matrix.language}}"
|
||||
|
||||
@@ -502,6 +502,7 @@ All MCP settings go in `superset_config.py`. Defaults are defined in `superset/m
|
||||
| `MCP_DEBUG` | `False` | Enable debug logging |
|
||||
| `MCP_DEV_USERNAME` | -- | Superset username for development mode (no auth) |
|
||||
| `MCP_RBAC_ENABLED` | `True` | Enforce Superset's role-based access control on MCP tool calls. When `True`, each tool checks that the authenticated user has the required FAB permission before executing. Disable only for testing or trusted-network deployments. |
|
||||
| `MCP_DISABLED_TOOLS` | `set()` | Set of tool names to remove from the MCP server at startup. Disabled tools are never advertised to AI clients during tool discovery. Useful when a custom extension tool should replace a built-in Superset tool. See [Disabling built-in tools](#disabling-built-in-tools). |
|
||||
|
||||
### Authentication
|
||||
|
||||
@@ -825,6 +826,32 @@ while True:
|
||||
page += 1
|
||||
```
|
||||
|
||||
## Disabling built-in tools
|
||||
|
||||
If you have deployed a custom tool via a Superset extension that supersedes one of the built-in Superset tools, you can suppress the built-in version so AI clients only discover your replacement. Disabled tools are removed from the server at startup and are never advertised during tool discovery.
|
||||
|
||||
Set `MCP_DISABLED_TOOLS` in your `superset_config.py` to a set of tool names:
|
||||
|
||||
```python
|
||||
# superset_config.py
|
||||
|
||||
# Disable one tool
|
||||
MCP_DISABLED_TOOLS = {"execute_sql"}
|
||||
|
||||
# Or disable multiple tools (use a single assignment; a later one replaces the earlier one)
|
||||
MCP_DISABLED_TOOLS = {"execute_sql", "health_check"}
|
||||
```
|
||||
|
||||
Tool names match the function name used in the `@tool` decorator (e.g., `execute_sql`, `list_charts`, `health_check`). Extension-prefixed tools can also be disabled using their full prefixed name:
|
||||
|
||||
```python
|
||||
MCP_DISABLED_TOOLS = {"extensions.myorg.myextension.some_tool"}
|
||||
```
|
||||
|
||||
:::note
|
||||
Specifying a tool name that does not exist logs a warning at startup and is otherwise ignored — it will not prevent the server from starting.
|
||||
:::
|
||||
|
||||
## Security Best Practices
|
||||
|
||||
- **Use TLS** for all production MCP endpoints -- place the server behind a reverse proxy with HTTPS
|
||||
|
||||
@@ -71,9 +71,9 @@
|
||||
"@storybook/theming": "^8.6.15",
|
||||
"@superset-ui/core": "^0.20.4",
|
||||
"@swc/core": "^1.15.33",
|
||||
"antd": "^6.4.2",
|
||||
"baseline-browser-mapping": "^2.10.29",
|
||||
"caniuse-lite": "^1.0.30001792",
|
||||
"antd": "^6.4.3",
|
||||
"baseline-browser-mapping": "^2.10.30",
|
||||
"caniuse-lite": "^1.0.30001793",
|
||||
"docusaurus-plugin-openapi-docs": "^5.0.2",
|
||||
"docusaurus-theme-openapi-docs": "^5.0.2",
|
||||
"js-yaml": "^4.1.1",
|
||||
@@ -87,7 +87,7 @@
|
||||
"react-svg-pan-zoom": "^3.13.1",
|
||||
"react-table": "^7.8.0",
|
||||
"remark-import-partial": "^0.0.2",
|
||||
"reselect": "^5.1.1",
|
||||
"reselect": "^5.2.0",
|
||||
"storybook": "^8.6.18",
|
||||
"swagger-ui-react": "^5.32.5",
|
||||
"swc-loader": "^0.2.7",
|
||||
|
||||
@@ -3078,7 +3078,7 @@
|
||||
dependencies:
|
||||
"@rc-component/util" "^1.2.0"
|
||||
|
||||
"@rc-component/notification@~2.0.6":
|
||||
"@rc-component/notification@~2.0.7":
|
||||
version "2.0.7"
|
||||
resolved "https://registry.yarnpkg.com/@rc-component/notification/-/notification-2.0.7.tgz#f2450a482f87e4698285833c4a8efcac169acabb"
|
||||
integrity sha512-nqZzpf6BPdaj+3ILx7si79LLmqPKyUmQoXa+/9gg0SkH0v1DbD66oJgRMSBEVnd/zUT3D4gwxWIHUKebYf2ZXQ==
|
||||
@@ -5498,10 +5498,10 @@ ansis@^3.2.0:
|
||||
resolved "https://registry.yarnpkg.com/ansis/-/ansis-3.17.0.tgz#fa8d9c2a93fe7d1177e0c17f9eeb562a58a832d7"
|
||||
integrity sha512-0qWUglt9JEqLFr3w1I1pbrChn1grhaiAR2ocX1PP/flRmxgtwTzPFFFnfIlD6aMOLQZgSuCRlidD70lvx8yhzg==
|
||||
|
||||
antd@^6.4.2:
|
||||
version "6.4.2"
|
||||
resolved "https://registry.yarnpkg.com/antd/-/antd-6.4.2.tgz#9fc0fee455a5c56e7ec27855495eefadc8df636a"
|
||||
integrity sha512-PNJz8Vxc/mC3EsOg/h3e2YuaZduJ1RDp4RmySDuDmKPCxVgyp4Da4kB36o87p9hbLbOWdAWCKQlnyopsN8utKQ==
|
||||
antd@^6.4.3:
|
||||
version "6.4.3"
|
||||
resolved "https://registry.yarnpkg.com/antd/-/antd-6.4.3.tgz#80a7aab9c13c35daa0e0e7eea80585ba57cb7203"
|
||||
integrity sha512-6H2avkxCGfxcF67r3J2mwm9Ck50el1pks/73vfM1wDsPL/tPtj5vHuauMgJFnrqmq7CH3g8aoZ0VBQbt+jpAsw==
|
||||
dependencies:
|
||||
"@ant-design/colors" "^8.0.1"
|
||||
"@ant-design/cssinjs" "^2.1.2"
|
||||
@@ -5525,7 +5525,7 @@ antd@^6.4.2:
|
||||
"@rc-component/menu" "~1.3.0"
|
||||
"@rc-component/motion" "^1.3.2"
|
||||
"@rc-component/mutate-observer" "^2.0.1"
|
||||
"@rc-component/notification" "~2.0.6"
|
||||
"@rc-component/notification" "~2.0.7"
|
||||
"@rc-component/pagination" "~1.2.0"
|
||||
"@rc-component/picker" "~1.10.0"
|
||||
"@rc-component/progress" "~1.0.2"
|
||||
@@ -5545,7 +5545,7 @@ antd@^6.4.2:
|
||||
"@rc-component/tree-select" "~1.9.0"
|
||||
"@rc-component/trigger" "^3.9.0"
|
||||
"@rc-component/upload" "~1.1.0"
|
||||
"@rc-component/util" "^1.10.1"
|
||||
"@rc-component/util" "^1.11.0"
|
||||
clsx "^2.1.1"
|
||||
dayjs "^1.11.11"
|
||||
scroll-into-view-if-needed "^3.1.0"
|
||||
@@ -5810,10 +5810,10 @@ base64-js@^1.3.1, base64-js@^1.5.1:
|
||||
resolved "https://registry.npmjs.org/base64-js/-/base64-js-1.5.1.tgz"
|
||||
integrity sha512-AKpaYlHn8t4SVbOHCy+b5+KKgvR4vrsD8vbvrbiQJps7fKDTkjkDry6ji0rUJjC0kzbNePLwzxq8iypo41qeWA==
|
||||
|
||||
baseline-browser-mapping@^2.10.29, baseline-browser-mapping@^2.9.0, baseline-browser-mapping@^2.9.19:
|
||||
version "2.10.29"
|
||||
resolved "https://registry.yarnpkg.com/baseline-browser-mapping/-/baseline-browser-mapping-2.10.29.tgz#47bdc13027af28d341f367a4f35a07ce872e27b4"
|
||||
integrity sha512-Asa2krT+XTPZINCS+2QcyS8WTkObE77RwkydwF7h6DmnKqbvlalz93m/dnphUyCa6SWSP51VgtEUf2FN+gelFQ==
|
||||
baseline-browser-mapping@^2.10.30, baseline-browser-mapping@^2.9.0, baseline-browser-mapping@^2.9.19:
|
||||
version "2.10.30"
|
||||
resolved "https://registry.yarnpkg.com/baseline-browser-mapping/-/baseline-browser-mapping-2.10.30.tgz#58915c74388b05f3b3504026194ea9fa98f6e6b6"
|
||||
integrity sha512-xjOFN16Ha1+Rz4nFYKqHU/LSB+gx/Vi3yQLX7r7sAW+Wa+8hhF2h4pvqTrTMc8+WcDBEunnUurr46Jvv0jk3Vg==
|
||||
|
||||
batch@0.6.1:
|
||||
version "0.6.1"
|
||||
@@ -6051,10 +6051,10 @@ caniuse-api@^3.0.0:
|
||||
lodash.memoize "^4.1.2"
|
||||
lodash.uniq "^4.5.0"
|
||||
|
||||
caniuse-lite@^1.0.0, caniuse-lite@^1.0.30001702, caniuse-lite@^1.0.30001759, caniuse-lite@^1.0.30001792:
|
||||
version "1.0.30001792"
|
||||
resolved "https://registry.yarnpkg.com/caniuse-lite/-/caniuse-lite-1.0.30001792.tgz#ca8bb9be244835a335e2018272ce7223691873c5"
|
||||
integrity sha512-hVLMUZFgR4JJ6ACt1uEESvQN1/dBVqPAKY0hgrV70eN3391K6juAfTjKZLKvOMsx8PxA7gsY1/tLMMTcfFLLpw==
|
||||
caniuse-lite@^1.0.0, caniuse-lite@^1.0.30001702, caniuse-lite@^1.0.30001759, caniuse-lite@^1.0.30001793:
|
||||
version "1.0.30001793"
|
||||
resolved "https://registry.yarnpkg.com/caniuse-lite/-/caniuse-lite-1.0.30001793.tgz#238887ddf5fcfc8c36d872394d0a78a517312a72"
|
||||
integrity sha512-iwSsYWaCOoh26cV8NwNRViHlrfUvYsHDfRVcbtmw0Kg6PJIZZXwMkj1442FYLBGkeUf1juAsU3DTfxW579mrPA==
|
||||
|
||||
ccount@^2.0.0:
|
||||
version "2.0.1"
|
||||
@@ -13291,10 +13291,10 @@ reselect@^4.0.0:
|
||||
resolved "https://registry.npmjs.org/reselect/-/reselect-4.1.8.tgz"
|
||||
integrity sha512-ab9EmR80F/zQTMNeneUr4cv+jSwPJgIlvEmVwLerwrWVbpLlBuls9XHzIeTFy4cegU2NHBp3va0LKOzU5qFEYQ==
|
||||
|
||||
reselect@^5.1.0, reselect@^5.1.1:
|
||||
version "5.1.1"
|
||||
resolved "https://registry.npmjs.org/reselect/-/reselect-5.1.1.tgz"
|
||||
integrity sha512-K/BG6eIky/SBpzfHZv/dd+9JBFiS4SWV7FIujVyJRux6e45+73RaUHXLmIR1f7WOMaQ0U1km6qwklRQxpJJY0w==
|
||||
reselect@^5.1.0, reselect@^5.1.1, reselect@^5.2.0:
|
||||
version "5.2.0"
|
||||
resolved "https://registry.yarnpkg.com/reselect/-/reselect-5.2.0.tgz#f380ef7664332d26ea06c1cba04bdbbdcaa955f1"
|
||||
integrity sha512-AgZ3UOZm3YndfrJ4OYjgrT7bmCm/1iqkjvEfH/oYjzh6PD2qw4QuT3jjnXIrpdt4MTpMXclMT3lXbmRY+XRakw==
|
||||
|
||||
resize-observer-polyfill@1.5.1:
|
||||
version "1.5.1"
|
||||
|
||||
654
superset-frontend/package-lock.json
generated
654
superset-frontend/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -161,8 +161,8 @@
|
||||
"@visx/scale": "^3.5.0",
|
||||
"@visx/tooltip": "^3.0.0",
|
||||
"@visx/xychart": "^3.5.1",
|
||||
"ag-grid-community": "35.2.1",
|
||||
"ag-grid-react": "35.2.1",
|
||||
"ag-grid-community": "35.3.0",
|
||||
"ag-grid-react": "35.3.0",
|
||||
"antd": "^5.26.0",
|
||||
"chrono-node": "^2.9.1",
|
||||
"classnames": "^2.2.5",
|
||||
|
||||
@@ -30,8 +30,8 @@
|
||||
"@types/json-bigint": "^1.0.4",
|
||||
"@visx/responsive": "^3.12.0",
|
||||
"ace-builds": "^1.44.0",
|
||||
"ag-grid-community": "35.2.1",
|
||||
"ag-grid-react": "35.2.1",
|
||||
"ag-grid-community": "35.3.0",
|
||||
"ag-grid-react": "35.3.0",
|
||||
"brace": "^0.11.1",
|
||||
"classnames": "^2.5.1",
|
||||
"core-js": "^3.49.0",
|
||||
|
||||
@@ -29,7 +29,7 @@
|
||||
"acorn": "^8.16.0",
|
||||
"d3-array": "^3.2.4",
|
||||
"lodash": "^4.18.1",
|
||||
"zod": "^4.4.1"
|
||||
"zod": "^4.4.3"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@apache-superset/core": "*",
|
||||
|
||||
@@ -29,7 +29,7 @@
|
||||
"@math.gl/web-mercator": "^4.1.0",
|
||||
"mapbox-gl": "^3.23.1",
|
||||
"maplibre-gl": "^5.24.0",
|
||||
"react-map-gl": "^8.1.1",
|
||||
"react-map-gl": "^8.1.0",
|
||||
"supercluster": "^8.0.1"
|
||||
},
|
||||
"peerDependencies": {
|
||||
|
||||
30
superset-websocket/package-lock.json
generated
30
superset-websocket/package-lock.json
generated
@@ -27,7 +27,7 @@
|
||||
"@types/ws": "^8.18.1",
|
||||
"@typescript-eslint/eslint-plugin": "^8.59.3",
|
||||
"@typescript-eslint/parser": "^8.59.3",
|
||||
"eslint": "^10.3.0",
|
||||
"eslint": "^10.4.0",
|
||||
"eslint-config-prettier": "^10.1.8",
|
||||
"eslint-plugin-lodash": "^8.0.0",
|
||||
"globals": "^17.6.0",
|
||||
@@ -802,9 +802,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@eslint/config-helpers": {
|
||||
"version": "0.5.5",
|
||||
"resolved": "https://registry.npmjs.org/@eslint/config-helpers/-/config-helpers-0.5.5.tgz",
|
||||
"integrity": "sha512-eIJYKTCECbP/nsKaaruF6LW967mtbQbsw4JTtSVkUQc9MneSkbrgPJAbKl9nWr0ZeowV8BfsarBmPpBzGelA2w==",
|
||||
"version": "0.6.0",
|
||||
"resolved": "https://registry.npmjs.org/@eslint/config-helpers/-/config-helpers-0.6.0.tgz",
|
||||
"integrity": "sha512-ii6Bw9jJ2zi2cWA2Z+9/QZ/+3DX6kwaV5Q986D/CdP3Lap3w/pgQZ373FV7byY/i7L4IRH/G43I5dz1ClsCbpA==",
|
||||
"dev": true,
|
||||
"license": "Apache-2.0",
|
||||
"dependencies": {
|
||||
@@ -2794,16 +2794,16 @@
|
||||
}
|
||||
},
|
||||
"node_modules/eslint": {
|
||||
"version": "10.3.0",
|
||||
"resolved": "https://registry.npmjs.org/eslint/-/eslint-10.3.0.tgz",
|
||||
"integrity": "sha512-XbEXaRva5cF0ZQB8w6MluHA0kZZfV2DuCMJ3ozyEOHLwDpZX2Lmm/7Pp0xdJmI0GL1W05VH5VwIFHEm1Vcw2gw==",
|
||||
"version": "10.4.0",
|
||||
"resolved": "https://registry.npmjs.org/eslint/-/eslint-10.4.0.tgz",
|
||||
"integrity": "sha512-loXy6bWOoP3EP6JA7jo6p5jMpBJmHmsNZM5SFRHLdh1MGOPurMnNBj4ZlAbaqUAaQWbCr7jHV4P7gzAyryZWkQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@eslint-community/eslint-utils": "^4.8.0",
|
||||
"@eslint-community/regexpp": "^4.12.2",
|
||||
"@eslint/config-array": "^0.23.5",
|
||||
"@eslint/config-helpers": "^0.5.5",
|
||||
"@eslint/config-helpers": "^0.6.0",
|
||||
"@eslint/core": "^1.2.1",
|
||||
"@eslint/plugin-kit": "^0.7.1",
|
||||
"@humanfs/node": "^0.16.6",
|
||||
@@ -7081,9 +7081,9 @@
|
||||
}
|
||||
},
|
||||
"@eslint/config-helpers": {
|
||||
"version": "0.5.5",
|
||||
"resolved": "https://registry.npmjs.org/@eslint/config-helpers/-/config-helpers-0.5.5.tgz",
|
||||
"integrity": "sha512-eIJYKTCECbP/nsKaaruF6LW967mtbQbsw4JTtSVkUQc9MneSkbrgPJAbKl9nWr0ZeowV8BfsarBmPpBzGelA2w==",
|
||||
"version": "0.6.0",
|
||||
"resolved": "https://registry.npmjs.org/@eslint/config-helpers/-/config-helpers-0.6.0.tgz",
|
||||
"integrity": "sha512-ii6Bw9jJ2zi2cWA2Z+9/QZ/+3DX6kwaV5Q986D/CdP3Lap3w/pgQZ373FV7byY/i7L4IRH/G43I5dz1ClsCbpA==",
|
||||
"dev": true,
|
||||
"requires": {
|
||||
"@eslint/core": "^1.2.1"
|
||||
@@ -8578,15 +8578,15 @@
|
||||
"dev": true
|
||||
},
|
||||
"eslint": {
|
||||
"version": "10.3.0",
|
||||
"resolved": "https://registry.npmjs.org/eslint/-/eslint-10.3.0.tgz",
|
||||
"integrity": "sha512-XbEXaRva5cF0ZQB8w6MluHA0kZZfV2DuCMJ3ozyEOHLwDpZX2Lmm/7Pp0xdJmI0GL1W05VH5VwIFHEm1Vcw2gw==",
|
||||
"version": "10.4.0",
|
||||
"resolved": "https://registry.npmjs.org/eslint/-/eslint-10.4.0.tgz",
|
||||
"integrity": "sha512-loXy6bWOoP3EP6JA7jo6p5jMpBJmHmsNZM5SFRHLdh1MGOPurMnNBj4ZlAbaqUAaQWbCr7jHV4P7gzAyryZWkQ==",
|
||||
"dev": true,
|
||||
"requires": {
|
||||
"@eslint-community/eslint-utils": "^4.8.0",
|
||||
"@eslint-community/regexpp": "^4.12.2",
|
||||
"@eslint/config-array": "^0.23.5",
|
||||
"@eslint/config-helpers": "^0.5.5",
|
||||
"@eslint/config-helpers": "^0.6.0",
|
||||
"@eslint/core": "^1.2.1",
|
||||
"@eslint/plugin-kit": "^0.7.1",
|
||||
"@humanfs/node": "^0.16.6",
|
||||
|
||||
@@ -35,7 +35,7 @@
|
||||
"@types/ws": "^8.18.1",
|
||||
"@typescript-eslint/eslint-plugin": "^8.59.3",
|
||||
"@typescript-eslint/parser": "^8.59.3",
|
||||
"eslint": "^10.3.0",
|
||||
"eslint": "^10.4.0",
|
||||
"eslint-config-prettier": "^10.1.8",
|
||||
"eslint-plugin-lodash": "^8.0.0",
|
||||
"globals": "^17.6.0",
|
||||
|
||||
@@ -55,20 +55,23 @@ class StreamingCSVExportCommand(BaseStreamingCSVExportCommand):
|
||||
"""Validate permissions and query context."""
|
||||
self._query_context.raise_for_access()
|
||||
|
||||
def _get_sql_and_database(self) -> tuple[str, Any]:
|
||||
def _get_sql_and_database(self) -> tuple[str, Any, str | None, str | None]:
|
||||
"""
|
||||
Get the SQL query and database for chart export.
|
||||
Get the SQL query, database, catalog, and schema for chart export.
|
||||
|
||||
Returns:
|
||||
Tuple of (sql_query, database_object)
|
||||
Tuple of (sql_query, database_object, catalog, schema)
|
||||
"""
|
||||
# Get datasource and generate SQL query
|
||||
# Note: datasource should already be attached to a session from query_context
|
||||
datasource = self._query_context.datasource
|
||||
query_obj = self._query_context.queries[0]
|
||||
sql_query = datasource.get_query_str(query_obj.to_dict())
|
||||
database = getattr(datasource, "database", None)
|
||||
catalog = getattr(datasource, "catalog", None)
|
||||
schema = getattr(datasource, "schema", None)
|
||||
|
||||
return sql_query, getattr(datasource, "database", None)
|
||||
return sql_query, database, catalog, schema
|
||||
|
||||
def _get_row_limit(self) -> int | None:
|
||||
"""
|
||||
|
||||
@@ -87,12 +87,12 @@ class StreamingSqlResultExportCommand(BaseStreamingCSVExportCommand):
|
||||
status=403,
|
||||
) from ex
|
||||
|
||||
def _get_sql_and_database(self) -> tuple[str, Any]:
|
||||
def _get_sql_and_database(self) -> tuple[str, Any, str | None, str | None]:
|
||||
"""
|
||||
Get the SQL query and database for SQL Lab export.
|
||||
Get the SQL query, database, catalog, and schema for SQL Lab export.
|
||||
|
||||
Returns:
|
||||
Tuple of (sql_query, database_object)
|
||||
Tuple of (sql_query, database_object, catalog, schema)
|
||||
"""
|
||||
assert self._query is not None
|
||||
|
||||
@@ -103,7 +103,7 @@ class StreamingSqlResultExportCommand(BaseStreamingCSVExportCommand):
|
||||
# Get the SQL query
|
||||
sql = select_sql or executed_sql
|
||||
|
||||
return sql, database
|
||||
return sql, database, self._query.catalog, self._query.schema
|
||||
|
||||
def _get_row_limit(self) -> int | None:
|
||||
"""
|
||||
|
||||
@@ -79,12 +79,12 @@ class BaseStreamingCSVExportCommand(BaseCommand):
|
||||
self._current_app = app._get_current_object()
|
||||
|
||||
@abstractmethod
|
||||
def _get_sql_and_database(self) -> tuple[str, Any]:
|
||||
def _get_sql_and_database(self) -> tuple[str, Any, str | None, str | None]:
|
||||
"""
|
||||
Get the SQL query and database for execution.
|
||||
Get the SQL query, database, catalog, and schema for execution.
|
||||
|
||||
Returns:
|
||||
Tuple of (sql_query, database_object)
|
||||
Tuple of (sql_query, database_object, catalog, schema)
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
@@ -150,7 +150,12 @@ class BaseStreamingCSVExportCommand(BaseCommand):
|
||||
yield remaining_data, row_count, data_bytes
|
||||
|
||||
def _execute_query_and_stream(
|
||||
self, sql: str, database: Any, limit: int | None
|
||||
self,
|
||||
sql: str,
|
||||
database: Any,
|
||||
limit: int | None,
|
||||
catalog: str | None = None,
|
||||
schema: str | None = None,
|
||||
) -> Generator[str, None, None]:
|
||||
"""Execute query with streaming and yield CSV chunks."""
|
||||
start_time = time.time()
|
||||
@@ -160,8 +165,9 @@ class BaseStreamingCSVExportCommand(BaseCommand):
|
||||
# Merge database to prevent DetachedInstanceError
|
||||
merged_database = session.merge(database)
|
||||
|
||||
# Execute query with streaming
|
||||
with merged_database.get_sqla_engine() as engine:
|
||||
with merged_database.get_sqla_engine(
|
||||
catalog=catalog, schema=schema
|
||||
) as engine:
|
||||
with engine.connect() as connection:
|
||||
result_proxy = connection.execution_options(
|
||||
stream_results=True
|
||||
@@ -209,7 +215,7 @@ class BaseStreamingCSVExportCommand(BaseCommand):
|
||||
"""
|
||||
# Load all needed data while session is still active
|
||||
# to avoid DetachedInstanceError
|
||||
sql, database = self._get_sql_and_database()
|
||||
sql, database, catalog, schema = self._get_sql_and_database()
|
||||
limit = self._get_row_limit()
|
||||
# Capture flask.g attributes to preserve request-scoped data
|
||||
# when the streaming generator runs in a new app context.
|
||||
@@ -222,7 +228,9 @@ class BaseStreamingCSVExportCommand(BaseCommand):
|
||||
with self._current_app.app_context():
|
||||
with preserve_g_context(captured_g):
|
||||
try:
|
||||
yield from self._execute_query_and_stream(sql, database, limit)
|
||||
yield from self._execute_query_and_stream(
|
||||
sql, database, limit, catalog, schema
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Error in streaming CSV generator: %s", e)
|
||||
import traceback
|
||||
|
||||
@@ -30,18 +30,68 @@ from fastmcp.server.middleware import Middleware
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prose snippets that reference get_instance_info.
|
||||
# These are included in the generated instructions only when that tool is
|
||||
# enabled; each snippet is a plain string constant so they can be read
|
||||
# independently of the filtering logic in get_default_instructions().
|
||||
# ---------------------------------------------------------------------------
|
||||
_SNIPPET_FEATURE_AVAILABILITY = (
|
||||
"Feature Availability:\n"
|
||||
"- Call get_instance_info to discover accessible menus for the current user.\n"
|
||||
"- Do NOT assume features exist; always check get_instance_info first.\n"
|
||||
"\n"
|
||||
)
|
||||
_SNIPPET_INSTANCE_INFO_ROLE_BULLET = (
|
||||
"- get_instance_info returns current_user.roles"
|
||||
' (e.g., ["Admin"], ["Alpha"], ["Viewer"]).\n'
|
||||
)
|
||||
_SNIPPET_ACCESSIBLE_MENUS_BULLET = (
|
||||
"- If you are unsure about a user's capabilities,"
|
||||
" check their accessible_menus in\n"
|
||||
" feature_availability from get_instance_info.\n"
|
||||
)
|
||||
_SNIPPET_UNSURE_GUIDANCE = (
|
||||
"\nIf you are unsure which tool to use, start with get_instance_info\n"
|
||||
"or use the quickstart prompt for an interactive guide.\n"
|
||||
)
|
||||
_SNIPPET_CONNECT_GUIDANCE = (
|
||||
"\nWhen you first connect, call get_instance_info to learn the user's identity.\n"
|
||||
"Greet them by their first name (from current_user) and offer to help.\n"
|
||||
)
|
||||
|
||||
def get_default_instructions(branding: str = "Apache Superset") -> str:
|
||||
|
||||
def get_default_instructions(
|
||||
branding: str = "Apache Superset",
|
||||
disabled_tools: set[str] | None = None,
|
||||
) -> str:
|
||||
"""Get default instructions with configurable branding.
|
||||
|
||||
Tool bullet-point lines for any tool name in ``disabled_tools`` are
|
||||
omitted so that LLM clients are never told to call a tool that has been
|
||||
suppressed via ``MCP_DISABLED_TOOLS``.
|
||||
|
||||
Args:
|
||||
branding: Product name to use in instructions
|
||||
(e.g., "ACME Analytics", "Apache Superset")
|
||||
disabled_tools: Set of tool names to omit from the tool listing.
|
||||
When ``None`` (default) all tools are included.
|
||||
|
||||
Returns:
|
||||
Formatted instructions string with branding applied
|
||||
"""
|
||||
return f"""
|
||||
_disabled = disabled_tools or set()
|
||||
|
||||
# Prose sections that reference get_instance_info are omitted when that
|
||||
# tool is disabled so the LLM is never directed to call a removed tool.
|
||||
_show = "get_instance_info" not in _disabled
|
||||
_feature_availability = _SNIPPET_FEATURE_AVAILABILITY if _show else ""
|
||||
_instance_info_role_bullet = _SNIPPET_INSTANCE_INFO_ROLE_BULLET if _show else ""
|
||||
_accessible_menus_bullet = _SNIPPET_ACCESSIBLE_MENUS_BULLET if _show else ""
|
||||
_unsure_guidance = _SNIPPET_UNSURE_GUIDANCE if _show else ""
|
||||
_connect_guidance = _SNIPPET_CONNECT_GUIDANCE if _show else ""
|
||||
|
||||
instructions = f"""
|
||||
You are connected to the {branding} MCP (Model Context Protocol) service.
|
||||
This service provides programmatic access to {branding} dashboards, charts, datasets,
|
||||
SQL Lab, and instance metadata via a comprehensive set of tools.
|
||||
@@ -302,13 +352,8 @@ Input format:
|
||||
- Tool request parameters accept structured objects (dicts/JSON)
|
||||
- FastMCP 3.1+ handles Pydantic BaseModel parameters natively
|
||||
|
||||
Feature Availability:
|
||||
- Call get_instance_info to discover accessible menus for the current user.
|
||||
- Do NOT assume features exist; always check get_instance_info first.
|
||||
|
||||
Permission Awareness:
|
||||
- get_instance_info returns current_user.roles (e.g., ["Admin"], ["Alpha"], ["Viewer"]).
|
||||
- ALWAYS check the user's roles BEFORE suggesting write operations (creating datasets,
|
||||
{_feature_availability}Permission Awareness:
|
||||
{_instance_info_role_bullet}- ALWAYS check the user's roles BEFORE suggesting write operations (creating datasets,
|
||||
charts, dashboards, or running SQL).
|
||||
- Do NOT disclose dashboard access lists, dashboard owners, chart owners, dataset
|
||||
owners, workspace admins, or other users' names, usernames, email addresses,
|
||||
@@ -332,15 +377,38 @@ Permission Awareness:
|
||||
1. Explain that they may not have access to the requested resources
|
||||
2. Suggest they ask a workspace admin to grant them access or share content with them
|
||||
3. Offer to help with what they CAN do (e.g., viewing dashboards they have access to)
|
||||
- If you are unsure about a user's capabilities, check their accessible_menus in
|
||||
feature_availability from get_instance_info.
|
||||
{_accessible_menus_bullet}{_unsure_guidance}{_connect_guidance}"""
|
||||
if not _disabled:
|
||||
return instructions
|
||||
|
||||
If you are unsure which tool to use, start with get_instance_info
|
||||
or use the quickstart prompt for an interactive guide.
|
||||
|
||||
When you first connect, call get_instance_info to learn the user's identity.
|
||||
Greet them by their first name (from current_user) and offer to help.
|
||||
"""
|
||||
# Strip any line that mentions a disabled tool — this covers both the
|
||||
# "- tool_name: ..." bullet entries and all prose/workflow references
|
||||
# (request wrapper examples, workflow steps, CRITICAL RULES, etc.).
|
||||
# Tool names are specific enough (e.g. execute_sql, generate_chart) that
|
||||
# false positives are not a practical concern.
|
||||
#
|
||||
# Bullet continuation lines (any non-blank, non-bullet lines directly after a disabled bullet; indentation is not checked)
|
||||
# are also dropped via the skip_continuation flag.
|
||||
filtered_lines = []
|
||||
skip_continuation = False
|
||||
for line in instructions.splitlines(keepends=True):
|
||||
stripped = line.lstrip()
|
||||
if stripped.startswith("- "):
|
||||
tool_part = stripped[2:].split(":")[0].strip()
|
||||
if tool_part in _disabled:
|
||||
skip_continuation = True
|
||||
continue
|
||||
skip_continuation = False
|
||||
elif skip_continuation and stripped and not stripped.startswith("- "):
|
||||
# Non-blank, non-bullet line right after a disabled bullet: treated as its continuation (indentation is not checked) — skip
|
||||
continue
|
||||
else:
|
||||
skip_continuation = False
|
||||
# Drop any prose line that names a disabled tool
|
||||
if any(tool in line for tool in _disabled):
|
||||
continue
|
||||
filtered_lines.append(line)
|
||||
return "".join(filtered_lines)
|
||||
|
||||
|
||||
# For backwards compatibility, keep DEFAULT_INSTRUCTIONS pointing to default branding
|
||||
@@ -569,6 +637,25 @@ from superset.mcp_service.system.tool import ( # noqa: F401, E402
|
||||
)
|
||||
|
||||
|
||||
def _remove_disabled_tools(disabled_tools: set[str]) -> None:
|
||||
"""Remove tools listed in MCP_DISABLED_TOOLS from the global MCP instance.
|
||||
|
||||
Disabled tools are removed before the server starts serving requests so they
|
||||
are never advertised to AI clients during tool discovery. Users configure
|
||||
this via MCP_DISABLED_TOOLS in superset_config.py.
|
||||
"""
|
||||
for tool_name in disabled_tools:
|
||||
try:
|
||||
mcp.local_provider.remove_tool(tool_name)
|
||||
logger.info("Disabled MCP tool: %s (MCP_DISABLED_TOOLS)", tool_name)
|
||||
except KeyError:
|
||||
logger.warning(
|
||||
"MCP_DISABLED_TOOLS: tool %r not found — "
|
||||
"check the tool name is correct",
|
||||
tool_name,
|
||||
)
|
||||
|
||||
|
||||
def init_fastmcp_server(
|
||||
name: str | None = None,
|
||||
instructions: str | None = None,
|
||||
@@ -608,8 +695,14 @@ def init_fastmcp_server(
|
||||
# Apply branding defaults if not explicitly provided
|
||||
if name is None:
|
||||
name = default_name
|
||||
|
||||
# Remove disabled tools BEFORE generating instructions so that the
|
||||
# instructions never advertise tools that clients cannot actually call.
|
||||
disabled_tools: set[str] = flask_app.config.get("MCP_DISABLED_TOOLS", set())
|
||||
_remove_disabled_tools(disabled_tools)
|
||||
|
||||
if instructions is None:
|
||||
instructions = get_default_instructions(branding)
|
||||
instructions = get_default_instructions(branding, disabled_tools)
|
||||
|
||||
# Configure the global mcp instance with provided settings.
|
||||
# Tools are already registered on this instance via @tool decorator imports above.
|
||||
|
||||
@@ -56,6 +56,19 @@ MCP_DEBUG = False
|
||||
# against the FAB security_manager before execution.
|
||||
MCP_RBAC_ENABLED = True
|
||||
|
||||
# MCP Disabled Tools - a set of tool names to remove from the MCP server at
|
||||
# startup. Disabled tools are silently omitted from tool discovery, so AI
|
||||
# clients never see them. Use this when a Superset-provided tool conflicts with
|
||||
# a custom tool added via an extension and you want to suppress the built-in
|
||||
# version.
|
||||
#
|
||||
# Example:
|
||||
# MCP_DISABLED_TOOLS = {"execute_sql", "health_check"}
|
||||
#
|
||||
# Extension-prefixed tools can also be disabled using their full name:
|
||||
# MCP_DISABLED_TOOLS = {"extensions.myorg.myext.some_tool"}
|
||||
MCP_DISABLED_TOOLS: set[str] = set()
|
||||
|
||||
# MCP JWT Debug Errors - controls server-side JWT debug logging.
|
||||
# When False (default), uses the default JWTVerifier with minimal logging.
|
||||
# When True, uses DetailedJWTVerifier with tiered logging:
|
||||
@@ -402,6 +415,7 @@ def get_mcp_config(app_config: Dict[str, Any] | None = None) -> Dict[str, Any]:
|
||||
"MCP_SERVICE_PORT": MCP_SERVICE_PORT,
|
||||
"MCP_DEBUG": MCP_DEBUG,
|
||||
"MCP_RBAC_ENABLED": MCP_RBAC_ENABLED,
|
||||
"MCP_DISABLED_TOOLS": set(MCP_DISABLED_TOOLS),
|
||||
**MCP_SESSION_CONFIG,
|
||||
**MCP_CSRF_CONFIG,
|
||||
}
|
||||
|
||||
@@ -468,13 +468,34 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint:
|
||||
engine_context_manager = app.config["ENGINE_CONTEXT_MANAGER"]
|
||||
with engine_context_manager(self, catalog, schema):
|
||||
with check_for_oauth2(self):
|
||||
yield self._get_sqla_engine(
|
||||
engine = self._get_sqla_engine(
|
||||
catalog=catalog,
|
||||
schema=schema,
|
||||
nullpool=nullpool,
|
||||
source=source,
|
||||
sqlalchemy_uri=sqlalchemy_uri,
|
||||
)
|
||||
prequeries = self.db_engine_spec.get_prequeries(
|
||||
database=self,
|
||||
catalog=catalog,
|
||||
schema=schema,
|
||||
)
|
||||
if prequeries:
|
||||
# SQLAlchemy connect event: runs prequeries on every new
|
||||
# DBAPI connection (e.g. SET search_path for PostgreSQL).
|
||||
def run_prequeries(
|
||||
dbapi_connection: Any,
|
||||
connection_record: Any, # pylint: disable=unused-argument
|
||||
) -> None:
|
||||
cursor = dbapi_connection.cursor()
|
||||
try:
|
||||
for prequery in prequeries:
|
||||
cursor.execute(prequery)
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
sqla.event.listen(engine, "connect", run_prequeries)
|
||||
yield engine
|
||||
|
||||
def _get_sqla_engine( # pylint: disable=too-many-locals # noqa: C901
|
||||
self,
|
||||
@@ -583,15 +604,6 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint:
|
||||
) as engine:
|
||||
with check_for_oauth2(self):
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
# pre-session queries are used to set the selected catalog/schema
|
||||
for prequery in self.db_engine_spec.get_prequeries(
|
||||
database=self,
|
||||
catalog=catalog,
|
||||
schema=schema,
|
||||
):
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(prequery)
|
||||
|
||||
yield conn
|
||||
|
||||
def get_default_catalog(self) -> str | None:
|
||||
|
||||
@@ -38,7 +38,7 @@ from superset.superset_typing import OAuth2ClientConfig, OAuth2State
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
from superset.models.core import Database, DatabaseUserOAuth2Tokens
|
||||
from superset.models.core import Database
|
||||
|
||||
JWT_EXPIRATION = timedelta(minutes=5)
|
||||
|
||||
@@ -116,7 +116,7 @@ def get_oauth2_access_token(
|
||||
return token.access_token
|
||||
|
||||
if token.refresh_token:
|
||||
return refresh_oauth2_token(config, database_id, user_id, db_engine_spec, token)
|
||||
return refresh_oauth2_token(config, database_id, user_id, db_engine_spec)
|
||||
|
||||
# since the access token is expired and there's no refresh token, delete the entry
|
||||
db.session.delete(token)
|
||||
@@ -129,8 +129,10 @@ def refresh_oauth2_token(
|
||||
database_id: int,
|
||||
user_id: int,
|
||||
db_engine_spec: type[BaseEngineSpec],
|
||||
token: DatabaseUserOAuth2Tokens,
|
||||
) -> str | None:
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from superset.models.core import DatabaseUserOAuth2Tokens
|
||||
|
||||
# Use longer TTL for OAuth2 token refresh (may involve network calls)
|
||||
with DistributedLock(
|
||||
namespace="refresh_oauth2_token",
|
||||
@@ -138,6 +140,22 @@ def refresh_oauth2_token(
|
||||
user_id=user_id,
|
||||
database_id=database_id,
|
||||
):
|
||||
# Short circuit in case another request already deleted the token
|
||||
token = (
|
||||
db.session.query(DatabaseUserOAuth2Tokens)
|
||||
.filter_by(user_id=user_id, database_id=database_id)
|
||||
.one_or_none()
|
||||
)
|
||||
if token is None:
|
||||
return None
|
||||
|
||||
if token.access_token and datetime.now() < token.access_token_expiration:
|
||||
return token.access_token
|
||||
|
||||
if not token.refresh_token:
|
||||
db.session.delete(token)
|
||||
return None
|
||||
|
||||
try:
|
||||
token_response = db_engine_spec.get_oauth2_fresh_token(
|
||||
config,
|
||||
|
||||
@@ -760,6 +760,54 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas
|
||||
db.session.delete(dashboard)
|
||||
db.session.commit()
|
||||
|
||||
def test_get_dashboards_admin_sees_existing_dashboards(self):
|
||||
"""Regression for #25890: GET /api/v1/dashboard/ as an Admin user should
|
||||
return existing dashboards, not an empty list. The original report
|
||||
showed an Admin getting {"count": 0, "ids": []} despite dashboards
|
||||
existing in the database."""
|
||||
admin = self.get_user("admin")
|
||||
dashboard = self.insert_dashboard(
|
||||
"regression_25890_dashboard", "regression-25890", [admin.id]
|
||||
)
|
||||
try:
|
||||
self.login(ADMIN_USERNAME)
|
||||
rv = self.client.get("api/v1/dashboard/")
|
||||
assert rv.status_code == 200
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
assert data["count"] >= 1, (
|
||||
f"Admin received empty dashboard list despite "
|
||||
f"{dashboard.dashboard_title!r} existing; see issue #25890"
|
||||
)
|
||||
titles = [d["dashboard_title"] for d in data["result"]]
|
||||
assert dashboard.dashboard_title in titles, (
|
||||
f"Admin list missing the inserted dashboard. Got titles: {titles}"
|
||||
)
|
||||
finally:
|
||||
db.session.delete(dashboard)
|
||||
db.session.commit()
|
||||
|
||||
def test_get_charts_admin_sees_existing_charts(self):
|
||||
"""Regression for #25890: GET /api/v1/chart/ as an Admin user should
|
||||
return existing charts, not an empty list."""
|
||||
admin = self.get_user("admin")
|
||||
chart = self.insert_chart("regression_25890_chart", [admin.id], 1, params="{}")
|
||||
try:
|
||||
self.login(ADMIN_USERNAME)
|
||||
rv = self.client.get("api/v1/chart/")
|
||||
assert rv.status_code == 200
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
assert data["count"] >= 1, (
|
||||
f"Admin received empty chart list despite "
|
||||
f"{chart.slice_name!r} existing; see issue #25890"
|
||||
)
|
||||
names = [c["slice_name"] for c in data["result"]]
|
||||
assert chart.slice_name in names, (
|
||||
f"Admin list missing the inserted chart. Got slice_names: {names}"
|
||||
)
|
||||
finally:
|
||||
db.session.delete(chart)
|
||||
db.session.commit()
|
||||
|
||||
def test_get_dashboards_filter(self):
|
||||
"""
|
||||
Dashboard API: Test get dashboards filter
|
||||
|
||||
@@ -30,6 +30,7 @@ from superset.utils import json
|
||||
from tests.conftest import with_config
|
||||
from tests.integration_tests.base_tests import SupersetTestCase
|
||||
from tests.integration_tests.constants import ADMIN_USERNAME, GAMMA_USERNAME
|
||||
from tests.integration_tests.test_app import app
|
||||
from tests.integration_tests.fixtures.birth_names_dashboard import (
|
||||
load_birth_names_dashboard_with_slices, # noqa: F401
|
||||
load_birth_names_data, # noqa: F401
|
||||
@@ -402,3 +403,42 @@ class TestSecurityRolesApi(SupersetTestCase):
|
||||
assert sorted(role2_api["user_ids"]) == role2_expected["user_ids"]
|
||||
assert sorted(role2_api["permission_ids"]) == role2_expected["permission_ids"]
|
||||
assert role2_api["group_ids"] == role2_expected["group_ids"]
|
||||
|
||||
|
||||
class TestLogoutSessionInvalidation(SupersetTestCase):
|
||||
"""Regression for #24713: a session cookie captured pre-logout must not grant
|
||||
access after the user logs out. The original report describes copying the
|
||||
session cookie out, calling /logout/, and successfully reusing the cookie in
|
||||
a second browser to bypass authentication."""
|
||||
|
||||
def test_session_cookie_invalidated_after_logout(self):
|
||||
self.login(ADMIN_USERNAME)
|
||||
|
||||
resp_authed = self.client.get("api/v1/dashboard/", follow_redirects=False)
|
||||
assert resp_authed.status_code == 200, (
|
||||
f"Login did not yield an authenticated session "
|
||||
f"(got {resp_authed.status_code})"
|
||||
)
|
||||
|
||||
# Werkzeug 2.3+ exposes the test client's cookies on `_cookies` as a
|
||||
# mapping keyed by (domain, path, key). Snapshot the session cookie
|
||||
# value — this is what a malicious actor would copy out of a browser.
|
||||
captured = None
|
||||
for cookie in self.client._cookies.values():
|
||||
if cookie.key == "session":
|
||||
captured = cookie.value
|
||||
break
|
||||
assert captured, "expected a session cookie after login"
|
||||
|
||||
self.client.get("/logout/", follow_redirects=True)
|
||||
|
||||
# Replay the captured cookie in a fresh client (simulates importing
|
||||
# the cookie into a second browser).
|
||||
replay_client = app.test_client()
|
||||
replay_client.set_cookie("session", captured, domain="localhost")
|
||||
|
||||
resp_replay = replay_client.get("api/v1/dashboard/", follow_redirects=False)
|
||||
assert resp_replay.status_code != 200, (
|
||||
f"Captured session cookie was still accepted after logout "
|
||||
f"(status={resp_replay.status_code}); see issue #24713"
|
||||
)
|
||||
|
||||
@@ -25,7 +25,10 @@ from superset.commands.chart.data.streaming_export_command import (
|
||||
|
||||
|
||||
def _setup_chart_mocks(
|
||||
mocker: MockerFixture, sql: str = "SELECT * FROM test"
|
||||
mocker: MockerFixture,
|
||||
sql: str = "SELECT * FROM test",
|
||||
catalog: str | None = None,
|
||||
schema: str | None = None,
|
||||
) -> tuple[MockerFixture, MockerFixture, MockerFixture]:
|
||||
"""Set up common mocks for chart streaming export tests."""
|
||||
mock_db = mocker.patch("superset.commands.streaming_export.base.db")
|
||||
@@ -36,6 +39,8 @@ def _setup_chart_mocks(
|
||||
datasource = mocker.MagicMock()
|
||||
datasource.get_query_str.return_value = sql
|
||||
datasource.database = mocker.MagicMock()
|
||||
datasource.catalog = catalog
|
||||
datasource.schema = schema
|
||||
query_context.datasource = datasource
|
||||
query_context.queries = [mocker.MagicMock()]
|
||||
mock_session.merge.return_value = datasource.database
|
||||
@@ -256,3 +261,38 @@ def test_empty_result_set(mocker: MockerFixture) -> None:
|
||||
lines = [line.strip() for line in csv_data.strip().split("\n")]
|
||||
assert len(lines) == 1
|
||||
assert lines[0] == "col1,col2"
|
||||
|
||||
|
||||
def test_catalog_and_schema_passed_to_engine(mocker: MockerFixture) -> None:
|
||||
"""Test that catalog and schema are forwarded to get_sqla_engine.
|
||||
|
||||
Prequeries (e.g. SET search_path for PostgreSQL) are now run automatically
|
||||
via a connect event listener registered inside get_sqla_engine, not by the
|
||||
streaming command itself.
|
||||
"""
|
||||
mock_db, query_context, datasource = _setup_chart_mocks(
|
||||
mocker, catalog="my_catalog", schema="my_schema"
|
||||
)
|
||||
|
||||
mock_result = mocker.MagicMock()
|
||||
mock_result.keys.return_value = ["col1"]
|
||||
mock_result.fetchmany.side_effect = [[("val",)], []]
|
||||
|
||||
mock_connection = mocker.MagicMock()
|
||||
mock_connection.execution_options.return_value.execute.return_value = mock_result
|
||||
mock_connection.__enter__.return_value = mock_connection
|
||||
mock_connection.__exit__.return_value = None
|
||||
|
||||
mock_engine = mocker.MagicMock()
|
||||
mock_engine.connect.return_value = mock_connection
|
||||
datasource.database.get_sqla_engine.return_value.__enter__.return_value = (
|
||||
mock_engine
|
||||
)
|
||||
|
||||
command = StreamingCSVExportCommand(query_context)
|
||||
list(command.run()())
|
||||
|
||||
datasource.database.get_sqla_engine.assert_called_once_with(
|
||||
catalog="my_catalog",
|
||||
schema="my_schema",
|
||||
)
|
||||
|
||||
@@ -55,6 +55,8 @@ def mock_query():
|
||||
query.select_sql = None
|
||||
query.executed_sql = "SELECT * FROM test_table"
|
||||
query.limiting_factor = LimitingFactor.NOT_LIMITED
|
||||
query.catalog = None
|
||||
query.schema = "public"
|
||||
query.database = MagicMock()
|
||||
query.database.db_engine_spec = MagicMock()
|
||||
query.database.db_engine_spec.engine = "postgresql"
|
||||
@@ -538,3 +540,40 @@ def test_null_values_handling(mocker, mock_query):
|
||||
assert "1,,100" in csv_data
|
||||
assert "2,test," in csv_data
|
||||
assert ",," in csv_data
|
||||
|
||||
|
||||
def test_catalog_and_schema_passed_to_engine(mocker, mock_query, mock_result_proxy):
|
||||
"""Test that catalog and schema are forwarded to get_sqla_engine.
|
||||
|
||||
Prequeries (e.g. SET search_path for PostgreSQL) are now run automatically
|
||||
via a connect event listener registered inside get_sqla_engine, not by the
|
||||
streaming command itself.
|
||||
"""
|
||||
mock_query.select_sql = "SELECT * FROM test"
|
||||
mock_query.catalog = "my_catalog"
|
||||
mock_query.schema = "my_schema"
|
||||
|
||||
mock_db, mock_session = _setup_sqllab_mocks(mocker, mock_query)
|
||||
|
||||
mock_connection = MagicMock()
|
||||
mock_connection.execution_options.return_value.execute.return_value = (
|
||||
mock_result_proxy
|
||||
)
|
||||
mock_connection.__enter__.return_value = mock_connection
|
||||
mock_connection.__exit__.return_value = None
|
||||
|
||||
mock_engine = MagicMock()
|
||||
mock_engine.connect.return_value = mock_connection
|
||||
mock_query.database.get_sqla_engine.return_value.__enter__.return_value = (
|
||||
mock_engine
|
||||
)
|
||||
|
||||
command = StreamingSqlResultExportCommand("test_client_123")
|
||||
command.validate()
|
||||
|
||||
list(command.run()())
|
||||
|
||||
mock_query.database.get_sqla_engine.assert_called_once_with(
|
||||
catalog="my_catalog",
|
||||
schema="my_schema",
|
||||
)
|
||||
|
||||
@@ -105,11 +105,22 @@ def test_get_default_instructions_forbid_disclosing_other_user_access_or_roles()
|
||||
assert "direct them to their workspace admin" in instructions
|
||||
|
||||
|
||||
def _mock_flask_config(app_name: str) -> MagicMock:
|
||||
"""Return a Flask app mock whose config.get() returns correct types per key."""
|
||||
mock = MagicMock()
|
||||
mock.config.get.side_effect = lambda key, default=None: (
|
||||
app_name
|
||||
if key == "APP_NAME"
|
||||
else set()
|
||||
if key == "MCP_DISABLED_TOOLS"
|
||||
else default
|
||||
)
|
||||
return mock
|
||||
|
||||
|
||||
def test_init_fastmcp_server_with_default_app_name():
|
||||
"""Test that default APP_NAME produces Superset branding."""
|
||||
# Mock Flask app config with default APP_NAME
|
||||
mock_flask_app = MagicMock()
|
||||
mock_flask_app.config.get.return_value = "Superset"
|
||||
mock_flask_app = _mock_flask_config("Superset")
|
||||
|
||||
# Patch at the import location to avoid actual Flask app creation
|
||||
with patch.dict(
|
||||
@@ -127,9 +138,7 @@ def test_init_fastmcp_server_with_default_app_name():
|
||||
def test_init_fastmcp_server_with_custom_app_name():
|
||||
"""Test that custom APP_NAME produces branded instructions."""
|
||||
custom_app_name = "ACME Analytics"
|
||||
# Mock Flask app config with custom APP_NAME
|
||||
mock_flask_app = MagicMock()
|
||||
mock_flask_app.config.get.return_value = custom_app_name
|
||||
mock_flask_app = _mock_flask_config(custom_app_name)
|
||||
|
||||
# Patch at the import location to avoid actual Flask app creation
|
||||
with patch.dict(
|
||||
@@ -149,10 +158,7 @@ def test_init_fastmcp_server_derives_server_name_from_app_name():
|
||||
"""Test that server name is derived from APP_NAME."""
|
||||
custom_app_name = "DataViz Platform"
|
||||
expected_server_name = f"{custom_app_name} MCP Server"
|
||||
|
||||
# Mock Flask app config
|
||||
mock_flask_app = MagicMock()
|
||||
mock_flask_app.config.get.return_value = custom_app_name
|
||||
mock_flask_app = _mock_flask_config(custom_app_name)
|
||||
|
||||
# Patch at the import location to avoid actual Flask app creation
|
||||
with patch.dict(
|
||||
@@ -168,8 +174,7 @@ def test_init_fastmcp_server_derives_server_name_from_app_name():
|
||||
|
||||
def test_init_fastmcp_server_applies_auth_to_global_instance():
|
||||
"""Test that auth is applied to the global mcp instance, not a new one."""
|
||||
mock_flask_app = MagicMock()
|
||||
mock_flask_app.config.get.return_value = "Superset"
|
||||
mock_flask_app = _mock_flask_config("Superset")
|
||||
mock_auth = MagicMock()
|
||||
|
||||
with patch.dict(
|
||||
@@ -187,8 +192,7 @@ def test_init_fastmcp_server_applies_auth_to_global_instance():
|
||||
|
||||
def test_init_fastmcp_server_applies_middleware_to_global_instance():
|
||||
"""Test that middleware is added to the global mcp instance."""
|
||||
mock_flask_app = MagicMock()
|
||||
mock_flask_app.config.get.return_value = "Superset"
|
||||
mock_flask_app = _mock_flask_config("Superset")
|
||||
mock_mw = MagicMock()
|
||||
|
||||
with patch.dict(
|
||||
@@ -200,3 +204,23 @@ def test_init_fastmcp_server_applies_middleware_to_global_instance():
|
||||
|
||||
# Middleware should be added via add_middleware
|
||||
mock_mcp.add_middleware.assert_called_once_with(mock_mw)
|
||||
|
||||
|
||||
def test_get_mcp_config_includes_mcp_disabled_tools_key() -> None:
|
||||
"""get_mcp_config must include MCP_DISABLED_TOOLS in its defaults dict so the
|
||||
key is available in flask_app.config for the standalone server startup path."""
|
||||
from superset.mcp_service.mcp_config import get_mcp_config
|
||||
|
||||
config = get_mcp_config()
|
||||
assert "MCP_DISABLED_TOOLS" in config
|
||||
assert config["MCP_DISABLED_TOOLS"] == set()
|
||||
|
||||
|
||||
def test_get_mcp_config_respects_app_config_override() -> None:
|
||||
"""When app_config provides MCP_DISABLED_TOOLS, it takes precedence over the
|
||||
module-level default."""
|
||||
from superset.mcp_service.mcp_config import get_mcp_config
|
||||
|
||||
custom = {"execute_sql", "health_check"}
|
||||
config = get_mcp_config({"MCP_DISABLED_TOOLS": custom})
|
||||
assert config["MCP_DISABLED_TOOLS"] == custom
|
||||
|
||||
@@ -18,6 +18,10 @@
|
||||
"""Test MCP app imports and tool/prompt registration."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from superset.mcp_service.app import get_default_instructions, init_fastmcp_server, mcp
|
||||
|
||||
|
||||
def _run(coro):
|
||||
@@ -95,3 +99,188 @@ def test_mcp_packages_discoverable_by_setuptools():
|
||||
f"MCP sub-packages missing __init__.py (will be excluded from "
|
||||
f"setuptools distributions): {missing}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MCP_DISABLED_TOOLS tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_flask_app_mock(disabled_tools: set[str]) -> MagicMock:
|
||||
"""Return a minimal Flask app mock with MCP_DISABLED_TOOLS configured."""
|
||||
flask_app = MagicMock()
|
||||
flask_app.config.get.side_effect = lambda key, default=None: (
|
||||
disabled_tools if key == "MCP_DISABLED_TOOLS" else default
|
||||
)
|
||||
return flask_app
|
||||
|
||||
|
||||
def test_disabled_tools_are_removed_from_mcp_server() -> None:
|
||||
"""Tools listed in MCP_DISABLED_TOOLS are removed before the server starts."""
|
||||
|
||||
flask_app = _make_flask_app_mock({"health_check", "list_charts"})
|
||||
|
||||
with (
|
||||
patch(
|
||||
"superset.mcp_service.flask_singleton.app",
|
||||
flask_app,
|
||||
),
|
||||
patch.object(mcp.local_provider, "remove_tool") as mock_remove,
|
||||
):
|
||||
init_fastmcp_server()
|
||||
|
||||
removed = {call.args[0] for call in mock_remove.call_args_list}
|
||||
assert "health_check" in removed
|
||||
assert "list_charts" in removed
|
||||
|
||||
|
||||
def test_unknown_disabled_tool_logs_warning_not_raises(caplog) -> None:
|
||||
"""An unknown tool name in MCP_DISABLED_TOOLS logs a warning and does not crash."""
|
||||
|
||||
flask_app = _make_flask_app_mock({"nonexistent_tool_xyz"})
|
||||
|
||||
with (
|
||||
patch(
|
||||
"superset.mcp_service.flask_singleton.app",
|
||||
flask_app,
|
||||
),
|
||||
patch.object(
|
||||
mcp.local_provider,
|
||||
"remove_tool",
|
||||
side_effect=KeyError("nonexistent_tool_xyz"),
|
||||
),
|
||||
caplog.at_level(logging.WARNING, logger="superset.mcp_service.app"),
|
||||
):
|
||||
# Must not raise
|
||||
init_fastmcp_server()
|
||||
|
||||
assert "nonexistent_tool_xyz" in caplog.text
|
||||
assert "MCP_DISABLED_TOOLS" in caplog.text
|
||||
|
||||
|
||||
def test_empty_disabled_tools_removes_nothing() -> None:
|
||||
"""An empty MCP_DISABLED_TOOLS set leaves all tools registered."""
|
||||
|
||||
flask_app = _make_flask_app_mock(set())
|
||||
|
||||
with (
|
||||
patch(
|
||||
"superset.mcp_service.flask_singleton.app",
|
||||
flask_app,
|
||||
),
|
||||
patch.object(mcp.local_provider, "remove_tool") as mock_remove,
|
||||
):
|
||||
init_fastmcp_server()
|
||||
|
||||
mock_remove.assert_not_called()
|
||||
|
||||
|
||||
def test_disabled_tools_read_from_flask_app_config() -> None:
|
||||
"""MCP_DISABLED_TOOLS is read from flask_app.config, matching the standard
|
||||
Superset pattern where users set overrides in superset_config.py, which
|
||||
create_app() loads into Flask config before any command runs."""
|
||||
from superset.mcp_service.app import init_fastmcp_server, mcp
|
||||
|
||||
flask_app = _make_flask_app_mock({"health_check"})
|
||||
|
||||
with (
|
||||
patch(
|
||||
"superset.mcp_service.flask_singleton.app",
|
||||
flask_app,
|
||||
),
|
||||
patch.object(mcp.local_provider, "remove_tool") as mock_remove,
|
||||
):
|
||||
init_fastmcp_server()
|
||||
|
||||
removed = {call.args[0] for call in mock_remove.call_args_list}
|
||||
assert "health_check" in removed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_default_instructions disabled_tools filtering tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_disabled_tools_absent_from_instructions() -> None:
|
||||
"""Tools in disabled_tools must not appear as bullet lines in instructions."""
|
||||
instructions = get_default_instructions(
|
||||
disabled_tools={"execute_sql", "health_check"}
|
||||
)
|
||||
|
||||
# The bullet-point entries for disabled tools must be gone
|
||||
assert "- execute_sql:" not in instructions
|
||||
assert "- health_check:" not in instructions
|
||||
# Non-disabled tools must still be present
|
||||
assert "- list_charts:" in instructions
|
||||
assert "- list_dashboards:" in instructions
|
||||
|
||||
|
||||
def test_disabling_get_instance_info_removes_all_prose_references() -> None:
|
||||
"""Disabling get_instance_info must remove ALL prose references to it,
|
||||
not only the bullet-point entry in the Available tools section."""
|
||||
instructions = get_default_instructions(disabled_tools={"get_instance_info"})
|
||||
|
||||
# Bullet entry must be gone
|
||||
assert "- get_instance_info:" not in instructions
|
||||
# Prose directives that instruct the LLM to call the tool must also be gone
|
||||
assert "start with get_instance_info" not in instructions
|
||||
assert "call get_instance_info" not in instructions
|
||||
assert "check their accessible_menus in" not in instructions
|
||||
assert "Feature Availability" not in instructions
|
||||
# Instructions for other tools must be unaffected
|
||||
assert "- list_charts:" in instructions
|
||||
assert "- execute_sql:" in instructions
|
||||
|
||||
|
||||
def test_disabling_execute_sql_removes_all_prose_references() -> None:
|
||||
"""Disabling execute_sql must remove all workflow and example lines that
|
||||
mention it, not only the bullet-point entry."""
|
||||
instructions = get_default_instructions(disabled_tools={"execute_sql"})
|
||||
|
||||
# Bullet entry must be gone
|
||||
assert "- execute_sql:" not in instructions
|
||||
# Workflow steps and request wrapper examples must also be gone
|
||||
assert "execute_sql(" not in instructions
|
||||
assert "execute_sql" not in instructions
|
||||
# Instructions for unrelated tools must be unaffected
|
||||
assert "- list_charts:" in instructions
|
||||
assert "- get_instance_info:" in instructions
|
||||
|
||||
|
||||
def test_no_disabled_tools_returns_full_instructions() -> None:
|
||||
"""Passing no disabled_tools (or empty set) returns the full instructions."""
|
||||
full = get_default_instructions()
|
||||
also_full = get_default_instructions(disabled_tools=set())
|
||||
|
||||
assert "- execute_sql:" in full
|
||||
assert "- health_check:" in full
|
||||
assert full == also_full
|
||||
|
||||
|
||||
def test_instructions_generated_after_disabled_tools_removed() -> None:
|
||||
"""init_fastmcp_server generates instructions AFTER removing disabled tools,
|
||||
so the instructions never advertise tools that clients cannot call."""
|
||||
flask_app = _make_flask_app_mock({"execute_sql"})
|
||||
|
||||
captured: list[str] = []
|
||||
|
||||
def fake_get_instructions(
|
||||
branding: str = "Apache Superset",
|
||||
disabled_tools: set[str] | None = None,
|
||||
) -> str:
|
||||
captured.append(str(disabled_tools))
|
||||
return f"instructions for {branding}"
|
||||
|
||||
with (
|
||||
patch("superset.mcp_service.flask_singleton.app", flask_app),
|
||||
patch.object(mcp.local_provider, "remove_tool"),
|
||||
patch(
|
||||
"superset.mcp_service.app.get_default_instructions",
|
||||
fake_get_instructions,
|
||||
),
|
||||
):
|
||||
init_fastmcp_server()
|
||||
|
||||
# get_default_instructions must have been called with the disabled set
|
||||
assert len(captured) == 1
|
||||
assert "execute_sql" in captured[0]
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable
|
||||
|
||||
import pytest
|
||||
from flask import current_app
|
||||
@@ -261,21 +262,6 @@ def test_table_column_database() -> None:
|
||||
assert TableColumn(database=database).database is database
|
||||
|
||||
|
||||
def test_get_prequeries(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Tests for ``get_prequeries``.
|
||||
"""
|
||||
mocker.patch.object(Database, "get_sqla_engine")
|
||||
db_engine_spec = mocker.patch.object(Database, "db_engine_spec")
|
||||
db_engine_spec.get_prequeries.return_value = ["set a=1", "set b=2"]
|
||||
|
||||
database = Database(database_name="db")
|
||||
with database.get_raw_connection() as conn:
|
||||
conn.cursor().execute.assert_has_calls(
|
||||
[mocker.call("set a=1"), mocker.call("set b=2")]
|
||||
)
|
||||
|
||||
|
||||
def test_catalog_cache() -> None:
|
||||
"""
|
||||
Test the catalog cache.
|
||||
@@ -634,6 +620,142 @@ def test_get_sqla_engine_user_impersonation_email(mocker: MockerFixture) -> None
|
||||
)
|
||||
|
||||
|
||||
def test_get_sqla_engine_registers_prequery_event_listener(
|
||||
app_context: None,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test that get_sqla_engine registers a connect event listener for prequeries.
|
||||
|
||||
Engines returned by get_sqla_engine must automatically execute prequeries
|
||||
(e.g. SET search_path) on every new connection, so that callers don't need
|
||||
to remember to call get_prequeries() themselves.
|
||||
"""
|
||||
|
||||
mock_engine = mocker.MagicMock()
|
||||
mocker.patch.object(Database, "_get_sqla_engine", return_value=mock_engine)
|
||||
db_engine_spec = mocker.patch.object(Database, "db_engine_spec")
|
||||
db_engine_spec.get_prequeries.return_value = ['SET search_path = "my_schema"']
|
||||
event_listen = mocker.patch("superset.models.core.sqla.event.listen")
|
||||
|
||||
database = Database(database_name="my_db", sqlalchemy_uri="postgresql://")
|
||||
with database.get_sqla_engine(catalog="my_catalog", schema="my_schema"):
|
||||
pass
|
||||
|
||||
db_engine_spec.get_prequeries.assert_called_once_with(
|
||||
database=database,
|
||||
catalog="my_catalog",
|
||||
schema="my_schema",
|
||||
)
|
||||
event_listen.assert_called_once_with(mock_engine, "connect", mocker.ANY)
|
||||
|
||||
# Call the captured closure directly to verify cursor create → execute → close.
|
||||
captured_fn = event_listen.call_args[0][2]
|
||||
mock_dbapi_conn = mocker.MagicMock()
|
||||
mock_cursor = mocker.MagicMock()
|
||||
mock_dbapi_conn.cursor.return_value = mock_cursor
|
||||
captured_fn(mock_dbapi_conn, None)
|
||||
mock_cursor.execute.assert_called_once_with('SET search_path = "my_schema"')
|
||||
mock_cursor.close.assert_called_once()
|
||||
|
||||
|
||||
def test_get_sqla_engine_prequery_cursor_closed_on_exception(
|
||||
app_context: None,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test that the cursor is always closed even when a prequery raises.
|
||||
"""
|
||||
mock_engine = mocker.MagicMock()
|
||||
mocker.patch.object(Database, "_get_sqla_engine", return_value=mock_engine)
|
||||
db_engine_spec = mocker.patch.object(Database, "db_engine_spec")
|
||||
db_engine_spec.get_prequeries.return_value = ['SET search_path = "bad_schema"']
|
||||
event_listen = mocker.patch("superset.models.core.sqla.event.listen")
|
||||
|
||||
database = Database(database_name="my_db", sqlalchemy_uri="postgresql://")
|
||||
with database.get_sqla_engine(catalog=None, schema="bad_schema"):
|
||||
pass
|
||||
|
||||
captured_fn = event_listen.call_args[0][2]
|
||||
mock_dbapi_conn = mocker.MagicMock()
|
||||
mock_cursor = mocker.MagicMock()
|
||||
mock_cursor.execute.side_effect = Exception("invalid schema")
|
||||
mock_dbapi_conn.cursor.return_value = mock_cursor
|
||||
|
||||
with pytest.raises(Exception, match="invalid schema"):
|
||||
captured_fn(mock_dbapi_conn, None)
|
||||
|
||||
mock_cursor.close.assert_called_once()
|
||||
|
||||
|
||||
def test_get_sqla_engine_no_prequeries_no_event_listener(
|
||||
app_context: None,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test that get_sqla_engine does not register an event listener when there
|
||||
are no prequeries.
|
||||
"""
|
||||
mock_engine = mocker.MagicMock()
|
||||
mocker.patch.object(Database, "_get_sqla_engine", return_value=mock_engine)
|
||||
db_engine_spec = mocker.patch.object(Database, "db_engine_spec")
|
||||
db_engine_spec.get_prequeries.return_value = []
|
||||
event_listen = mocker.patch("superset.models.core.sqla.event.listen")
|
||||
|
||||
database = Database(database_name="my_db", sqlalchemy_uri="postgresql://")
|
||||
with database.get_sqla_engine(catalog=None, schema=None):
|
||||
pass
|
||||
|
||||
event_listen.assert_not_called()
|
||||
|
||||
|
||||
def test_get_raw_connection_executes_prequeries_exactly_once(
|
||||
app_context: None,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test that get_raw_connection() runs prequeries exactly once through the
|
||||
connect event listener registered by get_sqla_engine().
|
||||
|
||||
Previously get_raw_connection() had its own manual prequery loop AND
|
||||
called get_sqla_engine() (which registers the listener), so prequeries
|
||||
ran twice. After removing the manual loop the listener is the sole
|
||||
execution point — this test proves exactly-once semantics.
|
||||
"""
|
||||
mock_engine = mocker.MagicMock()
|
||||
mocker.patch.object(Database, "_get_sqla_engine", return_value=mock_engine)
|
||||
db_engine_spec = mocker.patch.object(Database, "db_engine_spec")
|
||||
prequery = 'SET search_path = "my_schema"'
|
||||
db_engine_spec.get_prequeries.return_value = [prequery]
|
||||
|
||||
# Capture the closure registered via sqla.event.listen.
|
||||
captured_listeners: list[Callable[..., None]] = []
|
||||
original_listen = mocker.patch("superset.models.core.sqla.event.listen")
|
||||
original_listen.side_effect = lambda engine, event, fn: captured_listeners.append(
|
||||
fn
|
||||
)
|
||||
|
||||
# Simulate SQLAlchemy firing the "connect" event when raw_connection() is called.
|
||||
mock_dbapi_conn = mocker.MagicMock()
|
||||
mock_cursor = mocker.MagicMock()
|
||||
mock_dbapi_conn.cursor.return_value = mock_cursor
|
||||
|
||||
def raw_connection_side_effect() -> Any:
|
||||
for listener in captured_listeners:
|
||||
listener(mock_dbapi_conn, None)
|
||||
return mock_dbapi_conn
|
||||
|
||||
mock_engine.raw_connection.side_effect = raw_connection_side_effect
|
||||
|
||||
database = Database(database_name="my_db", sqlalchemy_uri="postgresql://")
|
||||
with database.get_raw_connection(schema="my_schema"):
|
||||
pass
|
||||
|
||||
# Exactly one prequery, exactly once — not twice, not zero.
|
||||
mock_cursor.execute.assert_called_once_with(prequery)
|
||||
mock_cursor.close.assert_called_once()
|
||||
|
||||
|
||||
def test_is_oauth2_enabled() -> None:
|
||||
"""
|
||||
Test the `is_oauth2_enabled` method.
|
||||
|
||||
@@ -184,3 +184,120 @@ def test_prophet_incorrect_time_grain():
|
||||
periods=10,
|
||||
confidence_interval=0.8,
|
||||
)
|
||||
|
||||
|
||||
def test_prophet_uncertainty_lower_bound_can_be_negative_for_negative_series():
|
||||
"""
|
||||
Regression for #21734: when the input series contains negative values,
|
||||
the forecast's lower confidence bound (``__yhat_lower``) must be allowed
|
||||
to go below zero. The original bug claimed Superset clipped the lower
|
||||
bound at 0, hiding the natural shape of the uncertainty interval for
|
||||
series like temperatures or signed deltas.
|
||||
|
||||
Superset's wrapper passes through Prophet's output unchanged (no
|
||||
clipping in ``superset/utils/pandas_postprocessing/prophet.py``); this
|
||||
test pins that contract end-to-end. If a future refactor introduces
|
||||
a ``max(0, lower)`` clamp, this test fails immediately.
|
||||
"""
|
||||
if find_spec("prophet") is None:
|
||||
pytest.skip("prophet not installed")
|
||||
|
||||
# All-negative monthly series — any reasonable forecast must predict
|
||||
# negative values (and therefore negative uncertainty bounds) too.
|
||||
negative_df = pd.DataFrame(
|
||||
{
|
||||
DTTM_ALIAS: [datetime(2020, m, 1) for m in range(1, 13)]
|
||||
+ [datetime(2021, m, 1) for m in range(1, 13)],
|
||||
"temperature": [
|
||||
-5.0,
|
||||
-7.0,
|
||||
-3.0,
|
||||
1.0,
|
||||
8.0,
|
||||
14.0,
|
||||
17.0,
|
||||
16.0,
|
||||
11.0,
|
||||
5.0,
|
||||
-1.0,
|
||||
-4.0,
|
||||
-6.0,
|
||||
-8.0,
|
||||
-2.0,
|
||||
2.0,
|
||||
9.0,
|
||||
15.0,
|
||||
18.0,
|
||||
17.0,
|
||||
12.0,
|
||||
6.0,
|
||||
0.0,
|
||||
-3.0,
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
result = prophet(
|
||||
df=negative_df,
|
||||
time_grain="P1M",
|
||||
periods=3,
|
||||
confidence_interval=0.9,
|
||||
)
|
||||
|
||||
assert "temperature__yhat_lower" in result.columns
|
||||
# Restrict to the forecast horizon (the last `periods` rows). The full
|
||||
# output also contains historical fitted points, which can be negative
|
||||
# for in-sample data even if a future-only clamp were introduced — so
|
||||
# asserting on the whole frame would let a future-only clamp slip past.
|
||||
forecast_periods = 3
|
||||
forecast_lower = result["temperature__yhat_lower"].iloc[-forecast_periods:]
|
||||
assert (forecast_lower < 0).any(), (
|
||||
"Forecast (future) lower bound was non-negative everywhere despite "
|
||||
"a series with negative actuals — suggests an unexpected clamp at "
|
||||
"zero was reintroduced (regression of #21734)."
|
||||
)
|
||||
|
||||
|
||||
def test_prophet_does_not_clamp_yhat_below_zero_for_negative_actuals():
    """
    Companion to the lower-bound test above: the central forecast
    (``__yhat``) must also be allowed to go negative.

    A bug that clamps the central forecast at zero would force the lower
    bound non-negative as a side effect, masking the wider issue.

    The test is skipped when the optional ``prophet`` package is not
    installed.
    """
    if find_spec("prophet") is None:
        pytest.skip("prophet not installed")

    # Single source of truth for the horizon: the same value is passed to
    # ``prophet`` and used to slice the result, so the two cannot drift apart.
    forecast_periods = 2

    # Twelve monthly observations, every one of them strictly negative.
    negative_df = pd.DataFrame(
        {
            DTTM_ALIAS: [datetime(2020, m, 1) for m in range(1, 13)],
            "balance": [
                -100.0,
                -110.0,
                -95.0,
                -120.0,
                -130.0,
                -125.0,
                -140.0,
                -135.0,
                -150.0,
                -145.0,
                -160.0,
                -155.0,
            ],
        }
    )

    result = prophet(
        df=negative_df,
        time_grain="P1M",
        periods=forecast_periods,
        confidence_interval=0.8,
    )

    # Restrict to the forecast horizon (the last ``forecast_periods`` rows).
    # A future-only clamp on `__yhat` could leave historical in-sample fitted
    # points negative and pass an unrestricted assertion.
    forecast_yhat = result["balance__yhat"].iloc[-forecast_periods:]
    assert (forecast_yhat < 0).any(), (
        "Forecast (future) central estimate was non-negative everywhere "
        "despite a series with only negative actuals — suggests an "
        "unexpected clamp at zero was introduced on `__yhat`."
    )
|
||||
|
||||
@@ -131,10 +131,12 @@ def test_refresh_oauth2_token_deletes_token_on_oauth2_exception(
|
||||
"Token revoked"
|
||||
)
|
||||
token = mocker.MagicMock()
|
||||
token.access_token = None
|
||||
token.refresh_token = "refresh-token" # noqa: S105
|
||||
db.session.query().filter_by().one_or_none.return_value = token
|
||||
|
||||
with pytest.raises(OAuth2ExceptionError):
|
||||
refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
|
||||
refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
|
||||
|
||||
db.session.delete.assert_called_with(token)
|
||||
db.session.flush.assert_called_once()
|
||||
@@ -160,10 +162,12 @@ def test_refresh_oauth2_token_keeps_token_on_other_exception(
|
||||
db_engine_spec.oauth2_exception = OAuth2ExceptionError
|
||||
db_engine_spec.get_oauth2_fresh_token.side_effect = Exception("Network error")
|
||||
token = mocker.MagicMock()
|
||||
token.access_token = None
|
||||
token.refresh_token = "refresh-token" # noqa: S105
|
||||
db.session.query().filter_by().one_or_none.return_value = token
|
||||
|
||||
with pytest.raises(Exception, match="Network error"):
|
||||
refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
|
||||
refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
|
||||
|
||||
db.session.delete.assert_not_called()
|
||||
|
||||
@@ -176,16 +180,18 @@ def test_refresh_oauth2_token_no_access_token_in_response(
|
||||
|
||||
This can happen when the refresh token was revoked.
|
||||
"""
|
||||
mocker.patch("superset.utils.oauth2.db")
|
||||
db = mocker.patch("superset.utils.oauth2.db")
|
||||
mocker.patch("superset.utils.oauth2.DistributedLock")
|
||||
db_engine_spec = mocker.MagicMock()
|
||||
db_engine_spec.get_oauth2_fresh_token.return_value = {
|
||||
"error": "invalid_grant",
|
||||
}
|
||||
token = mocker.MagicMock()
|
||||
token.access_token = None
|
||||
token.refresh_token = "refresh-token" # noqa: S105
|
||||
db.session.query().filter_by().one_or_none.return_value = token
|
||||
|
||||
result = refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
|
||||
result = refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
|
||||
|
||||
assert result is None
|
||||
|
||||
@@ -208,10 +214,12 @@ def test_refresh_oauth2_token_updates_refresh_token(
|
||||
"refresh_token": "new-refresh-token",
|
||||
}
|
||||
token = mocker.MagicMock()
|
||||
token.access_token = None
|
||||
token.refresh_token = "old-refresh-token" # noqa: S105
|
||||
db.session.query().filter_by().one_or_none.return_value = token
|
||||
|
||||
with freeze_time("2024-01-01"):
|
||||
refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
|
||||
refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
|
||||
|
||||
assert token.access_token == "new-access-token" # noqa: S105
|
||||
assert token.access_token_expiration == datetime(2024, 1, 1, 1)
|
||||
@@ -236,16 +244,127 @@ def test_refresh_oauth2_token_keeps_refresh_token(
|
||||
"expires_in": 3600,
|
||||
}
|
||||
token = mocker.MagicMock()
|
||||
token.access_token = None
|
||||
token.refresh_token = "original-refresh-token" # noqa: S105
|
||||
db.session.query().filter_by().one_or_none.return_value = token
|
||||
|
||||
with freeze_time("2024-01-01"):
|
||||
refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
|
||||
refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
|
||||
|
||||
assert token.access_token == "new-access-token" # noqa: S105
|
||||
assert token.refresh_token == "original-refresh-token" # noqa: S105
|
||||
db.session.add.assert_called_with(token)
|
||||
|
||||
|
||||
def test_refresh_oauth2_token_refreshes_when_access_token_expired_under_lock(
    mocker: MockerFixture,
) -> None:
    """
    Test that an expired access_token with a refresh_token gets refreshed.

    The mocked query hands back a row whose access_token expired one day before
    the frozen "now" and which still carries a refresh_token. The call is
    expected to ask the engine spec for a fresh token using that refresh_token,
    return the new access_token, and add the row back to the session.
    """
    mocked_db = mocker.patch("superset.utils.oauth2.db")
    mocker.patch("superset.utils.oauth2.DistributedLock")

    # Stored row: access token expired on 2024-01-01, refresh token present.
    stored_token = mocker.MagicMock()
    stored_token.access_token = "expired-token"  # noqa: S105
    stored_token.access_token_expiration = datetime(2024, 1, 1)
    stored_token.refresh_token = "refresh-token"  # noqa: S105
    mocked_db.session.query().filter_by().one_or_none.return_value = stored_token

    # Token endpoint response handed out by the engine spec.
    engine_spec = mocker.MagicMock()
    engine_spec.get_oauth2_fresh_token.return_value = {
        "access_token": "new-access-token",
        "expires_in": 3600,
    }

    with freeze_time("2024-01-02"):
        refreshed = refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, engine_spec)

    assert refreshed == "new-access-token"
    engine_spec.get_oauth2_fresh_token.assert_called_once_with(
        DUMMY_OAUTH2_CONFIG, "refresh-token"
    )
    mocked_db.session.add.assert_called_with(stored_token)
|
||||
|
||||
|
||||
def test_refresh_oauth2_token_returns_existing_token_when_still_valid_under_lock(
    mocker: MockerFixture,
) -> None:
    """
    Test that a still-valid access_token is returned without refreshing.

    Scenario: two requests race, the first refreshes the token and releases the
    lock before the second reaches `refresh_oauth2_token`. The second request
    should then reuse the already-refreshed access_token rather than hitting
    the token endpoint again or deleting the row.
    """
    mocked_db = mocker.patch("superset.utils.oauth2.db")
    mocker.patch("superset.utils.oauth2.DistributedLock")
    engine_spec = mocker.MagicMock()

    # Stored row: access token expires one day after the frozen "now".
    stored_token = mocker.MagicMock()
    stored_token.access_token = "fresh-access-token"  # noqa: S105
    stored_token.access_token_expiration = datetime(2024, 1, 2)
    stored_token.refresh_token = "refresh-token"  # noqa: S105
    mocked_db.session.query().filter_by().one_or_none.return_value = stored_token

    with freeze_time("2024-01-01"):
        returned = refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, engine_spec)

    assert returned == "fresh-access-token"
    engine_spec.get_oauth2_fresh_token.assert_not_called()
    mocked_db.session.delete.assert_not_called()
|
||||
|
||||
|
||||
def test_refresh_oauth2_token_deletes_when_no_refresh_token_under_lock(
    mocker: MockerFixture,
) -> None:
    """
    Test that an expired token without a refresh_token is deleted.

    The mocked query hands back a row whose access_token has expired and whose
    refresh_token is None. The row is expected to be deleted and None returned,
    so the caller can trigger the OAuth2 dance; the token endpoint must not be
    contacted.
    """
    mocked_db = mocker.patch("superset.utils.oauth2.db")
    mocker.patch("superset.utils.oauth2.DistributedLock")
    engine_spec = mocker.MagicMock()

    # Stored row: expired access token and nothing to refresh it with.
    stored_token = mocker.MagicMock()
    stored_token.access_token = "expired-token"  # noqa: S105
    stored_token.access_token_expiration = datetime(2024, 1, 1)
    stored_token.refresh_token = None
    mocked_db.session.query().filter_by().one_or_none.return_value = stored_token

    with freeze_time("2024-01-02"):
        returned = refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, engine_spec)

    assert returned is None
    mocked_db.session.delete.assert_called_with(stored_token)
    engine_spec.get_oauth2_fresh_token.assert_not_called()
|
||||
|
||||
|
||||
def test_refresh_oauth2_token_returns_none_when_row_deleted_under_lock(
    mocker: MockerFixture,
) -> None:
    """
    Test that a missing token row yields None and no refresh attempt.

    Scenario: two requests race, the first deletes the token row and releases
    the lock before the second reaches `refresh_oauth2_token`. Here the mocked
    query finds no row, so the call should return None without contacting the
    token endpoint.
    """
    mocked_db = mocker.patch("superset.utils.oauth2.db")
    mocker.patch("superset.utils.oauth2.DistributedLock")
    engine_spec = mocker.MagicMock()

    # The lookup finds nothing: the row is already gone.
    mocked_db.session.query().filter_by().one_or_none.return_value = None

    returned = refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, engine_spec)

    assert returned is None
    engine_spec.get_oauth2_fresh_token.assert_not_called()
|
||||
|
||||
|
||||
def test_generate_code_verifier_length() -> None:
|
||||
"""
|
||||
Test that generate_code_verifier produces a string of valid length (RFC 7636).
|
||||
|
||||
Reference in New Issue
Block a user