diff --git a/.github/workflows/superset-docs-verify.yml b/.github/workflows/superset-docs-verify.yml index e77e4916664..2862541e32d 100644 --- a/.github/workflows/superset-docs-verify.yml +++ b/.github/workflows/superset-docs-verify.yml @@ -20,7 +20,7 @@ jobs: steps: - uses: actions/checkout@v4 # Do not bump this linkinator-action version without opening - # an ASF Infra ticket to allow the new verison first! + # an ASF Infra ticket to allow the new version first! - uses: JustinBeckwith/linkinator-action@v1.11.0 continue-on-error: true # This will make the job advisory (non-blocking, no red X) with: diff --git a/Dockerfile b/Dockerfile index f2643e5dcc7..da859f6b1fb 100644 --- a/Dockerfile +++ b/Dockerfile @@ -18,7 +18,7 @@ ###################################################################### # Node stage to deal with static asset construction ###################################################################### -ARG PY_VER=3.11.11-slim-bookworm +ARG PY_VER=3.11.12-slim-bookworm # If BUILDPLATFORM is null, set it to 'amd64' (or leave as is otherwise). ARG BUILDPLATFORM=${BUILDPLATFORM:-amd64} diff --git a/docs/docs/configuration/sql-templating.mdx b/docs/docs/configuration/sql-templating.mdx index dcd5bb0869b..d2b74afa1ab 100644 --- a/docs/docs/configuration/sql-templating.mdx +++ b/docs/docs/configuration/sql-templating.mdx @@ -250,6 +250,14 @@ Will be rendered as: SELECT * FROM users WHERE role IN ('admin', 'viewer') ``` +**Current User RLS Rules** + +The `{{ current_user_rls_rules() }}` macro returns an array of RLS rules applied to the current dataset for the logged in user. + +If you have caching enabled in your Superset configuration, then the list of RLS Rules will be used +by Superset when calculating the cache key. A cache key is a unique identifier that determines if there's a +cache hit in the future and Superset can retrieve cached data. + **Custom URL Parameters** The `{{ url_param('custom_variable') }}` macro lets you define arbitrary URL diff --git a/docs/docs/installation/docker-builds.mdx b/docs/docs/installation/docker-builds.mdx index eaf60a35dc9..c0a268426e2 100644 --- a/docs/docs/installation/docker-builds.mdx +++ b/docs/docs/installation/docker-builds.mdx @@ -96,11 +96,16 @@ RUN . /app/.venv/bin/activate && \ pymssql \ # package needed for using single-sign on authentication: Authlib \ + # openpyxl to be able to upload Excel files + openpyxl \ + # Pillow for Alerts & Reports to generate PDFs of dashboards + Pillow \ # install Playwright for taking screenshots for Alerts & Reports. This assumes the feature flag PLAYWRIGHT_REPORTS_AND_THUMBNAILS is enabled # That feature flag will default to True starting in 6.0.0 # Playwright works only with Chrome. 
# If you are still using Selenium instead of Playwright, you would instead install here the selenium package and a headless browser & webdriver playwright \ + && playwright install-deps \ && PLAYWRIGHT_BROWSERS_PATH=/usr/local/share/playwright-browsers playwright install chromium # Switch back to the superset user diff --git a/docs/package.json b/docs/package.json index 61304bd1a2d..1fa0b5feb5a 100644 --- a/docs/package.json +++ b/docs/package.json @@ -29,7 +29,7 @@ "antd": "^5.25.1", "docusaurus-plugin-less": "^2.0.2", "less": "^4.3.0", - "less-loader": "^11.0.0", + "less-loader": "^12.3.0", "prism-react-renderer": "^2.4.1", "react": "^18.3.1", "react-dom": "^18.3.1", diff --git a/docs/yarn.lock b/docs/yarn.lock index 217a2242822..095df59e0dc 100644 --- a/docs/yarn.lock +++ b/docs/yarn.lock @@ -8211,10 +8211,10 @@ layout-base@^2.0.0: resolved "https://registry.yarnpkg.com/layout-base/-/layout-base-2.0.1.tgz#d0337913586c90f9c2c075292069f5c2da5dd285" integrity sha512-dp3s92+uNI1hWIpPGH3jK2kxE2lMjdXdr+DH8ynZHpd6PUlH6x6cbuXnoMmiNumznqaNO31xu9e79F0uuZ0JFg== -less-loader@^11.0.0: - version "11.1.4" - resolved "https://registry.npmjs.org/less-loader/-/less-loader-11.1.4.tgz" - integrity sha512-6/GrYaB6QcW6Vj+/9ZPgKKs6G10YZai/l/eJ4SLwbzqNTBsAqt5hSLVF47TgsiBxV1P6eAU0GYRH3YRuQU9V3A== +less-loader@^12.3.0: + version "12.3.0" + resolved "https://registry.yarnpkg.com/less-loader/-/less-loader-12.3.0.tgz#d4a00361568be86a97da3df4f16954b0d4c15340" + integrity sha512-0M6+uYulvYIWs52y0LqN4+QM9TqWAohYSNTo4htE8Z7Cn3G/qQMEmktfHmyJT23k+20kU9zHH2wrfFXkxNLtVw== less@^4.3.0: version "4.3.0" diff --git a/pyproject.toml b/pyproject.toml index 3fadc6b22e3..c9ed5db5a83 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ dependencies = [ "cryptography>=42.0.4, <45.0.0", "deprecation>=2.1.0, <2.2.0", "flask>=2.2.5, <3.0.0", - "flask-appbuilder>=4.6.4, <5.0.0", + "flask-appbuilder>=4.7.0, <5.0.0", "flask-caching>=2.1.0, <3", "flask-compress>=1.13, <2.0", "flask-talisman>=1.0.0, <2.0", diff --git a/requirements/base.txt b/requirements/base.txt index 42074038f5a..c887750e909 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -118,7 +118,7 @@ flask==2.3.3 # flask-session # flask-sqlalchemy # flask-wtf -flask-appbuilder==4.6.4 +flask-appbuilder==4.7.0 # via apache-superset (pyproject.toml) flask-babel==2.0.0 # via flask-appbuilder diff --git a/requirements/development.txt b/requirements/development.txt index d1f69ec6d74..770406029b2 100644 --- a/requirements/development.txt +++ b/requirements/development.txt @@ -202,7 +202,7 @@ flask==2.3.3 # flask-sqlalchemy # flask-testing # flask-wtf -flask-appbuilder==4.6.4 +flask-appbuilder==4.7.0 # via # -c requirements/base.txt # apache-superset diff --git a/superset-frontend/package-lock.json b/superset-frontend/package-lock.json index 3ff13ccdb79..5221d4b1028 100644 --- a/superset-frontend/package-lock.json +++ b/superset-frontend/package-lock.json @@ -216,7 +216,7 @@ "@typescript-eslint/parser": "^5.62.0", "@wojtekmaj/enzyme-adapter-react-17": "^0.8.0", "babel-jest": "^29.7.0", - "babel-loader": "^9.1.3", + "babel-loader": "^10.0.0", "babel-plugin-dynamic-import-node": "^2.3.3", "babel-plugin-jsx-remove-data-test-id": "^3.0.0", "babel-plugin-lodash": "^3.3.4", @@ -15633,142 +15633,20 @@ } }, "node_modules/babel-loader": { - "version": "9.2.1", - "resolved": "https://registry.npmjs.org/babel-loader/-/babel-loader-9.2.1.tgz", - "integrity": "sha512-fqe8naHt46e0yIdkjUZYqddSXfej3AHajX+CSO5X7oy0EmPc6o5Xh+RClNoHjnieWz9AW4kZxW9yyFMhVB1QLA==", + "version": 
"10.0.0", + "resolved": "https://registry.npmjs.org/babel-loader/-/babel-loader-10.0.0.tgz", + "integrity": "sha512-z8jt+EdS61AMw22nSfoNJAZ0vrtmhPRVi6ghL3rCeRZI8cdNYFiV5xeV3HbE7rlZZNmGH8BVccwWt8/ED0QOHA==", "dev": true, "license": "MIT", "dependencies": { - "find-cache-dir": "^4.0.0", - "schema-utils": "^4.0.0" + "find-up": "^5.0.0" }, "engines": { - "node": ">= 14.15.0" + "node": "^18.20.0 || ^20.10.0 || >=22.0.0" }, "peerDependencies": { "@babel/core": "^7.12.0", - "webpack": ">=5" - } - }, - "node_modules/babel-loader/node_modules/find-cache-dir": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/find-cache-dir/-/find-cache-dir-4.0.0.tgz", - "integrity": "sha512-9ZonPT4ZAK4a+1pUPVPZJapbi7O5qbbJPdYw/NOQWZZbVLdDTYM3A4R9z/DpAM08IDaFGsvPgiGZ82WEwUDWjg==", - "dev": true, - "license": "MIT", - "dependencies": { - "common-path-prefix": "^3.0.0", - "pkg-dir": "^7.0.0" - }, - "engines": { - "node": ">=14.16" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/babel-loader/node_modules/find-up": { - "version": "6.3.0", - "resolved": "https://registry.npmjs.org/find-up/-/find-up-6.3.0.tgz", - "integrity": "sha512-v2ZsoEuVHYy8ZIlYqwPe/39Cy+cFDzp4dXPaxNvkEuouymu+2Jbz0PxpKarJHYJTmv2HWT3O382qY8l4jMWthw==", - "dev": true, - "license": "MIT", - "dependencies": { - "locate-path": "^7.1.0", - "path-exists": "^5.0.0" - }, - "engines": { - "node": "^12.20.0 || ^14.13.1 || >=16.0.0" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/babel-loader/node_modules/locate-path": { - "version": "7.2.0", - "resolved": "https://registry.npmjs.org/locate-path/-/locate-path-7.2.0.tgz", - "integrity": "sha512-gvVijfZvn7R+2qyPX8mAuKcFGDf6Nc61GdvGafQsHL0sBIxfKzA+usWn4GFC/bk+QdwPUD4kWFJLhElipq+0VA==", - "dev": true, - "license": "MIT", - "dependencies": { - "p-locate": "^6.0.0" - }, - "engines": { - "node": "^12.20.0 || ^14.13.1 || >=16.0.0" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/babel-loader/node_modules/p-limit": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/p-limit/-/p-limit-4.0.0.tgz", - "integrity": "sha512-5b0R4txpzjPWVw/cXXUResoD4hb6U/x9BH08L7nw+GN1sezDzPdxeRvpc9c433fZhBan/wusjbCsqwqm4EIBIQ==", - "dev": true, - "license": "MIT", - "dependencies": { - "yocto-queue": "^1.0.0" - }, - "engines": { - "node": "^12.20.0 || ^14.13.1 || >=16.0.0" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/babel-loader/node_modules/p-locate": { - "version": "6.0.0", - "resolved": "https://registry.npmjs.org/p-locate/-/p-locate-6.0.0.tgz", - "integrity": "sha512-wPrq66Llhl7/4AGC6I+cqxT07LhXvWL08LNXz1fENOw0Ap4sRZZ/gZpTTJ5jpurzzzfS2W/Ge9BY3LgLjCShcw==", - "dev": true, - "license": "MIT", - "dependencies": { - "p-limit": "^4.0.0" - }, - "engines": { - "node": "^12.20.0 || ^14.13.1 || >=16.0.0" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/babel-loader/node_modules/path-exists": { - "version": "5.0.0", - "resolved": "https://registry.npmjs.org/path-exists/-/path-exists-5.0.0.tgz", - "integrity": "sha512-RjhtfwJOxzcFmNOi6ltcbcu4Iu+FL3zEj83dk4kAS+fVpTxXLO1b38RvJgT/0QwvV/L3aY9TAnyv0EOqW4GoMQ==", - "dev": true, - "license": "MIT", - "engines": { - "node": "^12.20.0 || ^14.13.1 || >=16.0.0" - } - }, - "node_modules/babel-loader/node_modules/pkg-dir": { - "version": "7.0.0", - "resolved": "https://registry.npmjs.org/pkg-dir/-/pkg-dir-7.0.0.tgz", - "integrity": 
"sha512-Ie9z/WINcxxLp27BKOCHGde4ITq9UklYKDzVo1nhk5sqGEXU3FpkwP5GM2voTGJkGd9B3Otl+Q4uwSOeSUtOBA==", - "dev": true, - "license": "MIT", - "dependencies": { - "find-up": "^6.3.0" - }, - "engines": { - "node": ">=14.16" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/babel-loader/node_modules/yocto-queue": { - "version": "1.1.1", - "resolved": "https://registry.npmjs.org/yocto-queue/-/yocto-queue-1.1.1.tgz", - "integrity": "sha512-b4JR1PFR10y1mKjhHY9LaGo6tmrgjit7hxVIeAmyMw3jegXR4dhYqLaQF5zMXZxY7tLpMyJeLjr1C4rLmkVe8g==", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=12.20" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" + "webpack": ">=5.61.0" } }, "node_modules/babel-plugin-dynamic-import-node": { @@ -17986,13 +17864,6 @@ "dev": true, "license": "ISC" }, - "node_modules/common-path-prefix": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/common-path-prefix/-/common-path-prefix-3.0.0.tgz", - "integrity": "sha512-QE33hToZseCH3jS0qN96O/bSh3kaw/h+Tq7ngyY9eWDUnTlTNUyqfqvCXioLe5Na5jFsL78ra/wuBU4iuEgd4w==", - "dev": true, - "license": "ISC" - }, "node_modules/common-tags": { "version": "1.8.2", "resolved": "https://registry.npmjs.org/common-tags/-/common-tags-1.8.2.tgz", @@ -23581,9 +23452,9 @@ } }, "node_modules/fastify": { - "version": "4.29.0", - "resolved": "https://registry.npmjs.org/fastify/-/fastify-4.29.0.tgz", - "integrity": "sha512-MaaUHUGcCgC8fXQDsDtioaCcag1fmPJ9j64vAKunqZF4aSub040ZGi/ag8NGE2714yREPOKZuHCfpPzuUD3UQQ==", + "version": "4.29.1", + "resolved": "https://registry.npmjs.org/fastify/-/fastify-4.29.1.tgz", + "integrity": "sha512-m2kMNHIG92tSNWv+Z3UeTR9AWLLuo7KctC7mlFPtMEVrfjIhmQhkQnT9v15qA/BfVq3vvj134Y0jl9SBje3jXQ==", "dev": true, "funding": [ { @@ -53327,7 +53198,7 @@ "@babel/preset-react": "^7.26.3", "@babel/preset-typescript": "^7.23.3", "@storybook/react-webpack5": "8.2.9", - "babel-loader": "^9.1.3", + "babel-loader": "^10.0.0", "fork-ts-checker-webpack-plugin": "^9.0.2", "ts-loader": "^9.5.2", "typescript": "^5.7.2" diff --git a/superset-frontend/package.json b/superset-frontend/package.json index 75b4add8452..5d6b5912699 100644 --- a/superset-frontend/package.json +++ b/superset-frontend/package.json @@ -284,7 +284,7 @@ "@typescript-eslint/parser": "^5.62.0", "@wojtekmaj/enzyme-adapter-react-17": "^0.8.0", "babel-jest": "^29.7.0", - "babel-loader": "^9.1.3", + "babel-loader": "^10.0.0", "babel-plugin-dynamic-import-node": "^2.3.3", "babel-plugin-jsx-remove-data-test-id": "^3.0.0", "babel-plugin-lodash": "^3.3.4", diff --git a/superset-frontend/packages/superset-ui-chart-controls/src/operators/aggregateOperator.ts b/superset-frontend/packages/superset-ui-chart-controls/src/operators/aggregateOperator.ts index aa3c518ad92..1431fbac2c6 100644 --- a/superset-frontend/packages/superset-ui-chart-controls/src/operators/aggregateOperator.ts +++ b/superset-frontend/packages/superset-ui-chart-controls/src/operators/aggregateOperator.ts @@ -30,7 +30,7 @@ export const aggregationOperator: PostProcessingFactory< > = (formData: QueryFormData, queryObject) => { const { aggregation = 'LAST_VALUE' } = formData; - if (aggregation === 'LAST_VALUE') { + if (aggregation === 'LAST_VALUE' || aggregation === 'raw') { return undefined; } diff --git a/superset-frontend/packages/superset-ui-chart-controls/src/shared-controls/customControls.tsx b/superset-frontend/packages/superset-ui-chart-controls/src/shared-controls/customControls.tsx index 971e2b39fc5..0f626c92390 100644 --- 
a/superset-frontend/packages/superset-ui-chart-controls/src/shared-controls/customControls.tsx +++ b/superset-frontend/packages/superset-ui-chart-controls/src/shared-controls/customControls.tsx @@ -70,6 +70,7 @@ export const aggregationControl = { clearable: false, renderTrigger: false, choices: [ + ['raw', t('None')], ['LAST_VALUE', t('Last Value')], ['sum', t('Total (Sum)')], ['mean', t('Average (Mean)')], @@ -77,7 +78,9 @@ export const aggregationControl = { ['max', t('Maximum')], ['median', t('Median')], ], - description: t('Select an aggregation method to apply to the metric.'), + description: t( + 'Aggregation method used to compute the Big Number from the Trendline.For non-additive metrics like ratios, averages, distinct counts, etc use NONE.', + ), provideFormDataToProps: true, mapStateToProps: ({ form_data }: ControlPanelState) => ({ value: form_data.aggregation || 'LAST_VALUE', diff --git a/superset-frontend/packages/superset-ui-demo/package.json b/superset-frontend/packages/superset-ui-demo/package.json index c099af9b7e5..a9068821d29 100644 --- a/superset-frontend/packages/superset-ui-demo/package.json +++ b/superset-frontend/packages/superset-ui-demo/package.json @@ -57,7 +57,7 @@ "@babel/preset-react": "^7.26.3", "@babel/preset-typescript": "^7.23.3", "@storybook/react-webpack5": "8.2.9", - "babel-loader": "^9.1.3", + "babel-loader": "^10.0.0", "fork-ts-checker-webpack-plugin": "^9.0.2", "ts-loader": "^9.5.2", "typescript": "^5.7.2" diff --git a/superset-frontend/plugins/plugin-chart-echarts/src/BigNumber/BigNumberWithTrendline/buildQuery.test.ts b/superset-frontend/plugins/plugin-chart-echarts/src/BigNumber/BigNumberWithTrendline/buildQuery.test.ts new file mode 100644 index 00000000000..e69d32646dd --- /dev/null +++ b/superset-frontend/plugins/plugin-chart-echarts/src/BigNumber/BigNumberWithTrendline/buildQuery.test.ts @@ -0,0 +1,86 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +import { QueryFormData } from '@superset-ui/core'; +import buildQuery from './buildQuery'; + +jest.mock('@superset-ui/core', () => ({ + ...jest.requireActual('@superset-ui/core'), + getXAxisColumn: jest.fn(() => 'order_date'), + isXAxisSet: jest.fn(() => true), +})); + +jest.mock('@superset-ui/chart-controls', () => ({ + pivotOperator: jest.fn(() => ({ operation: 'pivot' })), + aggregationOperator: jest.fn(formData => { + if (formData.aggregation === 'LAST_VALUE' || !formData.aggregation) { + return undefined; + } + return { + operation: 'aggregation', + options: { operator: formData.aggregation }, + }; + }), + flattenOperator: jest.fn(() => ({ operation: 'flatten' })), + resampleOperator: jest.fn(() => ({ operation: 'resample' })), + rollingWindowOperator: jest.fn(() => ({ operation: 'rolling' })), +})); + +describe('BigNumberWithTrendline buildQuery', () => { + const baseFormData: QueryFormData = { + datasource: '1__table', + viz_type: 'big_number', + metric: 'custom_metric', + aggregation: null, + }; + + it('creates raw metric query when aggregation is null', () => { + const queryContext = buildQuery({ ...baseFormData }); + const bigNumberQuery = queryContext.queries[1]; + + expect(bigNumberQuery.post_processing).toEqual([{ operation: 'pivot' }]); + expect(bigNumberQuery.is_timeseries).toBe(true); + }); + + it('adds aggregation operator when aggregation is "sum"', () => { + const queryContext = buildQuery({ ...baseFormData, aggregation: 'sum' }); + const bigNumberQuery = queryContext.queries[1]; + + expect(bigNumberQuery.post_processing).toEqual([ + { operation: 'pivot' }, + { operation: 'aggregation', options: { operator: 'sum' } }, + ]); + expect(bigNumberQuery.is_timeseries).toBe(true); + }); + + it('skips aggregation when aggregation is LAST_VALUE', () => { + const queryContext = buildQuery({ + ...baseFormData, + aggregation: 'LAST_VALUE', + }); + const bigNumberQuery = queryContext.queries[1]; + + expect(bigNumberQuery.post_processing).toEqual([{ operation: 'pivot' }]); + expect(bigNumberQuery.is_timeseries).toBe(true); + }); + + it('always returns two queries', () => { + const queryContext = buildQuery({ ...baseFormData }); + expect(queryContext.queries.length).toBe(2); + }); +}); diff --git a/superset-frontend/plugins/plugin-chart-echarts/src/BigNumber/BigNumberWithTrendline/buildQuery.ts b/superset-frontend/plugins/plugin-chart-echarts/src/BigNumber/BigNumberWithTrendline/buildQuery.ts index 398125719b1..5fb46aa96c5 100644 --- a/superset-frontend/plugins/plugin-chart-echarts/src/BigNumber/BigNumberWithTrendline/buildQuery.ts +++ b/superset-frontend/plugins/plugin-chart-echarts/src/BigNumber/BigNumberWithTrendline/buildQuery.ts @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ + import { buildQueryContext, ensureIsArray, @@ -32,15 +33,17 @@ import { } from '@superset-ui/chart-controls'; export default function buildQuery(formData: QueryFormData) { + const isRawMetric = formData.aggregation === 'raw'; + + const timeColumn = isXAxisSet(formData) + ? ensureIsArray(getXAxisColumn(formData)) + : []; + return buildQueryContext(formData, baseQueryObject => [ { ...baseQueryObject, - columns: [ - ...(isXAxisSet(formData) - ? ensureIsArray(getXAxisColumn(formData)) - : []), - ], - ...(isXAxisSet(formData) ? {} : { is_timeseries: true }), + columns: [...timeColumn], + ...(timeColumn.length ? 
{} : { is_timeseries: true }), post_processing: [ pivotOperator(formData, baseQueryObject), rollingWindowOperator(formData, baseQueryObject), @@ -48,19 +51,16 @@ export default function buildQuery(formData: QueryFormData) { flattenOperator(formData, baseQueryObject), ], }, - { ...baseQueryObject, - columns: [ - ...(isXAxisSet(formData) - ? ensureIsArray(getXAxisColumn(formData)) - : []), - ], - ...(isXAxisSet(formData) ? {} : { is_timeseries: true }), - post_processing: [ - pivotOperator(formData, baseQueryObject), - aggregationOperator(formData, baseQueryObject), - ], + columns: [...(isRawMetric ? [] : timeColumn)], + is_timeseries: !isRawMetric, + post_processing: isRawMetric + ? [] + : [ + pivotOperator(formData, baseQueryObject), + aggregationOperator(formData, baseQueryObject), + ], }, ]); } diff --git a/superset-frontend/plugins/plugin-chart-echarts/src/MixedTimeseries/transformProps.ts b/superset-frontend/plugins/plugin-chart-echarts/src/MixedTimeseries/transformProps.ts index 7526b820d07..478915cfce1 100644 --- a/superset-frontend/plugins/plugin-chart-echarts/src/MixedTimeseries/transformProps.ts +++ b/superset-frontend/plugins/plugin-chart-echarts/src/MixedTimeseries/transformProps.ts @@ -97,6 +97,7 @@ import { getXAxisFormatter, getYAxisFormatter, } from '../utils/formatters'; +import { getMetricDisplayName } from '../utils/metricDisplayName'; const getFormatter = ( customFormatters: Record, @@ -222,6 +223,10 @@ export default function transformProps( } const rebasedDataA = rebaseForecastDatum(data1, verboseMap); + + const MetricDisplayNameA = getMetricDisplayName(metrics[0], verboseMap); + const MetricDisplayNameB = getMetricDisplayName(metricsB[0], verboseMap); + const [rawSeriesA] = extractSeries(rebasedDataA, { fillNeighborValue: stack ? 0 : undefined, xAxis: xAxisLabel, @@ -373,6 +378,12 @@ export default function transformProps( const seriesName = inverted[entryName] || entryName; const colorScaleKey = getOriginalSeries(seriesName, array); + let displayName = `${entryName} (Query A)`; + + if (groupby.length > 0) { + displayName = `${MetricDisplayNameA} (Query A), ${entryName}`; + } + const seriesFormatter = getFormatter( customFormatters, formatter, @@ -382,7 +393,10 @@ export default function transformProps( ); const transformedSeries = transformSeries( - entry, + { + ...entry, + id: `${displayName || ''}`, + }, colorScale, colorScaleKey, { @@ -421,6 +435,12 @@ export default function transformProps( const seriesName = `${seriesEntry} (1)`; const colorScaleKey = getOriginalSeries(seriesEntry, array); + let displayName = `${entryName} (Query B)`; + + if (groupbyB.length > 0) { + displayName = `${MetricDisplayNameB} (Query B), ${entryName}`; + } + const seriesFormatter = getFormatter( customFormattersSecondary, formatterSecondary, @@ -430,7 +450,11 @@ export default function transformProps( ); const transformedSeries = transformSeries( - entry, + { + ...entry, + id: `${displayName || ''}`, + }, + colorScale, colorScaleKey, { @@ -444,9 +468,7 @@ export default function transformProps( stackIdSuffix: '\nb', yAxisIndex: yAxisIndexB, filterState, - seriesKey: primarySeries.has(entry.name as string) - ? 
`${entry.name} (1)` - : entry.name, + seriesKey: entry.name, sliceId, queryIndex: 1, formatter: diff --git a/superset-frontend/plugins/plugin-chart-echarts/src/Radar/transformProps.ts b/superset-frontend/plugins/plugin-chart-echarts/src/Radar/transformProps.ts index c991255abf3..690fabc7ad3 100644 --- a/superset-frontend/plugins/plugin-chart-echarts/src/Radar/transformProps.ts +++ b/superset-frontend/plugins/plugin-chart-echarts/src/Radar/transformProps.ts @@ -24,6 +24,7 @@ import { getNumberFormatter, getTimeFormatter, NumberFormatter, + isDefined, } from '@superset-ui/core'; import type { CallbackDataParams } from 'echarts/types/src/util/types'; import type { RadarSeriesDataItemOption } from 'echarts/types/src/chart/radar/RadarSeries'; @@ -35,6 +36,7 @@ import { EchartsRadarFormData, EchartsRadarLabelType, RadarChartTransformedProps, + SeriesNormalizedMap, } from './types'; import { DEFAULT_LEGEND_FORM_DATA, OpacityEnum } from '../constants'; import { @@ -46,18 +48,31 @@ import { import { defaultGrid } from '../defaults'; import { Refs } from '../types'; import { getDefaultTooltip } from '../utils/tooltip'; +import { findGlobalMax, renderNormalizedTooltip } from './utils'; export function formatLabel({ params, labelType, numberFormatter, + getDenormalizedSeriesValue, + metricsWithCustomBounds, + metricLabels, }: { params: CallbackDataParams; labelType: EchartsRadarLabelType; numberFormatter: NumberFormatter; + getDenormalizedSeriesValue: (seriesName: string, value: string) => number; + metricsWithCustomBounds: Set; + metricLabels: string[]; }): string { - const { name = '', value } = params; - const formattedValue = numberFormatter(value as number); + const { name = '', value, dimensionIndex = 0 } = params; + const metricLabel = metricLabels[dimensionIndex]; + + const formattedValue = numberFormatter( + metricsWithCustomBounds.has(metricLabel) + ? (value as number) + : (getDenormalizedSeriesValue(name, String(value)) as number), + ); switch (labelType) { case EchartsRadarLabelType.Value: @@ -85,6 +100,7 @@ export default function transformProps( } = chartProps; const refs: Refs = {}; const { data = [] } = queriesData[0]; + const globalMax = findGlobalMax(data, Object.keys(data[0] || {})); const coltypeMapping = getColtypesMapping(queriesData[0]); const { @@ -111,14 +127,38 @@ export default function transformProps( const { setDataMask = () => {}, onContextMenu } = hooks; const colorFn = CategoricalColorNamespace.getScale(colorScheme as string); const numberFormatter = getNumberFormatter(numberFormat); + const denormalizedSeriesValues: SeriesNormalizedMap = {}; + + const getDenormalizedSeriesValue = ( + seriesName: string, + normalizedValue: string, + ): number => + denormalizedSeriesValues?.[seriesName]?.[normalizedValue] ?? 
+ Number(normalizedValue); + + const metricLabels = metrics.map(getMetricLabel); + + const metricsWithCustomBounds = new Set( + metricLabels.filter(metricLabel => { + const config = columnConfig?.[metricLabel]; + const hasMax = !!isDefined(config?.radarMetricMaxValue); + const hasMin = + isDefined(config?.radarMetricMinValue) && + config?.radarMetricMinValue !== 0; + return hasMax || hasMin; + }), + ); + const formatter = (params: CallbackDataParams) => formatLabel({ params, numberFormatter, labelType, + getDenormalizedSeriesValue, + metricsWithCustomBounds, + metricLabels, }); - const metricLabels = metrics.map(getMetricLabel); const groupbyLabels = groupby.map(getColumnLabel); const metricLabelAndMaxValueMap = new Map(); @@ -212,28 +252,58 @@ export default function transformProps( {}, ); + const normalizeArray = (arr: number[], decimals = 10, seriesName: string) => + arr.map((value, index) => { + const metricLabel = metricLabels[index]; + if (metricsWithCustomBounds.has(metricLabel)) { + return value; + } + + const max = Math.max(...arr); + const normalizedValue = Number((value / max).toFixed(decimals)); + + denormalizedSeriesValues[seriesName][String(normalizedValue)] = value; + return normalizedValue; + }); + + // Normalize the transformed data + const normalizedTransformedData = transformedData.map(series => { + if (Array.isArray(series.value)) { + const seriesName = String(series?.name || ''); + denormalizedSeriesValues[seriesName] = {}; + + return { + ...series, + value: normalizeArray(series.value as number[], 10, seriesName), + }; + } + return series; + }); + const indicator = metricLabels.map(metricLabel => { + const isMetricWithCustomBounds = metricsWithCustomBounds.has(metricLabel); + if (!isMetricWithCustomBounds) { + return { + name: metricLabel, + max: 1, + min: 0, + }; + } const maxValueInControl = columnConfig?.[metricLabel]?.radarMetricMaxValue; const minValueInControl = columnConfig?.[metricLabel]?.radarMetricMinValue; // Ensure that 0 is at the center of the polar coordinates - const metricValueAsMax = + const maxValue = metricLabelAndMaxValueMap.get(metricLabel) === 0 ? Number.MAX_SAFE_INTEGER - : metricLabelAndMaxValueMap.get(metricLabel); - const max = - maxValueInControl === null ? metricValueAsMax : maxValueInControl; + : globalMax; + const max = isDefined(maxValueInControl) ? 
maxValueInControl : maxValue; let min: number; - // If the min value doesn't exist, set it to 0 (default), - // if it is null, set it to the min value of the data, - // otherwise, use the value from the control - if (minValueInControl === undefined) { - min = 0; - } else if (minValueInControl === null) { - min = metricLabelAndMinValueMap.get(metricLabel) || 0; - } else { + if (isDefined(minValueInControl)) { min = minValueInControl; + } else { + min = 0; } return { @@ -254,10 +324,24 @@ export default function transformProps( fontWeight: 'bold', }, }, - data: transformedData, + data: normalizedTransformedData, }, ]; + const NormalizedTooltipFormater = ( + params: CallbackDataParams & { + color: string; + name: string; + value: number[]; + }, + ) => + renderNormalizedTooltip( + params, + metricLabels, + getDenormalizedSeriesValue, + metricsWithCustomBounds, + ); + const echartOptions: EChartsCoreOption = { grid: { ...defaultGrid, @@ -266,6 +350,7 @@ export default function transformProps( ...getDefaultTooltip(refs), show: !inContextMenu, trigger: 'item', + formatter: NormalizedTooltipFormater, }, legend: { ...getLegendProps(legendType, legendOrientation, showLegend, theme), diff --git a/superset-frontend/plugins/plugin-chart-echarts/src/Radar/types.ts b/superset-frontend/plugins/plugin-chart-echarts/src/Radar/types.ts index 19812012bba..0f335683fe3 100644 --- a/superset-frontend/plugins/plugin-chart-echarts/src/Radar/types.ts +++ b/superset-frontend/plugins/plugin-chart-echarts/src/Radar/types.ts @@ -35,7 +35,7 @@ import { DEFAULT_LEGEND_FORM_DATA } from '../constants'; type RadarColumnConfig = Record< string, - { radarMetricMaxValue?: number; radarMetricMinValue?: number } + { radarMetricMaxValue?: number | null; radarMetricMinValue?: number } >; export type EchartsRadarFormData = QueryFormData & @@ -53,6 +53,7 @@ export type EchartsRadarFormData = QueryFormData & isCircle: boolean; numberFormat: string; dateFormat: string; + isNormalized: boolean; }; export enum EchartsRadarLabelType { @@ -83,3 +84,17 @@ export type RadarChartTransformedProps = BaseTransformedProps & ContextMenuTransformedProps & CrossFilterTransformedProps; + +/** + * Represents a mapping from a normalized value (as string) to an original numeric value. + */ +interface NormalizedValueMap { + [normalized: string]: number; +} + +/** + * Represents a collection of series, each containing its own NormalizedValueMap. + */ +export interface SeriesNormalizedMap { + [seriesName: string]: NormalizedValueMap; +} diff --git a/superset-frontend/plugins/plugin-chart-echarts/src/Radar/utils.ts b/superset-frontend/plugins/plugin-chart-echarts/src/Radar/utils.ts new file mode 100644 index 00000000000..343d9bbd392 --- /dev/null +++ b/superset-frontend/plugins/plugin-chart-echarts/src/Radar/utils.ts @@ -0,0 +1,92 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + function for finding the max metric values among all series data for Radar Chart +*/ +export const findGlobalMax = ( + data: Record[], + metrics: string[], +): number => { + if (!data?.length || !metrics?.length) return 0; + + return data.reduce((globalMax, row) => { + const rowMax = metrics.reduce((max, metric) => { + const value = row[metric]; + return typeof value === 'number' && + Number.isFinite(value) && + !Number.isNaN(value) + ? Math.max(max, value) + : max; + }, 0); + + return Math.max(globalMax, rowMax); + }, 0); +}; + +interface TooltipParams { + color: string; + name?: string; + value: number[]; +} + +interface TooltipMetricValue { + metric: string; + value: number; +} + +export const renderNormalizedTooltip = ( + params: TooltipParams, + metrics: string[], + getDenormalizedValue: (seriesName: string, value: string) => number, + metricsWithCustomBounds: Set, +): string => { + const { color, name = '', value: values } = params; + const seriesName = name || 'series0'; + + const colorDot = ``; + + // Get metric values with denormalization if needed + const metricValues: TooltipMetricValue[] = metrics.map((metric, index) => { + const value = values[index]; + const originalValue = metricsWithCustomBounds.has(metric) + ? value + : getDenormalizedValue(name, String(value)); + + return { + metric, + value: originalValue, + }; + }); + + const tooltipRows = metricValues + .map( + ({ metric, value }) => ` +
+        <div>
+          <span>${colorDot}${metric}:</span>
+          <span>${value}</span>
+        </div>
+      `,
+    )
+    .join('');
+
+  return `
+    <div>${seriesName}</div>
+ ${tooltipRows} + `; +}; diff --git a/superset-frontend/plugins/plugin-chart-echarts/src/Sankey/transformProps.ts b/superset-frontend/plugins/plugin-chart-echarts/src/Sankey/transformProps.ts index 3f2c057c8e0..e9f912d1af7 100644 --- a/superset-frontend/plugins/plugin-chart-echarts/src/Sankey/transformProps.ts +++ b/superset-frontend/plugins/plugin-chart-echarts/src/Sankey/transformProps.ts @@ -79,13 +79,25 @@ export default function transformProps( })); // stores a map with the total values for each node considering the links - const nodeValues = new Map(); + const incomingFlows = new Map(); + const outgoingFlows = new Map(); + const allNodeNames = new Set(); + links.forEach(link => { const { source, target, value } = link; - const sourceValue = nodeValues.get(source) || 0; - const targetValue = nodeValues.get(target) || 0; - nodeValues.set(source, sourceValue + value); - nodeValues.set(target, targetValue + value); + allNodeNames.add(source); + allNodeNames.add(target); + incomingFlows.set(target, (incomingFlows.get(target) || 0) + value); + outgoingFlows.set(source, (outgoingFlows.get(source) || 0) + value); + }); + + const nodeValues = new Map(); + + allNodeNames.forEach(nodeName => { + const totalIncoming = incomingFlows.get(nodeName) || 0; + const totalOutgoing = outgoingFlows.get(nodeName) || 0; + + nodeValues.set(nodeName, Math.max(totalIncoming, totalOutgoing)); }); const tooltipFormatter = (params: CallbackDataParams) => { diff --git a/superset-frontend/plugins/plugin-chart-echarts/src/Timeseries/Regular/Bar/controlPanel.tsx b/superset-frontend/plugins/plugin-chart-echarts/src/Timeseries/Regular/Bar/controlPanel.tsx index 55cd48736a0..3c61cab8093 100644 --- a/superset-frontend/plugins/plugin-chart-echarts/src/Timeseries/Regular/Bar/controlPanel.tsx +++ b/superset-frontend/plugins/plugin-chart-echarts/src/Timeseries/Regular/Bar/controlPanel.tsx @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. 
*/ -import { t } from '@superset-ui/core'; +import { JsonArray, t } from '@superset-ui/core'; import { ControlPanelConfig, ControlPanelsContainerProps, @@ -45,6 +45,7 @@ import { DEFAULT_FORM_DATA, TIME_SERIES_DESCRIPTION_TEXT, } from '../../constants'; +import { StackControlsValue } from '../../../constants'; const { logAxis, @@ -321,6 +322,38 @@ const config: ControlPanelConfig = { ['color_scheme'], ['time_shift_color'], ...showValueSection, + [ + { + name: 'stackDimension', + config: { + type: 'SelectControl', + label: t('Split stack by'), + visibility: ({ controls }) => + controls?.stack?.value === StackControlsValue.Stack, + renderTrigger: true, + description: t( + 'Stack in groups, where each group corresponds to a dimension', + ), + shouldMapStateToProps: ( + prevState, + state, + controlState, + chartState, + ) => true, + mapStateToProps: (state, controlState, chartState) => { + const value: JsonArray = state.controls.groupby + .value as JsonArray; + const valueAsStringArr: string[][] = value.map(v => { + if (v) return [v.toString(), v.toString()]; + return ['', '']; + }); + return { + choices: valueAsStringArr, + }; + }, + }, + }, + ], [minorTicks], [ { diff --git a/superset-frontend/plugins/plugin-chart-echarts/src/Timeseries/transformProps.ts b/superset-frontend/plugins/plugin-chart-echarts/src/Timeseries/transformProps.ts index 4bf91f8cf36..680000fa9f6 100644 --- a/superset-frontend/plugins/plugin-chart-echarts/src/Timeseries/transformProps.ts +++ b/superset-frontend/plugins/plugin-chart-echarts/src/Timeseries/transformProps.ts @@ -191,6 +191,7 @@ export default function transformProps( yAxisTitleMargin, yAxisTitlePosition, zoomable, + stackDimension, }: EchartsTimeseriesFormData = { ...DEFAULT_FORM_DATA, ...formData }; const refs: Refs = {}; const groupBy = ensureIsArray(groupby); @@ -418,6 +419,23 @@ export default function transformProps( } }); + if ( + stack === StackControlsValue.Stack && + stackDimension && + chartProps.rawFormData.groupby + ) { + const idxSelectedDimension = + formData.metrics.length > 1 + ? 1 + : 0 + chartProps.rawFormData.groupby.indexOf(stackDimension); + for (const s of series) { + if (s.id) { + const columnsArr = labelMap[s.id]; + (s as any).stack = columnsArr[idxSelectedDimension]; + } + } + } + // axis bounds need to be parsed to replace incompatible values with undefined const [xAxisMin, xAxisMax] = (xAxisBounds || []).map(parseAxisBound); let [yAxisMin, yAxisMax] = (yAxisBounds || []).map(parseAxisBound); diff --git a/superset-frontend/plugins/plugin-chart-echarts/src/Timeseries/types.ts b/superset-frontend/plugins/plugin-chart-echarts/src/Timeseries/types.ts index 88a55b46be4..bdcb736956c 100644 --- a/superset-frontend/plugins/plugin-chart-echarts/src/Timeseries/types.ts +++ b/superset-frontend/plugins/plugin-chart-echarts/src/Timeseries/types.ts @@ -74,6 +74,7 @@ export type EchartsTimeseriesFormData = QueryFormData & { rowLimit: number; seriesType: EchartsTimeseriesSeriesType; stack: StackType; + stackDimension: string; timeCompare?: string[]; tooltipTimeFormat?: string; showTooltipTotal?: boolean; diff --git a/superset-frontend/plugins/plugin-chart-echarts/src/utils/metricDisplayName.ts b/superset-frontend/plugins/plugin-chart-echarts/src/utils/metricDisplayName.ts new file mode 100644 index 00000000000..4f85836f5b2 --- /dev/null +++ b/superset-frontend/plugins/plugin-chart-echarts/src/utils/metricDisplayName.ts @@ -0,0 +1,59 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import { QueryFormMetric } from '@superset-ui/core'; + +export const getMetricDisplayName = ( + metric: QueryFormMetric, + verboseMap: Record = {}, +): string => { + // Case 1: Simple string metric - use verboseMap or the string itself + if (typeof metric === 'string') { + return verboseMap[metric] || metric; + } + + // Case 2: Metric with explicit label - always prefer this if available + if (metric.label) { + return metric.label; + } + + // Case 3: SIMPLE expression type (column with aggregate) + if (metric.expressionType === 'SIMPLE') { + const column = metric.column || {}; + const columnName = column.column_name || ''; + // Use verbose name from column if available + const displayName = column.verbose_name || columnName; + const aggregate = metric.aggregate || ''; + + // If the verbose map has this column, use that + if (verboseMap[columnName]) { + return `${aggregate}(${verboseMap[columnName]})`; + } + + return `${aggregate}(${displayName})`; + } + + // Case 4: SQL expression + if (metric.expressionType === 'SQL') { + return metric.sqlExpression || 'Custom SQL Metric'; + } + + // Fallback + return 'Unknown Metric'; +}; diff --git a/superset-frontend/plugins/plugin-chart-echarts/test/MixedTimeseries/transformProps.test.ts b/superset-frontend/plugins/plugin-chart-echarts/test/MixedTimeseries/transformProps.test.ts index 9a24c990c9a..533b1b23b74 100644 --- a/superset-frontend/plugins/plugin-chart-echarts/test/MixedTimeseries/transformProps.test.ts +++ b/superset-frontend/plugins/plugin-chart-echarts/test/MixedTimeseries/transformProps.test.ts @@ -118,7 +118,9 @@ const chartPropsConfig = { it('should transform chart props for viz', () => { const chartProps = new ChartProps(chartPropsConfig); - expect(transformProps(chartProps as EchartsMixedTimeseriesProps)).toEqual( + const transformed = transformProps(chartProps as EchartsMixedTimeseriesProps); + + expect(transformed).toEqual( expect.objectContaining({ echartOptions: expect.objectContaining({ series: expect.arrayContaining([ @@ -127,7 +129,7 @@ it('should transform chart props for viz', () => { [599616000000, 1], [599916000000, 3], ], - id: 'boy', + id: 'sum__num (Query A), boy', stack: 'obs\na', }), expect.objectContaining({ @@ -135,15 +137,16 @@ it('should transform chart props for viz', () => { [599616000000, 2], [599916000000, 4], ], - id: 'girl', + id: 'sum__num (Query A), girl', stack: 'obs\na', }), + // Query B — Bar series expect.objectContaining({ data: [ [599616000000, 1], [599916000000, 3], ], - id: 'boy (1)', + id: 'sum__num (Query B), boy', stack: 'obs\nb', }), expect.objectContaining({ @@ -151,7 +154,7 @@ it('should transform chart props for viz', () => { [599616000000, 2], [599916000000, 4], ], - id: 'girl (1)', + id: 'sum__num (Query B), girl', stack: 'obs\nb', }), ]), diff --git 
a/superset-frontend/plugins/plugin-chart-echarts/test/Radar/transformProps.test.ts b/superset-frontend/plugins/plugin-chart-echarts/test/Radar/transformProps.test.ts new file mode 100644 index 00000000000..c66e60b1c7e --- /dev/null +++ b/superset-frontend/plugins/plugin-chart-echarts/test/Radar/transformProps.test.ts @@ -0,0 +1,127 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +import { ChartProps, supersetTheme } from '@superset-ui/core'; +import { RadarSeriesOption } from 'echarts/charts'; +import transformProps from '../../src/Radar/transformProps'; +import { + EchartsRadarChartProps, + EchartsRadarFormData, +} from '../../src/Radar/types'; + +interface RadarIndicator { + name: string; + max: number; + min: number; +} + +type RadarShape = 'circle' | 'polygon'; + +interface RadarChartConfig { + shape: RadarShape; + indicator: RadarIndicator[]; +} + +interface RadarSeriesData { + value: number[]; + name: string; +} + +describe('Radar transformProps', () => { + const formData: Partial = { + colorScheme: 'supersetColors', + datasource: '3__table', + granularity_sqla: 'ds', + columnConfig: { + 'MAX(na_sales)': { + radarMetricMaxValue: null, + radarMetricMinValue: 0, + }, + 'SUM(eu_sales)': { + radarMetricMaxValue: 5000, + }, + }, + groupby: [], + metrics: [ + 'MAX(na_sales)', + 'SUM(jp_sales)', + 'SUM(other_sales)', + 'SUM(eu_sales)', + ], + viz_type: 'radar', + numberFormat: 'SMART_NUMBER', + dateFormat: 'smart_date', + showLegend: true, + showLabels: true, + isCircle: false, + }; + + const chartProps = new ChartProps({ + formData, + width: 800, + height: 600, + queriesData: [ + { + data: [ + { + 'MAX(na_sales)': 41.49, + 'SUM(jp_sales)': 1290.99, + 'SUM(other_sales)': 797.73, + 'SUM(eu_sales)': 2434.13, + }, + ], + }, + ], + theme: supersetTheme, + }); + + it('should transform chart props for normalized radar chart & normalize all metrics except the ones with custom min & max', () => { + const transformedProps = transformProps( + chartProps as EchartsRadarChartProps, + ); + const series = transformedProps.echartOptions.series as RadarSeriesOption[]; + const radar = transformedProps.echartOptions.radar as RadarChartConfig; + + expect((series[0].data as RadarSeriesData[])[0].value).toEqual([ + 0.0170451044, 0.5303701939, 0.3277269497, 2434.13, + ]); + + expect(radar.indicator).toEqual([ + { + name: 'MAX(na_sales)', + max: 1, + min: 0, + }, + { + name: 'SUM(jp_sales)', + max: 1, + min: 0, + }, + { + name: 'SUM(other_sales)', + max: 1, + min: 0, + }, + { + name: 'SUM(eu_sales)', + max: 5000, + min: 0, + }, + ]); + }); +}); diff --git a/superset-frontend/plugins/plugin-chart-table/src/controlPanel.tsx b/superset-frontend/plugins/plugin-chart-table/src/controlPanel.tsx index 0bd69e7cc53..1f9425354e5 100644 --- 
a/superset-frontend/plugins/plugin-chart-table/src/controlPanel.tsx +++ b/superset-frontend/plugins/plugin-chart-table/src/controlPanel.tsx @@ -343,6 +343,26 @@ const config: ControlPanelConfig = { }, }, ], + [ + { + name: 'order_desc', + config: { + type: 'CheckboxControl', + label: t('Sort descending'), + default: true, + description: t( + 'If enabled, this control sorts the results/values descending, otherwise it sorts the results ascending.', + ), + visibility: ({ controls }: ControlPanelsContainerProps) => { + const hasSortMetric = Boolean( + controls?.timeseries_limit_metric?.value, + ); + return hasSortMetric && isAggMode({ controls }); + }, + resetOnHide: false, + }, + }, + ], [ { name: 'server_pagination', @@ -413,21 +433,6 @@ const config: ControlPanelConfig = { }, }, ], - [ - { - name: 'order_desc', - config: { - type: 'CheckboxControl', - label: t('Sort descending'), - default: true, - description: t( - 'If enabled, this control sorts the results/values descending, otherwise it sorts the results ascending.', - ), - visibility: isAggMode, - resetOnHide: false, - }, - }, - ], [ { name: 'show_totals', diff --git a/superset-frontend/src/SqlLab/actions/sqlLab.js b/superset-frontend/src/SqlLab/actions/sqlLab.js index 70269c37ec7..82d103ed745 100644 --- a/superset-frontend/src/SqlLab/actions/sqlLab.js +++ b/superset-frontend/src/SqlLab/actions/sqlLab.js @@ -1294,7 +1294,8 @@ export function createDatasourceFailed(err) { export function createDatasource(vizOptions) { return dispatch => { dispatch(createDatasourceStarted()); - const { dbId, catalog, schema, datasourceName, sql } = vizOptions; + const { dbId, catalog, schema, datasourceName, sql, templateParams } = + vizOptions; return SupersetClient.post({ endpoint: '/api/v1/dataset/', headers: { 'Content-Type': 'application/json' }, @@ -1306,6 +1307,7 @@ export function createDatasource(vizOptions) { table_name: datasourceName, is_managed_externally: false, external_url: null, + template_params: templateParams, }), }) .then(({ json }) => { diff --git a/superset-frontend/src/SqlLab/components/SaveDatasetModal/SaveDatasetModal.test.tsx b/superset-frontend/src/SqlLab/components/SaveDatasetModal/SaveDatasetModal.test.tsx index b1ebe60016e..7d6effb54f5 100644 --- a/superset-frontend/src/SqlLab/components/SaveDatasetModal/SaveDatasetModal.test.tsx +++ b/superset-frontend/src/SqlLab/components/SaveDatasetModal/SaveDatasetModal.test.tsx @@ -29,6 +29,7 @@ import fetchMock from 'fetch-mock'; import { SaveDatasetModal } from 'src/SqlLab/components/SaveDatasetModal'; import { createDatasource } from 'src/SqlLab/actions/sqlLab'; import { user, testQuery, mockdatasets } from 'src/SqlLab/fixtures'; +import { FeatureFlag } from '@superset-ui/core'; const mockedProps = { visible: true, @@ -250,4 +251,88 @@ describe('SaveDatasetModal', () => { templateParams: undefined, }); }); + + it('does not renders a checkbox button when template processing is disabled', () => { + render(, { useRedux: true }); + expect(screen.queryByRole('checkbox')).not.toBeInTheDocument(); + }); + + it('renders a checkbox button when template processing is enabled', () => { + // @ts-ignore + global.featureFlags = { + [FeatureFlag.EnableTemplateProcessing]: true, + }; + render(, { useRedux: true }); + expect(screen.getByRole('checkbox')).toBeInTheDocument(); + }); + + it('correctly includes template parameters when template processing is enabled', () => { + // @ts-ignore + global.featureFlags = { + [FeatureFlag.EnableTemplateProcessing]: true, + }; + const propsWithTemplateParam = 
{ + ...mockedProps, + datasource: { + ...testQuery, + templateParams: JSON.stringify({ my_param: 12 }), + }, + }; + render(, { + useRedux: true, + }); + const inputFieldText = screen.getByDisplayValue(/unimportant/i); + fireEvent.change(inputFieldText, { target: { value: 'my dataset' } }); + + userEvent.click(screen.getByRole('checkbox')); + + const saveConfirmationBtn = screen.getByRole('button', { + name: /save/i, + }); + userEvent.click(saveConfirmationBtn); + + expect(createDatasource).toHaveBeenCalledWith({ + datasourceName: 'my dataset', + dbId: 1, + catalog: null, + schema: 'main', + sql: 'SELECT *', + templateParams: JSON.stringify({ my_param: 12 }), + }); + }); + + it('correctly excludes template parameters when template processing is enabled', () => { + // @ts-ignore + global.featureFlags = { + [FeatureFlag.EnableTemplateProcessing]: true, + }; + const propsWithTemplateParam = { + ...mockedProps, + datasource: { + ...testQuery, + templateParams: JSON.stringify({ my_param: 12 }), + }, + }; + render(, { + useRedux: true, + }); + const inputFieldText = screen.getByDisplayValue(/unimportant/i); + fireEvent.change(inputFieldText, { target: { value: 'my dataset' } }); + + userEvent.click(screen.getByRole('checkbox')); + + const saveConfirmationBtn = screen.getByRole('button', { + name: /save/i, + }); + userEvent.click(saveConfirmationBtn); + + expect(createDatasource).toHaveBeenCalledWith({ + datasourceName: 'my dataset', + dbId: 1, + catalog: null, + schema: 'main', + sql: 'SELECT *', + templateParams: undefined, + }); + }); }); diff --git a/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx b/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx index 22df5b4c6c7..47a0aecfb5c 100644 --- a/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx +++ b/superset-frontend/src/SqlLab/components/SaveDatasetModal/index.tsx @@ -23,6 +23,7 @@ import { Radio, RadioChangeEvent } from '@superset-ui/core/components/Radio'; import { AsyncSelect, Button, + Checkbox, Modal, Input, type SelectValue, @@ -36,6 +37,8 @@ import { QueryResponse, QueryFormData, VizType, + FeatureFlag, + isFeatureEnabled, } from '@superset-ui/core'; import { useSelector, useDispatch } from 'react-redux'; import dayjs from 'dayjs'; @@ -186,6 +189,8 @@ export const SaveDatasetModal = ({ const user = useSelector(state => state.user); const dispatch = useDispatch<(dispatch: any) => Promise>(); + const [includeTemplateParameters, setIncludeTemplateParameters] = + useState(false); const createWindow = (url: string) => { if (openWindow) { @@ -286,14 +291,21 @@ export const SaveDatasetModal = ({ // Remove the special filters entry from the templateParams // before saving the dataset. 
let templateParams; - if (typeof datasource?.templateParams === 'string') { - const p = JSON.parse(datasource.templateParams); - /* eslint-disable-next-line no-underscore-dangle */ - if (p._filters) { + if ( + typeof datasource?.templateParams === 'string' && + includeTemplateParameters + ) { + try { + const p = JSON.parse(datasource.templateParams); /* eslint-disable-next-line no-underscore-dangle */ - delete p._filters; - // eslint-disable-next-line no-param-reassign + if (p._filters) { + /* eslint-disable-next-line no-underscore-dangle */ + delete p._filters; + } templateParams = JSON.stringify(p); + } catch (e) { + // malformed templateParams, do not include it + templateParams = undefined; } } @@ -363,7 +375,27 @@ export const SaveDatasetModal = ({ title={t('Save or Overwrite Dataset')} onHide={onHide} footer={ - <> +
+          {isFeatureEnabled(FeatureFlag.EnableTemplateProcessing) && (
+            <>
+              <Checkbox
+                checked={includeTemplateParameters}
+                onChange={checked =>
+                  setIncludeTemplateParameters(checked ?? false)
+                }
+              />
+              <span>{t('Include Template Parameters')}</span>
+            </>
+          )}
          {newOrOverwrite === DatasetRadioState.SaveNew && (
} > diff --git a/superset-frontend/src/features/allEntities/AllEntitiesTable.test.tsx b/superset-frontend/src/features/allEntities/AllEntitiesTable.test.tsx index 7143c908879..c55c04e7427 100644 --- a/superset-frontend/src/features/allEntities/AllEntitiesTable.test.tsx +++ b/superset-frontend/src/features/allEntities/AllEntitiesTable.test.tsx @@ -91,12 +91,13 @@ describe('AllEntitiesTable', () => { jest.restoreAllMocks(); }); - it('renders when empty', () => { + it('renders when empty with button to tag if user has perm', () => { render( , { useRouter: true }, ); @@ -108,25 +109,68 @@ describe('AllEntitiesTable', () => { expect(screen.getByText('Add tag to entities')).toBeInTheDocument(); }); - it('renders the correct tags for each object type, excluding the current tag', () => { + it('renders when empty without button to tag if user does not have perm', () => { + render( + , + { useRouter: true }, + ); + + expect( + screen.getByText('No entities have this tag currently assigned'), + ).toBeInTheDocument(); + + expect(screen.queryByText('Add tag to entities')).not.toBeInTheDocument(); + }); + + it('renders the correct tags for each object type', () => { render( , { useRouter: true }, ); + expect(screen.getByText('Dashboards')).toBeInTheDocument(); expect(screen.getByText('Sales Dashboard')).toBeInTheDocument(); expect(screen.getByText('Sales')).toBeInTheDocument(); + expect(screen.getByText('Charts')).toBeInTheDocument(); expect(screen.getByText('Monthly Revenue')).toBeInTheDocument(); expect(screen.getByText('Revenue')).toBeInTheDocument(); + expect(screen.getByText('Queries')).toBeInTheDocument(); expect(screen.getByText('User Engagement')).toBeInTheDocument(); expect(screen.getByText('Engagement')).toBeInTheDocument(); + }); - expect(screen.queryByText('Current Tag')).not.toBeInTheDocument(); + it('Only list asset types that have entities', () => { + const mockObjects = { + dashboard: [], + chart: [mockObjectsWithTags.chart[0]], + query: [], + }; + + render( + , + { useRouter: true }, + ); + + expect(screen.queryByText('Dashboards')).not.toBeInTheDocument(); + expect(screen.getByText('Charts')).toBeInTheDocument(); + expect(screen.getByText('Monthly Revenue')).toBeInTheDocument(); + expect(screen.queryByText('Queries')).not.toBeInTheDocument(); }); }); diff --git a/superset-frontend/src/features/allEntities/AllEntitiesTable.tsx b/superset-frontend/src/features/allEntities/AllEntitiesTable.tsx index 2827ec261e5..d94504b5705 100644 --- a/superset-frontend/src/features/allEntities/AllEntitiesTable.tsx +++ b/superset-frontend/src/features/allEntities/AllEntitiesTable.tsx @@ -24,7 +24,6 @@ import { } from '@superset-ui/core/components/TableView'; import { EmptyState } from '@superset-ui/core/components'; import { FacePile, TagsList, type TagType } from 'src/components'; -import { NumberParam, useQueryParam } from 'use-query-params'; import { TaggedObject, TaggedObjects } from 'src/types/TaggedObject'; import { Typography } from '@superset-ui/core/components/Typography'; @@ -55,20 +54,21 @@ interface AllEntitiesTableProps { search?: string; setShowTagModal: (show: boolean) => void; objects: TaggedObjects; + canEditTag: boolean; } export default function AllEntitiesTable({ search = '', setShowTagModal, objects, + canEditTag, }: AllEntitiesTableProps) { type objectType = 'dashboard' | 'chart' | 'query'; - const [tagId] = useQueryParam('id', NumberParam); - const showListViewObjs = - objects.dashboard.length > 0 || - objects.chart.length > 0 || - objects.query.length > 0; + const 
showDashboardList = objects.dashboard.length > 0; + const showChartList = objects.chart.length > 0; + const showQueryList = objects.query.length > 0; + const showListViewObjs = showDashboardList || showChartList || showQueryList; const renderTable = (type: objectType) => { const data = objects[type].map((o: TaggedObject) => ({ @@ -107,8 +107,7 @@ export default function AllEntitiesTable({ tags={tags.filter( (tag: TagType) => tag.type !== undefined && - ['TagType.custom', 1].includes(tag.type) && - tag.id !== tagId, + ['TagType.custom', 1].includes(tag.type), )} maxTags={MAX_TAGS_TO_SHOW} /> @@ -139,20 +138,34 @@ export default function AllEntitiesTable({ {showListViewObjs ? ( <> -
{t('Dashboards')}
- {renderTable('dashboard')} -
{t('Charts')}
- {renderTable('chart')} -
{t('Queries')}
- {renderTable('query')} + {showDashboardList && ( + <> +
{t('Dashboards')}
+ {renderTable('dashboard')} + + )} + {showChartList && ( + <> +
{t('Charts')}
+ {renderTable('chart')} + + )} + {showQueryList && ( + <> +
{t('Queries')}
+ {renderTable('query')} + + )} ) : ( setShowTagModal(true)} - buttonText={t('Add tag to entities')} + {...(canEditTag && { + buttonAction: () => setShowTagModal(true), + buttonText: t('Add tag to entities'), + })} /> )}
diff --git a/superset-frontend/src/features/groups/types.ts b/superset-frontend/src/features/groups/types.ts new file mode 100644 index 00000000000..4bfab858bc6 --- /dev/null +++ b/superset-frontend/src/features/groups/types.ts @@ -0,0 +1,31 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +export interface BaseGroupListModalProps { + show: boolean; + onHide: () => void; + onSave: () => void; +} + +export interface FormValues { + name: string; + label?: string; + description?: string; + roles: number[]; + users: { value: number; label: string }[]; +} diff --git a/superset-frontend/src/features/groups/utils.ts b/superset-frontend/src/features/groups/utils.ts new file mode 100644 index 00000000000..89e3dc05ccd --- /dev/null +++ b/superset-frontend/src/features/groups/utils.ts @@ -0,0 +1,74 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +import { SupersetClient, t } from '@superset-ui/core'; +import rison from 'rison'; +import { FormValues } from './types'; + +export const createGroup = async (values: FormValues) => { + await SupersetClient.post({ + endpoint: '/api/v1/security/groups/', + jsonPayload: { ...values, users: values.users.map(user => user.value) }, + }); +}; + +export const updateGroup = async (groupId: number, values: FormValues) => { + await SupersetClient.put({ + endpoint: `/api/v1/security/groups/${groupId}`, + jsonPayload: { ...values, users: values.users.map(user => user.value) }, + }); +}; + +export const deleteGroup = async (groupId: number) => + SupersetClient.delete({ + endpoint: `/api/v1/security/groups/${groupId}`, + }); + +export const fetchUserOptions = async ( + filterValue: string, + page: number, + pageSize: number, + addDangerToast: (msg: string) => void, +) => { + const query = rison.encode({ + filter: filterValue, + page, + page_size: pageSize, + order_column: 'username', + order_direction: 'asc', + }); + + try { + const response = await SupersetClient.get({ + endpoint: `/api/v1/security/users/?q=${query}`, + }); + + const results = response.json?.result || []; + + return { + data: results.map((user: any) => ({ + value: user.id, + label: user.username, + })), + totalCount: response.json?.count ?? 0, + }; + } catch (error) { + addDangerToast(t('There was an error while fetching users')); + return { data: [], totalCount: 0 }; + } +}; diff --git a/superset-frontend/src/features/tags/tags.ts b/superset-frontend/src/features/tags/tags.ts index 661f601d73b..e51677c5d47 100644 --- a/superset-frontend/src/features/tags/tags.ts +++ b/superset-frontend/src/features/tags/tags.ts @@ -179,20 +179,6 @@ export function addTag( .catch(response => error(response)); } -export function fetchObjects( - { tags = '', types }: { tags: string; types: string | null }, - callback: (json: JsonObject) => void, - error: (response: Response) => void, -) { - let url = `/api/v1/tag/get_objects/?tags=${tags}`; - if (types) { - url += `&types=${types}`; - } - SupersetClient.get({ endpoint: url }) - .then(({ json }) => callback(json.result)) - .catch(response => error(response)); -} - export function fetchObjectsByTagIds( { tagIds = [], types }: { tagIds: number[] | string; types: string | null }, callback: (json: JsonObject) => void, diff --git a/superset-frontend/src/pages/ActionLog/index.tsx b/superset-frontend/src/pages/ActionLog/index.tsx new file mode 100644 index 00000000000..3e6b0bd1a14 --- /dev/null +++ b/superset-frontend/src/pages/ActionLog/index.tsx @@ -0,0 +1,282 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +import { useMemo } from 'react'; +import { t, css } from '@superset-ui/core'; +import SubMenu, { SubMenuProps } from 'src/features/home/SubMenu'; +import { useListViewResource } from 'src/views/CRUD/hooks'; +import { useToasts } from 'src/components/MessageToasts/withToasts'; +import { ListView, Filters, FilterOperator } from 'src/components/ListView'; +// eslint-disable-next-line no-restricted-imports +import { Typography } from '@superset-ui/core/components'; +import { fetchUserOptions } from 'src/features/groups/utils'; + +export type ActionLogObject = { + user: { + username: string; + }; + action: string; + dttm: string | null; + dashboard_id?: number; + slice_id?: number; + json?: string; + duration_ms?: number; + referrer?: string; +}; + +const PAGE_SIZE = 25; + +function ActionLogList() { + const { addDangerToast, addSuccessToast } = useToasts(); + const initialSort = [{ id: 'dttm', desc: true }]; + const subMenuButtons: SubMenuProps['buttons'] = []; + + const { + state: { + loading, + resourceCount: LogsCount, + resourceCollection: Logs, + bulkSelectEnabled, + }, + fetchData, + refreshData, + toggleBulkSelect, + } = useListViewResource( + 'log', + t('Log'), + addDangerToast, + false, + ); + const filters: Filters = useMemo( + () => [ + { + Header: t('Users'), + key: 'user', + id: 'user', + input: 'select', + operator: FilterOperator.RelationOneMany, + unfilteredLabel: t('All'), + fetchSelects: async (filterValue, page, pageSize) => + fetchUserOptions(filterValue, page, pageSize, addDangerToast), + }, + { + Header: t('Dashboard Id'), + key: 'dashboard_id', + id: 'dashboard_id', + input: 'search', + operator: FilterOperator.Equals, + }, + { + Header: t('Slice Id'), + key: 'slice_id', + id: 'slice_id', + input: 'search', + operator: FilterOperator.Equals, + }, + { + Header: t('Action'), + key: 'action', + id: 'action', + input: 'search', + operator: FilterOperator.Contains, + }, + { + Header: t('JSON'), + key: 'json', + id: 'json', + input: 'search', + operator: FilterOperator.Contains, + }, + { + Header: t('dttm'), + key: 'dttm', + id: 'dttm', + input: 'datetime_range', + operator: FilterOperator.Between, + dateFilterValueType: 'iso', + }, + { + Header: t('Referrer'), + key: 'referrer', + id: 'referrer', + input: 'search', + operator: FilterOperator.Equals, + }, + { + Header: t('Duration Ms'), + key: 'duration_ms', + id: 'duration_ms', + input: 'search', + operator: FilterOperator.Equals, + }, + ], + [], + ); + + const columns = useMemo( + () => [ + { + accessor: 'action', + Header: t('Action'), + Cell: ({ + row: { + original: { action }, + }, + }: any) => {action}, + }, + { + accessor: 'user', + Header: t('User'), + Cell: ({ + row: { + original: { user }, + }, + }: any) => {user?.username}, + }, + + { + accessor: 'duration_ms', + Header: t('Duration Ms'), + + Cell: ({ + row: { + original: { duration_ms }, + }, + }: any) => {duration_ms}, + }, + { + accessor: 'dashboard_id', + Header: t('Dashboard Id'), + hidden: false, + Cell: ({ + row: { + original: { dashboard_id }, + }, + }: any) => {dashboard_id}, + }, + { + accessor: 'slice_id', + Header: t('Slice Id'), + hidden: false, + Cell: ({ + row: { + original: { slice_id }, + }, + }: any) => {slice_id}, + }, + { + accessor: 'json', + Header: t('JSON'), + + Cell: ({ + row: { + original: { json }, + }, + }: any) => ( + + {json} + + ), + }, + + { + accessor: 'referrer', + Header: t('Referrer'), + + Cell: ({ + row: { + original: { referrer }, + }, + }: any) => ( + + {referrer} + + ), + }, + { + accessor: 'dttm', + Header: t('Dttm'), + 
Cell: ({ + row: { + original: { dttm }, + }, + }: any) => {dttm}, + }, + ], + [], + ); + + const emptyState = { + title: t('No Logs yet'), + image: 'filter-results.svg', + }; + + return ( + <> + + + className="action-log-view" + columns={columns} + count={LogsCount} + data={Logs} + fetchData={fetchData} + filters={filters} + initialSort={initialSort} + loading={loading} + pageSize={PAGE_SIZE} + bulkSelectEnabled={bulkSelectEnabled} + disableBulkSelect={toggleBulkSelect} + addDangerToast={addDangerToast} + addSuccessToast={addSuccessToast} + emptyState={emptyState} + refreshData={refreshData} + /> + + ); +} + +export default ActionLogList; diff --git a/superset-frontend/src/pages/AllEntities/index.tsx b/superset-frontend/src/pages/AllEntities/index.tsx index ff3e47dd4a0..6199b772a30 100644 --- a/superset-frontend/src/pages/AllEntities/index.tsx +++ b/superset-frontend/src/pages/AllEntities/index.tsx @@ -34,6 +34,9 @@ import withToasts, { useToasts } from 'src/components/MessageToasts/withToasts'; import { fetchObjectsByTagIds, fetchSingleTag } from 'src/features/tags/tags'; import getOwnerName from 'src/utils/getOwnerName'; import { TaggedObject, TaggedObjects } from 'src/types/TaggedObject'; +import { findPermission } from 'src/utils/findPermission'; +import { useSelector } from 'react-redux'; +import { RootState } from 'src/dashboard/types'; const additionalItemsStyles = (theme: SupersetTheme) => css` display: flex; @@ -99,6 +102,10 @@ function AllEntities() { query: [], }); + const canEditTag = useSelector((state: RootState) => + findPermission('can_write', 'Tag', state.user?.roles), + ); + const editableTitleProps = { title: tag?.name || '', placeholder: 'testing', @@ -210,14 +217,16 @@ function AllEntities() { } rightPanelAdditionalItems={ <> - + {canEditTag && ( + + )} } menuDropdownProps={{ @@ -231,6 +240,7 @@ function AllEntities() { search={tag?.name || ''} setShowTagModal={setShowTagModal} objects={objects} + canEditTag={canEditTag} /> diff --git a/superset-frontend/src/utils/datasourceUtils.js b/superset-frontend/src/utils/datasourceUtils.js index ef984044d13..144a3ff88b6 100644 --- a/superset-frontend/src/utils/datasourceUtils.js +++ b/superset-frontend/src/utils/datasourceUtils.js @@ -23,4 +23,5 @@ export const getDatasourceAsSaveableDataset = source => ({ sql: source?.sql || '', catalog: source?.catalog, schema: source?.schema, + templateParams: source?.templateParams, }); diff --git a/superset-frontend/src/views/routes.tsx b/superset-frontend/src/views/routes.tsx index 691c69245d8..242314ad079 100644 --- a/superset-frontend/src/views/routes.tsx +++ b/superset-frontend/src/views/routes.tsx @@ -137,6 +137,9 @@ const RolesList = lazy( const UsersList: LazyExoticComponent = lazy( () => import(/* webpackChunkName: "UsersList" */ 'src/pages/UsersList'), ); +const ActionLogList: LazyExoticComponent = lazy( + () => import(/* webpackChunkName: "ActionLogList" */ 'src/pages/ActionLog'), +); const Login = lazy( () => import(/* webpackChunkName: "Login" */ 'src/pages/Login'), @@ -260,6 +263,10 @@ export const routes: Routes = [ path: '/sqllab/', Component: SqlLab, }, + { + path: '/actionlog/list', + Component: ActionLogList, + }, ]; if (isFeatureEnabled(FeatureFlag.TaggingSystem)) { diff --git a/superset-websocket/package-lock.json b/superset-websocket/package-lock.json index 1c637770211..04cdc67d67d 100644 --- a/superset-websocket/package-lock.json +++ b/superset-websocket/package-lock.json @@ -31,8 +31,8 @@ "@types/ws": "^8.5.12", "@typescript-eslint/eslint-plugin": "^8.26.0", 
"@typescript-eslint/parser": "^8.29.0", - "eslint": "^9.17.0", - "eslint-config-prettier": "^9.1.0", + "eslint": "^9.27.0", + "eslint-config-prettier": "^10.1.5", "eslint-plugin-lodash": "^8.0.0", "globals": "^16.0.0", "jest": "^29.7.0", @@ -747,12 +747,13 @@ } }, "node_modules/@eslint/config-array": { - "version": "0.19.1", - "resolved": "https://registry.npmjs.org/@eslint/config-array/-/config-array-0.19.1.tgz", - "integrity": "sha512-fo6Mtm5mWyKjA/Chy1BYTdn5mGJoDNjC7C64ug20ADsRDGrA85bN3uK3MaKbeRkRuuIEAR5N33Jr1pbm411/PA==", + "version": "0.20.0", + "resolved": "https://registry.npmjs.org/@eslint/config-array/-/config-array-0.20.0.tgz", + "integrity": "sha512-fxlS1kkIjx8+vy2SjuCB94q3htSNrufYTXubwiBFeaQHbH6Ipi43gFJq2zCMt6PHhImH3Xmr0NksKDvchWlpQQ==", "dev": true, + "license": "Apache-2.0", "dependencies": { - "@eslint/object-schema": "^2.1.5", + "@eslint/object-schema": "^2.1.6", "debug": "^4.3.1", "minimatch": "^3.1.2" }, @@ -760,11 +761,22 @@ "node": "^18.18.0 || ^20.9.0 || >=21.1.0" } }, - "node_modules/@eslint/core": { - "version": "0.9.1", - "resolved": "https://registry.npmjs.org/@eslint/core/-/core-0.9.1.tgz", - "integrity": "sha512-GuUdqkyyzQI5RMIWkHhvTWLCyLo1jNK3vzkSyaExH5kHPDHcuL2VOpHjmMY+y3+NC69qAKToBqldTBgYeLSr9Q==", + "node_modules/@eslint/config-helpers": { + "version": "0.2.2", + "resolved": "https://registry.npmjs.org/@eslint/config-helpers/-/config-helpers-0.2.2.tgz", + "integrity": "sha512-+GPzk8PlG0sPpzdU5ZvIRMPidzAnZDl/s9L+y13iodqvb8leL53bTannOrQ/Im7UkpsmFU5Ily5U60LWixnmLg==", "dev": true, + "license": "Apache-2.0", + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + } + }, + "node_modules/@eslint/core": { + "version": "0.14.0", + "resolved": "https://registry.npmjs.org/@eslint/core/-/core-0.14.0.tgz", + "integrity": "sha512-qIbV0/JZr7iSDjqAc60IqbLdsj9GDt16xQtWD+B78d/HAlvysGdZZ6rpJHGAc2T0FQx1X6thsSPdnoiGKdNtdg==", + "dev": true, + "license": "Apache-2.0", "dependencies": { "@types/json-schema": "^7.0.15" }, @@ -773,10 +785,11 @@ } }, "node_modules/@eslint/eslintrc": { - "version": "3.2.0", - "resolved": "https://registry.npmjs.org/@eslint/eslintrc/-/eslintrc-3.2.0.tgz", - "integrity": "sha512-grOjVNN8P3hjJn/eIETF1wwd12DdnwFDoyceUJLYYdkpbwq3nLi+4fqrTAONx7XDALqlL220wC/RHSC/QTI/0w==", + "version": "3.3.1", + "resolved": "https://registry.npmjs.org/@eslint/eslintrc/-/eslintrc-3.3.1.tgz", + "integrity": "sha512-gtF186CXhIl1p4pJNGZw8Yc6RlshoePRvE0X91oPGb3vZ8pM3qOS9W9NGPat9LziaBV7XrJWGylNQXkGcnM3IQ==", "dev": true, + "license": "MIT", "dependencies": { "ajv": "^6.12.4", "debug": "^4.3.2", @@ -799,13 +812,15 @@ "version": "2.0.1", "resolved": "https://registry.npmjs.org/argparse/-/argparse-2.0.1.tgz", "integrity": "sha512-8+9WqebbFzpX9OR+Wa6O29asIogeRMzcGtAINdpMHHyAg10f05aSFVBbcEqGf/PXw1EjAZ+q2/bEBg3DvurK3Q==", - "dev": true + "dev": true, + "license": "Python-2.0" }, "node_modules/@eslint/eslintrc/node_modules/globals": { "version": "14.0.0", "resolved": "https://registry.npmjs.org/globals/-/globals-14.0.0.tgz", "integrity": "sha512-oahGvuMGQlPw/ivIYBjVSrWAfWLBeku5tpPE2fOPLi+WHffIWbuh2tCjhyQhTBPMf5E9jDEH4FOmTYgYwbKwtQ==", "dev": true, + "license": "MIT", "engines": { "node": ">=18" }, @@ -818,6 +833,7 @@ "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.0.tgz", "integrity": "sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==", "dev": true, + "license": "MIT", "dependencies": { "argparse": "^2.0.1" }, @@ -826,30 +842,36 @@ } }, "node_modules/@eslint/js": { - "version": "9.25.1", - "resolved": 
"https://registry.npmjs.org/@eslint/js/-/js-9.25.1.tgz", - "integrity": "sha512-dEIwmjntEx8u3Uvv+kr3PDeeArL8Hw07H9kyYxCjnM9pBjfEhk6uLXSchxxzgiwtRhhzVzqmUSDFBOi1TuZ7qg==", + "version": "9.27.0", + "resolved": "https://registry.npmjs.org/@eslint/js/-/js-9.27.0.tgz", + "integrity": "sha512-G5JD9Tu5HJEu4z2Uo4aHY2sLV64B7CDMXxFzqzjl3NKd6RVzSXNoE80jk7Y0lJkTTkjiIhBAqmlYwjuBY3tvpA==", "dev": true, "license": "MIT", "engines": { "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "url": "https://eslint.org/donate" } }, "node_modules/@eslint/object-schema": { - "version": "2.1.5", - "resolved": "https://registry.npmjs.org/@eslint/object-schema/-/object-schema-2.1.5.tgz", - "integrity": "sha512-o0bhxnL89h5Bae5T318nFoFzGy+YE5i/gGkoPAgkmTVdRKTiv3p8JHevPiPaMwoloKfEiiaHlawCqaZMqRm+XQ==", + "version": "2.1.6", + "resolved": "https://registry.npmjs.org/@eslint/object-schema/-/object-schema-2.1.6.tgz", + "integrity": "sha512-RBMg5FRL0I0gs51M/guSAj5/e14VQ4tpZnQNWwuDT66P14I43ItmPfIZRhO9fUVIPOAQXU47atlywZ/czoqFPA==", "dev": true, + "license": "Apache-2.0", "engines": { "node": "^18.18.0 || ^20.9.0 || >=21.1.0" } }, "node_modules/@eslint/plugin-kit": { - "version": "0.2.4", - "resolved": "https://registry.npmjs.org/@eslint/plugin-kit/-/plugin-kit-0.2.4.tgz", - "integrity": "sha512-zSkKow6H5Kdm0ZUQUB2kV5JIXqoG0+uH5YADhaEHswm664N9Db8dXSi0nMJpacpMf+MyyglF1vnZohpEg5yUtg==", + "version": "0.3.1", + "resolved": "https://registry.npmjs.org/@eslint/plugin-kit/-/plugin-kit-0.3.1.tgz", + "integrity": "sha512-0J+zgWxHN+xXONWIyPWKFMgVuJoZuGiIFu8yxk7RJjxkzpGmyja5wRFqZIVtjDVOQpV+Rw0iOAjYPE2eQyjr0w==", "dev": true, + "license": "Apache-2.0", "dependencies": { + "@eslint/core": "^0.14.0", "levn": "^0.4.1" }, "engines": { @@ -905,10 +927,11 @@ } }, "node_modules/@humanwhocodes/retry": { - "version": "0.4.1", - "resolved": "https://registry.npmjs.org/@humanwhocodes/retry/-/retry-0.4.1.tgz", - "integrity": "sha512-c7hNEllBlenFTHBky65mhq8WD2kbN9Q6gk0bTk8lSBvc554jpXSkST1iePudpt7+A/AQvuHs9EMqjHDXMY1lrA==", + "version": "0.4.3", + "resolved": "https://registry.npmjs.org/@humanwhocodes/retry/-/retry-0.4.3.tgz", + "integrity": "sha512-bV0Tgo9K4hfPCek+aMAn81RppFKv2ySDQeMoSZuvTASywNTnVJCArCZE2FWqpvIatKu7VMRLWlR1EazvVhDyhQ==", "dev": true, + "license": "Apache-2.0", "engines": { "node": ">=18.18" }, @@ -2527,6 +2550,7 @@ "resolved": "https://registry.npmjs.org/acorn-jsx/-/acorn-jsx-5.3.2.tgz", "integrity": "sha512-rq9s+JNhf0IChjtDXxllJ7g41oZk5SlXtp0LHwyA5cejwn7vKmKp4pPri6YEePv2PU65sAsegbXtIinmDFDXgQ==", "dev": true, + "license": "MIT", "peerDependencies": { "acorn": "^6.0.0 || ^7.0.0 || ^8.0.0" } @@ -2536,6 +2560,7 @@ "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", "integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==", "dev": true, + "license": "MIT", "dependencies": { "fast-deep-equal": "^3.1.1", "fast-json-stable-stringify": "^2.0.0", @@ -3166,21 +3191,23 @@ } }, "node_modules/eslint": { - "version": "9.17.0", - "resolved": "https://registry.npmjs.org/eslint/-/eslint-9.17.0.tgz", - "integrity": "sha512-evtlNcpJg+cZLcnVKwsai8fExnqjGPicK7gnUtlNuzu+Fv9bI0aLpND5T44VLQtoMEnI57LoXO9XAkIXwohKrA==", + "version": "9.27.0", + "resolved": "https://registry.npmjs.org/eslint/-/eslint-9.27.0.tgz", + "integrity": "sha512-ixRawFQuMB9DZ7fjU3iGGganFDp3+45bPOdaRurcFHSXO1e/sYwUX/FtQZpLZJR6SjMoJH8hR2pPEAfDyCoU2Q==", "dev": true, + "license": "MIT", "dependencies": { "@eslint-community/eslint-utils": "^4.2.0", "@eslint-community/regexpp": "^4.12.1", - "@eslint/config-array": 
"^0.19.0", - "@eslint/core": "^0.9.0", - "@eslint/eslintrc": "^3.2.0", - "@eslint/js": "9.17.0", - "@eslint/plugin-kit": "^0.2.3", + "@eslint/config-array": "^0.20.0", + "@eslint/config-helpers": "^0.2.1", + "@eslint/core": "^0.14.0", + "@eslint/eslintrc": "^3.3.1", + "@eslint/js": "9.27.0", + "@eslint/plugin-kit": "^0.3.1", "@humanfs/node": "^0.16.6", "@humanwhocodes/module-importer": "^1.0.1", - "@humanwhocodes/retry": "^0.4.1", + "@humanwhocodes/retry": "^0.4.2", "@types/estree": "^1.0.6", "@types/json-schema": "^7.0.15", "ajv": "^6.12.4", @@ -3188,7 +3215,7 @@ "cross-spawn": "^7.0.6", "debug": "^4.3.2", "escape-string-regexp": "^4.0.0", - "eslint-scope": "^8.2.0", + "eslint-scope": "^8.3.0", "eslint-visitor-keys": "^4.2.0", "espree": "^10.3.0", "esquery": "^1.5.0", @@ -3225,13 +3252,17 @@ } }, "node_modules/eslint-config-prettier": { - "version": "9.1.0", - "resolved": "https://registry.npmjs.org/eslint-config-prettier/-/eslint-config-prettier-9.1.0.tgz", - "integrity": "sha512-NSWl5BFQWEPi1j4TjVNItzYV7dZXZ+wP6I6ZhrBGpChQhZRUaElihE9uRRkcbRnNb76UMKDF3r+WTmNcGPKsqw==", + "version": "10.1.5", + "resolved": "https://registry.npmjs.org/eslint-config-prettier/-/eslint-config-prettier-10.1.5.tgz", + "integrity": "sha512-zc1UmCpNltmVY34vuLRV61r1K27sWuX39E+uyUnY8xS2Bex88VV9cugG+UZbRSRGtGyFboj+D8JODyme1plMpw==", "dev": true, + "license": "MIT", "bin": { "eslint-config-prettier": "bin/cli.js" }, + "funding": { + "url": "https://opencollective.com/eslint-config-prettier" + }, "peerDependencies": { "eslint": ">=7.0.0" } @@ -3253,10 +3284,11 @@ } }, "node_modules/eslint-scope": { - "version": "8.2.0", - "resolved": "https://registry.npmjs.org/eslint-scope/-/eslint-scope-8.2.0.tgz", - "integrity": "sha512-PHlWUfG6lvPc3yvP5A4PNyBL1W8fkDUccmI21JUu/+GKZBoH/W5u6usENXUrWFRsyoW5ACUjFGgAFQp5gUlb/A==", + "version": "8.3.0", + "resolved": "https://registry.npmjs.org/eslint-scope/-/eslint-scope-8.3.0.tgz", + "integrity": "sha512-pUNxi75F8MJ/GdeKtVLSbYg4ZI34J6C0C7sbL4YOp2exGwen7ZsuBqKzUhXd0qMQ362yET3z+uPwKeg/0C2XCQ==", "dev": true, + "license": "BSD-2-Clause", "dependencies": { "esrecurse": "^4.3.0", "estraverse": "^5.2.0" @@ -3280,16 +3312,6 @@ "url": "https://opencollective.com/eslint" } }, - "node_modules/eslint/node_modules/@eslint/js": { - "version": "9.17.0", - "resolved": "https://registry.npmjs.org/@eslint/js/-/js-9.17.0.tgz", - "integrity": "sha512-Sxc4hqcs1kTu0iID3kcZDW3JHq2a77HO9P8CP6YEA/FpH3Ll8UXE2r/86Rz9YJLKme39S9vU5OWNjC6Xl0Cr3w==", - "dev": true, - "license": "MIT", - "engines": { - "node": "^18.18.0 || ^20.9.0 || >=21.1.0" - } - }, "node_modules/eslint/node_modules/escape-string-regexp": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/escape-string-regexp/-/escape-string-regexp-4.0.0.tgz", @@ -3409,6 +3431,7 @@ "resolved": "https://registry.npmjs.org/espree/-/espree-10.3.0.tgz", "integrity": "sha512-0QYC8b24HWY8zjRnDTL6RiHfDbAWn63qb4LMj1Z4b076A4une81+z03Kg7l7mn/48PUTqoLptSXez8oknU8Clg==", "dev": true, + "license": "BSD-2-Clause", "dependencies": { "acorn": "^8.14.0", "acorn-jsx": "^5.3.2", @@ -3426,6 +3449,7 @@ "resolved": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-4.2.0.tgz", "integrity": "sha512-UyLnSehNt62FFhSwjZlHmeokpRK59rcz29j+F1/aDgbkbRTk7wIc9XzdoasMUbRNKDM0qQt/+BJ4BrpFeABemw==", "dev": true, + "license": "Apache-2.0", "engines": { "node": "^18.18.0 || ^20.9.0 || >=21.1.0" }, @@ -3464,6 +3488,7 @@ "resolved": "https://registry.npmjs.org/esrecurse/-/esrecurse-4.3.0.tgz", "integrity": 
"sha512-KmfKL3b6G+RXvP8N1vr3Tq1kL/oCFgn2NYXEtqP8/L3pKapUA4G8cFVaoF3SU323CD4XypR/ffioHmkti6/Tag==", "dev": true, + "license": "BSD-2-Clause", "dependencies": { "estraverse": "^5.2.0" }, @@ -3552,7 +3577,8 @@ "version": "3.1.3", "resolved": "https://registry.npmjs.org/fast-deep-equal/-/fast-deep-equal-3.1.3.tgz", "integrity": "sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==", - "dev": true + "dev": true, + "license": "MIT" }, "node_modules/fast-glob": { "version": "3.3.2", @@ -3885,10 +3911,11 @@ } }, "node_modules/import-fresh": { - "version": "3.3.0", - "resolved": "https://registry.npmjs.org/import-fresh/-/import-fresh-3.3.0.tgz", - "integrity": "sha512-veYYhQa+D1QBKznvhUHxb8faxlrwUnxseDAbAp457E0wLNio2bOSKnjYDhMj+YiAq61xrMGhQk9iXVk5FzgQMw==", + "version": "3.3.1", + "resolved": "https://registry.npmjs.org/import-fresh/-/import-fresh-3.3.1.tgz", + "integrity": "sha512-TR3KfrTZTYLPB6jUjfx6MF9WcWrHL9su5TObK4ZkYgBdWKPOFoSoQIdEuTuR82pmtxH2spWG9h6etwfr1pLBqQ==", "dev": true, + "license": "MIT", "dependencies": { "parent-module": "^1.0.0", "resolve-from": "^4.0.0" @@ -3905,6 +3932,7 @@ "resolved": "https://registry.npmjs.org/resolve-from/-/resolve-from-4.0.0.tgz", "integrity": "sha512-pb/MYmXstAkysRFx8piNI1tGFNQIFA3vkE3Gq4EuA1dF6gHp/+vgZqsCGJapvy8N3Q+4o7FwvquPJcnZ7RYy4g==", "dev": true, + "license": "MIT", "engines": { "node": ">=4" } @@ -5407,7 +5435,8 @@ "version": "0.4.1", "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-0.4.1.tgz", "integrity": "sha512-xbbCH5dCYU5T8LcEhhuh7HJ88HXuW3qsI3Y0zOZFKfZEHcpWiHU/Jxzk629Brsab/mMiHQti9wMP+845RPe3Vg==", - "dev": true + "dev": true, + "license": "MIT" }, "node_modules/json-stable-stringify-without-jsonify": { "version": "1.0.1", @@ -5826,6 +5855,7 @@ "resolved": "https://registry.npmjs.org/parent-module/-/parent-module-1.0.1.tgz", "integrity": "sha512-GQ2EWRpQV8/o+Aw8YqtfZZPfNRWZYkbidE9k5rpl/hC3vtHHBfGm2Ifi6qWV+coDGkrUKZAxE3Lot5kcsRlh+g==", "dev": true, + "license": "MIT", "dependencies": { "callsites": "^3.0.0" }, @@ -5993,6 +6023,7 @@ "resolved": "https://registry.npmjs.org/punycode/-/punycode-2.3.1.tgz", "integrity": "sha512-vYt7UD1U9Wg6138shLtLOvdAu+8DsC/ilFtEVHcH+wydcSpNE20AfSOduf6MkRFahL5FY7X1oU7nKVZFtfq8Fg==", "dev": true, + "license": "MIT", "engines": { "node": ">=6" } @@ -6812,6 +6843,7 @@ "resolved": "https://registry.npmjs.org/uri-js/-/uri-js-4.4.1.tgz", "integrity": "sha512-7rKUyy33Q1yc98pQ1DAmLtwX109F7TIfWlW1Ydo8Wl1ii1SeHieeh0HHfPeL2fMXK6z0s8ecKs9frCuLJvndBg==", "dev": true, + "license": "BSD-2-Clause", "dependencies": { "punycode": "^2.1.0" } @@ -7567,29 +7599,35 @@ "dev": true }, "@eslint/config-array": { - "version": "0.19.1", - "resolved": "https://registry.npmjs.org/@eslint/config-array/-/config-array-0.19.1.tgz", - "integrity": "sha512-fo6Mtm5mWyKjA/Chy1BYTdn5mGJoDNjC7C64ug20ADsRDGrA85bN3uK3MaKbeRkRuuIEAR5N33Jr1pbm411/PA==", + "version": "0.20.0", + "resolved": "https://registry.npmjs.org/@eslint/config-array/-/config-array-0.20.0.tgz", + "integrity": "sha512-fxlS1kkIjx8+vy2SjuCB94q3htSNrufYTXubwiBFeaQHbH6Ipi43gFJq2zCMt6PHhImH3Xmr0NksKDvchWlpQQ==", "dev": true, "requires": { - "@eslint/object-schema": "^2.1.5", + "@eslint/object-schema": "^2.1.6", "debug": "^4.3.1", "minimatch": "^3.1.2" } }, + "@eslint/config-helpers": { + "version": "0.2.2", + "resolved": "https://registry.npmjs.org/@eslint/config-helpers/-/config-helpers-0.2.2.tgz", + "integrity": "sha512-+GPzk8PlG0sPpzdU5ZvIRMPidzAnZDl/s9L+y13iodqvb8leL53bTannOrQ/Im7UkpsmFU5Ily5U60LWixnmLg==", + 
"dev": true + }, "@eslint/core": { - "version": "0.9.1", - "resolved": "https://registry.npmjs.org/@eslint/core/-/core-0.9.1.tgz", - "integrity": "sha512-GuUdqkyyzQI5RMIWkHhvTWLCyLo1jNK3vzkSyaExH5kHPDHcuL2VOpHjmMY+y3+NC69qAKToBqldTBgYeLSr9Q==", + "version": "0.14.0", + "resolved": "https://registry.npmjs.org/@eslint/core/-/core-0.14.0.tgz", + "integrity": "sha512-qIbV0/JZr7iSDjqAc60IqbLdsj9GDt16xQtWD+B78d/HAlvysGdZZ6rpJHGAc2T0FQx1X6thsSPdnoiGKdNtdg==", "dev": true, "requires": { "@types/json-schema": "^7.0.15" } }, "@eslint/eslintrc": { - "version": "3.2.0", - "resolved": "https://registry.npmjs.org/@eslint/eslintrc/-/eslintrc-3.2.0.tgz", - "integrity": "sha512-grOjVNN8P3hjJn/eIETF1wwd12DdnwFDoyceUJLYYdkpbwq3nLi+4fqrTAONx7XDALqlL220wC/RHSC/QTI/0w==", + "version": "3.3.1", + "resolved": "https://registry.npmjs.org/@eslint/eslintrc/-/eslintrc-3.3.1.tgz", + "integrity": "sha512-gtF186CXhIl1p4pJNGZw8Yc6RlshoePRvE0X91oPGb3vZ8pM3qOS9W9NGPat9LziaBV7XrJWGylNQXkGcnM3IQ==", "dev": true, "requires": { "ajv": "^6.12.4", @@ -7627,23 +7665,24 @@ } }, "@eslint/js": { - "version": "9.25.1", - "resolved": "https://registry.npmjs.org/@eslint/js/-/js-9.25.1.tgz", - "integrity": "sha512-dEIwmjntEx8u3Uvv+kr3PDeeArL8Hw07H9kyYxCjnM9pBjfEhk6uLXSchxxzgiwtRhhzVzqmUSDFBOi1TuZ7qg==", + "version": "9.27.0", + "resolved": "https://registry.npmjs.org/@eslint/js/-/js-9.27.0.tgz", + "integrity": "sha512-G5JD9Tu5HJEu4z2Uo4aHY2sLV64B7CDMXxFzqzjl3NKd6RVzSXNoE80jk7Y0lJkTTkjiIhBAqmlYwjuBY3tvpA==", "dev": true }, "@eslint/object-schema": { - "version": "2.1.5", - "resolved": "https://registry.npmjs.org/@eslint/object-schema/-/object-schema-2.1.5.tgz", - "integrity": "sha512-o0bhxnL89h5Bae5T318nFoFzGy+YE5i/gGkoPAgkmTVdRKTiv3p8JHevPiPaMwoloKfEiiaHlawCqaZMqRm+XQ==", + "version": "2.1.6", + "resolved": "https://registry.npmjs.org/@eslint/object-schema/-/object-schema-2.1.6.tgz", + "integrity": "sha512-RBMg5FRL0I0gs51M/guSAj5/e14VQ4tpZnQNWwuDT66P14I43ItmPfIZRhO9fUVIPOAQXU47atlywZ/czoqFPA==", "dev": true }, "@eslint/plugin-kit": { - "version": "0.2.4", - "resolved": "https://registry.npmjs.org/@eslint/plugin-kit/-/plugin-kit-0.2.4.tgz", - "integrity": "sha512-zSkKow6H5Kdm0ZUQUB2kV5JIXqoG0+uH5YADhaEHswm664N9Db8dXSi0nMJpacpMf+MyyglF1vnZohpEg5yUtg==", + "version": "0.3.1", + "resolved": "https://registry.npmjs.org/@eslint/plugin-kit/-/plugin-kit-0.3.1.tgz", + "integrity": "sha512-0J+zgWxHN+xXONWIyPWKFMgVuJoZuGiIFu8yxk7RJjxkzpGmyja5wRFqZIVtjDVOQpV+Rw0iOAjYPE2eQyjr0w==", "dev": true, "requires": { + "@eslint/core": "^0.14.0", "levn": "^0.4.1" } }, @@ -7678,9 +7717,9 @@ "dev": true }, "@humanwhocodes/retry": { - "version": "0.4.1", - "resolved": "https://registry.npmjs.org/@humanwhocodes/retry/-/retry-0.4.1.tgz", - "integrity": "sha512-c7hNEllBlenFTHBky65mhq8WD2kbN9Q6gk0bTk8lSBvc554jpXSkST1iePudpt7+A/AQvuHs9EMqjHDXMY1lrA==", + "version": "0.4.3", + "resolved": "https://registry.npmjs.org/@humanwhocodes/retry/-/retry-0.4.3.tgz", + "integrity": "sha512-bV0Tgo9K4hfPCek+aMAn81RppFKv2ySDQeMoSZuvTASywNTnVJCArCZE2FWqpvIatKu7VMRLWlR1EazvVhDyhQ==", "dev": true }, "@istanbuljs/load-nyc-config": { @@ -9359,21 +9398,22 @@ "dev": true }, "eslint": { - "version": "9.17.0", - "resolved": "https://registry.npmjs.org/eslint/-/eslint-9.17.0.tgz", - "integrity": "sha512-evtlNcpJg+cZLcnVKwsai8fExnqjGPicK7gnUtlNuzu+Fv9bI0aLpND5T44VLQtoMEnI57LoXO9XAkIXwohKrA==", + "version": "9.27.0", + "resolved": "https://registry.npmjs.org/eslint/-/eslint-9.27.0.tgz", + "integrity": 
"sha512-ixRawFQuMB9DZ7fjU3iGGganFDp3+45bPOdaRurcFHSXO1e/sYwUX/FtQZpLZJR6SjMoJH8hR2pPEAfDyCoU2Q==", "dev": true, "requires": { "@eslint-community/eslint-utils": "^4.2.0", "@eslint-community/regexpp": "^4.12.1", - "@eslint/config-array": "^0.19.0", - "@eslint/core": "^0.9.0", - "@eslint/eslintrc": "^3.2.0", - "@eslint/js": "9.17.0", - "@eslint/plugin-kit": "^0.2.3", + "@eslint/config-array": "^0.20.0", + "@eslint/config-helpers": "^0.2.1", + "@eslint/core": "^0.14.0", + "@eslint/eslintrc": "^3.3.1", + "@eslint/js": "9.27.0", + "@eslint/plugin-kit": "^0.3.1", "@humanfs/node": "^0.16.6", "@humanwhocodes/module-importer": "^1.0.1", - "@humanwhocodes/retry": "^0.4.1", + "@humanwhocodes/retry": "^0.4.2", "@types/estree": "^1.0.6", "@types/json-schema": "^7.0.15", "ajv": "^6.12.4", @@ -9381,7 +9421,7 @@ "cross-spawn": "^7.0.6", "debug": "^4.3.2", "escape-string-regexp": "^4.0.0", - "eslint-scope": "^8.2.0", + "eslint-scope": "^8.3.0", "eslint-visitor-keys": "^4.2.0", "espree": "^10.3.0", "esquery": "^1.5.0", @@ -9400,12 +9440,6 @@ "optionator": "^0.9.3" }, "dependencies": { - "@eslint/js": { - "version": "9.17.0", - "resolved": "https://registry.npmjs.org/@eslint/js/-/js-9.17.0.tgz", - "integrity": "sha512-Sxc4hqcs1kTu0iID3kcZDW3JHq2a77HO9P8CP6YEA/FpH3Ll8UXE2r/86Rz9YJLKme39S9vU5OWNjC6Xl0Cr3w==", - "dev": true - }, "escape-string-regexp": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/escape-string-regexp/-/escape-string-regexp-4.0.0.tgz", @@ -9481,9 +9515,9 @@ } }, "eslint-config-prettier": { - "version": "9.1.0", - "resolved": "https://registry.npmjs.org/eslint-config-prettier/-/eslint-config-prettier-9.1.0.tgz", - "integrity": "sha512-NSWl5BFQWEPi1j4TjVNItzYV7dZXZ+wP6I6ZhrBGpChQhZRUaElihE9uRRkcbRnNb76UMKDF3r+WTmNcGPKsqw==", + "version": "10.1.5", + "resolved": "https://registry.npmjs.org/eslint-config-prettier/-/eslint-config-prettier-10.1.5.tgz", + "integrity": "sha512-zc1UmCpNltmVY34vuLRV61r1K27sWuX39E+uyUnY8xS2Bex88VV9cugG+UZbRSRGtGyFboj+D8JODyme1plMpw==", "dev": true, "requires": {} }, @@ -9497,9 +9531,9 @@ } }, "eslint-scope": { - "version": "8.2.0", - "resolved": "https://registry.npmjs.org/eslint-scope/-/eslint-scope-8.2.0.tgz", - "integrity": "sha512-PHlWUfG6lvPc3yvP5A4PNyBL1W8fkDUccmI21JUu/+GKZBoH/W5u6usENXUrWFRsyoW5ACUjFGgAFQp5gUlb/A==", + "version": "8.3.0", + "resolved": "https://registry.npmjs.org/eslint-scope/-/eslint-scope-8.3.0.tgz", + "integrity": "sha512-pUNxi75F8MJ/GdeKtVLSbYg4ZI34J6C0C7sbL4YOp2exGwen7ZsuBqKzUhXd0qMQ362yET3z+uPwKeg/0C2XCQ==", "dev": true, "requires": { "esrecurse": "^4.3.0", @@ -9875,9 +9909,9 @@ "dev": true }, "import-fresh": { - "version": "3.3.0", - "resolved": "https://registry.npmjs.org/import-fresh/-/import-fresh-3.3.0.tgz", - "integrity": "sha512-veYYhQa+D1QBKznvhUHxb8faxlrwUnxseDAbAp457E0wLNio2bOSKnjYDhMj+YiAq61xrMGhQk9iXVk5FzgQMw==", + "version": "3.3.1", + "resolved": "https://registry.npmjs.org/import-fresh/-/import-fresh-3.3.1.tgz", + "integrity": "sha512-TR3KfrTZTYLPB6jUjfx6MF9WcWrHL9su5TObK4ZkYgBdWKPOFoSoQIdEuTuR82pmtxH2spWG9h6etwfr1pLBqQ==", "dev": true, "requires": { "parent-module": "^1.0.0", diff --git a/superset-websocket/package.json b/superset-websocket/package.json index db2b2140583..61ffc639598 100644 --- a/superset-websocket/package.json +++ b/superset-websocket/package.json @@ -39,8 +39,8 @@ "@types/ws": "^8.5.12", "@typescript-eslint/eslint-plugin": "^8.26.0", "@typescript-eslint/parser": "^8.29.0", - "eslint": "^9.17.0", - "eslint-config-prettier": "^9.1.0", + "eslint": "^9.27.0", + "eslint-config-prettier": "^10.1.5", 
"eslint-plugin-lodash": "^8.0.0", "globals": "^16.0.0", "jest": "^29.7.0", diff --git a/superset-websocket/src/index.ts b/superset-websocket/src/index.ts index 5c3f49fc17a..f43c651813b 100644 --- a/superset-websocket/src/index.ts +++ b/superset-websocket/src/index.ts @@ -102,7 +102,7 @@ if (startServer && opts.jwtSecret.length < 32) { if (startServer && opts.jwtSecret.startsWith('CHANGE-ME')) { console.warn( - 'WARNING: it appears you secret in your config.json is insecure', + 'WARNING: it appears your secret in your config.json is insecure', ); console.warn('DO NOT USE IN PRODUCTION'); } diff --git a/superset/commands/database/test_connection.py b/superset/commands/database/test_connection.py index 3c16730d00b..1e5fb8db44d 100644 --- a/superset/commands/database/test_connection.py +++ b/superset/commands/database/test_connection.py @@ -89,7 +89,9 @@ class TestConnectionDatabaseCommand(BaseCommand): self._context = context self._uri = uri - def run(self) -> None: # pylint: disable=too-many-statements,too-many-branches # noqa: C901 + def run( # noqa: C901 + self, + ) -> None: # pylint: disable=too-many-statements,too-many-branches self.validate() ex_str = "" ssh_tunnel = self._properties.get("ssh_tunnel") @@ -188,7 +190,7 @@ class TestConnectionDatabaseCommand(BaseCommand): ) # check for custom errors (wrong username, wrong password, etc) errors = database.db_engine_spec.extract_errors(ex, self._context) - raise SupersetErrorsException(errors) from ex + raise SupersetErrorsException(errors, status=400) from ex except OAuth2RedirectError: raise except SupersetSecurityException as ex: diff --git a/superset/commands/database/validate.py b/superset/commands/database/validate.py index eda5d75bed8..41f5ce3e19d 100644 --- a/superset/commands/database/validate.py +++ b/superset/commands/database/validate.py @@ -125,7 +125,7 @@ class ValidateDatabaseParametersCommand(BaseCommand): "database": url.database, } errors = database.db_engine_spec.extract_errors(ex, context) - raise DatabaseTestConnectionFailedError(errors) from ex + raise DatabaseTestConnectionFailedError(errors, status=400) from ex if not alive: raise DatabaseOfflineError( diff --git a/superset/commands/sql_lab/execute.py b/superset/commands/sql_lab/execute.py index 001d5609db4..20cba40cfc0 100644 --- a/superset/commands/sql_lab/execute.py +++ b/superset/commands/sql_lab/execute.py @@ -92,7 +92,7 @@ class ExecuteSqlCommand(BaseCommand): pass @transaction() - def run( # pylint: disable=too-many-statements,useless-suppression + def run( self, ) -> CommandResult: """Runs arbitrary sql and returns data as json""" diff --git a/superset/commands/sql_lab/export.py b/superset/commands/sql_lab/export.py index bfa73905483..44fdafe5cdb 100644 --- a/superset/commands/sql_lab/export.py +++ b/superset/commands/sql_lab/export.py @@ -27,7 +27,7 @@ from superset.commands.base import BaseCommand from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import SupersetErrorException, SupersetSecurityException from superset.models.sql_lab import Query -from superset.sql_parse import ParsedQuery +from superset.sql.parse import SQLScript from superset.sqllab.limiting_factor import LimitingFactor from superset.utils import core as utils, csv from superset.views.utils import _deserialize_results_payload @@ -115,10 +115,9 @@ class SqlResultExportCommand(BaseCommand): limit = None else: sql = self._query.executed_sql - limit = ParsedQuery( - sql, - engine=self._query.database.db_engine_spec.engine, - ).limit + script = 
SQLScript(sql, self._query.database.db_engine_spec.engine) + # when a query has multiple statements only the last one returns data + limit = script.statements[-1].get_limit_value() if limit is not None and self._query.limiting_factor in { LimitingFactor.QUERY, LimitingFactor.DROPDOWN, diff --git a/superset/config.py b/superset/config.py index 05a2ec63538..6d5fb9fc6b6 100644 --- a/superset/config.py +++ b/superset/config.py @@ -756,7 +756,7 @@ SCREENSHOT_WAIT_FOR_ERROR_MODAL_INVISIBLE = 5 SCREENSHOT_PLAYWRIGHT_WAIT_EVENT = "load" # Default timeout for Playwright browser context for all operations SCREENSHOT_PLAYWRIGHT_DEFAULT_TIMEOUT = int( - timedelta(seconds=30).total_seconds() * 1000 + timedelta(seconds=60).total_seconds() * 1000 ) # --------------------------------------------------- @@ -1290,7 +1290,7 @@ TRACKING_URL_TRANSFORMER = lambda url: url # noqa: E731 DB_POLL_INTERVAL_SECONDS: dict[str, int] = {} # Interval between consecutive polls when using Presto Engine -# See here: https://github.com/dropbox/PyHive/blob/8eb0aeab8ca300f3024655419b93dad926c1a351/pyhive/presto.py#L93 # pylint: disable=line-too-long,useless-suppression # noqa: E501 +# See here: https://github.com/dropbox/PyHive/blob/8eb0aeab8ca300f3024655419b93dad926c1a351/pyhive/presto.py#L93 # noqa: E501 PRESTO_POLL_INTERVAL = int(timedelta(seconds=1).total_seconds()) # Allow list of custom authentications for each DB engine. diff --git a/superset/daos/tag.py b/superset/daos/tag.py index 1af19bc1df2..aae3c0dacad 100644 --- a/superset/daos/tag.py +++ b/superset/daos/tag.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. import logging -from operator import and_ from typing import Any, Optional from flask import g @@ -24,11 +23,11 @@ from sqlalchemy.exc import NoResultFound from superset.commands.tag.exceptions import TagNotFoundError from superset.commands.tag.utils import to_object_type from superset.daos.base import BaseDAO +from superset.daos.chart import ChartDAO +from superset.daos.dashboard import DashboardDAO +from superset.daos.query import SavedQueryDAO from superset.exceptions import MissingUserContextException from superset.extensions import db -from superset.models.dashboard import Dashboard -from superset.models.slice import Slice -from superset.models.sql_lab import SavedQuery from superset.tags.models import ( get_tag, ObjectType, @@ -43,8 +42,6 @@ logger = logging.getLogger(__name__) class TagDAO(BaseDAO[Tag]): - # base_filter = TagAccessFilter - @staticmethod def create_custom_tagged_objects( object_type: ObjectType, object_id: int, tag_names: list[str] @@ -139,6 +136,13 @@ class TagDAO(BaseDAO[Tag]): """ return db.session.query(Tag).filter(Tag.name == name).first() + @staticmethod + def find_by_names(names: list[str]) -> list[Tag]: + """ + returns tags by their names. 
+ """ + return db.session.query(Tag).filter(Tag.name.in_(names)).all() + @staticmethod def find_tagged_object( object_type: ObjectType, object_id: int, tag_id: int @@ -157,111 +161,105 @@ class TagDAO(BaseDAO[Tag]): ) @staticmethod - def get_tagged_objects_by_tag_id( + def get_tagged_objects_by_tag_ids( tag_ids: Optional[list[int]], obj_types: Optional[list[str]] = None ) -> list[dict[str, Any]]: - tags = db.session.query(Tag).filter(Tag.id.in_(tag_ids)).all() - tag_names = [tag.name for tag in tags] - return TagDAO.get_tagged_objects_for_tags(tag_names, obj_types) + results: list[dict[str, Any]] = [] + + query = db.session.query(TaggedObject).filter(TaggedObject.tag_id.in_(tag_ids)) + if obj_types: + query = query.filter( + TaggedObject.object_type.in_( + [ObjectType[obj_type] for obj_type in obj_types] + ) + ) + tagged_objects = query.all() + + # dashboards + if not obj_types or "dashboard" in obj_types: + tagged_dashboards = [ + tagged_object.object_id + for tagged_object in tagged_objects + if tagged_object.object_type == ObjectType.dashboard + ] + if tagged_dashboards: + results.extend( + { + "id": obj.id, + "type": ObjectType.dashboard.name, + "name": obj.dashboard_title, + "url": obj.url, + "changed_on": obj.changed_on, + "created_by": obj.created_by_fk, + "creator": obj.creator(), + "tags": obj.tags, + "owners": obj.owners, + } + for obj in DashboardDAO.find_by_ids(tagged_dashboards) + ) + + # charts + if not obj_types or "chart" in obj_types: + tagged_charts = [ + tagged_object.object_id + for tagged_object in tagged_objects + if tagged_object.object_type == ObjectType.chart + ] + if tagged_charts: + results.extend( + { + "id": obj.id, + "type": ObjectType.chart.name, + "name": obj.slice_name, + "url": obj.url, + "changed_on": obj.changed_on, + "created_by": obj.created_by_fk, + "creator": obj.creator(), + "tags": obj.tags, + "owners": obj.owners, + } + for obj in ChartDAO.find_by_ids(tagged_charts) + ) + + # saved queries + if not obj_types or "query" in obj_types: + tagged_queries = [ + tagged_object.object_id + for tagged_object in tagged_objects + if tagged_object.object_type == ObjectType.query + ] + if tagged_queries: + results.extend( + { + "id": obj.id, + "type": ObjectType.query.name, + "name": obj.label, + "url": obj.url(), + "changed_on": obj.changed_on, + "created_by": obj.created_by_fk, + "creator": obj.creator(), + "tags": obj.tags, + "owners": [obj.creator()], + } + for obj in SavedQueryDAO.find_by_ids(tagged_queries) + ) + + return results @staticmethod - def get_tagged_objects_for_tags( - tags: Optional[list[str]] = None, obj_types: Optional[list[str]] = None + def get_tagged_objects_by_tag_names( + tag_names: Optional[list[str]] = None, obj_types: Optional[list[str]] = None ) -> list[dict[str, Any]]: """ returns a list of tagged objects filtered by tag names and object types if no filters applied returns all tagged objects """ - results: list[dict[str, Any]] = [] + tags = TagDAO.find_by_names(tag_names) if tag_names else TagDAO.find_all() + if not tags: + return [] - # dashboards - if (not obj_types) or ("dashboard" in obj_types): - dashboards = ( - db.session.query(Dashboard) - .join( - TaggedObject, - and_( - TaggedObject.object_id == Dashboard.id, - TaggedObject.object_type == ObjectType.dashboard, - ), - ) - .join(Tag, TaggedObject.tag_id == Tag.id) - .filter(not tags or Tag.name.in_(tags)) - ) - - results.extend( - { - "id": obj.id, - "type": ObjectType.dashboard.name, - "name": obj.dashboard_title, - "url": obj.url, - "changed_on": obj.changed_on, - 
"created_by": obj.created_by_fk, - "creator": obj.creator(), - "tags": obj.tags, - "owners": obj.owners, - } - for obj in dashboards - ) - - # charts - if (not obj_types) or ("chart" in obj_types): - charts = ( - db.session.query(Slice) - .join( - TaggedObject, - and_( - TaggedObject.object_id == Slice.id, - TaggedObject.object_type == ObjectType.chart, - ), - ) - .join(Tag, TaggedObject.tag_id == Tag.id) - .filter(not tags or Tag.name.in_(tags)) - ) - results.extend( - { - "id": obj.id, - "type": ObjectType.chart.name, - "name": obj.slice_name, - "url": obj.url, - "changed_on": obj.changed_on, - "created_by": obj.created_by_fk, - "creator": obj.creator(), - "tags": obj.tags, - "owners": obj.owners, - } - for obj in charts - ) - - # saved queries - if (not obj_types) or ("query" in obj_types): - saved_queries = ( - db.session.query(SavedQuery) - .join( - TaggedObject, - and_( - TaggedObject.object_id == SavedQuery.id, - TaggedObject.object_type == ObjectType.query, - ), - ) - .join(Tag, TaggedObject.tag_id == Tag.id) - .filter(not tags or Tag.name.in_(tags)) - ) - results.extend( - { - "id": obj.id, - "type": ObjectType.query.name, - "name": obj.label, - "url": obj.url(), - "changed_on": obj.changed_on, - "created_by": obj.created_by_fk, - "creator": obj.creator(), - "tags": obj.tags, - "owners": [obj.creator()], - } - for obj in saved_queries - ) - return results + tag_ids = [tag.id for tag in tags] + return TagDAO.get_tagged_objects_by_tag_ids(tag_ids, obj_types) @staticmethod def favorite_tag_by_id_for_current_user( # pylint: disable=invalid-name diff --git a/superset/datasets/schemas.py b/superset/datasets/schemas.py index c68d38879fa..54ce4481737 100644 --- a/superset/datasets/schemas.py +++ b/superset/datasets/schemas.py @@ -129,6 +129,7 @@ class DatasetPostSchema(Schema): external_url = fields.String(allow_none=True) normalize_columns = fields.Boolean(load_default=False) always_filter_main_dttm = fields.Boolean(load_default=False) + template_params = fields.String(allow_none=True) class DatasetPutSchema(Schema): diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 502c8ac82b4..94adb59a1cf 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -39,7 +39,6 @@ from uuid import uuid4 import pandas as pd import requests -import sqlparse from apispec import APISpec from apispec.ext.marshmallow import MarshmallowPlugin from deprecation import deprecated @@ -55,17 +54,22 @@ from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.url import URL from sqlalchemy.ext.compiler import compiles from sqlalchemy.sql import literal_column, quoted_name, text -from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom, TextClause +from sqlalchemy.sql.expression import ColumnClause, Select, TextClause from sqlalchemy.types import TypeEngine -from sqlparse.tokens import CTE from superset import db, sql_parse from superset.constants import QUERY_CANCEL_KEY, TimeGrain as TimeGrainConstants from superset.databases.utils import get_table_metadata, make_url_safe from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import DisallowedSQLFunction, OAuth2Error, OAuth2RedirectError -from superset.sql.parse import BaseSQLStatement, SQLScript, Table -from superset.sql_parse import ParsedQuery +from superset.sql.parse import ( + BaseSQLStatement, + LimitMethod, + RLSMethod, + SQLScript, + SQLStatement, + Table, +) from superset.superset_typing import ( OAuth2ClientConfig, 
OAuth2State, @@ -166,14 +170,6 @@ def compile_timegrain_expression( return element.name.replace("{col}", compiler.process(element.col, **kwargs)) -class LimitMethod: # pylint: disable=too-few-public-methods - """Enum the ways that limits can be applied""" - - FETCH_MANY = "fetch_many" - WRAP_SQL = "wrap_sql" - FORCE_LIMIT = "force_limit" - - class MetricType(TypedDict, total=False): """ Type for metrics return by `get_metrics`. @@ -377,16 +373,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods allows_cte_in_subquery = True # Define alias for CTE cte_alias = "__cte" - # Whether allow LIMIT clause in the SQL - # If True, then the database engine is allowed for LIMIT clause - # If False, then the database engine is allowed for TOP clause - allow_limit_clause = True # This set will give keywords for select statements # to consider for the engines with TOP SQL parsing select_keywords: set[str] = {"SELECT"} - # This set will give the keywords for data limit statements - # to consider for the engines with TOP SQL parsing - top_keywords: set[str] = {"TOP"} # A set of disallowed connection query parameters by driver name disallow_uri_query_params: dict[str, set[str]] = {} # A Dict of query parameters that will always be used on every connection @@ -450,6 +439,21 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods # the `cancel_query` value in the `extra` field of the `query` object has_query_id_before_execute = True + @classmethod + def get_rls_method(cls) -> RLSMethod: + """ + Returns the RLS method to be used for this engine. + + There are two ways to insert RLS: either replacing the table with a subquery + that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is + safer, but not supported in all databases. 
+ """ + return ( + RLSMethod.AS_SUBQUERY + if cls.allows_subqueries and cls.allows_alias_in_select + else RLSMethod.AS_PREDICATE + ) + @classmethod def is_oauth2_enabled(cls) -> bool: return ( @@ -1119,100 +1123,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return {} - @classmethod - def apply_limit_to_sql( - cls, sql: str, limit: int, database: Database, force: bool = False - ) -> str: - """ - Alters the SQL statement to apply a LIMIT clause - - :param sql: SQL query - :param limit: Maximum number of rows to be returned by the query - :param database: Database instance - :return: SQL query with limit clause - """ - if cls.limit_method == LimitMethod.WRAP_SQL: - sql = sql.strip("\t\n ;") - qry = ( - select("*") - .select_from(TextAsFrom(text(sql), ["*"]).alias("inner_qry")) - .limit(limit) - ) - return database.compile_sqla_query(qry) - - if cls.limit_method == LimitMethod.FORCE_LIMIT: - parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine) - sql = parsed_query.set_or_update_query_limit(limit, force=force) - - return sql - - @classmethod - def apply_top_to_sql(cls, sql: str, limit: int) -> str: # noqa: C901 - """ - Alters the SQL statement to apply a TOP clause - :param limit: Maximum number of rows to be returned by the query - :param sql: SQL query - :return: SQL query with top clause - """ - - cte = None - sql_remainder = None - sql = sql.strip(" \t\n;") - query_limit: int | None = sql_parse.extract_top_from_query( - sql, cls.top_keywords - ) - if not limit: - final_limit = query_limit - elif int(query_limit or 0) < limit and query_limit is not None: - final_limit = query_limit - else: - final_limit = limit - if not cls.allows_cte_in_subquery: - cte, sql_remainder = sql_parse.get_cte_remainder_query(sql) - if cte: - str_statement = str(sql_remainder) - cte = cte + "\n" - else: - cte = "" - str_statement = str(sql) - str_statement = str_statement.replace("\n", " ").replace("\r", "") - - tokens = str_statement.rstrip().split(" ") - tokens = [token for token in tokens if token] - if cls.top_not_in_sql(str_statement): - selects = [ - i - for i, word in enumerate(tokens) - if word.upper() in cls.select_keywords - ] - first_select = selects[0] - if tokens[first_select + 1].upper() == "DISTINCT": - first_select += 1 - - tokens.insert(first_select + 1, "TOP") - tokens.insert(first_select + 2, str(final_limit)) - - next_is_limit_token = False - new_tokens = [] - - for token in tokens: - if token in cls.top_keywords: - next_is_limit_token = True - elif next_is_limit_token: - if token.isdigit(): - token = str(final_limit) - next_is_limit_token = False - new_tokens.append(token) - sql = " ".join(new_tokens) - return cte + sql - - @classmethod - def top_not_in_sql(cls, sql: str) -> bool: - for top_word in cls.top_keywords: - if top_word.upper() in sql.upper(): - return False - return True - @classmethod def get_limit_from_sql(cls, sql: str) -> int | None: """ @@ -1221,20 +1131,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods :param sql: SQL query :return: Value of limit clause in query """ - parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine) - return parsed_query.limit - - @classmethod - def set_or_update_query_limit(cls, sql: str, limit: int) -> str: - """ - Create a query based on original query but with new limit clause - - :param sql: SQL query - :param limit: New limit to insert/replace into query - :return: Query with new limit - """ - parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine) - return 
parsed_query.set_or_update_query_limit(limit) + script = SQLScript(sql, engine=cls.engine) + return script.statements[-1].get_limit_value() @classmethod def get_cte_query(cls, sql: str) -> str | None: @@ -1246,18 +1144,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods """ if not cls.allows_cte_in_subquery: - stmt = sqlparse.parse(sql)[0] - - # The first meaningful token for CTE will be with WITH - idx, token = stmt.token_next(-1, skip_ws=True, skip_cm=True) - if not (token and token.ttype == CTE): - return None - idx, token = stmt.token_next(idx) - idx = stmt.token_index(token) + 1 - - # extract rest of the SQLs after CTE - remainder = "".join(str(token) for token in stmt.tokens[idx:]).strip() - return f"WITH {token.value},\n{cls.cte_alias} AS (\n{remainder}\n)" + statement = SQLStatement(sql, engine=cls.engine) + if statement.has_cte(): + return statement.as_cte(cls.cte_alias).format() return None @@ -1686,8 +1575,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods full_table_name = cls.quote_table(table, engine.dialect) qry = select(fields).select_from(text(full_table_name)) - if limit and cls.allow_limit_clause: - qry = qry.limit(limit) + qry = qry.limit(limit) if latest_partition: partition_query = cls.where_latest_partition( database, @@ -2088,14 +1976,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods logger.error(ex, exc_info=True) raise - @classmethod - def is_select_query(cls, parsed_query: ParsedQuery) -> bool: - """ - Determine if the statement should be considered as SELECT statement. - Some query dialects do not contain "SELECT" word in queries (eg. Kusto) - """ - return parsed_query.is_select() - @classmethod def get_column_spec( # pylint: disable=unused-argument cls, @@ -2201,10 +2081,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return False - @classmethod - def parse_sql(cls, sql: str) -> list[str]: - return [str(s).strip(" ;") for s in sqlparse.parse(sql)] - @classmethod def get_impersonation_key(cls, user: User | None) -> Any: """ diff --git a/superset/db_engine_specs/db2.py b/superset/db_engine_specs/db2.py index 6781701ac79..8dd7b00a6b9 100644 --- a/superset/db_engine_specs/db2.py +++ b/superset/db_engine_specs/db2.py @@ -20,9 +20,9 @@ from typing import Optional, Union from sqlalchemy.engine.reflection import Inspector from superset.constants import TimeGrain -from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod +from superset.db_engine_specs.base import BaseEngineSpec from superset.models.core import Database -from superset.sql_parse import Table +from superset.sql.parse import LimitMethod, Table logger = logging.getLogger(__name__) diff --git a/superset/db_engine_specs/firebird.py b/superset/db_engine_specs/firebird.py index 15c4bef7bf4..d8222d81d4f 100644 --- a/superset/db_engine_specs/firebird.py +++ b/superset/db_engine_specs/firebird.py @@ -20,7 +20,8 @@ from typing import Any, Optional from sqlalchemy import types from superset.constants import TimeGrain -from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod +from superset.db_engine_specs.base import BaseEngineSpec +from superset.sql.parse import LimitMethod class FirebirdEngineSpec(BaseEngineSpec): diff --git a/superset/db_engine_specs/hana.py b/superset/db_engine_specs/hana.py index 13b5674c87a..3ae70349eca 100644 --- a/superset/db_engine_specs/hana.py +++ b/superset/db_engine_specs/hana.py @@ -20,8 +20,8 @@ from typing import Any, Optional from sqlalchemy import types from superset.constants import 
TimeGrain -from superset.db_engine_specs.base import LimitMethod from superset.db_engine_specs.postgres import PostgresBaseEngineSpec +from superset.sql.parse import LimitMethod class HanaEngineSpec(PostgresBaseEngineSpec): diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index 0d2bdd3a5d9..1f15b2834f9 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -95,6 +95,7 @@ class HiveEngineSpec(PrestoEngineSpec): allows_hidden_orderby_agg = False supports_dynamic_schema = True + supports_cross_catalog_queries = False # When running `SHOW FUNCTIONS`, what is the name of the column with the # function names? diff --git a/superset/db_engine_specs/kusto.py b/superset/db_engine_specs/kusto.py index 59c3b1f2313..9181b078592 100644 --- a/superset/db_engine_specs/kusto.py +++ b/superset/db_engine_specs/kusto.py @@ -22,13 +22,13 @@ from sqlalchemy import types from sqlalchemy.dialects.mssql.base import SMALLDATETIME from superset.constants import TimeGrain -from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod +from superset.db_engine_specs.base import BaseEngineSpec from superset.db_engine_specs.exceptions import ( SupersetDBAPIDatabaseError, SupersetDBAPIOperationalError, SupersetDBAPIProgrammingError, ) -from superset.sql_parse import ParsedQuery +from superset.sql.parse import LimitMethod from superset.utils.core import GenericDataType @@ -106,7 +106,6 @@ class KustoSqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method class KustoKqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method - limit_method = LimitMethod.WRAP_SQL engine = "kustokql" engine_name = "KustoKQL" time_groupby_inline = True @@ -154,15 +153,3 @@ class KustoKqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method return f"""datetime({dttm.isoformat(timespec="microseconds")})""" return None - - @classmethod - def is_select_query(cls, parsed_query: ParsedQuery) -> bool: - return not parsed_query.sql.startswith(".") - - @classmethod - def parse_sql(cls, sql: str) -> list[str]: - """ - Kusto supports a single query statement, but it could include sub queries - and variables declared via let keyword. 
- """ - return [sql] diff --git a/superset/db_engine_specs/lib.py b/superset/db_engine_specs/lib.py index 106b9c75503..e66bcccc353 100644 --- a/superset/db_engine_specs/lib.py +++ b/superset/db_engine_specs/lib.py @@ -119,7 +119,7 @@ def diagnose(spec: type[BaseEngineSpec]) -> dict[str, Any]: output.update( { "module": spec.__module__, - "limit_method": spec.limit_method.upper(), + "limit_method": spec.limit_method.value, "joins": spec.allows_joins, "subqueries": spec.allows_subqueries, "alias_in_select": spec.allows_alias_in_select, @@ -129,7 +129,6 @@ def diagnose(spec: type[BaseEngineSpec]) -> dict[str, Any]: "order_by_not_in_select": spec.allows_hidden_orderby_agg, "expressions_in_orderby": spec.allows_hidden_cc_in_orderby, "cte_in_subquery": spec.allows_cte_in_subquery, - "limit_clause": spec.allow_limit_clause, "max_column_name": spec.max_column_name_length, "sql_comments": spec.allows_sql_comments, "escaped_colons": spec.allows_escaped_colons, @@ -223,7 +222,7 @@ def generate_table() -> list[list[Any]]: rows = [] # pylint: disable=redefined-outer-name rows.append(["Feature"] + list(info)) # header row - rows.append(["Module"] + list(db_info["module"] for db_info in info.values())) # noqa: C400 + rows.append(["Module"] + [db_info["module"] for db_info in info.values()]) # descriptive keys = [ @@ -244,14 +243,14 @@ def generate_table() -> list[list[Any]]: ] for key in keys: rows.append( - [DATABASE_DETAILS[key]] + list(db_info[key] for db_info in info.values()) # noqa: C400 + [DATABASE_DETAILS[key]] + [db_info[key] for db_info in info.values()] ) # basic for time_grain in TimeGrain: rows.append( [f"Has time grain {time_grain.name}"] - + list(db_info["time_grains"][time_grain.name] for db_info in info.values()) # noqa: C400 + + [db_info["time_grains"][time_grain.name] for db_info in info.values()] ) keys = [ "masked_encrypted_extra", @@ -259,9 +258,7 @@ def generate_table() -> list[list[Any]]: "function_names", ] for key in keys: - rows.append( - [BASIC_FEATURES[key]] + list(db_info[key] for db_info in info.values()) # noqa: C400 - ) + rows.append([BASIC_FEATURES[key]] + [db_info[key] for db_info in info.values()]) # nice to have keys = [ @@ -280,8 +277,7 @@ def generate_table() -> list[list[Any]]: ] for key in keys: rows.append( - [NICE_TO_HAVE_FEATURES[key]] - + list(db_info[key] for db_info in info.values()) # noqa: C400 + [NICE_TO_HAVE_FEATURES[key]] + [db_info[key] for db_info in info.values()] ) # advanced @@ -292,10 +288,10 @@ def generate_table() -> list[list[Any]]: ] for key in keys: rows.append( - [ADVANCED_FEATURES[key]] + list(db_info[key] for db_info in info.values()) # noqa: C400 + [ADVANCED_FEATURES[key]] + [db_info[key] for db_info in info.values()] ) - rows.append(["Score"] + list(db_info["score"] for db_info in info.values())) # noqa: C400 + rows.append(["Score"] + [db_info["score"] for db_info in info.values()]) return rows diff --git a/superset/db_engine_specs/mssql.py b/superset/db_engine_specs/mssql.py index c1f7e295dee..6e238e9ddfb 100644 --- a/superset/db_engine_specs/mssql.py +++ b/superset/db_engine_specs/mssql.py @@ -27,7 +27,7 @@ from sqlalchemy import types from sqlalchemy.dialects.mssql.base import SMALLDATETIME from superset.constants import TimeGrain -from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod +from superset.db_engine_specs.base import BaseEngineSpec from superset.errors import SupersetErrorType from superset.models.sql_types.mssql_sql_types import GUID from superset.utils.core import GenericDataType @@ -52,10 +52,8 @@ 
CONNECTION_HOST_DOWN_REGEX = re.compile( class MssqlEngineSpec(BaseEngineSpec): engine = "mssql" engine_name = "Microsoft SQL Server" - limit_method = LimitMethod.WRAP_SQL max_column_name_length = 128 allows_cte_in_subquery = False - allow_limit_clause = False supports_multivalues_insert = True _time_grain_expressions = { diff --git a/superset/db_engine_specs/ocient.py b/superset/db_engine_specs/ocient.py index a7b97ed6996..75889c706c2 100644 --- a/superset/db_engine_specs/ocient.py +++ b/superset/db_engine_specs/ocient.py @@ -225,7 +225,6 @@ def _find_columns_to_sanitize(cursor: Any) -> list[PlacedSanitizeFunc]: class OcientEngineSpec(BaseEngineSpec): engine = "ocient" engine_name = "Ocient" - # limit_method = LimitMethod.WRAP_SQL force_column_alias_quotes = True max_column_name_length = 30 diff --git a/superset/db_engine_specs/oracle.py b/superset/db_engine_specs/oracle.py index f03cea49120..1df5736b824 100644 --- a/superset/db_engine_specs/oracle.py +++ b/superset/db_engine_specs/oracle.py @@ -20,13 +20,12 @@ from typing import Any, Optional from sqlalchemy import types from superset.constants import TimeGrain -from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod +from superset.db_engine_specs.base import BaseEngineSpec class OracleEngineSpec(BaseEngineSpec): engine = "oracle" engine_name = "Oracle" - limit_method = LimitMethod.WRAP_SQL force_column_alias_quotes = True max_column_name_length = 30 diff --git a/superset/db_engine_specs/teradata.py b/superset/db_engine_specs/teradata.py index 887add24e90..08c9e9b7c99 100644 --- a/superset/db_engine_specs/teradata.py +++ b/superset/db_engine_specs/teradata.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod +from superset.db_engine_specs.base import BaseEngineSpec class TeradataEngineSpec(BaseEngineSpec): @@ -23,11 +23,8 @@ class TeradataEngineSpec(BaseEngineSpec): engine = "teradatasql" engine_name = "Teradata" - limit_method = LimitMethod.WRAP_SQL max_column_name_length = 30 # since 14.10 this is 128 - allow_limit_clause = False select_keywords = {"SELECT", "SEL"} - top_keywords = {"TOP", "SAMPLE"} _time_grain_expressions = { None: "{col}", diff --git a/superset/examples/flights.py b/superset/examples/flights.py index 4d8b04e42a0..b996fafcbb1 100644 --- a/superset/examples/flights.py +++ b/superset/examples/flights.py @@ -47,9 +47,9 @@ def load_flights(only_metadata: bool = False, force: bool = False) -> None: ) airports = airports.set_index("IATA_CODE") - pdf[ # pylint: disable=unsupported-assignment-operation,useless-suppression - "ds" - ] = pdf.YEAR.map(str) + "-0" + pdf.MONTH.map(str) + "-0" + pdf.DAY.map(str) + pdf["ds"] = ( + pdf.YEAR.map(str) + "-0" + pdf.MONTH.map(str) + "-0" + pdf.DAY.map(str) + ) pdf.ds = pd.to_datetime(pdf.ds) pdf.drop(columns=["DAY", "MONTH", "YEAR"]) pdf = pdf.join(airports, on="ORIGIN_AIRPORT", rsuffix="_ORIG") diff --git a/superset/exceptions.py b/superset/exceptions.py index c6105b44654..0ca6f0c70be 100644 --- a/superset/exceptions.py +++ b/superset/exceptions.py @@ -432,3 +432,54 @@ class TableNotFoundException(SupersetErrorException): level=ErrorLevel.ERROR, ) ) + + +class SupersetDMLNotAllowedException(SupersetErrorException): + def __init__(self) -> None: + error = SupersetError( + message=_( + "This database does not allow for DDL/DML, but the query mutates " + "data. Please contact your administrator for more assistance." 
+ ), + error_type=SupersetErrorType.DML_NOT_ALLOWED_ERROR, + level=ErrorLevel.ERROR, + ) + super().__init__(error) + + +class SupersetInvalidCTASException(SupersetErrorException): + def __init__(self) -> None: + error = SupersetError( + message=_( + "CTAS (create table as select) can only be run with a query where " + "the last statement is a SELECT. Please make sure your query has " + "a SELECT as its last statement. Then, try running your query again." + ), + error_type=SupersetErrorType.INVALID_CTAS_QUERY_ERROR, + level=ErrorLevel.ERROR, + ) + super().__init__(error) + + +class SupersetInvalidCVASException(SupersetErrorException): + def __init__(self) -> None: + error = SupersetError( + message=_( + "CVAS (create view as select) can only be run with a query with " + "a single SELECT statement. Please make sure your query has only " + "a SELECT statement. Then, try running your query again." + ), + error_type=SupersetErrorType.INVALID_CVAS_QUERY_ERROR, + level=ErrorLevel.ERROR, + ) + super().__init__(error) + + +class SupersetResultsBackendNotConfigureException(SupersetErrorException): + def __init__(self) -> None: + error = SupersetError( + message=_("Results backend is not configured."), + error_type=SupersetErrorType.RESULTS_BACKEND_NOT_CONFIGURED_ERROR, + level=ErrorLevel.ERROR, + ) + super().__init__(error) diff --git a/superset/initialization/__init__.py b/superset/initialization/__init__.py index aae49068506..f00f8fb146e 100644 --- a/superset/initialization/__init__.py +++ b/superset/initialization/__init__.py @@ -171,7 +171,7 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods from superset.views.error_handling import set_app_error_handlers from superset.views.explore import ExplorePermalinkView, ExploreView from superset.views.log.api import LogRestApi - from superset.views.log.views import LogModelView + from superset.views.logs import ActionLogView from superset.views.roles import RolesListView from superset.views.sql_lab.views import ( SavedQueryView, @@ -224,6 +224,7 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods appbuilder.add_api(TagRestApi) appbuilder.add_api(SqlLabRestApi) appbuilder.add_api(SqlLabPermalinkRestApi) + appbuilder.add_api(LogRestApi) # # Setup regular views # @@ -289,6 +290,14 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods category_label=__("Security"), ) + appbuilder.add_view( + ActionLogView, + "Action Logs", + label=__("Action Logs"), + category="Security", + category_label=__("Security"), + ) + appbuilder.add_view( DynamicPluginsView, "Plugins", @@ -369,19 +378,6 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods category="Manage", menu_cond=lambda: feature_flag_manager.is_feature_enabled("TAGGING_SYSTEM"), ) - appbuilder.add_api(LogRestApi) - appbuilder.add_view( - LogModelView, - "Action Log", - label=__("Action Log"), - category="Security", - category_label=__("Security"), - icon="fa-list-ol", - menu_cond=lambda: ( - self.config["FAB_ADD_SECURITY_VIEWS"] - and self.config["SUPERSET_LOG_VIEW"] - ), - ) appbuilder.add_api(SecurityRestApi) # # Conditionally setup email views diff --git a/superset/jinja_context.py b/superset/jinja_context.py index c4182b136eb..837150f0333 100644 --- a/superset/jinja_context.py +++ b/superset/jinja_context.py @@ -22,7 +22,7 @@ import re from dataclasses import dataclass from datetime import datetime from functools import lru_cache, partial -from typing import Any, Callable, cast, Optional, TYPE_CHECKING, TypedDict, Union 
+from typing import Any, Callable, cast, TYPE_CHECKING, TypedDict, Union import dateutil from flask import current_app, g, has_request_context, request @@ -109,6 +109,7 @@ class ExtraCache: r"current_user_id\([^()]*\)|" r"current_username\([^()]*\)|" r"current_user_email\([^()]*\)|" + r"current_user_rls_rules\([^()]*\)|" r"current_user_roles\([^()]*\)|" r"cache_key_wrapper\([^()]*\)|" r"url_param\([^()]*\)" @@ -118,12 +119,12 @@ class ExtraCache: def __init__( # pylint: disable=too-many-arguments self, - extra_cache_keys: Optional[list[Any]] = None, - applied_filters: Optional[list[str]] = None, - removed_filters: Optional[list[str]] = None, - database: Optional[Database] = None, - dialect: Optional[Dialect] = None, - table: Optional[SqlaTable] = None, + extra_cache_keys: list[Any] | None = None, + applied_filters: list[str] | None = None, + removed_filters: list[str] | None = None, + database: Database | None = None, + dialect: Dialect | None = None, + table: SqlaTable | None = None, ): self.extra_cache_keys = extra_cache_keys self.applied_filters = applied_filters if applied_filters is not None else [] @@ -132,7 +133,7 @@ class ExtraCache: self.dialect = dialect self.table = table - def current_user_id(self, add_to_cache_keys: bool = True) -> Optional[int]: + def current_user_id(self, add_to_cache_keys: bool = True) -> int | None: """ Return the user ID of the user who is currently logged in. @@ -146,7 +147,7 @@ class ExtraCache: return user_id return None - def current_username(self, add_to_cache_keys: bool = True) -> Optional[str]: + def current_username(self, add_to_cache_keys: bool = True) -> str | None: """ Return the username of the user who is currently logged in. @@ -160,7 +161,7 @@ class ExtraCache: return username return None - def current_user_email(self, add_to_cache_keys: bool = True) -> Optional[str]: + def current_user_email(self, add_to_cache_keys: bool = True) -> str | None: """ Return the email address of the user who is currently logged in. @@ -193,6 +194,31 @@ class ExtraCache: except Exception: # pylint: disable=broad-except return None + def current_user_rls_rules(self) -> list[str] | None: + """ + Return the row level security rules applied to the current user and dataset. + """ + if not self.table: + return None + + rls_rules = ( + sorted( + [ + rule["clause"] + for rule in security_manager.get_guest_rls_filters(self.table) + ] + ) + if security_manager.is_guest_user() + else sorted( + [rule.clause for rule in security_manager.get_rls_filters(self.table)] + ) + ) + if not rls_rules: + return None + + self.cache_key_wrapper(json.dumps(rls_rules)) + return rls_rules + def cache_key_wrapper(self, key: Any) -> Any: """ Adds values to a list that is added to the query object used for calculating a @@ -213,10 +239,10 @@ class ExtraCache: def url_param( self, param: str, - default: Optional[str] = None, + default: str | None = None, add_to_cache_keys: bool = True, escape_result: bool = True, - ) -> Optional[str]: + ) -> str | None: """ Read a url or post parameter and use it in your SQL Lab query. 
@@ -259,7 +285,7 @@ class ExtraCache: return result def filter_values( - self, column: str, default: Optional[str] = None, remove_filter: bool = False + self, column: str, default: str | None = None, remove_filter: bool = False ) -> list[Any]: """Gets a values for a particular filter as a list @@ -524,7 +550,7 @@ def validate_context_types(context: dict[str, Any]) -> dict[str, Any]: def validate_template_context( - engine: Optional[str], context: dict[str, Any] + engine: str | None, context: dict[str, Any] ) -> dict[str, Any]: if engine and engine in context: # validate engine context separately to allow for engine-specific methods @@ -543,7 +569,7 @@ class WhereInMacro: # pylint: disable=too-few-public-methods def __call__( self, values: list[Any], - mark: Optional[str] = None, + mark: str | None = None, default_to_none: bool = False, ) -> str | None: """ @@ -605,17 +631,17 @@ class BaseTemplateProcessor: Base class for database-specific jinja context """ - engine: Optional[str] = None + engine: str | None = None # pylint: disable=too-many-arguments def __init__( self, database: "Database", - query: Optional["Query"] = None, - table: Optional["SqlaTable"] = None, - extra_cache_keys: Optional[list[Any]] = None, - removed_filters: Optional[list[str]] = None, - applied_filters: Optional[list[str]] = None, + query: "Query" | None = None, + table: "SqlaTable" | None = None, + extra_cache_keys: list[Any] | None = None, + removed_filters: list[str] | None = None, + applied_filters: list[str] | None = None, **kwargs: Any, ) -> None: self._database = database @@ -641,6 +667,12 @@ class BaseTemplateProcessor: self._context.update(kwargs) self._context.update(context_addons()) + def get_context(self) -> dict[str, Any]: + """ + Returns the current template context. + """ + return self._context.copy() + def process_template(self, sql: str, **kwargs: Any) -> str: """Processes a sql template @@ -661,7 +693,7 @@ class BaseTemplateProcessor: class JinjaTemplateProcessor(BaseTemplateProcessor): - def _parse_datetime(self, dttm: str) -> Optional[datetime]: + def _parse_datetime(self, dttm: str) -> datetime | None: """ Try to parse a datetime and default to None in the worst case. @@ -713,6 +745,9 @@ class JinjaTemplateProcessor(BaseTemplateProcessor): "current_user_roles": partial( safe_proxy, extra_cache.current_user_roles ), + "current_user_rls_rules": partial( + safe_proxy, extra_cache.current_user_rls_rules + ), "cache_key_wrapper": partial(safe_proxy, extra_cache.cache_key_wrapper), "filter_values": partial(safe_proxy, extra_cache.filter_values), "get_filters": partial(safe_proxy, extra_cache.get_filters), @@ -757,14 +792,12 @@ class PrestoTemplateProcessor(JinjaTemplateProcessor): } @staticmethod - def _schema_table( - table_name: str, schema: Optional[str] - ) -> tuple[str, Optional[str]]: + def _schema_table(table_name: str, schema: str | None) -> tuple[str, str | None]: if "." 
in table_name: schema, table_name = table_name.split(".") return table_name, schema - def first_latest_partition(self, table_name: str) -> Optional[str]: + def first_latest_partition(self, table_name: str) -> str | None: """ Gets the first value in the array of all latest partitions @@ -776,7 +809,7 @@ class PrestoTemplateProcessor(JinjaTemplateProcessor): latest_partitions = self.latest_partitions(table_name) return latest_partitions[0] if latest_partitions else None - def latest_partitions(self, table_name: str) -> Optional[list[str]]: + def latest_partitions(self, table_name: str) -> list[str] | None: """ Gets the array of all latest partitions @@ -858,8 +891,8 @@ def get_template_processors() -> dict[str, Any]: def get_template_processor( database: "Database", - table: Optional["SqlaTable"] = None, - query: Optional["Query"] = None, + table: "SqlaTable" | None = None, + query: "Query" | None = None, **kwargs: Any, ) -> BaseTemplateProcessor: if feature_flag_manager.is_feature_enabled("ENABLE_TEMPLATE_PROCESSING"): @@ -874,9 +907,9 @@ def get_template_processor( def dataset_macro( dataset_id: int, include_metrics: bool = False, - columns: Optional[list[str]] = None, - from_dttm: Optional[datetime] = None, - to_dttm: Optional[datetime] = None, + columns: list[str] | None = None, + from_dttm: datetime | None = None, + to_dttm: datetime | None = None, ) -> str: """ Given a dataset ID, return the SQL that represents it. @@ -958,7 +991,7 @@ def metric_macro( env: Environment, context: dict[str, Any], metric_key: str, - dataset_id: Optional[int] = None, + dataset_id: int | None = None, ) -> str: """ Given a metric key, returns its syntax. diff --git a/superset/migrations/shared/migrate_viz/base.py b/superset/migrations/shared/migrate_viz/base.py index fe434a52f4b..929039c0f67 100644 --- a/superset/migrations/shared/migrate_viz/base.py +++ b/superset/migrations/shared/migrate_viz/base.py @@ -44,6 +44,7 @@ class Slice(Base): # type: ignore FORM_DATA_BAK_FIELD_NAME = "form_data_bak" +QUERIES_BAK_FIELD_NAME = "queries_bak" class MigrateViz: @@ -156,14 +157,24 @@ class MigrateViz: # because a source viz can be mapped to different target viz types slc.viz_type = clz.target_viz_type - # only backup params - slc.params = json.dumps( - {**clz.data, FORM_DATA_BAK_FIELD_NAME: form_data_bak} - ) + backup = {FORM_DATA_BAK_FIELD_NAME: form_data_bak} + + query_context = try_load_json(slc.query_context) + + if query_context: + if "form_data" in query_context: + query_context["form_data"] = clz.data + + queries_bak = copy.deepcopy(query_context["queries"]) + + queries = clz._build_query()["queries"] + query_context["queries"] = queries - if "form_data" in (query_context := try_load_json(slc.query_context)): - query_context["form_data"] = clz.data slc.query_context = json.dumps(query_context) + backup[QUERIES_BAK_FIELD_NAME] = queries_bak + + slc.params = json.dumps({**clz.data, **backup}) + except Exception as e: logger.warning(f"Failed to migrate slice {slc.id}: {e}") @@ -177,9 +188,12 @@ class MigrateViz: slc.params = json.dumps(form_data_bak) slc.viz_type = form_data_bak.get("viz_type") query_context = try_load_json(slc.query_context) + queries_bak = form_data.get(QUERIES_BAK_FIELD_NAME, {}) + query_context["queries"] = queries_bak if "form_data" in query_context: query_context["form_data"] = form_data_bak - slc.query_context = json.dumps(query_context) + + slc.query_context = json.dumps(query_context) except Exception as e: logger.warning(f"Failed to downgrade slice {slc.id}: {e}") @@ -205,3 +219,6 @@ 
class MigrateViz: lambda current, total: logger.info(f"Downgraded {current}/{total} charts"), ): cls.downgrade_slice(slc) + + def _build_query(self) -> Any | dict[str, Any]: + """Builds a query based on the form data.""" diff --git a/superset/migrations/shared/migrate_viz/processors.py b/superset/migrations/shared/migrate_viz/processors.py index 44e5aacfb02..c60b6bd42fc 100644 --- a/superset/migrations/shared/migrate_viz/processors.py +++ b/superset/migrations/shared/migrate_viz/processors.py @@ -16,6 +16,32 @@ # under the License. from typing import Any +from superset.migrations.shared.migrate_viz.query_functions import ( + build_query_context, + contribution_operator, + ensure_is_array, + extract_extra_metrics, + flatten_operator, + get_column_label, + get_metric_label, + get_x_axis_column, + histogram_operator, + is_physical_column, + is_time_comparison, + is_x_axis_set, + normalize_order_by, + pivot_operator, + prophet_operator, + rank_operator, + remove_form_data_suffix, + rename_operator, + resample_operator, + retain_form_data_suffix, + rolling_window_operator, + sort_operator, + time_compare_operator, + time_compare_pivot_operator, +) from superset.utils.core import as_list from .base import MigrateViz @@ -35,6 +61,19 @@ class MigrateTreeMap(MigrateViz): ): self.data["metric"] = self.data["metrics"][0] + def _build_query(self) -> dict[str, Any]: + metric = self.data.get("metric") + sort_by_metric = self.data.get("sort_by_metric") + + def process(base_query_object: dict[str, Any]) -> list[dict[str, Any]]: + new_query_object = base_query_object.copy() + + if sort_by_metric: + new_query_object["orderby"] = [[metric, False]] + return [new_query_object] + + return build_query_context(self.data, process) + class MigratePivotTable(MigrateViz): source_viz_type = "pivot_table" @@ -70,6 +109,58 @@ class MigratePivotTable(MigrateViz): self.data["rowOrder"] = "value_z_to_a" + def _build_query(self) -> dict[str, Any]: + groupby_columns = self.data.get("groupbyColumns", []) + groupby_rows = self.data.get("groupbyRows", []) + extra_form_data = self.data.get("extra_form_data", {}) + time_grain_sqla = extra_form_data.get("time_grain_sqla") or self.data.get( + "time_grain_sqla" + ) + + unique_columns = ensure_is_array(groupby_columns) + ensure_is_array( + groupby_rows + ) + + columns = [] + for col in unique_columns: + if ( + is_physical_column(col) + and time_grain_sqla + and ( + self.data.get("temporal_columns_lookup", {}).get(col) + or self.data.get("granularity_sqla") == col + ) + ): + col_dict = { + "timeGrain": time_grain_sqla, + "columnType": "BASE_AXIS", + "sqlExpression": col, + "label": col, + "expressionType": "SQL", + } + if col_dict not in columns: + columns.append(col_dict) + else: + if col not in columns: + columns.append(col) + + def process(base_query_object: dict[str, Any]) -> list[dict[str, Any]]: + series_limit_metric = base_query_object.get("series_limit_metric") + metrics = base_query_object.get("metrics") + order_desc = base_query_object.get("order_desc") + orderby = None + if series_limit_metric: + orderby = [[series_limit_metric, not order_desc]] + elif isinstance(metrics, list) and metrics and metrics[0]: + orderby = [[metrics[0], not order_desc]] + new_query_object = base_query_object.copy() + if orderby is not None: + new_query_object["orderby"] = orderby + new_query_object["columns"] = columns + return [new_query_object] + + return build_query_context(self.data, process) + class MigrateDualLine(MigrateViz): has_x_axis_control = True @@ -94,12 +185,73 @@ class 
MigrateDualLine(MigrateViz): super()._migrate_temporal_filter(rv_data) rv_data["adhoc_filters_b"] = rv_data.get("adhoc_filters") or [] + def _build_query(self) -> dict[str, Any]: + base_form_data = self.data.copy() + form_data1 = remove_form_data_suffix(base_form_data, "_b") + form_data2 = retain_form_data_suffix(base_form_data, "_b") + + def process_fn(fd: dict[str, Any]) -> dict[str, Any]: + def process(base_query_object: dict[str, Any]) -> list[dict[str, Any]]: + query_object = base_query_object.copy() + query_object["columns"] = ( + ensure_is_array(get_x_axis_column(self.data)) + if is_x_axis_set(self.data) + else [] + ) + ensure_is_array(fd.get("groupby")) + query_object["series_columns"] = fd.get("groupby") + if not is_x_axis_set(self.data): + query_object["is_timeseries"] = True + pivot_operator_runtime = ( + time_compare_pivot_operator(fd, query_object) + if is_time_comparison(fd, query_object) + else pivot_operator(fd, query_object) + ) + tmp_query_object = query_object.copy() + tmp_query_object["time_offsets"] = ( + fd.get("time_compare") + if is_time_comparison(fd, query_object) + else [] + ) + tmp_query_object["post_processing"] = [ + pivot_operator_runtime, + rolling_window_operator(fd, query_object), + time_compare_operator(fd, query_object), + resample_operator(fd, query_object), + rename_operator(fd, query_object), + flatten_operator(fd, query_object), + ] + + if tmp_query_object["series_columns"] is None: + tmp_query_object.pop("series_columns") + return [normalize_order_by(tmp_query_object)] + + return build_query_context(fd, process) + + query_contexts = [process_fn(form_data1), process_fn(form_data2)] + qc0 = query_contexts[0] + qc1 = query_contexts[1] + merged = qc0.copy() + merged["queries"] = qc0.get("queries", []) + qc1.get("queries", []) + return merged + class MigrateSunburst(MigrateViz): source_viz_type = "sunburst" target_viz_type = "sunburst_v2" rename_keys = {"groupby": "columns"} + def _build_query(self) -> dict[str, Any]: + metric = self.data.get("metric") + sort_by_metric = self.data.get("sort_by_metric") + + def process(base_query_object: dict[str, Any]) -> list[dict[str, Any]]: + result = base_query_object.copy() + if sort_by_metric: + result["orderby"] = [[metric, False]] + return [result] + + return build_query_context(self.data, process) + class TimeseriesChart(MigrateViz): has_x_axis_control = True @@ -142,9 +294,11 @@ class TimeseriesChart(MigrateViz): if (rolling_type := self.data.get("rolling_type")) and rolling_type != "None": self.data["rolling_type"] = rolling_type - if time_compare := self.data.get("time_compare"): + if (time_compare := self.data.get("time_compare")) is not None: self.data["time_compare"] = [ - value + " ago" for value in as_list(time_compare) if value + v if v.endswith(" ago") else v + " ago" + for value in as_list(time_compare) + if (v := value.strip()) ] comparison_type = self.data.get("comparison_type") or "values" @@ -155,6 +309,63 @@ class TimeseriesChart(MigrateViz): if x_ticks_layout := self.data.get("x_ticks_layout"): self.data["x_ticks_layout"] = 45 if x_ticks_layout == "45°" else 0 + def _build_query(self) -> dict[str, Any]: + groupby = self.data.get("groupby") + + def query_builder(base_query_object: dict[str, Any]) -> list[dict[str, Any]]: + """ + The `pivot_operator_in_runtime` determines how to pivot the dataframe + returned from the raw query. + 1. If it's a time compared query, there will return a pivoted + dataframe that append time compared metrics. 
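+            2. If it's a regular query, a standard pivot keyed on the x-axis and series columns is returned instead.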
+ """ + extra_metrics = extract_extra_metrics(self.data) + + pivot_operator_in_runtime = ( + time_compare_pivot_operator(self.data, base_query_object) + if is_time_comparison(self.data, base_query_object) + else pivot_operator(self.data, base_query_object) + ) + + columns = ( + ensure_is_array(get_x_axis_column(self.data)) + if is_x_axis_set(self.data) + else [] + ) + ensure_is_array(groupby) + + time_offsets = ( + self.data.get("time_compare") + if is_time_comparison(self.data, base_query_object) + else [] + ) + + result = { + **base_query_object, + "metrics": (base_query_object.get("metrics") or []) + extra_metrics, + "columns": columns, + "series_columns": groupby, + **({"is_timeseries": True} if not is_x_axis_set(self.data) else {}), + # todo: move `normalize_order_by to extract_query_fields` + "orderby": normalize_order_by(base_query_object).get("orderby"), + "time_offsets": time_offsets, + "post_processing": [ + pivot_operator_in_runtime, + rolling_window_operator(self.data, base_query_object), + time_compare_operator(self.data, base_query_object), + resample_operator(self.data, base_query_object), + rename_operator(self.data, base_query_object), + contribution_operator(self.data, base_query_object, time_offsets), + sort_operator(self.data, base_query_object), + flatten_operator(self.data, base_query_object), + # todo: move prophet before flatten + prophet_operator(self.data, base_query_object), + ], + } + + return [result] + + return build_query_context(self.data, query_builder) + class MigrateLineChart(TimeseriesChart): source_viz_type = "line" @@ -173,6 +384,9 @@ class MigrateLineChart(TimeseriesChart): self.target_viz_type = "echarts_timeseries_step" self.data["seriesType"] = "end" + def _build_query(self) -> dict[str, Any]: + return super()._build_query() + class MigrateAreaChart(TimeseriesChart): source_viz_type = "area" @@ -194,6 +408,9 @@ class MigrateAreaChart(TimeseriesChart): self.data["opacity"] = 0.7 + def _build_query(self) -> dict[str, Any]: + return super()._build_query() + class MigrateBarChart(TimeseriesChart): source_viz_type = "bar" @@ -208,6 +425,9 @@ class MigrateBarChart(TimeseriesChart): self.data["stack"] = "Stack" if self.data.get("bar_stacked") else None + def _build_query(self) -> dict[str, Any]: + return super()._build_query() + class MigrateDistBarChart(TimeseriesChart): source_viz_type = "dist_bar" @@ -238,6 +458,9 @@ class MigrateDistBarChart(TimeseriesChart): self.data["stack"] = "Stack" if self.data.get("bar_stacked") else None self.data["x_ticks_layout"] = 45 + def _build_query(self) -> dict[str, Any]: + return super()._build_query() + class MigrateBubbleChart(MigrateViz): source_viz_type = "bubble" @@ -267,6 +490,30 @@ class MigrateBubbleChart(MigrateViz): # Truncate y-axis by default to preserve layout self.data["y_axis_showminmax"] = True + def _build_query(self) -> dict[str, Any]: + columns = ensure_is_array(self.data.get("entity")) + ensure_is_array( + self.data.get("series") + ) + + def process(base_query_object: dict[str, Any]) -> list[dict[str, Any]]: + if base_query_object.get("orderby"): + orderby = [ + [ + base_query_object["orderby"][0], + not base_query_object.get("order_desc", False), + ] + ] + else: + orderby = None + + new_query_object = {**base_query_object, "columns": columns} + if orderby is not None: + new_query_object["orderby"] = orderby + + return [new_query_object] + + return build_query_context(self.data, process) + class MigrateHeatmapChart(MigrateViz): source_viz_type = "heatmap" @@ -282,6 +529,53 @@ class 
MigrateHeatmapChart(MigrateViz): def _pre_action(self) -> None: self.data["legend_type"] = "continuous" + def _build_query(self) -> dict[str, Any]: + groupby = self.data.get("groupby") + normalize_across = self.data.get("normalize_across") + sort_x_axis = self.data.get("sort_x_axis") + sort_y_axis = self.data.get("sort_y_axis") + x_axis = self.data.get("x_axis") + + metric = get_metric_label(self.data.get("metric")) + + columns = ensure_is_array(get_x_axis_column(self.data)) + ensure_is_array( + groupby + ) + + orderby = [] + if sort_x_axis: + chosen = metric if "value" in sort_x_axis else columns[0] + ascending = "asc" in sort_x_axis + orderby.append([chosen, ascending]) + if sort_y_axis: + chosen = metric if "value" in sort_y_axis else columns[1] + ascending = "asc" in sort_y_axis + orderby.append([chosen, ascending]) + + if normalize_across == "x": + group_by = get_column_label(x_axis) + elif normalize_across == "y": + group_by = get_column_label(groupby) + else: + group_by = None + + def process(base_query_object: dict[str, Any]) -> list[dict[str, Any]]: + new_query_object = base_query_object.copy() + new_query_object["columns"] = columns + if orderby: + new_query_object["orderby"] = orderby + new_query_object["post_processing"] = [ + rank_operator( + self.data, + base_query_object, + {"metric": metric, "group_by": group_by}, + ) + ] + + return [new_query_object] + + return build_query_context(self.data, process) + class MigrateHistogramChart(MigrateViz): source_viz_type = "histogram" @@ -305,6 +599,22 @@ class MigrateHistogramChart(MigrateViz): if not groupby: self.data["groupby"] = [] + def _build_query(self) -> dict[str, Any]: + column = self.data.get("column") + groupby = self.data.get("groupby", []) + + def process(base_query_object: dict[str, Any]) -> list[dict[str, Any]]: + result = base_query_object.copy() + result["columns"] = groupby + [column] + result["post_processing"] = [ + histogram_operator(self.data, base_query_object) + ] + if "metrics" in result.keys(): + result.pop("metrics", None) + return [result] + + return build_query_context(self.data, process) + class MigrateSankey(MigrateViz): source_viz_type = "sankey" @@ -316,3 +626,19 @@ class MigrateSankey(MigrateViz): if groupby and len(groupby) > 1: self.data["source"] = groupby[0] self.data["target"] = groupby[1] + + def _build_query(self) -> dict[str, Any]: + metric = self.data.get("metric") + sort_by_metric = self.data.get("sort_by_metric") + source = self.data.get("source") + target = self.data.get("target") + groupby = [source, target] + + def process(base_query_object: dict[str, Any]) -> list[dict[str, Any]]: + result = base_query_object.copy() + result["groupby"] = groupby + if sort_by_metric: + result["orderby"] = [[metric, False]] + return [result] + + return build_query_context(self.data, process) diff --git a/superset/migrations/shared/migrate_viz/query_functions.py b/superset/migrations/shared/migrate_viz/query_functions.py new file mode 100644 index 00000000000..736a38c5772 --- /dev/null +++ b/superset/migrations/shared/migrate_viz/query_functions.py @@ -0,0 +1,1507 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import json +import math +from enum import Enum +from typing import Any, Dict, List, Optional, Union + + +class RollingType(Enum): + Mean = "mean" + Sum = "sum" + Std = "std" + Cumsum = "cumsum" + + +class ComparisonType(Enum): + Values = "values" + Difference = "difference" + Percentage = "percentage" + Ratio = "ratio" + + +class DatasourceType(Enum): + Table = "table" + Query = "query" + Dataset = "dataset" + SlTable = "sl_table" + SavedQuery = "saved_query" + + +UNARY_OPERATORS = ["IS NOT NULL", "IS NULL"] +BINARY_OPERATORS = [ + "==", + "!=", + ">", + "<", + ">=", + "<=", + "ILIKE", + "LIKE", + "NOT LIKE", + "REGEX", + "TEMPORAL_RANGE", +] +SET_OPERATORS = ["IN", "NOT IN"] + +unary_operator_set = set(UNARY_OPERATORS) +binary_operator_set = set(BINARY_OPERATORS) +set_operator_set = set(SET_OPERATORS) + + +class DatasourceKey: + def __init__(self, key: str): + id_str, type_str = key.split("__", 1) + self.id = int(id_str) + # Default to Table; if type_str is 'query', then use Query. + self.type = DatasourceType.Table + if type_str == "query": + self.type = DatasourceType.Query + + def __str__(self) -> str: + return f"{self.id}__{self.type.value}" + + def to_object(self) -> dict[str, Any]: + return { + "id": self.id, + "type": self.type.value, + } + + +TIME_COMPARISON_SEPARATOR = "__" +DTTM_ALIAS = "__timestamp" +NO_TIME_RANGE = "No filter" + +EXTRA_FORM_DATA_OVERRIDE_EXTRA_KEYS = [ + "relative_start", + "relative_end", + "time_grain_sqla", +] + +EXTRA_FORM_DATA_APPEND_KEYS = [ + "adhoc_filters", + "filters", + "interactive_groupby", + "interactive_highlight", + "interactive_drilldown", + "custom_form_data", +] + +EXTRA_FORM_DATA_OVERRIDE_REGULAR_MAPPINGS = { + "granularity": "granularity", + "granularity_sqla": "granularity", + "time_column": "time_column", + "time_grain": "time_grain", + "time_range": "time_range", +} + +EXTRA_FORM_DATA_OVERRIDE_REGULAR_KEYS = list( + EXTRA_FORM_DATA_OVERRIDE_REGULAR_MAPPINGS.keys() +) + +EXTRA_FORM_DATA_OVERRIDE_KEYS = ( + EXTRA_FORM_DATA_OVERRIDE_REGULAR_KEYS + EXTRA_FORM_DATA_OVERRIDE_EXTRA_KEYS +) + + +def ensure_is_array(value: Optional[Union[List[Any], Any]] = None) -> List[Any]: + """ + Ensure a nullable value input is a list. Useful when consolidating + input format from a select control. + """ + if value is None: + return [] + return value if isinstance(value, list) else [value] + + +def is_empty(value: Any) -> bool: + """ + A simple implementation similar to lodash's isEmpty. + Returns True if value is None or an empty collection. 
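+    For example, is_empty([]) and is_empty("") return True, while is_empty(0) returns False because numbers are not collections.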
+ """ + if value is None: + return True + if isinstance(value, (list, dict, str, tuple, set)): + return len(value) == 0 + return False + + +def is_saved_metric(metric: Any) -> bool: + """Return True if metric is a saved metric (str).""" + return isinstance(metric, str) + + +def is_adhoc_metric_simple(metric: Any) -> bool: + """Return True if metric dict is a simple adhoc metric.""" + return ( + not isinstance(metric, str) + and isinstance(metric, dict) + and metric.get("expressionType") == "SIMPLE" + ) + + +def is_adhoc_metric_sql(metric: Any) -> bool: + """Return True if metric dict is an SQL adhoc metric.""" + return ( + not isinstance(metric, str) + and isinstance(metric, dict) + and metric.get("expressionType") == "SQL" + ) + + +def is_query_form_metric(metric: Any) -> bool: + """Return True if metric is of any query form type.""" + return ( + is_saved_metric(metric) + or is_adhoc_metric_simple(metric) + or is_adhoc_metric_sql(metric) + ) + + +def get_metric_label(metric: Any | dict[str, Any]) -> Any | dict[str, Any]: + """ + Get the label for a given metric. + + Args: + metric (dict): The metric object. + + Returns: + dict: The label of the metric. + """ + if is_saved_metric(metric): + return metric + if "label" in metric and metric["label"]: + return metric["label"] + if is_adhoc_metric_simple(metric): + column_name = metric["column"].get("columnName") or metric["column"].get( + "column_name" + ) + return f"{metric['aggregate']}({column_name})" + return metric["sqlExpression"] + + +def extract_extra_metrics(form_data: Dict[str, Any]) -> List[Any]: + """ + Extract extra metrics from the form data. + + Args: + form_data (Dict[str, Any]): The query form data. + + Returns: + List[Any]: A list of extra metrics. + """ + groupby = form_data.get("groupby", []) + timeseries_limit_metric = form_data.get("timeseries_limit_metric") + x_axis_sort = form_data.get("x_axis_sort") + metrics = form_data.get("metrics", []) + + extra_metrics = [] + limit_metric = ( + ensure_is_array(timeseries_limit_metric)[0] if timeseries_limit_metric else None + ) + + if ( + not groupby + and limit_metric + and get_metric_label(limit_metric) == x_axis_sort + and not any(get_metric_label(metric) == x_axis_sort for metric in metrics) + ): + extra_metrics.append(limit_metric) + + return extra_metrics + + +def get_metric_offsets_map( + form_data: dict[str, List[str]], query_object: dict[str, List[str]] +) -> dict[str, Any]: + """ + Return a dictionary mapping metric offset-labels to metric-labels. + + Args: + form_data (Dict[str, List[str]]): The form data containing time comparisons. + query_object (Dict[str, List[str]]): The query object containing metrics. + + Returns: + Dict[str, str]: A dictionary with offset-labels as keys and metric-labels + as values. + """ + query_metrics = ensure_is_array(query_object.get("metrics", [])) + time_offsets = ensure_is_array(form_data.get("time_compare", [])) + + metric_labels = [get_metric_label(metric) for metric in query_metrics] + metric_offset_map = {} + + for metric in metric_labels: + for offset in time_offsets: + key = f"{metric}{TIME_COMPARISON_SEPARATOR}{offset}" + metric_offset_map[key] = metric + + return metric_offset_map + + +def is_time_comparison(form_data: dict[str, Any], query_object: dict[str, Any]) -> bool: + """ + Determine if the query involves a time comparison. + + Args: + form_data (dict): The form data containing query parameters. + query_object (dict): The query object. + + Returns: + bool: True if it is a time comparison, False otherwise. 
+ """ + comparison_type = form_data.get("comparison_type") + metric_offset_map = get_metric_offsets_map(form_data, query_object) + + return ( + comparison_type in [ct.value for ct in ComparisonType] + and len(metric_offset_map) > 0 + ) + + +def ensure_is_int(value: Any, default_value: Any = None) -> Any | float: + """ + Convert the given value to an integer. + If conversion fails, returns default_value if provided, + otherwise returns NaN (as float('nan')). + """ + try: + val = int(str(value)) + except (ValueError, TypeError): + return default_value if default_value is not None else float("nan") + return val + + +def is_physical_column(column: Any = None) -> bool: + """Return True if column is a physical column (string).""" + return isinstance(column, str) + + +def is_adhoc_column(column: Any = None) -> bool: + """Return True if column is an adhoc column (object with SQL expression).""" + if type(column) is not dict: + return False + return ( + "sqlExpression" in column.keys() + and column["sqlExpression"] is not None + and "label" in column.keys() + and column["label"] is not None + and ("sqlExpression" not in column.keys() or column["expressionType"] == "SQL") + ) + + +def is_query_form_column(column: Any) -> bool: + """Return True if column is either physical or adhoc.""" + return is_physical_column(column) or is_adhoc_column(column) + + +def is_x_axis_set(form_data: dict[str, Any]) -> bool: + """Return True if the x_axis is specified in form_data.""" + return is_query_form_column(form_data.get("x_axis")) + + +def get_x_axis_column(form_data: dict[str, Any]) -> Optional[Any]: + """Return x_axis column.""" + if not (form_data.get("granularity_sqla") or form_data.get("x_axis")): + return None + + if is_x_axis_set(form_data): + return form_data.get("x_axis") + + return DTTM_ALIAS + + +def get_column_label(column: Any) -> Optional[str]: + """Return the string label for a column.""" + if is_physical_column(column): + return column + if column and column.get("label"): + return column.get("label") + return column.get("sqlExpression", None) + + +def get_x_axis_label(form_data: dict[str, Any]) -> Optional[str]: + """Return the x_axis label from form_data.""" + if col := get_x_axis_column(form_data): + return get_column_label(col) + return None + + +def time_compare_pivot_operator( + form_data: dict[str, Any], query_object: dict[str, Any] +) -> Optional[dict[str, Any]]: + """ + A post-processing factory function for pivot operations. 
+ + Args: + form_data: The form data containing configuration + query_object: The query object with series and columns information + + Returns: + Dictionary with pivot operation configuration or None + """ + metric_offset_map = get_metric_offsets_map(form_data, query_object) + x_axis_label = get_x_axis_label(form_data) + columns = ( + query_object.get("series_columns") + if query_object.get("series_columns") is not None + else query_object.get("columns") + ) + + if is_time_comparison(form_data, query_object) and x_axis_label: + # Create aggregates dictionary from metric offset map + metrics = list(metric_offset_map.values()) + list(metric_offset_map.keys()) + aggregates = { + metric: {"operator": "mean"} # use 'mean' aggregates to avoid dropping NaN + for metric in metrics + } + + return { + "operation": "pivot", + "options": { + "index": [x_axis_label], + "columns": [get_column_label(col) for col in ensure_is_array(columns)], + "drop_missing_columns": not form_data.get("show_empty_columns"), + "aggregates": aggregates, + }, + } + + return None + + +def pivot_operator( + form_data: dict[str, Any], query_object: dict[str, Any] +) -> Optional[dict[str, Any]]: + """ + Construct a pivot operator configuration for post-processing. + + This function extracts metric labels (including extra metrics) from the query object + and form data, and retrieves the x-axis label. If both an x-axis label and at + least one metric label are present, it builds a pivot configuration that sets + the index as the x-axis label, transforms the columns via get_column_label, + and creates dummy 'mean' aggregates for each metric. + + Args: + form_data (dict): The form data containing query parameters. + query_object (dict): The base query object containing metrics + and column information. + + Returns: + dict or None: A dict with the pivot operator configuration + if the conditions are met, + otherwise None. + """ + metric_labels = [ + *ensure_is_array(query_object.get("metrics", [])), + *extract_extra_metrics(form_data), + ] + metric_labels = [get_metric_label(metric) for metric in metric_labels] + x_axis_label = get_x_axis_label(form_data) + columns = ( + query_object.get("series_columns") + if query_object.get("series_columns") is not None + else query_object.get("columns") + ) + + if x_axis_label and metric_labels: + cols_list = [get_column_label(col) for col in ensure_is_array(columns)] + return { + "operation": "pivot", + "options": { + "index": [x_axis_label], + "columns": cols_list, + # Create 'dummy' mean aggregates to assign cell values in pivot table + # using the 'mean' aggregates to avoid dropping NaN values + "aggregates": { + metric: {"operator": "mean"} for metric in metric_labels + }, + "drop_missing_columns": not form_data.get("show_empty_columns"), + }, + } + + return None + + +def normalize_order_by(query_object: dict[str, Any]) -> dict[str, Any]: + """ + Normalize the orderby clause in the query object. + + If the "orderby" key already contains a valid clause (a list whose first element + is a list of two elements, where the first element is truthy and the second a bool), + the original query_object is returned. Otherwise, the function creates a copy of + query_object, removes invalid orderby-related keys, and sets an orderby clause based + on available keys: "series_limit_metric", "legacy_order_by", or the first metric in + the "metrics" list. The sorting order is determined by the negation of "order_desc". + + Args: + query_object (dict): The query object containing orderby and related keys. 
+ + Returns: + dict: A modified query object with a normalized "orderby" clause. + """ + if ( + isinstance(query_object.get("orderby"), list) + and len(query_object.get("orderby", [])) > 0 + ): + # ensure a valid orderby clause + orderby_clause = query_object["orderby"][0] + if ( + isinstance(orderby_clause, list) + and len(orderby_clause) == 2 + and orderby_clause[0] + and isinstance(orderby_clause[1], bool) + ): + return query_object + + # remove invalid orderby keys from a copy + clone_query_object = query_object.copy() + clone_query_object.pop("series_limit_metric", None) + clone_query_object.pop("legacy_order_by", None) + clone_query_object.pop("order_desc", None) + clone_query_object.pop("orderby", None) + + is_asc = not query_object.get("order_desc", False) + + if query_object.get("series_limit_metric") is not None and query_object.get( + "series_limit_metric" + ): + return { + **clone_query_object, + "orderby": [[query_object["series_limit_metric"], is_asc]], + } + + # todo: Removed `legacy_order_by` after refactoring + if query_object.get("legacy_order_by") is not None and query_object.get( + "legacy_order_by" + ): + return { + **clone_query_object, + "orderby": [[query_object["legacy_order_by"], is_asc]], + } + + if ( + isinstance(query_object.get("metrics"), list) + and len(query_object.get("metrics", [])) > 0 + ): + return {**clone_query_object, "orderby": [[query_object["metrics"][0], is_asc]]} + + return clone_query_object + + +def remove_duplicates(items: Any, hash_func: Any = None) -> list[Any]: + """ + Remove duplicate items from a list. + + Args: + items: List of items to deduplicate + hash_func: Optional function to generate a hash for comparison + + Returns: + List with duplicates removed + """ + if hash_func: + seen = set() + result = [] + for x in items: + item_hash = hash_func(x) + if item_hash not in seen: + seen.add(item_hash) + result.append(x) + return result + else: + # Using Python's built-in uniqueness for lists + return list(dict.fromkeys(items)) # Preserves order in Python 3.7+ + + +def extract_fields_from_form_data( + rest_form_data: dict[str, Any], + query_field_aliases: dict[str, Any], + query_mode: Any | str, +) -> tuple[list[Any], list[Any], list[Any]]: + """ + Extract fields from form data based on aliases and query mode. + + Args: + rest_form_data (dict): The residual form data. + query_field_aliases (dict): A mapping of key aliases. + query_mode (str): The query mode, e.g. 'aggregate' or 'raw'. + + Returns: + tuple: A tuple of three lists: (columns, metrics, orderby) + """ + columns = [] + metrics = [] + orderby = [] + + for key, value in rest_form_data.items(): + if value is None: + continue + + normalized_key = query_field_aliases.get(key, key) + + if query_mode == "aggregate" and normalized_key == "columns": + continue + if query_mode == "raw" and normalized_key in ["groupby", "metrics"]: + continue + + if normalized_key == "groupby": + normalized_key = "columns" + + if normalized_key == "metrics": + metrics.extend(value if isinstance(value, list) else [value]) + elif normalized_key == "columns": + columns.extend(value if isinstance(value, list) else [value]) + elif normalized_key == "orderby": + orderby.extend(value if isinstance(value, list) else [value]) + + return columns, metrics, orderby + + +def extract_query_fields( + form_data: dict[Any, Any], aliases: Any = None +) -> Union[dict[str, Any]]: + """ + Extract query fields from form data. 
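+    Aliases such as "metric", "secondary_metric" and "all_columns" are first normalized (to "metrics" and "columns"), and the resulting lists are de-duplicated by label.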
+ + Args: + form_data: Form data residual + aliases: Query field aliases + + Returns: + Dictionary with columns, metrics, and orderby fields + """ + query_field_aliases = { + "metric": "metrics", + "metric_2": "metrics", + "secondary_metric": "metrics", + "x": "metrics", + "y": "metrics", + "size": "metrics", + "all_columns": "columns", + "series": "groupby", + "order_by_cols": "orderby", + } + + if aliases: + query_field_aliases.update(aliases) + query_mode = form_data.pop("query_mode", None) + rest_form_data = form_data + + columns, metrics, orderby = extract_fields_from_form_data( + rest_form_data, query_field_aliases, query_mode + ) + + result: dict[str, Any] = { + "columns": remove_duplicates( + [col for col in columns if col != ""], get_column_label + ), + "orderby": None, + } + if query_mode != "raw": + result["metrics"] = remove_duplicates(metrics, get_metric_label) + else: + result["metrics"] = None + if orderby: + result["orderby"] = [] + for item in orderby: + if isinstance(item, str): + try: + result["orderby"].append(json.loads(item)) + except Exception as err: + raise ValueError("Found invalid orderby options") from err + else: + result["orderby"].append(item) + + return result + + +def extract_extras(form_data: dict[str, Any]) -> dict[str, Any]: + """ + Extract extras from the form_data analogous to the TS version. + """ + applied_time_extras: dict[str, Any] = {} + filters: list[Any] = [] + extras: dict[str, Any] = {} + extract: dict[str, Any] = { + "filters": filters, + "extras": extras, + "applied_time_extras": applied_time_extras, + } + + # Mapping reserved columns to query field names + reserved_columns_to_query_field = { + "__time_range": "time_range", + "__time_col": "granularity_sqla", + "__time_grain": "time_grain_sqla", + "__granularity": "granularity", + } + + extra_filters = form_data.get("extra_filters", []) + for filter_item in extra_filters: + col = filter_item.get("col") + # Check if filter col is reserved + if col in reserved_columns_to_query_field: + query_field = reserved_columns_to_query_field[col] + # Assign the filter value to the extract dict + extract[query_field] = filter_item.get("val") + applied_time_extras[col] = filter_item.get("val") + else: + filters.append(filter_item) + + # SQL: set extra properties based on TS logic + if "time_grain_sqla" in form_data.keys() or "time_grain_sqla" in extract.keys(): + # If time_grain_sqla is set in form_data, use it + # Otherwise, use the value from extract + value = form_data.get("time_grain_sqla") or form_data.get("time_grain_sqla") + extras["time_grain_sqla"] = value + + extract["granularity"] = ( + extract.get("granularity_sqla") + or form_data.get("granularity") + or form_data.get("granularity_sqla") + ) + # Remove temporary keys + extract.pop("granularity_sqla", None) + extract.pop("time_grain_sqla", None) + if extract["granularity"] is None: + extract.pop("granularity", None) + + return extract + + +def is_defined(x: Any) -> bool: + """ + Returns True if x is not None. + This is equivalent to checking that x is neither null nor undefined in TypeScript. + """ + return x is not None + + +def sanitize_clause(clause: str) -> str: + """ + Sanitize a SQL clause. If the clause contains '--', append a newline. + Then wrap the clause in parentheses. 
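+    For example, sanitize_clause("status = 'active'") returns "(status = 'active')"; a clause containing an inline comment ("--") gets a trailing newline before being wrapped.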
+ """ + if clause is None: + return "" + sanitized_clause = clause + if "--" in clause: + sanitized_clause = clause + "\n" + return f"({sanitized_clause})" + + +def is_unary_operator(operator: Any | str) -> bool: + """Return True if operator is unary.""" + return operator in unary_operator_set + + +def is_binary_operator(operator: Any | str) -> bool: + """Return True if operator is binary.""" + return operator in binary_operator_set + + +def is_set_operator(operator: Any | str) -> bool: + """Return True if operator is a set operator.""" + return operator in set_operator_set + + +def is_unary_adhoc_filter(filter_item: dict[str, Any]) -> bool: + """Return True if the filter's operator is unary.""" + return is_unary_operator(filter_item.get("operator")) + + +def is_binary_adhoc_filter(filter_item: dict[str, Any]) -> bool: + """Return True if the filter's operator is binary.""" + return is_binary_operator(filter_item.get("operator")) + + +def convert_filter(filter_item: dict[str, Any]) -> dict[str, Any]: + """Convert an adhoc filter to a query clause dict.""" + subject = filter_item.get("subject") + if is_unary_adhoc_filter(filter_item): + operator = filter_item.get("operator") + return {"col": subject, "op": operator} + if is_binary_adhoc_filter(filter_item): + operator = filter_item.get("operator") + val = filter_item.get("comparator") + result = {"col": subject, "op": operator} + if val is not None: + result["val"] = val + return result + operator = filter_item.get("operator") + val = filter_item.get("comparator") + result = {"col": subject, "op": operator} + if val is not None: + result["val"] = val + return result + + +def is_simple_adhoc_filter(filter_item: dict[str, Any]) -> bool: + """Return True if the filter is a simple adhoc filter.""" + return filter_item.get("expressionType") == "SIMPLE" + + +def process_filters(form_data: dict[str, Any]) -> dict[str, Any]: + """ + Process filters from form_data: + - Split adhoc_filters according to clause and expression type. + - Build simple filter and freeform SQL clauses for WHERE/HAVING. + - Place freeform clauses into extras. + """ + adhoc_filters = form_data.get("adhoc_filters", []) + extras = form_data.get("extras", {}) + filters_list = form_data.get("filters", []) + + # Copy filters_list into simple_where + simple_where = filters_list[:] + freeform_where = [] + freeform_having = [] + + if where := form_data.get("where"): + freeform_where.append(where) + + for filter_item in adhoc_filters: + clause = filter_item.get("clause") + if is_simple_adhoc_filter(filter_item): + filter_clause = convert_filter(filter_item) + if clause == "WHERE": + simple_where.append(filter_clause) + else: + sql_expression = filter_item.get("sqlExpression") + if clause == "WHERE": + freeform_where.append(sql_expression) + else: + freeform_having.append(sql_expression) + + extras["having"] = " AND ".join([sanitize_clause(s) for s in freeform_having]) + extras["where"] = " AND ".join([sanitize_clause(s) for s in freeform_where]) + + return { + "filters": simple_where, + "extras": extras, + } + + +def override_extra_form_data( + query_object: dict[str, Any], override_form_data: dict[str, Any] +) -> dict[str, Any]: + """ + Override parts of the query_object with values from override_form_data. + + Mimics the behavior of the TypeScript function: + - For keys in EXTRA_FORM_DATA_OVERRIDE_REGULAR_MAPPINGS, + if set in override_form_data, assign the value in query_object + under the mapped target key. 
+ - For keys in EXTRA_FORM_DATA_OVERRIDE_EXTRA_KEYS, + if present in override_form_data, add them to query_object['extras']. + """ + # Create a copy of the query object + overridden_form_data = query_object.copy() + # Ensure extras is a mutable copy of what's in query_object (or an empty dict) + overridden_extras = overridden_form_data.get("extras", {}).copy() + + # Process regular mappings + for key, target in EXTRA_FORM_DATA_OVERRIDE_REGULAR_MAPPINGS.items(): + value = override_form_data.get(key) + if value is not None: + overridden_form_data[target] = value + + # Process extra keys + for key in EXTRA_FORM_DATA_OVERRIDE_EXTRA_KEYS: + if key in override_form_data: + overridden_extras[key] = override_form_data[key] + + if overridden_extras: + overridden_form_data["extras"] = overridden_extras + + return overridden_form_data + + +def build_query_object( + form_data: dict[str, Any], query_fields: Any = None +) -> dict[str, Any]: + """ + Build a query object from form data. + + Args: + form_data: Dictionary containing form data + query_fields: Optional query field aliases + + Returns: + Dictionary representing the query object + """ + # Extract fields from form_data with defaults + annotation_layers = form_data.get("annotation_layers", []) + extra_form_data = form_data.get("extra_form_data", {}) + time_range = form_data.get("time_range") + since = form_data.get("since") + until = form_data.get("until") + row_limit = form_data.get("row_limit") + row_offset = form_data.get("row_offset") + order_desc = form_data.get("order_desc") + limit: Any | int = form_data.get("limit") + timeseries_limit_metric = form_data.get("timeseries_limit_metric") + granularity = form_data.get("granularity") + url_params = form_data.get("url_params", {}) + custom_params = form_data.get("custom_params", {}) + series_columns = form_data.get("series_columns") + series_limit: Any | str = form_data.get("series_limit") + series_limit_metric = form_data.get("series_limit_metric") + + # Create residual_form_data by removing extracted fields + residual_form_data = { + k: v + for k, v in form_data.items() + if k + not in [ + "annotation_layers", + "extra_form_data", + "time_range", + "since", + "until", + "row_limit", + "row_offset", + "order_desc", + "limit", + "timeseries_limit_metric", + "granularity", + "url_params", + "custom_params", + "series_columns", + "series_limit", + "series_limit_metric", + ] + } + + # Extract fields from extra_form_data + append_adhoc_filters = ( + extra_form_data.get("adhoc_filters", []) if extra_form_data else [] + ) + append_filters = extra_form_data.get("filters", []) if extra_form_data else [] + custom_form_data = ( + extra_form_data.get("custom_form_data", {}) if extra_form_data else {} + ) + overrides = ( + { + k: v + for k, v in extra_form_data.items() + if k not in ["adhoc_filters", "filters", "custom_form_data"] + } + if extra_form_data + else {} + ) + + # Convert to numeric values + numeric_row_limit: Any = float(row_limit) if row_limit is not None else None + numeric_row_offset: Any = float(row_offset) if row_offset is not None else None + + # Extract query fields + extracted_fields = extract_query_fields(residual_form_data, query_fields) + metrics = extracted_fields.get("metrics") + columns = extracted_fields.get("columns") + orderby = extracted_fields.get("orderby") + + # Collect and process filters + extras = extract_extras(form_data) + extra_filters = extras.get("filters", []) + filter_form_data = { + "filters": extra_filters + append_filters, + "adhoc_filters": 
(form_data.get("adhoc_filters") or []) + append_adhoc_filters, + } + extras_and_filters = process_filters({**form_data, **extras, **filter_form_data}) + + def normalize_series_limit_metric(metric: Any) -> Optional[Any]: + if is_query_form_metric(metric): + return metric + return None + + # Build the query object + query_object: dict[Any, Any] = { + **extras, + **extras_and_filters, + "columns": columns, + "metrics": metrics, + "orderby": orderby, + "annotation_layers": annotation_layers, + "series_columns": series_columns, + "row_limit": ( + None + if row_limit is None or math.isnan(numeric_row_limit) + else int(numeric_row_limit) + ), + "series_limit": ( + series_limit + if series_limit is not None + else (int(limit) if is_defined(limit) else 0) + ), + "order_desc": True if order_desc is None else order_desc, + "url_params": url_params, + "custom_params": custom_params, + } + + row_offset = ( + None + if row_offset is None or math.isnan(numeric_row_offset) + else numeric_row_offset + ) + + temp = normalize_series_limit_metric(series_limit_metric) + series_limit_metric = temp if temp is not None else timeseries_limit_metric + + for key, value in [ + ("time_range", time_range), + ("since", since), + ("until", until), + ("granularity", granularity), + ("series_limit_metric", series_limit_metric), + ("row_offset", row_offset), + ]: + if value is not None: + query_object[key] = value + + # Override extra form data + query_object = override_extra_form_data(query_object, overrides) + + query_object = {k: v for k, v in query_object.items() if v is not None} + + # Return the final query object with custom form data + return {**query_object, "custom_form_data": custom_form_data} + + +def omit(d: dict[str, Any], keys: list[Any]) -> dict[str, Any]: + """ + Return a copy of dictionary d without the specified keys. + """ + return {k: v for k, v in d.items() if k not in keys} + + +def normalize_time_column( + form_data: dict[str, Any], query_object: dict[str, Any] +) -> dict[str, Any]: + """ + If x_axis is set in form_data, find its index in query_object's columns and update + that column + with timeGrain and columnType information. The updated query_object omits + the 'is_timeseries' key. + """ + if not is_x_axis_set(form_data): + return query_object + + _columns: list[Any] = query_object.get("columns", []) + _extras = query_object.get("extras", {}) + # Create a shallow copy of columns + mutated_columns = list(_columns) + x_axis: Any = form_data.get("x_axis") + axis_idx = None + + # Find the index of the x_axis in the columns list + for idx, col in enumerate(_columns): + if ( + is_physical_column(col) and is_physical_column(x_axis) and col == x_axis + ) or ( + is_adhoc_column(col) + and is_adhoc_column(x_axis) + and col.get("sqlExpression") == x_axis.get("sqlExpression") + ): + axis_idx = idx + break + + if axis_idx is not None and axis_idx > -1 and x_axis and isinstance(_columns, list): + if is_adhoc_column(_columns[axis_idx]): + # Update the adhoc column with additional keys. + updated = dict(_columns[axis_idx]) + updated["columnType"] = "BASE_AXIS" + if _extras: + if "time_grain_sqla" in _extras.keys(): + updated["timeGrain"] = _extras["time_grain_sqla"] + mutated_columns[axis_idx] = updated + + else: + # For physical columns, create a new column entry. 
+ mutated_columns[axis_idx] = { + "columnType": "BASE_AXIS", + "sqlExpression": x_axis, + "label": x_axis, + "expressionType": "SQL", + } + if _extras: + if "time_grain_sqla" in _extras.keys(): + mutated_columns[axis_idx]["timeGrain"] = _extras["time_grain_sqla"] + + # Create a new query object without the 'is_timeseries' key. + new_query_object = omit(query_object, ["is_timeseries"]) + new_query_object["columns"] = mutated_columns + return new_query_object + + # Fallback: return the original query_object + return query_object + + +def build_query_context( + form_data: dict[str, Any], options: Any = None +) -> dict[str, Any]: + # Handle options based on type + def default_build_query(x: Any) -> list[Any]: + return [x] + + if callable(options): + query_fields = {} + build_query = options + elif options: + query_fields = options.get("query_fields", {}) + build_query = options.get("build_query", lambda x: [x]) + else: + query_fields = {} + build_query = default_build_query + + queries = build_query(build_query_object(form_data, query_fields)) + + for query in queries: + if isinstance(query.get("post_processing"), list): + query["post_processing"] = [p for p in query["post_processing"] if p] + + if is_x_axis_set(form_data): + queries = [normalize_time_column(form_data, query) for query in queries] + + return { + "datasource": DatasourceKey(form_data["datasource"]).to_object(), + "force": form_data.get("force", False), + "queries": queries, + "form_data": form_data, + "result_format": form_data.get("result_format", "json"), + "result_type": form_data.get("result_type", "full"), + } + + +def rolling_window_operator( + form_data: dict[str, Any], query_object: dict[str, Any] +) -> Optional[dict[str, Any]]: + """ + Builds a post-processing configuration for a rolling window. + + - If it's a time comparison, compute the columns from the metric offsets map. + - Otherwise, derive the columns from query_object.metrics. + - Then, based on the rolling_type, return a configuration dict. + """ + # Determine the columns to operate on + if is_time_comparison(form_data, query_object): + metrics_map = get_metric_offsets_map(form_data, query_object) + columns = list(metrics_map.values()) + list(metrics_map.keys()) + else: + metrics = ensure_is_array(query_object.get("metrics")) + columns = [] + for metric in metrics: + if isinstance(metric, str): + columns.append(metric) + elif isinstance(metric, dict): + columns.append(metric.get("label")) + + # Build a columns map from the list of columns + columns_map = {col: col for col in columns if col is not None} + + # Determine the operation based on rolling_type + rolling_type = form_data.get("rolling_type") + + if rolling_type == RollingType.Cumsum.value: + return { + "operation": "cum", + "options": { + "operator": "sum", + "columns": columns_map, + }, + } + + if rolling_type in [ + RollingType.Sum.value, + RollingType.Mean.value, + RollingType.Std.value, + ]: + return { + "operation": "rolling", + "options": { + "rolling_type": rolling_type, + "window": ensure_is_int(form_data.get("rolling_periods"), 1), + "min_periods": ensure_is_int(form_data.get("min_periods"), 0), + "columns": columns_map, + }, + } + + return None + + +def time_compare_operator( + form_data: Dict[str, Any], query_object: Dict[str, Any] +) -> Optional[Dict[str, Any]]: + """ + Returns a post-processing configuration for time comparison if applicable. + + If time comparison is enabled and the comparison type is not 'values', + builds a configuration dict that specifies the operation and options. 
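As a concrete illustration of `rolling_window_operator` above, assuming no time comparison is configured and that `RollingType.Mean.value` is the string "mean":

    form_data = {"rolling_type": "mean", "rolling_periods": 7, "min_periods": 0}
    query_object = {"metrics": ["count"]}
    rolling_window_operator(form_data, query_object)
    # Expected result:
    # {
    #     "operation": "rolling",
    #     "options": {
    #         "rolling_type": "mean",
    #         "window": 7,
    #         "min_periods": 0,
    #         "columns": {"count": "count"},
    #     },
    # }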
+ """ + comparison_type = form_data.get("comparison_type") + metric_offset_map = get_metric_offsets_map(form_data, query_object) + + if ( + is_time_comparison(form_data, query_object) + and comparison_type != ComparisonType.Values.value + ): + return { + "operation": "compare", + "options": { + "source_columns": list(metric_offset_map.values()), + "compare_columns": list(metric_offset_map.keys()), + "compare_type": comparison_type, + "drop_original_columns": True, + }, + } + return None + + +def resample_operator( + form_data: dict[str, Any], query_object: dict[str, Any] +) -> Any | dict[str, Any]: + """ + Returns a post-processing configuration for resampling if the required + resample_method and resample_rule are provided in form_data. + """ + resample_zero_fill = form_data.get("resample_method") == "zerofill" + resample_method = ( + "asfreq" if resample_zero_fill else form_data.get("resample_method") + ) + resample_rule = form_data.get("resample_rule") + + if resample_method and resample_rule: + return { + "operation": "resample", + "options": { + "method": resample_method, + "rule": resample_rule, + "fill_value": 0 if resample_zero_fill else None, + }, + } + return None + + +def rename_operator( + form_data: dict[str, Any], query_object: dict[str, Any] +) -> Optional[dict[str, Any]]: + """ + Produces a post-processing configuration to rename columns based on the criteria: + 1) Only one metric exists. + 2) There is at least one dimension (series_columns or columns). + 3) An x-axis label exists. + 4) If time comparison is enabled and its comparison type is not one of + [difference, ratio, percentage]. + 5) The form data contains a truthy 'truncate_metric' flag. + + Additionally, if time comparison is active and the comparison type is 'values', + the operator renames the metric with the corresponding offset label. + """ + metrics: Any = ensure_is_array(query_object.get("metrics")) + columns = ensure_is_array( + query_object.get("series_columns") + if query_object.get("series_columns") is not None + else query_object.get("columns") + ) + truncate_metric = form_data.get("truncate_metric") + x_axis_label = get_x_axis_label(form_data) + + # Check conditions for renaming + if ( + len(metrics) == 1 + and len(columns) > 0 + and x_axis_label + and not ( + is_time_comparison(form_data, query_object) + and form_data.get("comparison_type") + in { + ComparisonType.Difference.value, + ComparisonType.Ratio.value, + ComparisonType.Percentage.value, + } + ) + and truncate_metric is not None + and bool(truncate_metric) + ): + rename_pairs: Any = [] + + if ( + is_time_comparison(form_data, query_object) + and form_data.get("comparison_type") == ComparisonType.Values.value + ): + metric_offset_map = get_metric_offsets_map(form_data, query_object) + time_offsets = ensure_is_array(form_data.get("time_compare")) + for metric_with_offset in list(metric_offset_map.keys()): + offset_label = next( + (offset for offset in time_offsets if offset in metric_with_offset), + None, + ) + rename_pairs.append((metric_with_offset, offset_label)) + + rename_pairs.append((get_metric_label(metrics[0]), None)) + + return { + "operation": "rename", + "options": { + "columns": dict(rename_pairs), + "level": 0, + "inplace": True, + }, + } + + return None + + +def contribution_operator( + form_data: dict[str, Any], query_object: dict[str, Any], time_shifts: Any +) -> Optional[dict[str, Any]]: + """ + Returns a post-processing configuration for contribution if + form_data.contributionMode is truthy. 
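A quick sketch of the two `resample_operator` paths above; "zerofill" is translated to pandas' `asfreq` with a fill value of 0:

    resample_operator({"resample_method": "zerofill", "resample_rule": "1D"}, {})
    # -> {"operation": "resample",
    #     "options": {"method": "asfreq", "rule": "1D", "fill_value": 0}}

    resample_operator({"resample_method": "ffill", "resample_rule": "1D"}, {})
    # -> {"operation": "resample",
    #     "options": {"method": "ffill", "rule": "1D", "fill_value": None}}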
+ """ + if form_data.get("contributionMode"): + return { + "operation": "contribution", + "options": { + "orientation": form_data.get("contributionMode"), + "time_shifts": time_shifts, + }, + } + return None + + +def sort_operator( + form_data: Dict[str, Any], query_object: Dict[str, Any] +) -> Optional[Dict[str, Any]]: + """ + Build a sort post-processing configuration if the conditions are met. + + Conditions: + - form_data.x_axis_sort and form_data.x_axis_sort_asc are defined. + - The sort key exists in sortableLabels. + - groupby is empty. + + If the sort key matches the x-axis label, sort using the index. + Otherwise, sort by the provided sort key. + """ + # Build the sortable labels list + sortable_labels: list[Any] = [ + get_x_axis_label(form_data), + ] + sortable_labels += [ + get_metric_label(m) for m in ensure_is_array(form_data.get("metrics")) + ] + sortable_labels += [get_metric_label(m) for m in extract_extra_metrics(form_data)] + # Filter out any falsy values + sortable_labels = [label for label in sortable_labels if label] + + # Check the required conditions. + if ( + is_defined(form_data.get("x_axis_sort")) + and is_defined(form_data.get("x_axis_sort_asc")) + and form_data.get("x_axis_sort") in sortable_labels + and is_empty(form_data.get("groupby")) ## + ): + if form_data.get("x_axis_sort") == get_x_axis_label(form_data): + return { + "operation": "sort", + "options": { + "is_sort_index": True, + "ascending": form_data.get("x_axis_sort_asc"), + }, + } + return { + "operation": "sort", + "options": { + "by": form_data.get("x_axis_sort"), + "ascending": form_data.get("x_axis_sort_asc"), + }, + } + return None + + +def flatten_operator( + form_data: dict[str, Any], query_object: dict[str, Any] +) -> dict[str, Any]: + """ + Returns a post-processing configuration that indicates a flatten operation. + """ + return {"operation": "flatten"} + + +def prophet_operator( + form_data: dict[str, Any], query_object: dict[str, Any] +) -> Any | dict[str, Any]: + """ + Returns a post-processing configuration for prophet forecasting + if forecast is enabled and an x-axis label is present. + """ + x_axis_label = get_x_axis_label(form_data) + if form_data.get("forecastEnabled") and x_axis_label: + try: + periods = int(form_data.get("forecastPeriods", 0)) + except (TypeError, ValueError): + periods = 0 + try: + confidence_interval = float(form_data.get("forecastInterval", 0)) + except (TypeError, ValueError): + confidence_interval = 0.0 + + return { + "operation": "prophet", + "options": { + "time_grain": form_data.get("time_grain_sqla"), + "periods": periods, + "confidence_interval": confidence_interval, + "yearly_seasonality": form_data.get("forecastSeasonalityYearly"), + "weekly_seasonality": form_data.get("forecastSeasonalityWeekly"), + "daily_seasonality": form_data.get("forecastSeasonalityDaily"), + "index": x_axis_label, + }, + } + return None + + +def rank_operator( + form_data: dict[str, Any], query_object: dict[str, Any], options: dict[str, Any] +) -> dict[str, Any]: + """ + Returns a post-processing configuration for ranking. + + Args: + form_data (dict): The form data for the query. + query_object (dict): The base query object. + options (dict): Options for the rank operator. + + Returns: + dict: A configuration dict with the ranking operation. 
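To illustrate the `sort_operator` conditions above, assuming the usual helper behavior (`get_x_axis_label` returns the x-axis column name, `get_metric_label` returns the metric name, `extract_extra_metrics` returns an empty list, and `is_empty([])` is true):

    form_data = {
        "x_axis": "ds",
        "metrics": ["count"],
        "groupby": [],
        "x_axis_sort": "count",
        "x_axis_sort_asc": False,
    }
    sort_operator(form_data, {})
    # -> {"operation": "sort", "options": {"by": "count", "ascending": False}}

    # Sorting by the x-axis itself switches to an index sort:
    sort_operator({**form_data, "x_axis_sort": "ds"}, {})
    # -> {"operation": "sort", "options": {"is_sort_index": True, "ascending": False}}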
+ """ + options_dict = options + if options_dict.get("group_by") is None: + options_dict.pop("group_by", None) + return { + "operation": "rank", + "options": options_dict, + } + + +def drop_none_values(options: dict[str, Any]) -> dict[str, Any]: + return {k: v for k, v in options.items() if v is not None} + + +def histogram_operator( + form_data: dict[str, str | Any], query_object: dict[str, Any] +) -> dict[str, Any]: + """ + Build a histogram operator configuration. + + This function extracts histogram parameters from the form data and builds an + operator configuration for generating a histogram. It attempts to parse the + 'bins' value as an integer (defaulting to 5 if parsing fails), retrieves the + column and groupby details by using get_column_label, and collects additional + options such as cumulative and normalize flags. + + Args: + form_data (dict): Dictionary containing histogram parameters + such as 'bins', 'column','cumulative', 'groupby', and 'normalize'. + query_object (dict): Dictionary representing the query object + + Returns: + dict: A dictionary with keys "operation" and "options" + that defines the histogram operator. + """ + bins: Any | int = form_data.get("bins") + column = form_data.get("column") + cumulative = form_data.get("cumulative") + groupby = form_data.get("groupby", []) + normalize = form_data.get("normalize") + try: + parsed_bins = int(bins) + except (TypeError, ValueError): + parsed_bins = 5 + parsed_column = get_column_label(column) + parsed_groupby = [get_column_label(g) for g in groupby] + + options = { + "column": parsed_column, + "groupby": parsed_groupby, + "bins": parsed_bins, + "cumulative": cumulative, + "normalize": normalize, + } + + result = {"operation": "histogram", "options": drop_none_values(options)} + + return result + + +def retain_form_data_suffix( + form_data: dict[str, Any], control_suffix: str +) -> dict[str, Any]: + """ + Retain keys from the form data that end with a specified suffix + and remove the suffix from them. + + The function creates a new form data dictionary. For keys ending + with the provided + control_suffix, it removes the suffix and assigns the corresponding + value. If a key does + not end with the suffix and is not already set in the new dictionary + (i.e. via a suffixed key), it is retained as-is. + + Args: + form_data (dict): The original form data dictionary. + control_suffix (str): The suffix string to look for in keys. + + Returns: + dict: A new dictionary containing the retained and modified keys. + """ + new_form_data = {} + entries = sorted( + form_data.items(), + key=lambda kv: 1 if kv[0].endswith(control_suffix) else 0, + reverse=True, + ) + for key, value in entries: + if key.endswith(control_suffix): + new_form_data[key[: -len(control_suffix)]] = value + if not key.endswith(control_suffix) and key not in new_form_data.keys(): + new_form_data[key] = value + return new_form_data + + +def remove_form_data_suffix( + form_data: dict[str, Any], control_suffix: str +) -> dict[str, Any]: + """ + Remove keys from the form data that end with a specified suffix. + + This function builds a new dictionary containing only those key-value pairs + where the key does NOT end with the given control_suffix. + + Args: + form_data (dict): The original form data dictionary. + control_suffix (str): The suffix indicating which keys should be removed. + + Returns: + dict: A new dictionary with the keys ending with control_suffix removed. 
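The suffix helpers above are easiest to see with a small example; "_b" stands in for the kind of control suffix used by secondary queries and is used here purely for illustration:

    form_data = {"metrics": ["count"], "metrics_b": ["sum__num"], "row_limit_b": 100}

    retain_form_data_suffix(form_data, "_b")
    # -> {"metrics": ["sum__num"], "row_limit": 100}

    remove_form_data_suffix(form_data, "_b")
    # -> {"metrics": ["count"]}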
+ """ + new_form_data = {} + for key, value in form_data.items(): + if not key.endswith(control_suffix): + new_form_data[key] = value + return new_form_data diff --git a/superset/migrations/shared/utils.py b/superset/migrations/shared/utils.py index a0a096168a8..a7f40b07453 100644 --- a/superset/migrations/shared/utils.py +++ b/superset/migrations/shared/utils.py @@ -506,9 +506,21 @@ def cast_text_column_to_json( conn.execute( text( f""" - ALTER TABLE {table} - ALTER COLUMN {column} TYPE jsonb - USING {column}::jsonb +CREATE OR REPLACE FUNCTION safe_to_jsonb(input text) + RETURNS jsonb + LANGUAGE plpgsql + IMMUTABLE +AS $$ +BEGIN + RETURN input::jsonb; +EXCEPTION WHEN invalid_text_representation THEN + RETURN NULL; +END; +$$; + +ALTER TABLE {table} +ALTER COLUMN {column} TYPE jsonb +USING safe_to_jsonb({column}); """ ) ) @@ -525,6 +537,13 @@ def cast_text_column_to_json( stmt_select = select(t.c[pk], t.c[column]).where(t.c[column].is_not(None)) for row_pk, value in conn.execute(stmt_select): + try: + json.loads(value) + except json.JSONDecodeError: + logger.warning( + f"Invalid JSON value in column {column} for {pk}={row_pk}: {value}" + ) + continue stmt_update = update(t).where(t.c[pk] == row_pk).values({tmp_column: value}) conn.execute(stmt_update) diff --git a/superset/models/core.py b/superset/models/core.py index 9378452bd85..1768aa41fdd 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -307,7 +307,6 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable with suppress(TypeError, json.JSONDecodeError): encrypted_config = json.loads(masked_encrypted_extra) try: - # pylint: disable=useless-suppression parameters = self.db_engine_spec.get_parameters_from_uri( # type: ignore masked_uri, encrypted_extra=encrypted_config, @@ -660,7 +659,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable schema: str | None = None, mutator: Callable[[pd.DataFrame], None] | None = None, ) -> pd.DataFrame: - sqls = self.db_engine_spec.parse_sql(sql) + script = SQLScript(sql, self.db_engine_spec.engine) with self.get_sqla_engine(catalog=catalog, schema=schema) as engine: engine_url = engine.url @@ -677,8 +676,11 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable with self.get_raw_connection(catalog=catalog, schema=schema) as conn: cursor = conn.cursor() df = None - for i, sql_ in enumerate(sqls): - sql_ = self.mutate_sql_based_on_config(sql_, is_split=True) + for i, statement in enumerate(script.statements): + sql_ = self.mutate_sql_based_on_config( + statement.format(), + is_split=True, + ) _log_query(sql_) with event_logger.log_context( action="execute_sql", @@ -687,7 +689,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable ): self.db_engine_spec.execute(cursor, sql_, self) - rows = self.fetch_rows(cursor, i == len(sqls) - 1) + rows = self.fetch_rows(cursor, i == len(script.statements) - 1) if rows is not None: df = self.load_into_dataframe(cursor.description, rows) @@ -762,11 +764,19 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable ) def apply_limit_to_sql( - self, sql: str, limit: int = 1000, force: bool = False + self, + sql: str, + limit: int = 1000, + force: bool = False, ) -> str: - if self.db_engine_spec.allow_limit_clause: - return self.db_engine_spec.apply_limit_to_sql(sql, limit, self, force=force) - return self.db_engine_spec.apply_top_to_sql(sql, limit) + script = SQLScript(sql, self.db_engine_spec.engine) + statement = 
script.statements[-1] + current_limit = statement.get_limit_value() or float("inf") + + if limit < current_limit or force: + statement.set_limit_value(limit, self.db_engine_spec.limit_method) + + return script.format() def safe_sqlalchemy_uri(self) -> str: return self.sqlalchemy_uri diff --git a/superset/models/helpers.py b/superset/models/helpers.py index db05b415848..341585e5da5 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -17,6 +17,8 @@ # pylint: disable=too-many-lines """a collection of model-related helper classes and functions""" +from __future__ import annotations + import builtins import dataclasses import logging @@ -32,7 +34,6 @@ import numpy as np import pandas as pd import pytz import sqlalchemy as sa -import sqlparse import yaml from flask import g from flask_appbuilder import Model @@ -63,15 +64,12 @@ from superset.exceptions import ( ColumnNotFoundException, QueryClauseValidationException, QueryObjectValidationError, - SupersetParseError, SupersetSecurityException, ) from superset.extensions import feature_flag_manager from superset.jinja_context import BaseTemplateProcessor -from superset.sql.parse import SQLScript +from superset.sql.parse import SQLScript, SQLStatement from superset.sql_parse import ( - has_table_query, - insert_rls_in_predicate, sanitize_clause, ) from superset.superset_typing import ( @@ -111,9 +109,10 @@ ADVANCED_DATA_TYPES = config["ADVANCED_DATA_TYPES"] def validate_adhoc_subquery( sql: str, - database_id: int, - engine: str, + database: Database, + catalog: str | None, default_schema: str, + engine: str, ) -> str: """ Check if adhoc SQL contains sub-queries or nested sub-queries with table. @@ -125,28 +124,23 @@ def validate_adhoc_subquery( :raise SupersetSecurityException if sql contains sub-queries or nested sub-queries with table """ - statements = [] - for statement in sqlparse.parse(sql): - try: - has_table = has_table_query(str(statement), engine) - except SupersetParseError: - has_table = True + from superset.sql_lab import apply_rls - if has_table: - if not is_feature_enabled("ALLOW_ADHOC_SUBQUERY"): - raise SupersetSecurityException( - SupersetError( - error_type=SupersetErrorType.ADHOC_SUBQUERY_NOT_ALLOWED_ERROR, - message=_("Custom SQL fields cannot contain sub-queries."), - level=ErrorLevel.ERROR, - ) + parsed_statement = SQLStatement(sql, engine) + if parsed_statement.has_subquery(): + if not is_feature_enabled("ALLOW_ADHOC_SUBQUERY"): + raise SupersetSecurityException( + SupersetError( + error_type=SupersetErrorType.ADHOC_SUBQUERY_NOT_ALLOWED_ERROR, + message=_("Custom SQL fields cannot contain sub-queries."), + level=ErrorLevel.ERROR, ) - # TODO (betodealmeida): reimplement with sqlglot - statement = insert_rls_in_predicate(statement, database_id, default_schema) + ) - statements.append(statement) + # enforce RLS rules in any relevant tables + apply_rls(database, catalog, default_schema, parsed_statement) - return ";\n".join(str(statement) for statement in statements) + return parsed_statement.format() def json_to_dict(json_str: str) -> dict[Any, Any]: @@ -784,7 +778,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods raise NotImplementedError() @property - def database(self) -> "Database": + def database(self) -> Database: raise NotImplementedError() @property @@ -839,9 +833,10 @@ class ExploreMixin: # pylint: disable=too-many-public-methods if expression: expression = validate_adhoc_subquery( expression, - database_id, - engine, + self.database, + self.catalog, schema, + engine, ) try: 
expression = sanitize_clause(expression) @@ -1467,6 +1462,10 @@ class ExploreMixin: # pylint: disable=too-many-public-methods extras = extras or {} time_grain = extras.get("time_grain_sqla") + # DB-specifc quoting for identifiers + with self.database.get_sqla_engine() as engine: + quote = engine.dialect.identifier_preparer.quote + template_kwargs = { "columns": columns, "from_dttm": from_dttm.isoformat() if from_dttm else None, @@ -1515,6 +1514,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods columns_by_name: dict[str, "TableColumn"] = { col.column_name: col for col in self.columns } + quoted_columns_by_name = {quote(k): v for k, v in columns_by_name.items()} metrics_by_name: dict[str, "SqlMetric"] = { m.metric_name: m for m in self.metrics @@ -1636,15 +1636,17 @@ class ExploreMixin: # pylint: disable=too-many-public-methods else: selected = validate_adhoc_subquery( selected, - self.database_id, - self.database.backend, + self.database, + self.catalog, self.schema, + self.database.db_engine_spec.engine, ) outer = literal_column(f"({selected})") outer = self.make_sqla_column_compatible(outer, selected) else: outer = self.adhoc_column_to_sqla( - col=selected, template_processor=template_processor + col=selected, + template_processor=template_processor, ) groupby_all_columns[outer.name] = outer if ( @@ -1658,23 +1660,24 @@ class ExploreMixin: # pylint: disable=too-many-public-methods _sql = selected["sqlExpression"] _column_label = selected["label"] elif isinstance(selected, str): - _sql = selected + _sql = quote(selected) _column_label = selected selected = validate_adhoc_subquery( _sql, - self.database_id, - self.database.backend, + self.database, + self.catalog, self.schema, + self.database.db_engine_spec.engine, ) select_exprs.append( self.convert_tbl_column_to_sqla_col( - columns_by_name[selected], + quoted_columns_by_name[selected], template_processor=template_processor, label=_column_label, ) - if isinstance(selected, str) and selected in columns_by_name + if selected in quoted_columns_by_name else self.make_sqla_column_compatible( literal_column(selected), _column_label ) @@ -1989,9 +1992,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods and db_engine_spec.allows_hidden_cc_in_orderby and col.name in [select_col.name for select_col in select_exprs] ): - with self.database.get_sqla_engine() as engine: - quote = engine.dialect.identifier_preparer.quote - col = literal_column(quote(col.name)) + col = literal_column(quote(col.name)) direction = sa.asc if ascending else sa.desc qry = qry.order_by(direction(col)) diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index 4d443423d18..f73334a1e99 100644 --- a/superset/models/sql_lab.py +++ b/superset/models/sql_lab.py @@ -56,7 +56,8 @@ from superset.models.helpers import ( ExtraJSONMixin, ImportExportMixin, ) -from superset.sql_parse import CtasMethod, extract_tables_from_jinja_sql, Table +from superset.sql.parse import CTASMethod +from superset.sql_parse import extract_tables_from_jinja_sql, Table from superset.sqllab.limiting_factor import LimitingFactor from superset.utils import json from superset.utils.core import ( @@ -128,7 +129,7 @@ class Query( ) select_as_cta = Column(Boolean) select_as_cta_used = Column(Boolean, default=False) - ctas_method = Column(String(16), default=CtasMethod.TABLE) + ctas_method = Column(String(16), default=CTASMethod.TABLE.name) progress = Column(Integer, default=0) # 1..100 # # of rows in the result set or rows modified. 
diff --git a/superset/security/manager.py b/superset/security/manager.py index 90ea1515d6f..84d31971ae1 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -263,7 +263,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods ADMIN_ONLY_VIEW_MENUS = { "Access Requests", - "Action Log", + "Action Logs", "Log", "List Users", "UsersListView", diff --git a/superset/sql/parse.py b/superset/sql/parse.py index dc9ca632bba..73255bef13e 100644 --- a/superset/sql/parse.py +++ b/superset/sql/parse.py @@ -27,12 +27,9 @@ from dataclasses import dataclass from typing import Any, Generic, TypeVar import sqlglot -import sqlparse -from deprecation import deprecated from sqlglot import exp from sqlglot.dialects.dialect import Dialect, Dialects from sqlglot.errors import ParseError -from sqlglot.expressions import Func from sqlglot.optimizer.pushdown_predicates import pushdown_predicates from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope @@ -58,7 +55,7 @@ SQLGLOT_DIALECTS = { # "db2": ??? # "dremio": ??? "drill": Dialects.DRILL, - # "druid": ??? + "druid": Dialects.DRUID, "duckdb": Dialects.DUCKDB, # "dynamodb": ??? # "elasticsearch": ??? @@ -99,6 +96,167 @@ SQLGLOT_DIALECTS = { } +class LimitMethod(enum.Enum): + """ + Limit methods. + + This is used to determine how to add a limit to a SQL statement. + """ + + FORCE_LIMIT = enum.auto() + WRAP_SQL = enum.auto() + FETCH_MANY = enum.auto() + + +class CTASMethod(enum.Enum): + TABLE = enum.auto() + VIEW = enum.auto() + + +class RLSMethod(enum.Enum): + """ + Methods for enforcing RLS. + """ + + AS_PREDICATE = enum.auto() + AS_SUBQUERY = enum.auto() + + +class RLSTransformer: + """ + AST transformer to apply RLS rules. + """ + + def __init__( + self, + catalog: str | None, + schema: str | None, + rules: dict[Table, list[exp.Expression]], + ) -> None: + self.catalog = catalog + self.schema = schema + self.rules = rules + + def get_predicate(self, table_node: exp.Table) -> exp.Expression | None: + """ + Get the combined RLS predicate for a table. + """ + table = Table( + table_node.name, + table_node.db if table_node.db else self.schema, + table_node.catalog if table_node.catalog else self.catalog, + ) + if predicates := self.rules.get(table): + return ( + exp.And( + this=predicates[0], + expressions=predicates[1:], + ) + if len(predicates) > 1 + else predicates[0] + ) + + return None + + +class RLSAsPredicateTransformer(RLSTransformer): + """ + Apply Row Level Security role as a predicate. + + This transformer will apply any RLS predicates to the relevant tables. For example, + given the RLS rule: + + table: some_table + clause: id = 42 + + If a user subject to the rule runs the following query: + + SELECT foo FROM some_table WHERE bar = 'baz' + + The query will be modified to: + + SELECT foo FROM some_table WHERE bar = 'baz' AND id = 42 + + This approach is probably less secure than using subqueries, so it's only used for + databases without support for subqueries. 
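To tie the transformers above to the public API added further down in this file, here is a rough usage sketch of `SQLStatement.apply_rls`; the table, schema, and predicate are made up, and the exact rendering (quoting, casing, line breaks) depends on the dialect:

    stmt = SQLStatement("SELECT foo FROM some_table WHERE bar = 'baz'")
    rules = {
        Table("some_table", "public", None): [stmt.parse_predicate("id = 42")],
    }
    stmt.apply_rls(catalog=None, schema="public", predicates=rules, method=RLSMethod.AS_PREDICATE)
    print(stmt.format())
    # Roughly: SELECT foo FROM some_table WHERE some_table.id = 42 AND (bar = 'baz')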
+ """ + + def __call__(self, node: exp.Expression) -> exp.Expression: + if not isinstance(node, exp.Table): + return node + + predicate = self.get_predicate(node) + if not predicate: + return node + + # qualify columns with table name + for column in predicate.find_all(exp.Column): + column.set("table", node.alias or node.this) + + if isinstance(node.parent, exp.From): + select = node.parent.parent + if where := select.args.get("where"): + predicate = exp.And( + this=predicate, + expression=exp.Paren(this=where.this), + ) + select.set("where", exp.Where(this=predicate)) + + elif isinstance(node.parent, exp.Join): + join = node.parent + if on := join.args.get("on"): + predicate = exp.And( + this=predicate, + expression=exp.Paren(this=on), + ) + join.set("on", predicate) + + return node + + +class RLSAsSubqueryTransformer(RLSTransformer): + """ + Apply Row Level Security role as a subquery. + + This transformer will apply any RLS predicates to the relevant tables. For example, + given the RLS rule: + + table: some_table + clause: id = 42 + + If a user subject to the rule runs the following query: + + SELECT foo FROM some_table WHERE bar = 'baz' + + The query will be modified to: + + SELECT foo FROM (SELECT * FROM some_table WHERE id = 42) AS some_table + WHERE bar = 'baz' + + This approach is probably more secure than using predicates, but it doesn't work for + all databases. + """ + + def __call__(self, node: exp.Expression) -> exp.Expression: + if not isinstance(node, exp.Table): + return node + + if predicate := self.get_predicate(node): + # use alias or name + alias = node.alias or node.sql() + node.set("alias", None) + node = exp.Subquery( + this=exp.Select( + expressions=[exp.Star()], + where=exp.Where(this=predicate), + **{"from": exp.From(this=node.copy())}, + ), + alias=alias, + ) + + return node + + @dataclass(eq=True, frozen=True) class Table: """ @@ -155,12 +313,17 @@ class BaseSQLStatement(Generic[InternalRepresentation]): def __init__( self, - statement: str, - engine: str, + statement: str | None = None, + engine: str = "base", ast: InternalRepresentation | None = None, ): - self._sql = statement - self._parsed = ast or self._parse_statement(statement, engine) + if ast: + self._parsed = ast + elif statement: + self._parsed = self._parse_statement(statement, engine) + else: + raise ValueError("Either statement or ast must be provided") + self.engine = engine self.tables = self._extract_tables_from_statement(self._parsed, self.engine) @@ -223,6 +386,12 @@ class BaseSQLStatement(Generic[InternalRepresentation]): """ raise NotImplementedError() + def is_select(self) -> bool: + """ + Check if the statement is a `SELECT` statement. + """ + raise NotImplementedError() + def is_mutating(self) -> bool: """ Check if the statement mutates data (DDL/DML). @@ -237,6 +406,95 @@ class BaseSQLStatement(Generic[InternalRepresentation]): """ raise NotImplementedError() + def check_functions_present(self, functions: set[str]) -> bool: + """ + Check if any of the given functions are present in the script. + + :param functions: List of functions to check for + :return: True if any of the functions are present + """ + raise NotImplementedError() + + def get_limit_value(self) -> int | None: + """ + Get the limit value of the statement. + """ + raise NotImplementedError() + + def set_limit_value( + self, + limit: int, + method: LimitMethod = LimitMethod.FORCE_LIMIT, + ) -> None: + """ + Add a limit to the statement. 
+ """ + raise NotImplementedError() + + def has_cte(self) -> bool: + """ + Check if the statement has a CTE. + + :return: True if the statement has a CTE at the top level. + """ + raise NotImplementedError() + + def as_cte(self, alias: str = "__cte") -> BaseSQLStatement[InternalRepresentation]: + """ + Rewrite the statement as a CTE. + + :param alias: The alias to use for the CTE. + :return: A new BaseSQLStatement[InternalRepresentation] with the CTE. + """ + raise NotImplementedError() + + def as_create_table( + self, + table: Table, + method: CTASMethod, + ) -> BaseSQLStatement[InternalRepresentation]: + """ + Rewrite the statement as a `CREATE TABLE AS` statement. + + :param table: The table to create. + :param method: The method to use for creating the table. + :return: A new BaseSQLStatement[InternalRepresentation] with the CTE. + """ + raise NotImplementedError() + + def has_subquery(self) -> bool: + """ + Check if the statement has a subquery. + + :return: True if the statement has a subquery at the top level. + """ + raise NotImplementedError() + + def parse_predicate(self, predicate: str) -> InternalRepresentation: + """ + Parse a predicate string into an AST. + + :param predicate: The predicate to parse. + :return: The parsed predicate. + """ + raise NotImplementedError() + + def apply_rls( + self, + catalog: str | None, + schema: str | None, + predicates: dict[Table, list[InternalRepresentation]], + method: RLSMethod, + ) -> None: + """ + Apply relevant RLS rules to the statement inplace. + + :param catalog: The default catalog for non-qualified table names + :param schema: The default schema for non-qualified table names + :param method: The method to use for applying the rules. + """ + raise NotImplementedError() + def __str__(self) -> str: return self.format() @@ -250,8 +508,8 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): def __init__( self, - statement: str, - engine: str, + statement: str | None = None, + engine: str = "base", ast: exp.Expression | None = None, ): self._dialect = SQLGLOT_DIALECTS.get(engine) @@ -264,7 +522,7 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): """ dialect = SQLGLOT_DIALECTS.get(engine) try: - return sqlglot.parse(script, dialect=dialect) + statements = sqlglot.parse(script, dialect=dialect) except sqlglot.errors.ParseError as ex: error = ex.errors[0] raise SupersetParseError( @@ -281,6 +539,20 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): message="Unable to parse script", ) from ex + # `sqlglot` will parse comments after the last semicolon as a separate + # statement; move them back to the last token in the last real statement + if len(statements) > 1 and isinstance(statements[-1], exp.Semicolon): + last_statement = statements.pop() + target = statements[-1] + for node in statements[-1].walk(): + if hasattr(node, "comments"): + target = node + + target.comments = target.comments or [] + target.comments.extend(last_statement.comments) + + return statements + @classmethod def split_script( cls, @@ -356,6 +628,12 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): dialect = SQLGLOT_DIALECTS.get(engine) return extract_tables_from_statement(parsed, dialect) + def is_select(self) -> bool: + """ + Check if the statement is a `SELECT` statement. + """ + return isinstance(self._parsed, exp.Select) + def is_mutating(self) -> bool: """ Check if the statement mutates data (DDL/DML). 
@@ -389,7 +667,10 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): and self._parsed.expression.name.upper().startswith("ANALYZE ") ): analyzed_sql = self._parsed.expression.name[len("ANALYZE ") :] - return SQLStatement(analyzed_sql, self.engine).is_mutating() + return SQLStatement( + statement=analyzed_sql, + engine=self.engine, + ).is_mutating() return False @@ -397,34 +678,12 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): """ Pretty-format the SQL statement. """ - if self._dialect: - try: - write = Dialect.get_or_raise(self._dialect) - return write.generate( - self._parsed, - copy=False, - comments=comments, - pretty=True, - ) - except ValueError: - pass - - return self._fallback_formatting() - - @deprecated(deprecated_in="4.0") - def _fallback_formatting(self) -> str: - """ - Format SQL without a specific dialect. - - Reformatting SQL using the generic sqlglot dialect is known to break queries. - For example, it will change `foo NOT IN (1, 2)` to `NOT foo IN (1,2)`, which - breaks the query for Firebolt. To avoid this, we use sqlparse for formatting - when the dialect is not known. - - In 5.0 we should remove `sqlparse`, and the method should return the query - unmodified. - """ - return sqlparse.format(self._sql, reindent=True, keyword_case="upper") + return Dialect.get_or_raise(self._dialect).generate( + self._parsed, + copy=True, + comments=comments, + pretty=True, + ) def get_settings(self) -> dict[str, str | bool]: """ @@ -447,12 +706,11 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): """ # only optimize statements that have a custom dialect if not self._dialect: - return SQLStatement(self._sql, self.engine, self._parsed.copy()) + return SQLStatement(ast=self._parsed.copy(), engine=self.engine) optimized = pushdown_predicates(self._parsed, dialect=self._dialect) - sql = optimized.sql(dialect=self._dialect) - return SQLStatement(sql, self.engine, optimized) + return SQLStatement(ast=optimized, engine=self.engine) def check_functions_present(self, functions: set[str]) -> bool: """ @@ -467,10 +725,136 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): if function.sql_name() != "ANONYMOUS" else function.name.upper() ) - for function in self._parsed.find_all(Func) + for function in self._parsed.find_all(exp.Func) } return any(function.upper() in present for function in functions) + def get_limit_value(self) -> int | None: + """ + Parse a SQL query and return the `LIMIT` or `TOP` value, if present. + """ + if limit_node := self._parsed.args.get("limit"): + literal = limit_node.args.get("expression") or getattr( + limit_node, "this", None + ) + if isinstance(literal, exp.Literal) and literal.is_int: + return int(literal.name) + + return None + + def set_limit_value( + self, + limit: int, + method: LimitMethod = LimitMethod.FORCE_LIMIT, + ) -> None: + """ + Modify the `LIMIT` or `TOP` value of the SQL statement inplace. + """ + if method == LimitMethod.FORCE_LIMIT: + self._parsed.args["limit"] = exp.Limit( + expression=exp.Literal(this=str(limit), is_string=False) + ) + elif method == LimitMethod.WRAP_SQL: + self._parsed = exp.Select( + expressions=[exp.Star()], + limit=exp.Limit( + expression=exp.Literal(this=str(limit), is_string=False) + ), + **{"from": exp.From(this=exp.Subquery(this=self._parsed.copy()))}, + ) + else: # method == LimitMethod.FETCH_MANY + pass + + def has_cte(self) -> bool: + """ + Check if the statement has a CTE. + + :return: True if the statement has a CTE at the top level. 
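A short sketch of the limit helpers above; the rendered SQL depends on the dialect, so the strings in the comments are only indicative:

    stmt = SQLStatement("SELECT * FROM logs LIMIT 1000")
    stmt.get_limit_value()     # -> 1000

    stmt.set_limit_value(100)  # FORCE_LIMIT rewrites the LIMIT clause in place
    stmt.format()              # roughly: SELECT * FROM logs LIMIT 100

    wrapped = SQLStatement("SELECT * FROM logs")
    wrapped.set_limit_value(100, LimitMethod.WRAP_SQL)
    wrapped.format()           # roughly: SELECT * FROM (SELECT * FROM logs) LIMIT 100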
+ """ + return "with" in self._parsed.args + + def as_cte(self, alias: str = "__cte") -> SQLStatement: + """ + Rewrite the statement as a CTE. + + This is needed by MS SQL when the query includes CTEs. In that case the CTEs + need to be moved to the top of the query when we wrap it as a subquery when + building charts. + + :param alias: The alias to use for the CTE. + :return: A new SQLStatement with the CTE. + """ + existing_ctes = self._parsed.args["with"].expressions if self.has_cte() else [] + self._parsed.args["with"] = None + new_cte = exp.CTE( + this=self._parsed.copy(), + alias=exp.TableAlias(this=exp.Identifier(this=alias)), + ) + return SQLStatement( + ast=exp.With(expressions=[*existing_ctes, new_cte], this=None), + engine=self.engine, + ) + + def as_create_table(self, table: Table, method: CTASMethod) -> SQLStatement: + """ + Rewrite the statement as a `CREATE TABLE AS` statement. + + :param table: The table to create. + :param method: The method to use for creating the table. + :return: A new SQLStatement with the create table statement. + """ + create_table = exp.Create( + this=sqlglot.parse_one(str(table), into=exp.Table), + kind=method.name, + expression=self._parsed.copy(), + ) + + return SQLStatement(ast=create_table, engine=self.engine) + + def has_subquery(self) -> bool: + """ + Check if the statement has a subquery. + + :return: True if the statement has a subquery at the top level. + """ + return bool(self._parsed.find(exp.Subquery)) + + def parse_predicate(self, predicate: str) -> exp.Expression: + """ + Parse a predicate string into an AST. + + :param predicate: The predicate to parse. + :return: The parsed predicate. + """ + return sqlglot.parse_one(predicate, dialect=self._dialect) + + def apply_rls( + self, + catalog: str | None, + schema: str | None, + predicates: dict[Table, list[exp.Expression]], + method: RLSMethod, + ) -> None: + """ + Apply relevant RLS rules to the statement inplace. + + :param catalog: The default catalog for non-qualified table names + :param schema: The default schema for non-qualified table names + :param method: The method to use for applying the rules. + """ + if not predicates: + return + + transformers = { + RLSMethod.AS_PREDICATE: RLSAsPredicateTransformer, + RLSMethod.AS_SUBQUERY: RLSAsSubqueryTransformer, + } + if method not in transformers: + raise ValueError(f"Invalid RLS method: {method}") + + transformer = transformers[method](catalog, schema, predicates) + self._parsed = self._parsed.transform(transformer) + class KQLSplitState(enum.Enum): """ @@ -486,48 +870,121 @@ class KQLSplitState(enum.Enum): INSIDE_MULTILINE_STRING = enum.auto() +class KQLTokenType(enum.Enum): + """ + Token types for KQL. + """ + + STRING = enum.auto() + WORD = enum.auto() + NUMBER = enum.auto() + SEMICOLON = enum.auto() + WHITESPACE = enum.auto() + OTHER = enum.auto() + + +def classify_non_string_kql(text: str) -> list[tuple[KQLTokenType, str]]: + """ + Classify non-string KQL. 
+ """ + tokens: list[tuple[KQLTokenType, str]] = [] + for m in re.finditer(r"[A-Za-z_][A-Za-z_0-9]*|\d+|\s+|.", text): + tok = m.group(0) + if tok == ";": + tokens.append((KQLTokenType.SEMICOLON, tok)) + elif tok.isdigit(): + tokens.append((KQLTokenType.NUMBER, tok)) + elif re.match(r"[A-Za-z_][A-Za-z_0-9]*", tok): + tokens.append((KQLTokenType.WORD, tok)) + elif re.match(r"\s+", tok): + tokens.append((KQLTokenType.WHITESPACE, tok)) + else: + tokens.append((KQLTokenType.OTHER, tok)) + + return tokens + + +def tokenize_kql(kql: str) -> list[tuple[KQLTokenType, str]]: + """ + Turn a KQL script into a flat list of tokens. + """ + + state = KQLSplitState.OUTSIDE_STRING + tokens: list[tuple[KQLTokenType, str]] = [] + buffer = "" + script = kql + + for i, ch in enumerate(script): + if state == KQLSplitState.OUTSIDE_STRING: + if ch in {"'", '"'}: + if buffer: + tokens.extend(classify_non_string_kql(buffer)) + buffer = "" + state = ( + KQLSplitState.INSIDE_SINGLE_QUOTED_STRING + if ch == "'" + else KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING + ) + buffer = ch + elif ch == "`" and script[i - 2 : i] == "``": + if buffer: + tokens.extend(classify_non_string_kql(buffer)) + buffer = "" + state = KQLSplitState.INSIDE_MULTILINE_STRING + buffer = "`" + else: + buffer += ch + else: + buffer += ch + end_str = ( + ( + state == KQLSplitState.INSIDE_SINGLE_QUOTED_STRING + and ch == "'" + and script[i - 1] != "\\" + ) + or ( + state == KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING + and ch == '"' + and script[i - 1] != "\\" + ) + or ( + state == KQLSplitState.INSIDE_MULTILINE_STRING + and ch == "`" + and script[i - 2 : i] == "``" + ) + ) + if end_str: + tokens.append((KQLTokenType.STRING, buffer)) + buffer = "" + state = KQLSplitState.OUTSIDE_STRING + + if buffer: + tokens.extend(classify_non_string_kql(buffer)) + + return tokens + + def split_kql(kql: str) -> list[str]: """ - Custom function for splitting KQL statements. + Split a KQL script into statements on semicolons, + ignoring those inside strings. 
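The rewritten `split_kql` above splits on semicolons outside of strings; a small sketch of the intended behavior (whitespace in the output is preserved as-is):

    split_kql('StormEvents | where State == "WA; OR"; StormEvents | take 10')
    # -> ['StormEvents | where State == "WA; OR"', ' StormEvents | take 10']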
""" - statements = [] - state = KQLSplitState.OUTSIDE_STRING - statement_start = 0 - script = kql if kql.endswith(";") else kql + ";" - for i, character in enumerate(script): - if state == KQLSplitState.OUTSIDE_STRING: - if character == ";": - statements.append(script[statement_start:i]) - statement_start = i + 1 - elif character == "'": - state = KQLSplitState.INSIDE_SINGLE_QUOTED_STRING - elif character == '"': - state = KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING - elif character == "`" and script[i - 2 : i] == "``": - state = KQLSplitState.INSIDE_MULTILINE_STRING + tokens = tokenize_kql(kql) + stmts_tokens: list[list[tuple[KQLTokenType, str]]] = [] + current: list[tuple[KQLTokenType, str]] = [] - elif ( - state == KQLSplitState.INSIDE_SINGLE_QUOTED_STRING - and character == "'" - and script[i - 1] != "\\" - ): - state = KQLSplitState.OUTSIDE_STRING + for ttype, val in tokens: + if ttype == KQLTokenType.SEMICOLON: + if current: + stmts_tokens.append(current) + current = [] + else: + current.append((ttype, val)) - elif ( - state == KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING - and character == '"' - and script[i - 1] != "\\" - ): - state = KQLSplitState.OUTSIDE_STRING + if current: + stmts_tokens.append(current) - elif ( - state == KQLSplitState.INSIDE_MULTILINE_STRING - and character == "`" - and script[i - 2 : i] == "``" - ): - state = KQLSplitState.OUTSIDE_STRING - - return statements + return ["".join(val for _, val in stmt) for stmt in stmts_tokens] class KustoKQLStatement(BaseSQLStatement[str]): @@ -547,6 +1004,14 @@ class KustoKQLStatement(BaseSQLStatement[str]): details about it. """ + def __init__( + self, + statement: str | None = None, + engine: str = "kustokql", + ast: str | None = None, + ): + super().__init__(statement, engine, ast) + @classmethod def split_script( cls, @@ -604,7 +1069,7 @@ class KustoKQLStatement(BaseSQLStatement[str]): """ Pretty-format the SQL statement. """ - return self._sql.strip() + return self._parsed.strip() def get_settings(self) -> dict[str, str | bool]: """ @@ -621,6 +1086,12 @@ class KustoKQLStatement(BaseSQLStatement[str]): return {} + def is_select(self) -> bool: + """ + Check if the statement is a `SELECT` statement. + """ + return not self._parsed.startswith(".") + def is_mutating(self) -> bool: """ Check if the statement mutates data (DDL/DML). @@ -635,7 +1106,7 @@ class KustoKQLStatement(BaseSQLStatement[str]): Kusto KQL doesn't support optimization, so this method is a no-op. """ - return KustoKQLStatement(self._sql, self.engine, self._parsed) + return KustoKQLStatement(ast=self._parsed, engine=self.engine) def check_functions_present(self, functions: set[str]) -> bool: """ @@ -647,6 +1118,66 @@ class KustoKQLStatement(BaseSQLStatement[str]): logger.warning("Kusto KQL doesn't support checking for functions present.") return True + def get_limit_value(self) -> int | None: + """ + Get the limit value of the statement. + """ + tokens = [ + token + for token in tokenize_kql(self._parsed) + if token[0] != KQLTokenType.WHITESPACE + ] + for idx, (ttype, val) in enumerate(tokens): + if ttype != KQLTokenType.STRING and val.lower() in {"take", "limit"}: + if idx + 1 < len(tokens) and tokens[idx + 1][0] == KQLTokenType.NUMBER: + return int(tokens[idx + 1][1]) + break + + return None + + def set_limit_value( + self, + limit: int, + method: LimitMethod = LimitMethod.FORCE_LIMIT, + ) -> None: + """ + Add a limit to the statement. 
+ """ + if method != LimitMethod.FORCE_LIMIT: + raise SupersetParseError("Kusto KQL only supports the FORCE_LIMIT method.") + + tokens = tokenize_kql(self._parsed) + found_limit_token = False + for idx, (ttype, val) in enumerate(tokens): + if ttype != KQLTokenType.STRING and val.lower() in {"take", "limit"}: + found_limit_token = True + + if found_limit_token and ttype == KQLTokenType.NUMBER: + tokens[idx] = (KQLTokenType.NUMBER, str(limit)) + break + else: + tokens.extend( + [ + (KQLTokenType.WHITESPACE, " "), + (KQLTokenType.WORD, "|"), + (KQLTokenType.WHITESPACE, " "), + (KQLTokenType.WORD, "take"), + (KQLTokenType.WHITESPACE, " "), + (KQLTokenType.NUMBER, str(limit)), + ] + ) + + self._parsed = "".join(val for _, val in tokens) + + def parse_predicate(self, predicate: str) -> str: + """ + Parse a predicate string into an AST. + + :param predicate: The predicate to parse. + :return: The parsed predicate. + """ + return predicate + class SQLScript: """ @@ -724,6 +1255,24 @@ class SQLScript: for statement in self.statements ) + def is_valid_ctas(self) -> bool: + """ + Check if the script contains a valid CTAS statement. + + CTAS (`CREATE TABLE AS SELECT`) can only be run with scripts where the last + statement is a `SELECT`. + """ + return self.statements[-1].is_select() + + def is_valid_cvas(self) -> bool: + """ + Check if the script contains a valid CVAS statement. + + CVAS (`CREATE VIEW AS SELECT`) can only be run with scripts with a single + `SELECT` statement. + """ + return len(self.statements) == 1 and self.statements[0].is_select() + def extract_tables_from_statement( statement: exp.Expression, @@ -732,7 +1281,7 @@ def extract_tables_from_statement( """ Extract all table references in a single statement. - Please not that this is not trivial; consider the following queries: + Please note that this is not trivial; consider the following queries: DESCRIBE some_table; SHOW PARTITIONS FROM some_table; diff --git a/superset/sql_lab.py b/superset/sql_lab.py index c157896fcc6..d36a530c26c 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -15,6 +15,9 @@ # specific language governing permissions and limitations # under the License. 
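Before moving into the SQL Lab changes below, a quick sketch of the two validity helpers added to `SQLScript`; the "postgresql" engine key is assumed to map to a sqlglot dialect:

    script = SQLScript("SELECT 1; SELECT 2", "postgresql")
    script.is_valid_ctas()   # True: the last statement is a SELECT
    script.is_valid_cvas()   # False: CVAS requires a single SELECT statement

    SQLScript("SELECT 1", "postgresql").is_valid_cvas()   # True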
# pylint: disable=consider-using-transaction + +from __future__ import annotations + import dataclasses import logging import sys @@ -22,13 +25,14 @@ import uuid from contextlib import closing from datetime import datetime from sys import getsizeof -from typing import Any, cast, Optional, Union +from typing import Any, cast, Optional, TYPE_CHECKING, TypeVar, Union import backoff import msgpack from celery.exceptions import SoftTimeLimitExceeded from flask import current_app from flask_babel import gettext as __ +from sqlalchemy import and_, or_ from superset import ( app, @@ -39,27 +43,24 @@ from superset import ( security_manager, ) from superset.common.db_query_status import QueryStatus +from superset.connectors.sqla.models import SqlaTable from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY from superset.dataframe import df_to_records from superset.db_engine_specs import BaseEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import ( OAuth2RedirectError, + SupersetDMLNotAllowedException, SupersetErrorException, SupersetErrorsException, - SupersetParseError, + SupersetInvalidCTASException, + SupersetInvalidCVASException, + SupersetResultsBackendNotConfigureException, ) from superset.extensions import celery_app, event_logger -from superset.models.core import Database from superset.models.sql_lab import Query from superset.result_set import SupersetResultSet -from superset.sql.parse import SQLStatement, Table -from superset.sql_parse import ( - CtasMethod, - insert_rls_as_subquery, - insert_rls_in_predicate, - ParsedQuery, -) +from superset.sql.parse import BaseSQLStatement, CTASMethod, SQLScript, Table from superset.sqllab.limiting_factor import LimitingFactor from superset.sqllab.utils import write_ipc_buffer from superset.utils import json @@ -71,6 +72,9 @@ from superset.utils.core import ( from superset.utils.dates import now_as_float from superset.utils.decorators import stats_timing +if TYPE_CHECKING: + from superset.models.core import Database + config = app.config stats_logger = config["STATS_LOGGER"] SQLLAB_TIMEOUT = config["SQLLAB_ASYNC_TIME_LIMIT_SEC"] @@ -197,97 +201,144 @@ def get_sql_results( # pylint: disable=too-many-arguments return handle_query_error(ex, query) -def execute_sql_statement( # pylint: disable=too-many-statements, too-many-locals # noqa: C901 - sql_statement: str, +def apply_rls( + database: Database, + catalog: str | None, + schema: str, + parsed_statement: BaseSQLStatement[Any], +) -> None: + """ + Modify statement inplace to ensure RLS rules are applied. + """ + # There are two ways to insert RLS: either replacing the table with a subquery + # that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is + # safer, but not supported in all databases. 
+ method = database.db_engine_spec.get_rls_method() + + # collect all RLS predicates for all tables in the query + predicates: dict[Table, list[Any]] = {} + for table in parsed_statement.tables: + # fully qualify table + table = Table( + table.table, + table.schema or schema, + table.catalog or catalog, + ) + + predicates[table] = [ + parsed_statement.parse_predicate(predicate) + for predicate in get_predicates_for_table( + table, + database, + database.get_default_catalog(), + ) + if predicate + ] + + parsed_statement.apply_rls(catalog, schema, predicates, method) + + +def get_predicates_for_table( + table: Table, + database: Database, + default_catalog: str | None, +) -> list[str]: + """ + Get the RLS predicates for a table. + + This is used to inject RLS rules into SQL statements run in SQL Lab. Note that the + table must be fully qualified, with catalog (null if the DB doesn't support) and + schema. + """ + # if the dataset in the RLS has null catalog, match it when using the default + # catalog + catalog_predicate = SqlaTable.catalog == table.catalog + if table.catalog and table.catalog == default_catalog: + catalog_predicate = or_( + catalog_predicate, + SqlaTable.catalog.is_(None), + ) + + dataset = ( + db.session.query(SqlaTable) + .filter( + and_( + SqlaTable.database_id == database.id, + catalog_predicate, + SqlaTable.schema == table.schema, + SqlaTable.table_name == table.table, + ) + ) + .one_or_none() + ) + if not dataset: + return [] + + return [ + str( + predicate.compile( + dialect=database.get_dialect(), + compile_kwargs={"literal_binds": True}, + ) + ) + for predicate in dataset.get_sqla_row_level_filters() + ] + + +S = TypeVar("S", bound=BaseSQLStatement[Any]) + + +def apply_ctas(query: Query, parsed_statement: S) -> S: + """ + Apply CTAS/CVAS. + """ + if not query.tmp_table_name: + start_dttm = datetime.fromtimestamp(query.start_time) + prefix = f"tmp_{query.user_id}_table" + query.tmp_table_name = start_dttm.strftime(f"{prefix}_%Y_%m_%d_%H_%M_%S") + + catalog = ( + query.catalog + if query.database.db_engine_spec.supports_cross_catalog_queries + else None + ) + table = Table(query.tmp_table_name, query.tmp_schema_name, catalog) + method = CTASMethod[query.ctas_method.upper()] + + return parsed_statement.as_create_table(table, method) # type: ignore[return-value] + + +def apply_limit(query: Query, parsed_statement: BaseSQLStatement[Any]) -> None: + """ + Apply limit to the SQL statement. 
+ """ + # Do not apply limit to the CTA queries when SQLLAB_CTAS_NO_LIMIT is set to true + if parsed_statement.is_mutating() or ( + query.select_as_cta_used and SQLLAB_CTAS_NO_LIMIT + ): + return + + if SQL_MAX_ROW and (not query.limit or query.limit > SQL_MAX_ROW): + query.limit = SQL_MAX_ROW + + if query.limit: + parsed_statement.set_limit_value( + # fetch an extra row to inform user if there are more rows + query.limit + 1, + query.database.db_engine_spec.limit_method, + ) + + +def execute_query( # pylint: disable=too-many-statements, too-many-locals # noqa: C901 query: Query, cursor: Any, - log_params: Optional[dict[str, Any]], - apply_ctas: bool = False, + log_params: Optional[dict[str, Any]] = None, ) -> SupersetResultSet: """Executes a single SQL statement""" database: Database = query.database db_engine_spec = database.db_engine_spec - parsed_query = ParsedQuery(sql_statement, engine=db_engine_spec.engine) - if is_feature_enabled("RLS_IN_SQLLAB"): - # There are two ways to insert RLS: either replacing the table with a subquery - # that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is - # safer, but not supported in all databases. - insert_rls = ( - insert_rls_as_subquery - if database.db_engine_spec.allows_subqueries - and database.db_engine_spec.allows_alias_in_select - else insert_rls_in_predicate - ) - - # Insert any applicable RLS predicates - parsed_query = ParsedQuery( - str( - insert_rls( - parsed_query._parsed[0], # pylint: disable=protected-access - database.id, - query.schema, - ) - ), - engine=db_engine_spec.engine, - ) - - sql = parsed_query.stripped() - - # This is a test to see if the query is being - # limited by either the dropdown or the sql. - # We are testing to see if more rows exist than the limit. - increased_limit = None if query.limit is None else query.limit + 1 - - if not database.allow_dml: - errors = [] - try: - parsed_statement = SQLStatement(sql_statement, engine=db_engine_spec.engine) - disallowed = parsed_statement.is_mutating() - except SupersetParseError as ex: - # if we fail to parse the query, disallow by default - disallowed = True - errors.append(ex.error) - - if disallowed: - errors.append( - SupersetError( - message=__( - "This database does not allow for DDL/DML, and the query " - "could not be parsed to confirm it is a read-only query. Please " # noqa: E501 - "contact your administrator for more assistance." 
- ), - error_type=SupersetErrorType.DML_NOT_ALLOWED_ERROR, - level=ErrorLevel.ERROR, - ) - ) - raise SupersetErrorsException(errors) - - if apply_ctas: - if not query.tmp_table_name: - start_dttm = datetime.fromtimestamp(query.start_time) - query.tmp_table_name = ( - f"tmp_{query.user_id}_table_{start_dttm.strftime('%Y_%m_%d_%H_%M_%S')}" - ) - sql = parsed_query.as_create_table( - query.tmp_table_name, - schema_name=query.tmp_schema_name, - method=query.ctas_method, - ) - query.select_as_cta_used = True - - # Do not apply limit to the CTA queries when SQLLAB_CTAS_NO_LIMIT is set to true - if db_engine_spec.is_select_query(parsed_query) and not ( - query.select_as_cta_used and SQLLAB_CTAS_NO_LIMIT - ): - if SQL_MAX_ROW and (not query.limit or query.limit > SQL_MAX_ROW): - query.limit = SQL_MAX_ROW - sql = apply_limit_if_exists(database, increased_limit, query, sql) - - # Hook to allow environment-specific mutation (usually comments) to the SQL - sql = database.mutate_sql_based_on_config(sql) try: - query.executed_sql = sql if log_query: log_query( query.database.sqlalchemy_uri, @@ -304,7 +355,7 @@ def execute_sql_statement( # pylint: disable=too-many-statements, too-many-loca object_ref=__name__, ): with stats_timing("sqllab.query.time_executing_query", stats_logger): - db_engine_spec.execute_with_cursor(cursor, sql, query) + db_engine_spec.execute_with_cursor(cursor, query.executed_sql, query) with stats_timing("sqllab.query.time_fetching_results", stats_logger): logger.debug( @@ -312,6 +363,7 @@ def execute_sql_statement( # pylint: disable=too-many-statements, too-many-loca query.id, str(query.to_dict()), ) + increased_limit = None if query.limit is None else query.limit + 1 data = db_engine_spec.fetch_data(cursor, increased_limit) if query.limit is None or len(data) <= query.limit: query.limiting_factor = LimitingFactor.NOT_LIMITED @@ -352,19 +404,6 @@ def execute_sql_statement( # pylint: disable=too-many-statements, too-many-loca return SupersetResultSet(data, cursor_description, db_engine_spec) -def apply_limit_if_exists( - database: Database, increased_limit: Optional[int], query: Query, sql: str -) -> str: - if query.limit and increased_limit: - # We are fetching one more than the requested limit in order - # to test whether there are more rows than the limit. According to the DB - # Engine support it will choose top or limit parse - # Later, the extra row will be dropped before sending - # the results back to the user. 
- sql = database.apply_limit_to_sql(sql, increased_limit, force=True) - return sql - - def _serialize_payload( payload: dict[Any, Any], use_msgpack: Optional[bool] = False ) -> Union[bytes, str]: @@ -430,67 +469,53 @@ def execute_sql_statements( # noqa: C901 db_engine_spec.patch() if database.allow_run_async and not results_backend: - raise SupersetErrorException( - SupersetError( - message=__("Results backend is not configured."), - error_type=SupersetErrorType.RESULTS_BACKEND_NOT_CONFIGURED_ERROR, - level=ErrorLevel.ERROR, - ) - ) - - # Breaking down into multiple statements - parsed_query = ParsedQuery( - rendered_query, - engine=db_engine_spec.engine, - ) - if not db_engine_spec.run_multiple_statements_as_one: - statements = parsed_query.get_statements() - logger.info( - "Query %s: Executing %i statement(s)", str(query_id), len(statements) - ) - else: - statements = [rendered_query] - logger.info("Query %s: Executing query as a single statement", str(query_id)) + raise SupersetResultsBackendNotConfigureException() logger.info("Query %s: Set query to 'running'", str(query_id)) query.status = QueryStatus.RUNNING query.start_running_time = now_as_float() db.session.commit() - # Should we create a table or view from the select? - if ( - query.select_as_cta - and query.ctas_method == CtasMethod.TABLE - and not parsed_query.is_valid_ctas() - ): - raise SupersetErrorException( - SupersetError( - message=__( - "CTAS (create table as select) can only be run with a query where " - "the last statement is a SELECT. Please make sure your query has " - "a SELECT as its last statement. Then, try running your query " - "again." - ), - error_type=SupersetErrorType.INVALID_CTAS_QUERY_ERROR, - level=ErrorLevel.ERROR, - ) - ) - if ( - query.select_as_cta - and query.ctas_method == CtasMethod.VIEW - and not parsed_query.is_valid_cvas() - ): - raise SupersetErrorException( - SupersetError( - message=__( - "CVAS (create view as select) can only be run with a query with " - "a single SELECT statement. Please make sure your query has only " - "a SELECT statement. Then, try running your query again." - ), - error_type=SupersetErrorType.INVALID_CVAS_QUERY_ERROR, - level=ErrorLevel.ERROR, - ) + parsed_script = SQLScript(rendered_query, engine=db_engine_spec.engine) + + if parsed_script.has_mutation() and not database.allow_dml: + raise SupersetDMLNotAllowedException() + + if is_feature_enabled("RLS_IN_SQLLAB"): + default_schema = query.database.get_default_schema_for_query(query) + for statement in parsed_script.statements: + apply_rls(query.database, query.catalog, default_schema, statement) + + if query.select_as_cta: + # CTAS is valid when the last statement is a SELECT, while CVAS is valid when + # there is only a single statement which must be a SELECT. + if ( + query.ctas_method == CTASMethod.TABLE.name + and not parsed_script.is_valid_ctas() + ): + raise SupersetInvalidCTASException() + if ( + query.ctas_method == CTASMethod.VIEW.name + and not parsed_script.is_valid_cvas() + ): + raise SupersetInvalidCVASException() + + parsed_script.statements[-1] = apply_ctas( # type: ignore + query, + parsed_script.statements[-1], ) + query.select_as_cta_used = True + + for statement in parsed_script.statements: + apply_limit(query, statement) + + # some databases (like BigQuery and Kusto) do not persist state across mmultiple + # statements if they're run separately (especially when using `NullPool`), so we run + # the query as a single block. 
+ if db_engine_spec.run_multiple_statements_as_one: + blocks = [parsed_script.format()] + else: + blocks = [statement.format() for statement in parsed_script.statements] with database.get_raw_connection( catalog=query.catalog, @@ -500,40 +525,35 @@ def execute_sql_statements( # noqa: C901 # Sharing a single connection and cursor across the # execution of all statements (if many) cursor = conn.cursor() + cancel_query_id = db_engine_spec.get_cancel_query_id(cursor, query) if cancel_query_id is not None: query.set_extra_json_key(QUERY_CANCEL_KEY, cancel_query_id) db.session.commit() - statement_count = len(statements) - for i, statement in enumerate(statements): + + block_count = len(blocks) + for i, block in enumerate(blocks): # Check if stopped db.session.refresh(query) if query.status == QueryStatus.STOPPED: payload.update({"status": query.status}) return payload - # For CTAS we create the table only on the last statement - apply_ctas = query.select_as_cta and ( - query.ctas_method == CtasMethod.VIEW - or (query.ctas_method == CtasMethod.TABLE and i == len(statements) - 1) - ) + # Run statement msg = __( - "Running statement %(statement_num)s out of %(statement_count)s", - statement_num=i + 1, - statement_count=statement_count, + "Running block %(block_num)s out of %(block_count)s", + block_num=i + 1, + block_count=block_count, ) logger.info("Query %s: %s", str(query_id), msg) query.set_extra_json_key("progress", msg) db.session.commit() - try: - result_set = execute_sql_statement( - statement, - query, - cursor, - log_params, - apply_ctas, - ) + # Hook to allow environment-specific mutation (usually comments) to the SQL + query.executed_sql = database.mutate_sql_based_on_config(block) + + try: + result_set = execute_query(query, cursor, log_params) except SqlLabQueryStoppedException: payload.update({"status": QueryStatus.STOPPED}) return payload @@ -541,22 +561,18 @@ def execute_sql_statements( # noqa: C901 msg = str(ex) prefix_message = ( __( - "Statement %(statement_num)s out of %(statement_count)s", - statement_num=i + 1, - statement_count=statement_count, + "Block %(block_num)s out of %(block_count)s", + block_num=i + 1, + block_count=block_count, ) - if statement_count > 1 + if block_count > 1 else "" ) payload = handle_query_error(ex, query, payload, prefix_message) return payload # Commit the connection so CTA queries will create the table and any DML. - should_commit = ( - not db_engine_spec.is_select_query(parsed_query) # check if query is DML - or apply_ctas - ) - if should_commit: + if parsed_script.has_mutation() or query.select_as_cta: conn.commit() # Success, updating the query entry in database diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 8fae4507efa..6422a5c7bbe 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -469,39 +469,6 @@ class ParsedQuery: exec_sql += f"CREATE {method} {full_table_name} AS \n{sql}" return exec_sql - def set_or_update_query_limit(self, new_limit: int, force: bool = False) -> str: - """Returns the query with the specified limit. - - Does not change the underlying query if user did not apply the limit, - otherwise replaces the limit with the lower value between existing limit - in the query and new_limit. 
- - :param new_limit: Limit to be incorporated into returned query - :return: The original query with new limit - """ - if not self._limit: - return f"{self.stripped()}\nLIMIT {new_limit}" - limit_pos = None - statement = self._parsed[0] - # Add all items to before_str until there is a limit - for pos, item in enumerate(statement.tokens): - if item.ttype in Keyword and item.value.lower() == "limit": - limit_pos = pos - break - _, limit = statement.token_next(idx=limit_pos) - # Override the limit only when it exceeds the configured value. - if limit.ttype == sqlparse.tokens.Literal.Number.Integer and ( - force or new_limit < int(limit.value) - ): - limit.value = new_limit - elif limit.is_group: - limit.value = f"{next(limit.get_identifiers())}, {new_limit}" - - str_res = "" - for i in statement.tokens: - str_res += str(i.value) - return str_res - def sanitize_clause(clause: str) -> str: # clause = sqlparse.format(clause, strip_comments=True) @@ -566,7 +533,7 @@ def has_table_query(expression: str, engine: str) -> bool: expression = f"({expression})" sql = f"SELECT {expression}" - statement = SQLStatement(sql, engine) + statement = SQLStatement(statement=sql, engine=engine) return any(statement.tables) @@ -972,12 +939,12 @@ def extract_tables_from_jinja_sql(sql: str, database: Database) -> set[Table]: node.data = "NULL" # re-render template back into a string - rendered_template = Template(template).render() + rendered_sql = Template(template).render(processor.get_context()) return ( tables | ParsedQuery( - sql_statement=processor.process_template(rendered_template), + sql_statement=processor.process_template(rendered_sql), engine=database.db_engine_spec.engine, ).tables ) diff --git a/superset/sqllab/sqllab_execution_context.py b/superset/sqllab/sqllab_execution_context.py index ab0f91bbf30..0e579ede9b6 100644 --- a/superset/sqllab/sqllab_execution_context.py +++ b/superset/sqllab/sqllab_execution_context.py @@ -26,7 +26,7 @@ from sqlalchemy.orm.exc import DetachedInstanceError from superset import is_feature_enabled from superset.models.sql_lab import Query -from superset.sql_parse import CtasMethod +from superset.sql.parse import CTASMethod from superset.utils import core as utils, json from superset.utils.core import apply_max_row_limit, get_user_id from superset.utils.dates import now_as_float @@ -148,6 +148,7 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes def create_query(self) -> Query: start_time = now_as_float() + ctas = cast(CreateTableAsSelect, self.create_table_as_select) if self.select_as_cta: return Query( database_id=self.database_id, @@ -155,14 +156,14 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes catalog=self.catalog, schema=self.schema, select_as_cta=True, - ctas_method=self.create_table_as_select.ctas_method, # type: ignore + ctas_method=ctas.ctas_method.name, start_time=start_time, tab_name=self.tab_name, status=self.status, limit=self.limit, sql_editor_id=self.sql_editor_id, - tmp_table_name=self.create_table_as_select.target_table_name, # type: ignore - tmp_schema_name=self.create_table_as_select.target_schema_name, # type: ignore + tmp_table_name=ctas.target_table_name, + tmp_schema_name=ctas.target_schema_name, user_id=self.user_id, client_id=self.client_id_or_short_id, ) @@ -190,12 +191,12 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes class CreateTableAsSelect: # pylint: disable=too-few-public-methods - ctas_method: CtasMethod + ctas_method: CTASMethod 
target_schema_name: str | None target_table_name: str def __init__( - self, ctas_method: CtasMethod, target_schema_name: str, target_table_name: str + self, ctas_method: CTASMethod, target_schema_name: str, target_table_name: str ): self.ctas_method = ctas_method self.target_schema_name = target_schema_name @@ -203,7 +204,7 @@ class CreateTableAsSelect: # pylint: disable=too-few-public-methods @staticmethod def create_from(query_params: dict[str, Any]) -> CreateTableAsSelect: - ctas_method = query_params.get("ctas_method", CtasMethod.TABLE) + ctas_method = CTASMethod[query_params.get("ctas_method", "table").upper()] schema = cast(str, query_params.get("schema")) tmp_table_name = cast(str, query_params.get("tmp_table_name")) return CreateTableAsSelect(ctas_method, schema, tmp_table_name) diff --git a/superset/tags/api.py b/superset/tags/api.py index 66a22d21eb2..2d92b2ff5c6 100644 --- a/superset/tags/api.py +++ b/superset/tags/api.py @@ -583,9 +583,9 @@ class TagRestApi(BaseSupersetModelRestApi): if tag_ids: # priotize using ids for lookups vs. names mainly using this # for backward compatibility - tagged_objects = TagDAO.get_tagged_objects_by_tag_id(tag_ids, types) + tagged_objects = TagDAO.get_tagged_objects_by_tag_ids(tag_ids, types) else: - tagged_objects = TagDAO.get_tagged_objects_for_tags(tags, types) + tagged_objects = TagDAO.get_tagged_objects_by_tag_names(tags, types) result = [ self.object_entity_response_schema.dump(tagged_object) diff --git a/superset/utils/date_parser.py b/superset/utils/date_parser.py index 585c837f9d0..2db8deb1ecb 100644 --- a/superset/utils/date_parser.py +++ b/superset/utils/date_parser.py @@ -77,7 +77,7 @@ def parse_human_datetime(human_readable: str) -> datetime: def normalize_time_delta(human_readable: str) -> dict[str, int]: - x_unit = r"^\s*([0-9]+)\s+(second|minute|hour|day|week|month|quarter|year)s?\s+(ago|later)*$" # pylint: disable=line-too-long,useless-suppression # noqa: E501 + x_unit = r"^\s*([0-9]+)\s+(second|minute|hour|day|week|month|quarter|year)s?\s+(ago|later)*$" # noqa: E501 matched = re.match(x_unit, human_readable, re.IGNORECASE) if not matched: raise TimeDeltaAmbiguousError(human_readable) @@ -362,13 +362,13 @@ def get_since_until( # pylint: disable=too-many-arguments,too-many-locals,too-m and time_range.startswith("previous calendar week") and separator not in time_range ): - time_range = "DATETRUNC(DATEADD(DATETIME('today'), -1, WEEK), WEEK) : DATETRUNC(DATETIME('today'), WEEK)" # pylint: disable=line-too-long,useless-suppression # noqa: E501 + time_range = "DATETRUNC(DATEADD(DATETIME('today'), -1, WEEK), WEEK) : DATETRUNC(DATETIME('today'), WEEK)" # noqa: E501 if ( time_range and time_range.startswith("previous calendar month") and separator not in time_range ): - time_range = "DATETRUNC(DATEADD(DATETIME('today'), -1, MONTH), MONTH) : DATETRUNC(DATETIME('today'), MONTH)" # pylint: disable=line-too-long,useless-suppression # noqa: E501 + time_range = "DATETRUNC(DATEADD(DATETIME('today'), -1, MONTH), MONTH) : DATETRUNC(DATETIME('today'), MONTH)" # noqa: E501 if ( time_range and time_range.startswith("previous calendar quarter") @@ -376,44 +376,44 @@ def get_since_until( # pylint: disable=too-many-arguments,too-many-locals,too-m ): time_range = ( "DATETRUNC(DATEADD(DATETIME('today'), -1, QUARTER), QUARTER) : " - "DATETRUNC(DATETIME('today'), QUARTER)" # pylint: disable=line-too-long,useless-suppression # noqa: E501 + "DATETRUNC(DATETIME('today'), QUARTER)" # noqa: E501 ) if ( time_range and time_range.startswith("previous calendar 
year") and separator not in time_range ): - time_range = "DATETRUNC(DATEADD(DATETIME('today'), -1, YEAR), YEAR) : DATETRUNC(DATETIME('today'), YEAR)" # pylint: disable=line-too-long,useless-suppression # noqa: E501 + time_range = "DATETRUNC(DATEADD(DATETIME('today'), -1, YEAR), YEAR) : DATETRUNC(DATETIME('today'), YEAR)" # noqa: E501 if ( time_range and time_range.startswith("Current day") and separator not in time_range ): - time_range = "DATETRUNC(DATEADD(DATETIME('today'), 0, DAY), DAY) : DATETRUNC(DATEADD(DATETIME('today'), 1, DAY), DAY)" # pylint: disable=line-too-long,useless-suppression # noqa: E501 + time_range = "DATETRUNC(DATEADD(DATETIME('today'), 0, DAY), DAY) : DATETRUNC(DATEADD(DATETIME('today'), 1, DAY), DAY)" # noqa: E501 if ( time_range and time_range.startswith("Current week") and separator not in time_range ): - time_range = "DATETRUNC(DATEADD(DATETIME('today'), 0, WEEK), WEEK) : DATETRUNC(DATEADD(DATETIME('today'), 1, WEEK), WEEK)" # pylint: disable=line-too-long,useless-suppression # noqa: E501 + time_range = "DATETRUNC(DATEADD(DATETIME('today'), 0, WEEK), WEEK) : DATETRUNC(DATEADD(DATETIME('today'), 1, WEEK), WEEK)" # noqa: E501 if ( time_range and time_range.startswith("Current month") and separator not in time_range ): - time_range = "DATETRUNC(DATEADD(DATETIME('today'), 0, MONTH), MONTH) : DATETRUNC(DATEADD(DATETIME('today'), 1, MONTH), MONTH)" # pylint: disable=line-too-long,useless-suppression # noqa: E501 + time_range = "DATETRUNC(DATEADD(DATETIME('today'), 0, MONTH), MONTH) : DATETRUNC(DATEADD(DATETIME('today'), 1, MONTH), MONTH)" # noqa: E501 if ( time_range and time_range.startswith("Current quarter") and separator not in time_range ): - time_range = "DATETRUNC(DATEADD(DATETIME('today'), 0, QUARTER), QUARTER) : DATETRUNC(DATEADD(DATETIME('today'), 1, QUARTER), QUARTER)" # pylint: disable=line-too-long,useless-suppression # noqa: E501 + time_range = "DATETRUNC(DATEADD(DATETIME('today'), 0, QUARTER), QUARTER) : DATETRUNC(DATEADD(DATETIME('today'), 1, QUARTER), QUARTER)" # noqa: E501 if ( time_range and time_range.startswith("Current year") and separator not in time_range ): - time_range = "DATETRUNC(DATEADD(DATETIME('today'), 0, YEAR), YEAR) : DATETRUNC(DATEADD(DATETIME('today'), 1, YEAR), YEAR)" # pylint: disable=line-too-long,useless-suppression # noqa: E501 + time_range = "DATETRUNC(DATEADD(DATETIME('today'), 0, YEAR), YEAR) : DATETRUNC(DATEADD(DATETIME('today'), 1, YEAR), YEAR)" # noqa: E501 if time_range and separator in time_range: time_range_lookup = [ @@ -421,7 +421,7 @@ def get_since_until( # pylint: disable=too-many-arguments,too-many-locals,too-m r"^(start of|beginning of|end of)\s+" r"(this|last|next|prior)\s+" r"([0-9]+)?\s*" - r"(day|week|month|quarter|year)s?$", # Matches phrases like "start of next month" # pylint: disable=line-too-long,useless-suppression # noqa: E501 + r"(day|week|month|quarter|year)s?$", # Matches phrases like "start of next month" # noqa: E501 lambda modifier, scope, delta, unit: handle_modifier_and_unit( modifier, scope, @@ -433,13 +433,13 @@ def get_since_until( # pylint: disable=too-many-arguments,too-many-locals,too-m ( r"^(this|last|next|prior)\s+" r"([0-9]+)?\s*" - r"(second|minute|day|week|month|quarter|year)s?$", # Matches "next 5 days" or "last 2 weeks" # pylint: disable=line-too-long,useless-suppression # noqa: E501 + r"(second|minute|day|week|month|quarter|year)s?$", # Matches "next 5 days" or "last 2 weeks" # noqa: E501 lambda scope, delta, unit: handle_scope_and_unit( scope, delta, unit, get_relative_base(unit, 
relative_start) ), ), ( - r"^(DATETIME.*|DATEADD.*|DATETRUNC.*|LASTDAY.*|HOLIDAY.*)$", # Matches date-related keywords # pylint: disable=line-too-long,useless-suppression # noqa: E501 + r"^(DATETIME.*|DATEADD.*|DATETRUNC.*|LASTDAY.*|HOLIDAY.*)$", # Matches date-related keywords # noqa: E501 lambda text: text, ), ] diff --git a/superset/views/all_entities.py b/superset/views/all_entities.py index 5ca7e41e1d7..d78c6385898 100644 --- a/superset/views/all_entities.py +++ b/superset/views/all_entities.py @@ -22,6 +22,7 @@ from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_appbuilder.security.decorators import has_access from superset import is_feature_enabled +from superset.constants import RouteMethod from superset.superset_typing import FlaskResponse from superset.tags.models import Tag from superset.views.base import SupersetModelView @@ -33,7 +34,7 @@ class TaggedObjectsModelView(SupersetModelView): route_base = "/superset/all_entities" datamodel = SQLAInterface(Tag) class_permission_name = "Tags" - include_route_methods = {"list"} + include_route_methods = {RouteMethod.LIST} @has_access @expose("/") diff --git a/superset/views/log/api.py b/superset/views/log/api.py index ffa3a860060..53a7e08d7f7 100644 --- a/superset/views/log/api.py +++ b/superset/views/log/api.py @@ -19,6 +19,7 @@ from typing import Any, Optional from flask import current_app as app from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.hooks import before_request +from flask_appbuilder.models.sqla.filters import FilterRelationOneToManyEqual from flask_appbuilder.models.sqla.interface import SQLAInterface import superset.models.core as models @@ -45,7 +46,8 @@ class LogRestApi(LogMixin, BaseSupersetModelRestApi): resource_name = "log" allow_browser_login = True list_columns = [ - "user.username", + "user", + "user_id", "action", "dttm", "json", @@ -55,6 +57,21 @@ class LogRestApi(LogMixin, BaseSupersetModelRestApi): "duration_ms", "referrer", ] + search_columns = [ + "user", + "user_id", + "action", + "dttm", + "json", + "slice_id", + "dashboard_id", + "user_id", + "duration_ms", + "referrer", + ] + search_filters = { + "user": [FilterRelationOneToManyEqual], + } show_columns = list_columns page_size = 20 apispec_parameter_schemas = { diff --git a/tests/integration_tests/db_engine_specs/base_tests.py b/superset/views/logs.py similarity index 54% rename from tests/integration_tests/db_engine_specs/base_tests.py rename to superset/views/logs.py index c836e71b689..39fe10d0a3e 100644 --- a/tests/integration_tests/db_engine_specs/base_tests.py +++ b/superset/views/logs.py @@ -14,23 +14,21 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
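As a self-contained check of the time-range patterns quoted above (only the regexes themselves are taken from this diff; the sample phrases are arbitrary):

import re

# "start of next month" style phrases
modifier_pattern = (
    r"^(start of|beginning of|end of)\s+"
    r"(this|last|next|prior)\s+"
    r"([0-9]+)?\s*"
    r"(day|week|month|quarter|year)s?$"
)
assert re.match(modifier_pattern, "start of next month", re.IGNORECASE).groups() == (
    "start of",
    "next",
    None,
    "month",
)

# "last 2 weeks" style phrases
scope_pattern = (
    r"^(this|last|next|prior)\s+"
    r"([0-9]+)?\s*"
    r"(second|minute|day|week|month|quarter|year)s?$"
)
assert re.match(scope_pattern, "last 2 weeks", re.IGNORECASE).groups() == (
    "last",
    "2",
    "week",
)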
-# isort:skip_file +from flask_appbuilder import permission_name +from flask_appbuilder.api import expose +from flask_appbuilder.security.decorators import has_access -from tests.integration_tests.test_app import app # noqa: F401 -from tests.integration_tests.base_tests import SupersetTestCase -from superset.db_engine_specs.base import BaseEngineSpec -from superset.models.core import Database +from superset.superset_typing import FlaskResponse + +from .base import BaseSupersetView -class TestDbEngineSpec(SupersetTestCase): - def sql_limit_regex( - self, - sql, - expected_sql, - engine_spec_class=BaseEngineSpec, - limit=1000, - force=False, - ): - main = Database(database_name="test_database", sqlalchemy_uri="sqlite://") - limited = engine_spec_class.apply_limit_to_sql(sql, limit, main, force) - assert expected_sql == limited +class ActionLogView(BaseSupersetView): + route_base = "/" + class_permission_name = "security" + + @expose("/actionlog/list") + @has_access + @permission_name("read") + def list(self) -> FlaskResponse: + return super().render_app_template() diff --git a/tests/integration_tests/base_tests.py b/tests/integration_tests/base_tests.py index 5ddda0c5912..1cf91ca2bbd 100644 --- a/tests/integration_tests/base_tests.py +++ b/tests/integration_tests/base_tests.py @@ -36,7 +36,7 @@ from sqlalchemy.dialects.mysql import dialect from tests.integration_tests.constants import ADMIN_USERNAME from tests.integration_tests.test_app import app, login -from superset.sql_parse import CtasMethod +from superset.sql.parse import CTASMethod from superset import db, security_manager from superset.connectors.sqla.models import BaseDatasource, SqlaTable from superset.models import core as models @@ -387,7 +387,7 @@ class SupersetTestCase(TestCase): select_as_cta=False, tmp_table_name=None, schema=None, - ctas_method=CtasMethod.TABLE, + ctas_method=CTASMethod.TABLE, template_params="{}", ): if username: @@ -400,7 +400,7 @@ class SupersetTestCase(TestCase): "client_id": client_id, "queryLimit": query_limit, "sql_editor_id": sql_editor_id, - "ctas_method": ctas_method, + "ctas_method": ctas_method.name, "templateParams": template_params, } if tmp_table_name: diff --git a/tests/integration_tests/celery_tests.py b/tests/integration_tests/celery_tests.py index 050f1e1dc26..b517c20d2de 100644 --- a/tests/integration_tests/celery_tests.py +++ b/tests/integration_tests/celery_tests.py @@ -39,7 +39,7 @@ from superset.db_engine_specs.base import BaseEngineSpec from superset.errors import ErrorLevel, SupersetErrorType from superset.extensions import celery_app from superset.models.sql_lab import Query -from superset.sql_parse import ParsedQuery, CtasMethod +from superset.sql.parse import CTASMethod from superset.utils.core import backend from superset.utils.database import get_example_database from tests.integration_tests.conftest import CTAS_SCHEMA_NAME @@ -76,13 +76,19 @@ def setup_sqllab(): db.session.query(Query).delete() db.session.commit() for tbl in TMP_TABLES: - drop_table_if_exists(f"{tbl}_{CtasMethod.TABLE.lower()}", CtasMethod.TABLE) - drop_table_if_exists(f"{tbl}_{CtasMethod.VIEW.lower()}", CtasMethod.VIEW) drop_table_if_exists( - f"{CTAS_SCHEMA_NAME}.{tbl}_{CtasMethod.TABLE.lower()}", CtasMethod.TABLE + f"{tbl}_{CTASMethod.TABLE.name.lower()}", CTASMethod.TABLE ) drop_table_if_exists( - f"{CTAS_SCHEMA_NAME}.{tbl}_{CtasMethod.VIEW.lower()}", CtasMethod.VIEW + f"{tbl}_{CTASMethod.VIEW.name.lower()}", CTASMethod.VIEW + ) + drop_table_if_exists( + 
f"{CTAS_SCHEMA_NAME}.{tbl}_{CTASMethod.TABLE.name.lower()}", + CTASMethod.TABLE, + ) + drop_table_if_exists( + f"{CTAS_SCHEMA_NAME}.{tbl}_{CTASMethod.VIEW.name.lower()}", + CTASMethod.VIEW, ) @@ -90,7 +96,7 @@ def run_sql( test_client, sql, cta=False, - ctas_method=CtasMethod.TABLE, + ctas_method=CTASMethod.TABLE, tmp_table="tmp", async_=False, ): @@ -104,14 +110,14 @@ def run_sql( select_as_cta=cta, tmp_table_name=tmp_table, client_id="".join(random.choice(string.ascii_lowercase) for i in range(5)), # noqa: S311 - ctas_method=ctas_method, + ctas_method=ctas_method.name, ), ).json -def drop_table_if_exists(table_name: str, table_type: CtasMethod) -> None: +def drop_table_if_exists(table_name: str, table_type: CTASMethod) -> None: """Drop table if it exists, works on any DB""" - sql = f"DROP {table_type} IF EXISTS {table_name}" + sql = f"DROP {table_type.name} IF EXISTS {table_name}" database = get_example_database() with database.get_sqla_engine() as engine: engine.execute(sql) @@ -124,10 +130,10 @@ def quote_f(value: Optional[str]): return inspector.engine.dialect.identifier_preparer.quote_identifier(value) -def cta_result(ctas_method: CtasMethod): +def cta_result(ctas_method: CTASMethod): if backend() != "presto": return [], [] - if ctas_method == CtasMethod.TABLE: + if ctas_method == CTASMethod.TABLE: return [{"rows": 1}], [{"name": "rows", "type": "BIGINT", "is_dttm": False}] return [{"result": True}], [{"name": "result", "type": "BOOLEAN", "is_dttm": False}] @@ -143,13 +149,13 @@ def get_select_star(table: str, limit: int, schema: Optional[str] = None): @pytest.mark.usefixtures("login_as_admin") -@pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW]) +@pytest.mark.parametrize("ctas_method", [CTASMethod.TABLE, CTASMethod.VIEW]) def test_run_sync_query_dont_exist(test_client, ctas_method): examples_db = get_example_database() engine_name = examples_db.db_engine_spec.engine_name sql_dont_exist = "SELECT name FROM table_dont_exist" result = run_sql(test_client, sql_dont_exist, cta=True, ctas_method=ctas_method) - if backend() == "sqlite" and ctas_method == CtasMethod.VIEW: + if backend() == "sqlite" and ctas_method == CTASMethod.VIEW: assert QueryStatus.SUCCESS == result["status"], result elif backend() == "presto": assert ( @@ -188,9 +194,9 @@ def test_run_sync_query_dont_exist(test_client, ctas_method): @pytest.mark.usefixtures("load_birth_names_data", "login_as_admin") -@pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW]) -def test_run_sync_query_cta(test_client, ctas_method): - tmp_table_name = f"{TEST_SYNC}_{ctas_method.lower()}" +@pytest.mark.parametrize("ctas_method", [CTASMethod.TABLE, CTASMethod.VIEW]) +def test_run_sync_query_cta(test_client, ctas_method: CTASMethod) -> None: + tmp_table_name = f"{TEST_SYNC}_{ctas_method.name.lower()}" result = run_sql( test_client, QUERY, tmp_table=tmp_table_name, cta=True, ctas_method=ctas_method ) @@ -218,16 +224,44 @@ def test_run_sync_query_cta_no_data(test_client): @pytest.mark.usefixtures("load_birth_names_data", "login_as_admin") -@pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW]) +@pytest.mark.parametrize( + "ctas_method, expected", + [ + ( + CTASMethod.TABLE, + """ +CREATE TABLE sqllab_test_db.test_sync_cta_table AS +SELECT + name +FROM birth_names +LIMIT 1 + """.strip(), + ), + ( + CTASMethod.VIEW, + """ +CREATE VIEW sqllab_test_db.test_sync_cta_view AS +SELECT + name +FROM birth_names +LIMIT 1 + """.strip(), + ), + ], +) @mock.patch( # noqa: PT008 
"superset.sqllab.sqllab_execution_context.get_cta_schema_name", lambda d, u, s, sql: CTAS_SCHEMA_NAME, ) -def test_run_sync_query_cta_config(test_client, ctas_method): +def test_run_sync_query_cta_config( + test_client, + ctas_method: CTASMethod, + expected: str, +) -> None: if backend() == "sqlite": # sqlite doesn't support schemas return - tmp_table_name = f"{TEST_SYNC_CTA}_{ctas_method.lower()}" + tmp_table_name = f"{TEST_SYNC_CTA}_{ctas_method.name.lower()}" result = run_sql( test_client, QUERY, cta=True, ctas_method=ctas_method, tmp_table=tmp_table_name ) @@ -235,10 +269,7 @@ def test_run_sync_query_cta_config(test_client, ctas_method): assert cta_result(ctas_method) == (result["data"], result["columns"]) query = get_query_by_id(result["query"]["serverId"]) - assert ( - f"CREATE {ctas_method} {CTAS_SCHEMA_NAME}.{tmp_table_name} AS \n{QUERY}" - == query.executed_sql - ) + assert query.executed_sql == expected assert query.select_sql == get_select_star( tmp_table_name, limit=query.limit, schema=CTAS_SCHEMA_NAME ) @@ -249,16 +280,44 @@ def test_run_sync_query_cta_config(test_client, ctas_method): @pytest.mark.usefixtures("load_birth_names_data", "login_as_admin") -@pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW]) +@pytest.mark.parametrize( + "ctas_method, expected", + [ + ( + CTASMethod.TABLE, + """ +CREATE TABLE sqllab_test_db.test_async_cta_config_table AS +SELECT + name +FROM birth_names +LIMIT 1 + """.strip(), + ), + ( + CTASMethod.VIEW, + """ +CREATE VIEW sqllab_test_db.test_async_cta_config_view AS +SELECT + name +FROM birth_names +LIMIT 1 + """.strip(), + ), + ], +) @mock.patch( # noqa: PT008 "superset.sqllab.sqllab_execution_context.get_cta_schema_name", lambda d, u, s, sql: CTAS_SCHEMA_NAME, ) -def test_run_async_query_cta_config(test_client, ctas_method): +def test_run_async_query_cta_config( + test_client, + ctas_method: CTASMethod, + expected: str, +) -> None: if backend() == "sqlite": # sqlite doesn't support schemas return - tmp_table_name = f"{TEST_ASYNC_CTA_CONFIG}_{ctas_method.lower()}" + tmp_table_name = f"{TEST_ASYNC_CTA_CONFIG}_{ctas_method.name.lower()}" result = run_sql( test_client, QUERY, @@ -275,18 +334,43 @@ def test_run_async_query_cta_config(test_client, ctas_method): get_select_star(tmp_table_name, limit=query.limit, schema=CTAS_SCHEMA_NAME) == query.select_sql ) - assert ( - f"CREATE {ctas_method} {CTAS_SCHEMA_NAME}.{tmp_table_name} AS \n{QUERY}" - == query.executed_sql - ) + assert query.executed_sql == expected delete_tmp_view_or_table(f"{CTAS_SCHEMA_NAME}.{tmp_table_name}", ctas_method) @pytest.mark.usefixtures("load_birth_names_data", "login_as_admin") -@pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW]) -def test_run_async_cta_query(test_client, ctas_method): - table_name = f"{TEST_ASYNC_CTA}_{ctas_method.lower()}" +@pytest.mark.parametrize( + "ctas_method, expected", + [ + ( + CTASMethod.TABLE, + """ +CREATE TABLE test_async_cta_table AS +SELECT + name +FROM birth_names +LIMIT 1 + """.strip(), + ), + ( + CTASMethod.VIEW, + """ +CREATE VIEW test_async_cta_view AS +SELECT + name +FROM birth_names +LIMIT 1 + """.strip(), + ), + ], +) +def test_run_async_cta_query( + test_client, + ctas_method: CTASMethod, + expected: str, +) -> None: + table_name = f"{TEST_ASYNC_CTA}_{ctas_method.name.lower()}" result = run_sql( test_client, QUERY, @@ -301,7 +385,7 @@ def test_run_async_cta_query(test_client, ctas_method): assert QueryStatus.SUCCESS == query.status assert get_select_star(table_name, query.limit) in 
query.select_sql - assert f"CREATE {ctas_method} {table_name} AS \n{QUERY}" == query.executed_sql + assert query.executed_sql == expected assert QUERY == query.sql assert query.rows == (1 if backend() == "presto" else 0) assert query.select_as_cta @@ -311,9 +395,37 @@ def test_run_async_cta_query(test_client, ctas_method): @pytest.mark.usefixtures("load_birth_names_data", "login_as_admin") -@pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW]) -def test_run_async_cta_query_with_lower_limit(test_client, ctas_method): - tmp_table = f"{TEST_ASYNC_LOWER_LIMIT}_{ctas_method.lower()}" +@pytest.mark.parametrize( + "ctas_method, expected", + [ + ( + CTASMethod.TABLE, + """ +CREATE TABLE test_async_lower_limit_table AS +SELECT + name +FROM birth_names +LIMIT 1 + """.strip(), + ), + ( + CTASMethod.VIEW, + """ +CREATE VIEW test_async_lower_limit_view AS +SELECT + name +FROM birth_names +LIMIT 1 + """.strip(), + ), + ], +) +def test_run_async_cta_query_with_lower_limit( + test_client, + ctas_method: CTASMethod, + expected: str, +) -> None: + tmp_table = f"{TEST_ASYNC_LOWER_LIMIT}_{ctas_method.name.lower()}" result = run_sql( test_client, QUERY, @@ -332,7 +444,7 @@ def test_run_async_cta_query_with_lower_limit(test_client, ctas_method): else get_select_star(tmp_table, query.limit) ) - assert f"CREATE {ctas_method} {tmp_table} AS \n{QUERY}" == query.executed_sql + assert query.executed_sql == expected assert QUERY == query.sql assert query.rows == (1 if backend() == "presto" else 0) @@ -442,28 +554,6 @@ def test_msgpack_payload_serialization(): assert isinstance(serialized, bytes) -def test_create_table_as(): - q = ParsedQuery("SELECT * FROM outer_space;") - - assert "CREATE TABLE tmp AS \nSELECT * FROM outer_space" == q.as_create_table("tmp") - assert ( - "DROP TABLE IF EXISTS tmp;\nCREATE TABLE tmp AS \nSELECT * FROM outer_space" - == q.as_create_table("tmp", overwrite=True) - ) - - # now without a semicolon - q = ParsedQuery("SELECT * FROM outer_space") - assert "CREATE TABLE tmp AS \nSELECT * FROM outer_space" == q.as_create_table("tmp") - - # now a multi-line query - multi_line_query = "SELECT * FROM planets WHERE\nLuke_Father = 'Darth Vader'" - q = ParsedQuery(multi_line_query) - assert ( - "CREATE TABLE tmp AS \nSELECT * FROM planets WHERE\nLuke_Father = 'Darth Vader'" - == q.as_create_table("tmp") - ) - - def test_in_app_context(): @celery_app.task(bind=True) def my_task(self): @@ -484,8 +574,8 @@ def test_in_app_context(): ) -def delete_tmp_view_or_table(name: str, db_object_type: str): - db.get_engine().execute(f"DROP {db_object_type} IF EXISTS {name}") +def delete_tmp_view_or_table(name: str, ctas_method: CTASMethod): + db.get_engine().execute(f"DROP {ctas_method.name} IF EXISTS {name}") def wait_for_success(result): diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py index fa5e7b44ba2..59364b4ed04 100644 --- a/tests/integration_tests/charts/data/api_tests.py +++ b/tests/integration_tests/charts/data/api_tests.py @@ -661,7 +661,7 @@ class TestPostChartDataApi(BaseTestChartDataApi): ] rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") - assert rv.status_code == 400 + assert rv.status_code == 422 @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_with_invalid_having_parameter_closing_and_comment__400(self): @@ -709,7 +709,7 @@ class TestPostChartDataApi(BaseTestChartDataApi): rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") 
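Circling back to the `CREATE TABLE ... AS` expectations above, a hedged sketch of how such statements are now produced (the API names come from this diff; the schema/table names and the exact pretty-printed output are illustrative):

from superset.sql.parse import CTASMethod, SQLStatement, Table

stmt = SQLStatement(
    statement="SELECT name FROM birth_names LIMIT 1",
    engine="postgresql",
)
ctas = stmt.as_create_table(Table("tmp_table", "sqllab_test_db"), CTASMethod.TABLE)
# ctas.format() pretty-prints along the lines of:
#   CREATE TABLE sqllab_test_db.tmp_table AS
#   SELECT
#     name
#   FROM birth_names
#   LIMIT 1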
result = rv.json["result"][0]["query"] if get_example_database().backend != "presto": - assert "('boy' = 'boy')" in result + assert "(\n 'boy' = 'boy'\n )" in result @unittest.skip("Extremely flaky test on MySQL") @with_feature_flags(GLOBAL_ASYNC_QUERIES=True) @@ -840,7 +840,7 @@ class TestPostChartDataApi(BaseTestChartDataApi): unique_names = {row["name"] for row in data} self.maxDiff = None assert len(unique_names) == SERIES_LIMIT - assert {column for column in data[0].keys()} == {"state", "name", "sum__num"} # noqa: C416 + assert set(data[0]) == {"state", "name", "sum__num"} @pytest.mark.usefixtures( "create_annotation_layers", "load_birth_names_dashboard_with_slices" @@ -931,7 +931,7 @@ class TestPostChartDataApi(BaseTestChartDataApi): assert rv.status_code == 200 result = rv.json["result"][0] data = result["data"] - assert {col for col in data[0].keys()} == {"foo", "bar", "state", "count"} # noqa: C416 + assert set(data[0]) == {"foo", "bar", "state", "count"} # make sure results and query parameters are unescaped assert {row["foo"] for row in data} == {":foo"} assert {row["bar"] for row in data} == {":bar:"} @@ -1251,7 +1251,7 @@ class TestGetChartDataApi(BaseTestChartDataApi): response_payload = json.loads(rv.data.decode("utf-8")) result = response_payload["result"][0] data = result["data"] - assert {column for column in data[0].keys()} == {"male_or_female", "sum__num"} # noqa: C416 + assert set(data[0]) == {"male_or_female", "sum__num"} unique_genders = {row["male_or_female"] for row in data} assert unique_genders == {"male", "female"} assert result["applied_filters"] == [{"column": "male_or_female"}] @@ -1271,7 +1271,7 @@ class TestGetChartDataApi(BaseTestChartDataApi): response_payload = json.loads(rv.data.decode("utf-8")) result = response_payload["result"][0] data = result["data"] - assert {column for column in data[0].keys()} == {"male_or_female", "sum__num"} # noqa: C416 + assert set(data[0]) == {"male_or_female", "sum__num"} unique_genders = {row["male_or_female"] for row in data} assert unique_genders == {"male", "female"} assert result["applied_filters"] == [{"column": "male_or_female"}] diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index 0c6dad0b1df..3bf5d7d7652 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -1337,7 +1337,7 @@ class TestDatabaseApi(SupersetTestCase): expected_response_postgres = { "errors": [dataclasses.asdict(superset_error_postgres)] } - assert response.status_code == 500 + assert response.status_code == 400 if example_db.backend == "mysql": assert response_data == expected_response_mysql else: @@ -2450,7 +2450,7 @@ class TestDatabaseApi(SupersetTestCase): url = "api/v1/database/test_connection/" rv = self.post_assert_metric(url, data, "test_connection") - assert rv.status_code == 500 + assert rv.status_code == 400 assert rv.headers["Content-Type"] == "application/json; charset=utf-8" response = json.loads(rv.data.decode("utf-8")) expected_response = {"errors": [dataclasses.asdict(superset_error)]} diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py index c9ba88411de..21162ca52d6 100644 --- a/tests/integration_tests/datasource_tests.py +++ b/tests/integration_tests/datasource_tests.py @@ -568,6 +568,9 @@ def test_get_samples(test_client, login_as_admin, virtual_dataset): def test_get_samples_with_incorrect_cc(test_client, login_as_admin, virtual_dataset): + if 
get_example_database().backend == "sqlite": + return + TableColumn( column_name="DUMMY CC", type="VARCHAR(255)", @@ -702,7 +705,7 @@ def test_get_samples_with_multiple_filters( assert "2000-01-02" in rv.json["result"]["query"] assert "2000-01-04" in rv.json["result"]["query"] assert "col3 = 1.2" in rv.json["result"]["query"] - assert "col4 is null" in rv.json["result"]["query"] + assert "col4 IS NULL" in rv.json["result"]["query"] assert "col2 = 'c'" in rv.json["result"]["query"] diff --git a/tests/integration_tests/db_engine_specs/ascend_tests.py b/tests/integration_tests/db_engine_specs/ascend_tests.py index cd1fa372858..045cac7d76a 100644 --- a/tests/integration_tests/db_engine_specs/ascend_tests.py +++ b/tests/integration_tests/db_engine_specs/ascend_tests.py @@ -15,10 +15,10 @@ # specific language governing permissions and limitations # under the License. from superset.db_engine_specs.ascend import AscendEngineSpec -from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec +from tests.integration_tests.base_tests import SupersetTestCase -class TestAscendDbEngineSpec(TestDbEngineSpec): +class TestAscendDbEngineSpec(SupersetTestCase): def test_convert_dttm(self): dttm = self.get_dttm() diff --git a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py index 7dd9c5cb95e..46a5951f10f 100644 --- a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py +++ b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py @@ -25,14 +25,13 @@ from superset.db_engine_specs.base import ( BaseEngineSpec, BasicParametersMixin, builtin_time_grains, - LimitMethod, ) from superset.db_engine_specs.mysql import MySQLEngineSpec from superset.db_engine_specs.sqlite import SqliteEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.sql_parse import Table from superset.utils.database import get_example_database -from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec +from tests.integration_tests.base_tests import SupersetTestCase from tests.integration_tests.test_app import app from ..fixtures.birth_names_dashboard import ( @@ -46,7 +45,7 @@ from ..fixtures.energy_dashboard import ( from ..fixtures.pyodbcRow import Row -class TestDbEngineSpecs(TestDbEngineSpec): +class SupersetTestCases(SupersetTestCase): def test_extract_limit_from_query(self, engine_spec_class=BaseEngineSpec): q0 = "select * from table" q1 = "select * from mytable limit 10" @@ -74,124 +73,9 @@ class TestDbEngineSpecs(TestDbEngineSpec): assert engine_spec_class.get_limit_from_sql(q10) is None assert engine_spec_class.get_limit_from_sql(q11) is None - def test_wrapped_semi_tabs(self): - self.sql_limit_regex( - "SELECT * FROM a \t \n ; \t \n ", "SELECT * FROM a\nLIMIT 1000" - ) - - def test_simple_limit_query(self): - self.sql_limit_regex("SELECT * FROM a", "SELECT * FROM a\nLIMIT 1000") - - def test_modify_limit_query(self): - self.sql_limit_regex("SELECT * FROM a LIMIT 9999", "SELECT * FROM a LIMIT 1000") - - def test_limit_query_with_limit_subquery(self): # pylint: disable=invalid-name - self.sql_limit_regex( - "SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 9999", - "SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 1000", - ) - - def test_limit_query_without_force(self): - self.sql_limit_regex( - "SELECT * FROM a LIMIT 10", - "SELECT * FROM a LIMIT 10", - limit=11, - ) - - def test_limit_query_with_force(self): - self.sql_limit_regex( - "SELECT * FROM a 
LIMIT 10", - "SELECT * FROM a LIMIT 11", - limit=11, - force=True, - ) - - def test_limit_with_expr(self): - self.sql_limit_regex( - """ - SELECT - 'LIMIT 777' AS a - , b - FROM - table - LIMIT 99990""", - """SELECT - 'LIMIT 777' AS a - , b - FROM - table - LIMIT 1000""", - ) - - def test_limit_expr_and_semicolon(self): - self.sql_limit_regex( - """ - SELECT - 'LIMIT 777' AS a - , b - FROM - table - LIMIT 99990 ;""", - """SELECT - 'LIMIT 777' AS a - , b - FROM - table - LIMIT 1000""", - ) - def test_get_datatype(self): assert "VARCHAR" == BaseEngineSpec.get_datatype("VARCHAR") - def test_limit_with_implicit_offset(self): - self.sql_limit_regex( - """ - SELECT - 'LIMIT 777' AS a - , b - FROM - table - LIMIT 99990, 999999""", - """SELECT - 'LIMIT 777' AS a - , b - FROM - table - LIMIT 99990, 1000""", - ) - - def test_limit_with_explicit_offset(self): - self.sql_limit_regex( - """ - SELECT - 'LIMIT 777' AS a - , b - FROM - table - LIMIT 99990 - OFFSET 999999""", - """SELECT - 'LIMIT 777' AS a - , b - FROM - table - LIMIT 1000 - OFFSET 999999""", - ) - - def test_limit_with_non_token_limit(self): - self.sql_limit_regex( - """SELECT 'LIMIT 777'""", """SELECT 'LIMIT 777'\nLIMIT 1000""" - ) - - def test_limit_with_fetch_many(self): - class DummyEngineSpec(BaseEngineSpec): - limit_method = LimitMethod.FETCH_MANY - - self.sql_limit_regex( - "SELECT * FROM table", "SELECT * FROM table", DummyEngineSpec - ) - def test_engine_time_grain_validity(self): time_grains = set(builtin_time_grains.keys()) # loop over all subclasses of BaseEngineSpec diff --git a/tests/integration_tests/db_engine_specs/bigquery_tests.py b/tests/integration_tests/db_engine_specs/bigquery_tests.py index 636fc3523ae..00b8414127a 100644 --- a/tests/integration_tests/db_engine_specs/bigquery_tests.py +++ b/tests/integration_tests/db_engine_specs/bigquery_tests.py @@ -26,7 +26,7 @@ from superset.db_engine_specs.base import BaseEngineSpec from superset.db_engine_specs.bigquery import BigQueryEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.sql_parse import Table -from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec +from tests.integration_tests.base_tests import SupersetTestCase from tests.integration_tests.fixtures.birth_names_dashboard import ( load_birth_names_dashboard_with_slices, # noqa: F401 load_birth_names_data, # noqa: F401 @@ -42,7 +42,7 @@ def mock_engine_with_credentials(*args, **kwargs): yield engine_mock -class TestBigQueryDbEngineSpec(TestDbEngineSpec): +class TestBigQueryDbEngineSpec(SupersetTestCase): def test_bigquery_sqla_column_label(self): """ DB Eng Specs (bigquery): Test column label diff --git a/tests/integration_tests/db_engine_specs/databricks_tests.py b/tests/integration_tests/db_engine_specs/databricks_tests.py index bf4d7e8b9f9..ec6ed2964ef 100644 --- a/tests/integration_tests/db_engine_specs/databricks_tests.py +++ b/tests/integration_tests/db_engine_specs/databricks_tests.py @@ -18,12 +18,12 @@ from unittest import mock from superset.db_engine_specs import get_engine_spec from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec -from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec +from tests.integration_tests.base_tests import SupersetTestCase from tests.integration_tests.fixtures.certificates import ssl_certificate from tests.integration_tests.fixtures.database import default_db_extra -class TestDatabricksDbEngineSpec(TestDbEngineSpec): +class 
TestDatabricksDbEngineSpec(SupersetTestCase): def test_get_engine_spec(self): """ DB Eng Specs (databricks): Test "databricks" in engine spec diff --git a/tests/integration_tests/db_engine_specs/elasticsearch_tests.py b/tests/integration_tests/db_engine_specs/elasticsearch_tests.py index 8027c031a5d..2ac4f6aa2a8 100644 --- a/tests/integration_tests/db_engine_specs/elasticsearch_tests.py +++ b/tests/integration_tests/db_engine_specs/elasticsearch_tests.py @@ -19,10 +19,10 @@ from sqlalchemy import column from superset.constants import TimeGrain from superset.db_engine_specs.elasticsearch import ElasticSearchEngineSpec -from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec +from tests.integration_tests.base_tests import SupersetTestCase -class TestElasticsearchDbEngineSpec(TestDbEngineSpec): +class TestElasticsearchDbEngineSpec(SupersetTestCase): @parameterized.expand( [ [TimeGrain.SECOND, "DATE_TRUNC('second', ts)"], diff --git a/tests/integration_tests/db_engine_specs/gsheets_tests.py b/tests/integration_tests/db_engine_specs/gsheets_tests.py index 212af15c333..6368d730d47 100644 --- a/tests/integration_tests/db_engine_specs/gsheets_tests.py +++ b/tests/integration_tests/db_engine_specs/gsheets_tests.py @@ -16,10 +16,10 @@ # under the License. from superset.db_engine_specs.gsheets import GSheetsEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType -from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec +from tests.integration_tests.base_tests import SupersetTestCase -class TestGsheetsDbEngineSpec(TestDbEngineSpec): +class TestGsheetsDbEngineSpec(SupersetTestCase): def test_extract_errors(self): """ Test that custom error messages are extracted correctly. diff --git a/tests/integration_tests/db_engine_specs/hive_tests.py b/tests/integration_tests/db_engine_specs/hive_tests.py index c6bbfb683ff..05e31bbee55 100644 --- a/tests/integration_tests/db_engine_specs/hive_tests.py +++ b/tests/integration_tests/db_engine_specs/hive_tests.py @@ -17,7 +17,6 @@ # isort:skip_file from unittest import mock import unittest -from .base_tests import SupersetTestCase import pytest import pandas as pd @@ -26,6 +25,7 @@ from sqlalchemy.sql import select from superset.db_engine_specs.hive import HiveEngineSpec, upload_to_s3 from superset.exceptions import SupersetException from superset.sql_parse import Table +from tests.integration_tests.base_tests import SupersetTestCase from tests.integration_tests.test_app import app diff --git a/tests/integration_tests/db_engine_specs/mysql_tests.py b/tests/integration_tests/db_engine_specs/mysql_tests.py index 23af61f17d9..2698721c651 100644 --- a/tests/integration_tests/db_engine_specs/mysql_tests.py +++ b/tests/integration_tests/db_engine_specs/mysql_tests.py @@ -21,12 +21,12 @@ from sqlalchemy.dialects.mysql import DATE, NVARCHAR, TEXT, VARCHAR from superset.db_engine_specs.mysql import MySQLEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType -from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec +from tests.integration_tests.base_tests import SupersetTestCase -class TestMySQLEngineSpecsDbEngineSpec(TestDbEngineSpec): +class TestMySQLEngineSpecsDbEngineSpec(SupersetTestCase): @unittest.skipUnless( - TestDbEngineSpec.is_module_installed("MySQLdb"), "mysqlclient not installed" + SupersetTestCase.is_module_installed("MySQLdb"), "mysqlclient not installed" ) def test_get_datatype_mysql(self): """Tests related to datatype mapping for 
MySQL""" diff --git a/tests/integration_tests/db_engine_specs/pinot_tests.py b/tests/integration_tests/db_engine_specs/pinot_tests.py index 66d4865fb81..f6872a2e2bd 100755 --- a/tests/integration_tests/db_engine_specs/pinot_tests.py +++ b/tests/integration_tests/db_engine_specs/pinot_tests.py @@ -17,10 +17,10 @@ from sqlalchemy import column from superset.db_engine_specs.pinot import PinotEngineSpec -from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec +from tests.integration_tests.base_tests import SupersetTestCase -class TestPinotDbEngineSpec(TestDbEngineSpec): +class TestPinotDbEngineSpec(SupersetTestCase): """Tests pertaining to our Pinot database support""" def test_pinot_time_expression_sec_one_1d_grain(self): diff --git a/tests/integration_tests/db_engine_specs/postgres_tests.py b/tests/integration_tests/db_engine_specs/postgres_tests.py index e45e0189f40..236d293df4b 100644 --- a/tests/integration_tests/db_engine_specs/postgres_tests.py +++ b/tests/integration_tests/db_engine_specs/postgres_tests.py @@ -27,12 +27,12 @@ from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.models.sql_lab import Query from superset.utils.core import backend from superset.utils.database import get_example_database -from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec +from tests.integration_tests.base_tests import SupersetTestCase from tests.integration_tests.fixtures.certificates import ssl_certificate from tests.integration_tests.fixtures.database import default_db_extra -class TestPostgresDbEngineSpec(TestDbEngineSpec): +class TestPostgresDbEngineSpec(SupersetTestCase): def test_get_table_names(self): """ DB Eng Specs (postgres): Test get table names diff --git a/tests/integration_tests/db_engine_specs/presto_tests.py b/tests/integration_tests/db_engine_specs/presto_tests.py index c57bec88008..2a27d12df9a 100644 --- a/tests/integration_tests/db_engine_specs/presto_tests.py +++ b/tests/integration_tests/db_engine_specs/presto_tests.py @@ -27,11 +27,11 @@ from superset.db_engine_specs.presto import PrestoEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.sql_parse import Table from superset.utils.database import get_example_database -from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec +from tests.integration_tests.base_tests import SupersetTestCase -class TestPrestoDbEngineSpec(TestDbEngineSpec): - @skipUnless(TestDbEngineSpec.is_module_installed("pyhive"), "pyhive not installed") +class TestPrestoDbEngineSpec(SupersetTestCase): + @skipUnless(SupersetTestCase.is_module_installed("pyhive"), "pyhive not installed") def test_get_datatype_presto(self): assert "STRING" == PrestoEngineSpec.get_datatype("string") diff --git a/tests/integration_tests/db_engine_specs/redshift_tests.py b/tests/integration_tests/db_engine_specs/redshift_tests.py index 2d46c73fca7..38d3bc091c5 100644 --- a/tests/integration_tests/db_engine_specs/redshift_tests.py +++ b/tests/integration_tests/db_engine_specs/redshift_tests.py @@ -24,11 +24,11 @@ from sqlalchemy.types import NVARCHAR from superset.db_engine_specs.redshift import RedshiftEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.sql_parse import Table -from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec +from tests.integration_tests.base_tests import SupersetTestCase from tests.integration_tests.test_app import app -class 
TestRedshiftDbEngineSpec(TestDbEngineSpec): +class TestRedshiftDbEngineSpec(SupersetTestCase): def test_extract_errors(self): """ Test that custom error messages are extracted correctly. diff --git a/tests/integration_tests/log_api_tests.py b/tests/integration_tests/log_api_tests.py index ddec917cb4a..a08068defab 100644 --- a/tests/integration_tests/log_api_tests.py +++ b/tests/integration_tests/log_api_tests.py @@ -98,7 +98,7 @@ class TestLogApi(SupersetTestCase): response = json.loads(rv.data.decode("utf-8")) assert list(response["result"][0].keys()) == EXPECTED_COLUMNS assert response["result"][0]["action"] == "some_action" - assert response["result"][0]["user"] == {"username": "admin"} + assert response["result"][0]["user"]["username"] == "admin" db.session.delete(log) db.session.commit() @@ -132,7 +132,7 @@ class TestLogApi(SupersetTestCase): assert list(response["result"].keys()) == EXPECTED_COLUMNS assert response["result"]["action"] == "some_action" - assert response["result"]["user"] == {"username": "admin"} + assert response["result"]["user"]["username"] == "admin" db.session.delete(log) db.session.commit() diff --git a/tests/integration_tests/log_model_view_tests.py b/tests/integration_tests/log_model_view_tests.py deleted file mode 100644 index e347f39e9a4..00000000000 --- a/tests/integration_tests/log_model_view_tests.py +++ /dev/null @@ -1,37 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from unittest.mock import patch - -from superset.views.log.views import LogModelView -from tests.integration_tests.base_tests import SupersetTestCase -from tests.integration_tests.constants import ADMIN_USERNAME - - -class TestLogModelView(SupersetTestCase): - def test_disabled(self): - with patch.object(LogModelView, "is_enabled", return_value=False): - self.login(ADMIN_USERNAME) - uri = "/logmodelview/list/" - rv = self.client.get(uri) - self.assert404(rv) - - def test_enabled(self): - with patch.object(LogModelView, "is_enabled", return_value=True): - self.login(ADMIN_USERNAME) - uri = "/logmodelview/list/" - rv = self.client.get(uri) - self.assert200(rv) diff --git a/tests/integration_tests/migrations/c747c78868b6_migrating_legacy_treemap__tests.py b/tests/integration_tests/migrations/c747c78868b6_migrating_legacy_treemap__tests.py deleted file mode 100644 index 0fd92761210..00000000000 --- a/tests/integration_tests/migrations/c747c78868b6_migrating_legacy_treemap__tests.py +++ /dev/null @@ -1,91 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from superset.app import SupersetApp -from superset.migrations.shared.migrate_viz import MigrateTreeMap -from superset.utils import json - -treemap_form_data = """{ - "adhoc_filters": [ - { - "clause": "WHERE", - "comparator": [ - "Edward" - ], - "expressionType": "SIMPLE", - "filterOptionName": "filter_xhbus6irfa_r10k9nwmwy", - "isExtra": false, - "isNew": false, - "operator": "IN", - "operatorId": "IN", - "sqlExpression": null, - "subject": "name" - } - ], - "color_scheme": "bnbColors", - "datasource": "2__table", - "extra_form_data": {}, - "granularity_sqla": "ds", - "groupby": [ - "state", - "gender" - ], - "metrics": [ - "sum__num" - ], - "number_format": ",d", - "order_desc": true, - "row_limit": 10, - "time_range": "No filter", - "timeseries_limit_metric": "sum__num", - "treemap_ratio": 1.618033988749895, - "viz_type": "treemap" -} -""" - - -def test_treemap_migrate(app_context: SupersetApp) -> None: - from superset.models.slice import Slice - - slc = Slice( - viz_type=MigrateTreeMap.source_viz_type, - datasource_type="table", - params=treemap_form_data, - query_context=f'{{"form_data": {treemap_form_data}}}', - ) - - MigrateTreeMap.upgrade_slice(slc) - assert slc.viz_type == MigrateTreeMap.target_viz_type - # verify form_data - new_form_data = json.loads(slc.params) - assert new_form_data["metric"] == "sum__num" - assert new_form_data["viz_type"] == "treemap_v2" - assert "metrics" not in new_form_data - assert json.dumps(new_form_data["form_data_bak"], sort_keys=True) == json.dumps( - json.loads(treemap_form_data), sort_keys=True - ) - - # verify query_context - new_query_context = json.loads(slc.query_context) - assert new_query_context["form_data"]["viz_type"] == "treemap_v2" - - # downgrade - MigrateTreeMap.downgrade_slice(slc) - assert slc.viz_type == MigrateTreeMap.source_viz_type - assert json.dumps(json.loads(slc.params), sort_keys=True) == json.dumps( - json.loads(treemap_form_data), sort_keys=True - ) diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py index f3d73ae5534..499526cba99 100644 --- a/tests/integration_tests/model_tests.py +++ b/tests/integration_tests/model_tests.py @@ -447,7 +447,11 @@ class TestSqlaTableModel(SupersetTestCase): return None old_inner_join = spec.allows_joins spec.allows_joins = inner_join - arbitrary_gby = "state || gender || '_test'" + arbitrary_gby = ( + "state OR gender OR '_test'" + if get_example_database().backend == "mysql" + else "state || gender || '_test'" + ) arbitrary_metric = dict( # noqa: C408 label="arbitrary", expressionType="SQL", sqlExpression="SUM(num_boys)" ) diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py index a56e8338352..cb5afc98aff 100644 --- a/tests/integration_tests/query_context_tests.py +++ b/tests/integration_tests/query_context_tests.py @@ -876,12 +876,6 @@ def test_special_chars_in_column_name(app_context, physical_dataset): "columns": [ 
"col1", "time column with spaces", - { - "label": "I_AM_A_TRUNC_COLUMN", - "sqlExpression": "time column with spaces", - "columnType": "BASE_AXIS", - "timeGrain": "P1Y", - }, ], "metrics": ["count"], "orderby": [["col1", True]], @@ -897,10 +891,8 @@ def test_special_chars_in_column_name(app_context, physical_dataset): if query_object.datasource.database.backend == "sqlite": # sqlite returns string as timestamp column assert df["time column with spaces"][0] == "2002-01-03 00:00:00" - assert df["I_AM_A_TRUNC_COLUMN"][0] == "2002-01-01 00:00:00" else: assert df["time column with spaces"][0].strftime("%Y-%m-%d") == "2002-01-03" - assert df["I_AM_A_TRUNC_COLUMN"][0].strftime("%Y-%m-%d") == "2002-01-01" @only_postgresql diff --git a/tests/integration_tests/sql_lab/api_tests.py b/tests/integration_tests/sql_lab/api_tests.py index 7d0ea7cc2ef..dad966c6de6 100644 --- a/tests/integration_tests/sql_lab/api_tests.py +++ b/tests/integration_tests/sql_lab/api_tests.py @@ -33,7 +33,9 @@ from tests.integration_tests.test_app import app from superset import db, sql_lab from superset.common.db_query_status import QueryStatus from superset.models.core import Database # noqa: F401 -from superset.utils.database import get_example_database, get_main_database # noqa: F401 +from superset.utils.database import ( + get_example_database, +) # noqa: F401 from superset.utils import core as utils, json from superset.models.sql_lab import Query @@ -281,7 +283,7 @@ class TestSqlLabApi(SupersetTestCase): "/api/v1/sqllab/format_sql/", json=data, ) - success_resp = {"result": "SELECT 1\nFROM my_table"} + success_resp = {"result": "SELECT\n 1\nFROM my_table"} resp_data = json.loads(rv.data.decode("utf-8")) self.assertDictEqual(resp_data, success_resp) # noqa: PT009 assert rv.status_code == 200 diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py index 80ada22e0aa..793e2db6652 100644 --- a/tests/integration_tests/sqla_models_tests.py +++ b/tests/integration_tests/sqla_models_tests.py @@ -198,7 +198,7 @@ class TestDatabaseModel(SupersetTestCase): # assert dataset saved metric assert "count('bar_P1D')" in query # assert adhoc metric - assert "SUM(case when user = 'user_abc' then 1 else 0 end)" in query + assert "SUM(CASE WHEN user = 'user_abc' THEN 1 ELSE 0 END)" in query # Cleanup db.session.delete(table) db.session.commit() diff --git a/tests/integration_tests/sqllab_tests.py b/tests/integration_tests/sqllab_tests.py index 514053f7d1e..99c347d95f9 100644 --- a/tests/integration_tests/sqllab_tests.py +++ b/tests/integration_tests/sqllab_tests.py @@ -31,16 +31,15 @@ from superset.db_engine_specs import BaseEngineSpec from superset.db_engine_specs.hive import HiveEngineSpec from superset.db_engine_specs.presto import PrestoEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType -from superset.exceptions import SupersetErrorException +from superset.exceptions import SupersetErrorException, SupersetInvalidCVASException from superset.models.sql_lab import Query from superset.result_set import SupersetResultSet from superset.sqllab.limiting_factor import LimitingFactor +from superset.sql.parse import CTASMethod from superset.sql_lab import ( cancel_query, execute_sql_statements, - apply_limit_if_exists, ) -from superset.sql_parse import CtasMethod from superset.utils.core import backend from superset.utils import json from superset.utils.json import datetime_to_epoch # noqa: F401 @@ -132,31 +131,13 @@ class TestSqlLab(SupersetTestCase): 
self.login(ADMIN_USERNAME) data = self.run_sql("DELETE FROM birth_names", "1") - assert data == { - "errors": [ - { - "message": ( - "This database does not allow for DDL/DML, and the query " - "could not be parsed to confirm it is a read-only query. Please " # noqa: E501 - "contact your administrator for more assistance." - ), - "error_type": SupersetErrorType.DML_NOT_ALLOWED_ERROR, - "level": ErrorLevel.ERROR, - "extra": { - "issue_codes": [ - { - "code": 1022, - "message": "Issue 1022 - Database does not allow data manipulation.", # noqa: E501 - } - ] - }, - } - ] - } + assert ( + data["errors"][0]["error_type"] == SupersetErrorType.DML_NOT_ALLOWED_ERROR + ) - @parameterized.expand([CtasMethod.TABLE, CtasMethod.VIEW]) + @parameterized.expand([CTASMethod.TABLE, CTASMethod.VIEW]) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - def test_sql_json_cta_dynamic_db(self, ctas_method): + def test_sql_json_cta_dynamic_db(self, ctas_method: CTASMethod) -> None: examples_db = get_example_database() if examples_db.backend == "sqlite": # sqlite doesn't support database creation @@ -170,7 +151,7 @@ class TestSqlLab(SupersetTestCase): examples_db.allow_ctas = True # enable cta self.login(ADMIN_USERNAME) - tmp_table_name = f"test_target_{ctas_method.lower()}" + tmp_table_name = f"test_target_{ctas_method.name.lower()}" self.run_sql( "SELECT * FROM birth_names", "1", @@ -195,7 +176,9 @@ class TestSqlLab(SupersetTestCase): ) # SQL_MAX_ROW not applied due to the SQLLAB_CTAS_NO_LIMIT set to True # cleanup - engine.execute(f"DROP {ctas_method} admin_database.{tmp_table_name}") + engine.execute( + f"DROP {ctas_method.name} admin_database.{tmp_table_name}" + ) examples_db.allow_ctas = old_allow_ctas db.session.commit() @@ -608,10 +591,10 @@ class TestSqlLab(SupersetTestCase): @mock.patch("superset.sql_lab.db") @mock.patch("superset.sql_lab.get_query") - @mock.patch("superset.sql_lab.execute_sql_statement") + @mock.patch("superset.sql_lab.execute_query") def test_execute_sql_statements( self, - mock_execute_sql_statement, + mock_execute_query, mock_get_query, mock_db, ): @@ -623,7 +606,7 @@ class TestSqlLab(SupersetTestCase): """ ) mock_db = mock.MagicMock() # noqa: F841 - mock_query = mock.MagicMock() + mock_query = mock.MagicMock(select_as_cta=False) mock_query.database.allow_run_async = False mock_cursor = mock.MagicMock() mock_query.database.get_raw_connection().__enter__().cursor.return_value = ( @@ -641,30 +624,20 @@ class TestSqlLab(SupersetTestCase): expand_data=False, log_params=None, ) - mock_execute_sql_statement.assert_has_calls( + mock_execute_query.assert_has_calls( [ - mock.call( - "-- comment\nSET @value = 42", - mock_query, - mock_cursor, - None, - False, - ), - mock.call( - "SELECT /*+ hint */ @value AS foo", - mock_query, - mock_cursor, - None, - False, - ), + mock.call(mock_query, mock_cursor, None), + mock.call(mock_query, mock_cursor, None), ] ) @mock.patch("superset.sql_lab.results_backend", None) @mock.patch("superset.sql_lab.get_query") - @mock.patch("superset.sql_lab.execute_sql_statement") + @mock.patch("superset.sql_lab.execute_query") def test_execute_sql_statements_no_results_backend( - self, mock_execute_sql_statement, mock_get_query + self, + mock_execute_query, + mock_get_query, ): sql = dedent( """ @@ -712,10 +685,10 @@ class TestSqlLab(SupersetTestCase): @mock.patch("superset.sql_lab.db") @mock.patch("superset.sql_lab.get_query") - @mock.patch("superset.sql_lab.execute_sql_statement") + @mock.patch("superset.sql_lab.execute_query") def 
test_execute_sql_statements_ctas( self, - mock_execute_sql_statement, + mock_execute_query, mock_get_query, mock_db, ): @@ -727,7 +700,13 @@ class TestSqlLab(SupersetTestCase): """ ) mock_db = mock.MagicMock() # noqa: F841 - mock_query = mock.MagicMock() + mock_query = mock.MagicMock( + select_as_cta=True, + ctas_method=CTASMethod.TABLE.name, + tmp_table_name="table", + tmp_schema_name="schema", + catalog="catalog", + ) mock_query.database.allow_run_async = False mock_cursor = mock.MagicMock() mock_query.database.get_raw_connection().__enter__().cursor.return_value = ( @@ -738,7 +717,7 @@ class TestSqlLab(SupersetTestCase): # set the query to CTAS mock_query.select_as_cta = True - mock_query.ctas_method = CtasMethod.TABLE + mock_query.ctas_method = CTASMethod.TABLE.name execute_sql_statements( query_id=1, @@ -749,22 +728,10 @@ class TestSqlLab(SupersetTestCase): expand_data=False, log_params=None, ) - mock_execute_sql_statement.assert_has_calls( + mock_execute_query.assert_has_calls( [ - mock.call( - "-- comment\nSET @value = 42", - mock_query, - mock_cursor, - None, - False, - ), - mock.call( - "SELECT /*+ hint */ @value AS foo", - mock_query, - mock_cursor, - None, - True, # apply_ctas - ), + mock.call(mock_query, mock_cursor, None), + mock.call(mock_query, mock_cursor, None), ] ) @@ -795,7 +762,7 @@ class TestSqlLab(SupersetTestCase): ) # try invalid CVAS - mock_query.ctas_method = CtasMethod.VIEW + mock_query.ctas_method = CTASMethod.VIEW.name sql = dedent( """ -- comment @@ -803,7 +770,7 @@ class TestSqlLab(SupersetTestCase): SELECT /*+ hint */ @value AS foo; """ ) - with pytest.raises(SupersetErrorException) as excinfo: + with pytest.raises(SupersetInvalidCVASException) as excinfo: execute_sql_statements( query_id=1, rendered_query=sql, @@ -870,29 +837,6 @@ class TestSqlLab(SupersetTestCase): ] } - def test_apply_limit_if_exists_when_incremented_limit_is_none(self): - sql = """ - SET @value = 42; - SELECT @value AS foo; - """ - database = get_example_database() - mock_query = mock.MagicMock() - mock_query.limit = 300 - final_sql = apply_limit_if_exists(database, None, mock_query, sql) - - assert final_sql == sql - - def test_apply_limit_if_exists_when_increased_limit(self): - sql = """ - SET @value = 42; - SELECT @value AS foo; - """ - database = get_example_database() - mock_query = mock.MagicMock() - mock_query.limit = 300 - final_sql = apply_limit_if_exists(database, 1000, mock_query, sql) - assert "LIMIT 1000" in final_sql - @pytest.mark.parametrize("spec", [HiveEngineSpec, PrestoEngineSpec]) def test_cancel_query_implicit(spec: BaseEngineSpec) -> None: diff --git a/tests/integration_tests/tags/api_tests.py b/tests/integration_tests/tags/api_tests.py index 7f65c862a4b..61e93c6ac98 100644 --- a/tests/integration_tests/tags/api_tests.py +++ b/tests/integration_tests/tags/api_tests.py @@ -14,33 +14,32 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-# isort:skip_file """Unit tests for Superset""" -import prison from datetime import datetime - -from flask import g # noqa: F401 -import pytest -import prison # noqa: F811 -from freezegun import freeze_time -from sqlalchemy.sql import func -from sqlalchemy import and_ # noqa: F401 -from superset.models.dashboard import Dashboard -from superset.models.slice import Slice -from superset.models.sql_lab import SavedQuery # noqa: F401 -from superset.tags.models import user_favorite_tag_table # noqa: F401 from unittest.mock import patch from urllib import parse +import prison +import pytest +from freezegun import freeze_time +from sqlalchemy import and_ +from sqlalchemy.sql import func -import tests.integration_tests.test_app # noqa: F401 -from superset import db, security_manager # noqa: F401 -from superset.common.db_query_status import QueryStatus # noqa: F401 -from superset.models.core import Database # noqa: F401 -from superset.utils.database import get_example_database, get_main_database # noqa: F401 +from superset import db +from superset.connectors.sqla.models import SqlaTable +from superset.daos.tag import TagDAO +from superset.models.dashboard import Dashboard +from superset.models.slice import Slice +from superset.tags.models import ( + ObjectType, + Tag, + TaggedObject, + TagType, + user_favorite_tag_table, +) from superset.utils import json -from superset.tags.models import ObjectType, Tag, TagType, TaggedObject +from tests.integration_tests.base_tests import SupersetTestCase from tests.integration_tests.constants import ADMIN_USERNAME, ALPHA_USERNAME from tests.integration_tests.fixtures.birth_names_dashboard import ( load_birth_names_dashboard_with_slices, # noqa: F401 @@ -50,10 +49,7 @@ from tests.integration_tests.fixtures.world_bank_dashboard import ( load_world_bank_dashboard_with_slices, # noqa: F401 load_world_bank_data, # noqa: F401 ) -from tests.integration_tests.fixtures.tags import with_tagging_system_feature # noqa: F401 -from tests.integration_tests.base_tests import SupersetTestCase -from superset.daos.tag import TagDAO -from superset.tags.models import ObjectType # noqa: F811 +from tests.integration_tests.insert_chart_mixin import InsertChartMixin TAGS_FIXTURE_COUNT = 10 @@ -71,7 +67,7 @@ TAGS_LIST_COLUMNS = [ ] -class TestTagApi(SupersetTestCase): +class TestTagApi(InsertChartMixin, SupersetTestCase): def insert_tag( self, name: str, @@ -406,6 +402,96 @@ class TestTagApi(SupersetTestCase): # clean up tagged object tagged_objects.delete() + def test_get_tagged_objects_restricted(self): + """ + Test that the get_objects endpoint returns only assets + the user has access to. 
+ """ + owner = self.get_user(ADMIN_USERNAME) + + # Create a tag + tag = self.insert_tag( + name="test_tagged_objects_visibility", + tag_type="custom", + ) + + # Create a chart + chart_first_dataset = self.insert_chart("first_chart", [owner.id], 1) + first_tag_relation = self.insert_tagged_object( + tag_id=tag.id, + object_id=chart_first_dataset.id, + object_type=ObjectType.chart, + ) + + # Create another chart and add it to a dashboard + chart_second_dataset = self.insert_chart("second_chart", [owner.id], 2) + second_tag_relation = self.insert_tagged_object( + tag_id=tag.id, + object_id=chart_second_dataset.id, + object_type=ObjectType.chart, + ) + dashboard = self.insert_dashboard( + "test_dashboard", + "test_dashboard", + [owner.id], + slices=[chart_second_dataset], + published=True, + ) + dashboard_tag_relation = self.insert_tagged_object( + tag_id=tag.id, + object_id=dashboard.id, + object_type=ObjectType.dashboard, + ) + + # Create a user without access to these items + user = self.create_user_with_roles( + "test_restricted_user", + ["testing_new_role"], + should_create_roles=True, + ) + self.login("test_restricted_user") + + uri = f"api/v1/tag/get_objects/?tagIds={tag.id}" + rv = self.client.get(uri) + assert rv.status_code == 200 + assert rv.json["result"] == [] + + # grant access to dataset ID 1 + first_dataset = db.session.query(SqlaTable).filter(SqlaTable.id == 1).first() + self.grant_role_access_to_table(first_dataset, "testing_new_role") + + rv = self.client.get(uri) + assert rv.status_code == 200 + result = rv.json["result"] + assert len(result) == 1 + assert result[0]["id"] == chart_first_dataset.id + + # grant access to dataset ID 2 + second_dataset = db.session.query(SqlaTable).filter(SqlaTable.id == 2).first() + self.grant_role_access_to_table(second_dataset, "testing_new_role") + + rv = self.client.get(uri) + assert rv.status_code == 200 + result = rv.json["result"] + assert len(result) == 3 + assert sorted([res["id"] for res in result]) == sorted( + [chart_first_dataset.id, chart_second_dataset.id, dashboard.id] + ) + + # Clean up + db.session.delete(dashboard_tag_relation) + db.session.delete(dashboard) + db.session.delete(second_tag_relation) + db.session.delete(chart_second_dataset) + db.session.delete(first_tag_relation) + db.session.delete(chart_first_dataset) + db.session.delete(tag) + self.revoke_role_access_to_table("testing_new_role", first_dataset) + self.revoke_role_access_to_table("testing_new_role", second_dataset) + db.session.delete(user.roles[0]) + db.session.delete(user) + db.session.commit() + # test delete tags @pytest.mark.usefixtures("create_tags") def test_delete_tags(self): @@ -443,9 +529,6 @@ class TestTagApi(SupersetTestCase): rv = self.client.post(uri, follow_redirects=True) assert rv.status_code == 200 - from sqlalchemy import and_ # noqa: F811 - from superset.tags.models import user_favorite_tag_table # noqa: F811 - from flask import g # noqa: F401, F811 association_row = ( db.session.query(user_favorite_tag_table) @@ -630,10 +713,10 @@ class TestTagApi(SupersetTestCase): assert rv.status_code == 200 - result = TagDAO.get_tagged_objects_for_tags(tags, ["dashboard"]) + result = TagDAO.get_tagged_objects_by_tag_names(tags, ["dashboard"]) assert len(result) == 1 - result = TagDAO.get_tagged_objects_for_tags(tags, ["chart"]) + result = TagDAO.get_tagged_objects_by_tag_names(tags, ["chart"]) assert len(result) == 1 tagged_objects = ( diff --git a/tests/integration_tests/tags/dao_tests.py b/tests/integration_tests/tags/dao_tests.py index 
16e09f41494..0bc866da035 100644 --- a/tests/integration_tests/tags/dao_tests.py +++ b/tests/integration_tests/tags/dao_tests.py @@ -14,27 +14,25 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# isort:skip_file from operator import and_ -from unittest.mock import patch # noqa: F401 -import pytest -from superset.models.slice import Slice -from superset.models.sql_lab import SavedQuery # noqa: F401 -from superset.daos.tag import TagDAO -from superset.tags.exceptions import InvalidTagNameError # noqa: F401 -from superset.tags.models import ObjectType, Tag, TaggedObject -from tests.integration_tests.tags.api_tests import TAGS_FIXTURE_COUNT -import tests.integration_tests.test_app # pylint: disable=unused-import # noqa: F401 -from superset import db, security_manager # noqa: F401 -from superset.daos.dashboard import DashboardDAO # noqa: F401 +import pytest + +from superset import db +from superset.daos.tag import TagDAO from superset.models.dashboard import Dashboard +from superset.models.slice import Slice +from superset.tags.models import ObjectType, Tag, TaggedObject from tests.integration_tests.base_tests import SupersetTestCase +from tests.integration_tests.constants import ADMIN_USERNAME +from tests.integration_tests.fixtures.tags import ( + with_tagging_system_feature, # noqa: F401 +) from tests.integration_tests.fixtures.world_bank_dashboard import ( load_world_bank_dashboard_with_slices, # noqa: F401 load_world_bank_data, # noqa: F401 ) -from tests.integration_tests.fixtures.tags import with_tagging_system_feature # noqa: F401 +from tests.integration_tests.tags.api_tests import TAGS_FIXTURE_COUNT class TestTagsDAO(SupersetTestCase): @@ -151,6 +149,7 @@ class TestTagsDAO(SupersetTestCase): @pytest.mark.usefixtures("create_tags") # test get objects from tag def test_get_objects_from_tag(self): + self.login(ADMIN_USERNAME) # create tagged objects dashboard = ( db.session.query(Dashboard) @@ -163,17 +162,17 @@ class TestTagsDAO(SupersetTestCase): object_id=dashboard_id, object_type=ObjectType.dashboard, tag_id=tag.id ) # get objects - tagged_objects = TagDAO.get_tagged_objects_for_tags( + tagged_objects = TagDAO.get_tagged_objects_by_tag_names( ["example_tag_1", "example_tag_2"] ) assert len(tagged_objects) == 1 # test get objects from tag with type - tagged_objects = TagDAO.get_tagged_objects_for_tags( + tagged_objects = TagDAO.get_tagged_objects_by_tag_names( ["example_tag_1", "example_tag_2"], obj_types=["dashboard", "chart"] ) assert len(tagged_objects) == 1 - tagged_objects = TagDAO.get_tagged_objects_for_tags( + tagged_objects = TagDAO.get_tagged_objects_by_tag_names( ["example_tag_1", "example_tag_2"], obj_types=["chart"] ) assert len(tagged_objects) == 0 @@ -206,12 +205,12 @@ class TestTagsDAO(SupersetTestCase): + num_charts ) # gets all tagged objects of type dashboard and chart - tagged_objects = TagDAO.get_tagged_objects_for_tags( + tagged_objects = TagDAO.get_tagged_objects_by_tag_names( obj_types=["dashboard", "chart"] ) assert len(tagged_objects) == num_charts_and_dashboards # test objects are retrieved by type - tagged_objects = TagDAO.get_tagged_objects_for_tags(obj_types=["chart"]) + tagged_objects = TagDAO.get_tagged_objects_by_tag_names(obj_types=["chart"]) assert len(tagged_objects) == num_charts @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices") @@ -219,6 +218,7 @@ class TestTagsDAO(SupersetTestCase): @pytest.mark.usefixtures("create_tags") # test get objects 
from tag def test_get_objects_from_tag_with_id(self): + self.login(ADMIN_USERNAME) # create tagged objects dashboard = ( db.session.query(Dashboard) @@ -233,16 +233,16 @@ class TestTagsDAO(SupersetTestCase): object_id=dashboard_id, object_type=ObjectType.dashboard, tag_id=tag_1.id ) # get objects - tagged_objects = TagDAO.get_tagged_objects_by_tag_id(tag_ids) + tagged_objects = TagDAO.get_tagged_objects_by_tag_ids(tag_ids) assert len(tagged_objects) == 1 # test get objects from tag with type - tagged_objects = TagDAO.get_tagged_objects_by_tag_id( + tagged_objects = TagDAO.get_tagged_objects_by_tag_ids( tag_ids, obj_types=["dashboard", "chart"] ) assert len(tagged_objects) == 1 - tagged_objects = TagDAO.get_tagged_objects_by_tag_id( + tagged_objects = TagDAO.get_tagged_objects_by_tag_ids( tag_ids, obj_types=["chart"] ) assert len(tagged_objects) == 0 diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py index c100007d383..421be19eb6d 100644 --- a/tests/unit_tests/db_engine_specs/test_base.py +++ b/tests/unit_tests/db_engine_specs/test_base.py @@ -49,32 +49,6 @@ def test_get_text_clause_with_colon() -> None: assert text_clause.text == "SELECT foo FROM tbl WHERE foo = '123\\:456')" -def test_parse_sql_single_statement() -> None: - """ - `parse_sql` should properly strip leading and trailing spaces and semicolons - """ - - from superset.db_engine_specs.base import BaseEngineSpec - - queries = BaseEngineSpec.parse_sql(" SELECT foo FROM tbl ; ") - assert queries == ["SELECT foo FROM tbl"] - - -def test_parse_sql_multi_statement() -> None: - """ - For string with multiple SQL-statements `parse_sql` method should return list - where each element represents the single SQL-statement - """ - - from superset.db_engine_specs.base import BaseEngineSpec - - queries = BaseEngineSpec.parse_sql("SELECT foo FROM tbl1; SELECT bar FROM tbl2;") - assert queries == [ - "SELECT foo FROM tbl1", - "SELECT bar FROM tbl2", - ] - - def test_validate_db_uri(mocker: MockerFixture) -> None: """ Ensures that the `validate_database_uri` method invokes the validator correctly @@ -206,9 +180,6 @@ def test_select_star(mocker: MockerFixture) -> None: """ from superset.db_engine_specs.base import BaseEngineSpec - class NoLimitDBEngineSpec(BaseEngineSpec): - allow_limit_clause = False - cols: list[ResultSetColumnType] = [ { "column_name": "a", @@ -243,19 +214,7 @@ def test_select_star(mocker: MockerFixture) -> None: latest_partition=False, cols=cols, ) - assert sql == "SELECT a\nFROM my_table\nLIMIT ?\nOFFSET ?" - - sql = NoLimitDBEngineSpec.select_star( - database=database, - table=Table("my_table"), - engine=engine, - limit=100, - show_cols=True, - indent=True, - latest_partition=False, - cols=cols, - ) - assert sql == "SELECT a\nFROM my_table" + assert sql == "SELECT\n a\nFROM my_table\nLIMIT ?\nOFFSET ?" 
def test_extra_table_metadata(mocker: MockerFixture) -> None: diff --git a/tests/unit_tests/db_engine_specs/test_kusto.py b/tests/unit_tests/db_engine_specs/test_kusto.py index e8759f38cf4..a21e82a5676 100644 --- a/tests/unit_tests/db_engine_specs/test_kusto.py +++ b/tests/unit_tests/db_engine_specs/test_kusto.py @@ -23,7 +23,6 @@ from sqlalchemy import column from superset.db_engine_specs.kusto import KustoKqlEngineSpec from superset.sql.parse import SQLScript -from superset.sql_parse import ParsedQuery from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm from tests.unit_tests.fixtures.common import dttm # noqa: F401 @@ -53,26 +52,6 @@ def test_sql_has_mutation(sql: str, expected: bool) -> None: ) -@pytest.mark.parametrize( - "kql,expected", - [ - ("tbl | limit 100", True), - ("let foo = 1; tbl | where bar == foo", True), - (".show tables", False), - ], -) -def test_kql_is_select_query(kql: str, expected: bool) -> None: - """ - Make sure that KQL dialect consider only statements that do not start with "." (dot) - as a SELECT statements - """ - - from superset.db_engine_specs.kusto import KustoKqlEngineSpec - - parsed_query = ParsedQuery(kql) - assert KustoKqlEngineSpec.is_select_query(parsed_query) == expected - - @pytest.mark.parametrize( "kql,expected", [ @@ -101,19 +80,6 @@ def test_kql_has_mutation(kql: str, expected: bool) -> None: ) -def test_kql_parse_sql() -> None: - """ - parse_sql method should always return a list with a single element - which is an original query - """ - - from superset.db_engine_specs.kusto import KustoKqlEngineSpec - - queries = KustoKqlEngineSpec.parse_sql("let foo = 1; tbl | where bar == foo") - - assert queries == ["let foo = 1; tbl | where bar == foo"] - - @pytest.mark.parametrize( "target_type,expected_result", [ diff --git a/tests/unit_tests/db_engine_specs/test_mssql.py b/tests/unit_tests/db_engine_specs/test_mssql.py index 35fa91108ab..e0ce5e1180c 100644 --- a/tests/unit_tests/db_engine_specs/test_mssql.py +++ b/tests/unit_tests/db_engine_specs/test_mssql.py @@ -217,17 +217,21 @@ select 'EUR' as cur select * from currency union all select * from currency_2 """ ), - dedent( - """WITH currency as ( -select 'INR' as cur -), -currency_2 as ( -select 'EUR' as cur -), -__cte AS ( -select * from currency union all select * from currency_2 -)""" - ), + """WITH currency AS ( + SELECT + 'INR' AS cur +), currency_2 AS ( + SELECT + 'EUR' AS cur +), __cte AS ( + SELECT + * + FROM currency + UNION ALL + SELECT + * + FROM currency_2 +)""", ), ( "SELECT 1 as cnt", @@ -254,36 +258,6 @@ def test_cte_query_parsing(original: TypeEngine, expected: str) -> None: assert actual == expected -@pytest.mark.parametrize( - "original,expected,top", - [ - ("SEL TOP 1000 * FROM My_table", "SEL TOP 100 * FROM My_table", 100), - ("SEL TOP 1000 * FROM My_table;", "SEL TOP 100 * FROM My_table", 100), - ("SEL TOP 1000 * FROM My_table;", "SEL TOP 1000 * FROM My_table", 10000), - ("SEL TOP 1000 * FROM My_table;", "SEL TOP 1000 * FROM My_table", 1000), - ( - """with abc as (select * from test union select * from test1) -select TOP 100 * from currency""", - """WITH abc as (select * from test union select * from test1) -select TOP 100 * from currency""", - 1000, - ), - ("SELECT DISTINCT x from tbl", "SELECT DISTINCT TOP 100 x from tbl", 100), - ("SELECT 1 as cnt", "SELECT TOP 10 1 as cnt", 10), - ( - "select TOP 1000 * from abc where id=1", - "select TOP 10 * from abc where id=1", - 10, - ), - ], -) -def test_top_query_parsing(original: TypeEngine, expected: str, top: int) 
-> None: - from superset.db_engine_specs.mssql import MssqlEngineSpec - - actual = MssqlEngineSpec.apply_top_to_sql(original, top) - assert actual == expected - - def test_extract_errors() -> None: """ Test that custom error messages are extracted correctly. diff --git a/tests/unit_tests/db_engine_specs/test_teradata.py b/tests/unit_tests/db_engine_specs/test_teradata.py deleted file mode 100644 index eab03e040d5..00000000000 --- a/tests/unit_tests/db_engine_specs/test_teradata.py +++ /dev/null @@ -1,43 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=unused-argument, import-outside-toplevel, protected-access -import pytest - - -@pytest.mark.parametrize( - "limit,original,expected", - [ - (100, "SEL TOP 1000 * FROM My_table", "SEL TOP 100 * FROM My_table"), - (100, "SEL TOP 1000 * FROM My_table;", "SEL TOP 100 * FROM My_table"), - (10000, "SEL TOP 1000 * FROM My_table;", "SEL TOP 1000 * FROM My_table"), - (1000, "SEL TOP 1000 * FROM My_table;", "SEL TOP 1000 * FROM My_table"), - (100, "SELECT TOP 1000 * FROM My_table", "SELECT TOP 100 * FROM My_table"), - (100, "SEL SAMPLE 1000 * FROM My_table", "SEL SAMPLE 100 * FROM My_table"), - (10000, "SEL SAMPLE 1000 * FROM My_table", "SEL SAMPLE 1000 * FROM My_table"), - ], -) -def test_apply_top_to_sql_limit( - limit: int, - original: str, - expected: str, -) -> None: - """ - Ensure limits are applied to the query correctly - """ - from superset.db_engine_specs.teradata import TeradataEngineSpec - - assert TeradataEngineSpec.apply_top_to_sql(original, limit) == expected diff --git a/tests/unit_tests/jinja_context_test.py b/tests/unit_tests/jinja_context_test.py index fa79cc04936..4cb8cbec2d2 100644 --- a/tests/unit_tests/jinja_context_test.py +++ b/tests/unit_tests/jinja_context_test.py @@ -31,7 +31,12 @@ from sqlalchemy.dialects.postgresql import dialect from superset import app from superset.commands.dataset.exceptions import DatasetNotFoundError -from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn +from superset.connectors.sqla.models import ( + RowLevelSecurityFilter, + SqlaTable, + SqlMetric, + TableColumn, +) from superset.exceptions import SupersetTemplateException from superset.jinja_context import ( dataset_macro, @@ -46,6 +51,7 @@ from superset.jinja_context import ( from superset.models.core import Database from superset.models.slice import Slice from superset.utils import json +from tests.unit_tests.conftest import with_feature_flags def test_filter_values_adhoc_filters() -> None: @@ -355,16 +361,29 @@ def test_safe_proxy_nested_lambda() -> None: safe_proxy(func, {"foo": lambda: "bar"}) -def test_user_macros(mocker: MockerFixture): +@pytest.mark.parametrize( + "add_to_cache_keys,mock_cache_key_wrapper_call_count", + [ + (True, 4), + (False, 0), + ], +) 
+def test_user_macros( + mocker: MockerFixture, + add_to_cache_keys: bool, + mock_cache_key_wrapper_call_count: int, +): """ Test all user macros: - ``current_user_id`` - ``current_username`` - ``current_user_email`` - ``current_user_roles`` + - ``current_user_rls_rules`` """ mock_g = mocker.patch("superset.utils.core.g") mock_get_user_roles = mocker.patch("superset.security_manager.get_user_roles") + mock_get_user_rls = mocker.patch("superset.security_manager.get_rls_filters") mock_cache_key_wrapper = mocker.patch( "superset.jinja_context.ExtraCache.cache_key_wrapper" ) @@ -372,36 +391,20 @@ def test_user_macros(mocker: MockerFixture): mock_g.user.username = "my_username" mock_g.user.email = "my_email@test.com" mock_get_user_roles.return_value = [Role(name="my_role1"), Role(name="my_role2")] - cache = ExtraCache() - assert cache.current_user_id() == 1 - assert cache.current_username() == "my_username" - assert cache.current_user_email() == "my_email@test.com" - assert cache.current_user_roles() == ["my_role1", "my_role2"] - assert mock_cache_key_wrapper.call_count == 4 + mock_get_user_rls.return_value = [ + RowLevelSecurityFilter(group_key="test", clause="1=1"), + RowLevelSecurityFilter(group_key="other_test", clause="product_id=1"), + ] + cache = ExtraCache(table=mocker.MagicMock()) + assert cache.current_user_id(add_to_cache_keys) == 1 + assert cache.current_username(add_to_cache_keys) == "my_username" + assert cache.current_user_email(add_to_cache_keys) == "my_email@test.com" + assert cache.current_user_roles(add_to_cache_keys) == ["my_role1", "my_role2"] + assert mock_cache_key_wrapper.call_count == mock_cache_key_wrapper_call_count - mock_get_user_roles.return_value = [] - assert cache.current_user_roles() is None - - -def test_user_macros_without_cache_key_inclusion(mocker: MockerFixture): - """ - Test all user macros with ``add_to_cache_keys`` set to ``False``. - """ - mock_g = mocker.patch("superset.utils.core.g") - mock_get_user_roles = mocker.patch("superset.security_manager.get_user_roles") - mock_cache_key_wrapper = mocker.patch( - "superset.jinja_context.ExtraCache.cache_key_wrapper" - ) - mock_g.user.id = 1 - mock_g.user.username = "my_username" - mock_g.user.email = "my_email@test.com" - mock_get_user_roles.return_value = [Role(name="my_role1"), Role(name="my_role2")] - cache = ExtraCache() - assert cache.current_user_id(False) == 1 - assert cache.current_username(False) == "my_username" - assert cache.current_user_email(False) == "my_email@test.com" - assert cache.current_user_roles(False) == ["my_role1", "my_role2"] - assert mock_cache_key_wrapper.call_count == 0 + # Testing {{ current_user_rls_rules() }} macro isolated and always without + # the param because it does not support it to avoid shared cache. + assert cache.current_user_rls_rules() == ["1=1", "product_id=1"] def test_user_macros_without_user_info(mocker: MockerFixture): @@ -410,11 +413,55 @@ def test_user_macros_without_user_info(mocker: MockerFixture): """ mock_g = mocker.patch("superset.utils.core.g") mock_g.user = None + cache = ExtraCache(table=mocker.MagicMock()) + assert cache.current_user_id() is None + assert cache.current_username() is None + assert cache.current_user_email() is None + assert cache.current_user_roles() is None + assert cache.current_user_rls_rules() is None + + +def test_current_user_rls_rules_with_no_table(mocker: MockerFixture): + """ + Test the ``current_user_rls_rules`` macro when no table is provided. 
+ """ + mock_g = mocker.patch("superset.utils.core.g") + mock_get_user_rls = mocker.patch("superset.security_manager.get_rls_filters") + mock_is_guest_user = mocker.patch("superset.security_manager.is_guest_user") + mock_cache_key_wrapper = mocker.patch( + "superset.jinja_context.ExtraCache.cache_key_wrapper" + ) + mock_g.user.id = 1 + mock_g.user.username = "my_username" + mock_g.user.email = "my_email@test.com" cache = ExtraCache() - assert cache.current_user_id() == None # noqa: E711 - assert cache.current_username() == None # noqa: E711 - assert cache.current_user_email() == None # noqa: E711 - assert cache.current_user_roles() == None # noqa: E711 + assert cache.current_user_rls_rules() is None + assert mock_cache_key_wrapper.call_count == 0 + assert mock_get_user_rls.call_count == 0 + assert mock_is_guest_user.call_count == 0 + + +@with_feature_flags(EMBEDDED_SUPERSET=True) +def test_current_user_rls_rules_guest_user(mocker: MockerFixture): + """ + Test the ``current_user_rls_rules`` with an embedded user. + """ + mock_g = mocker.patch("superset.utils.core.g") + mock_gg = mocker.patch("superset.tasks.utils.g") + mock_ggg = mocker.patch("superset.security.manager.g") + mock_get_user_rls = mocker.patch("superset.security_manager.get_guest_rls_filters") + mock_user = mocker.MagicMock() + mock_user.username = "my_username" + mock_user.is_guest_user = True + mock_user.is_anonymous = False + mock_g.user = mock_gg.user = mock_ggg.user = mock_user + + mock_get_user_rls.return_value = [ + {"group_key": "test", "clause": "1=1"}, + {"group_key": "other_test", "clause": "product_id=1"}, + ] + cache = ExtraCache(table=mocker.MagicMock()) + assert cache.current_user_rls_rules() == ["1=1", "product_id=1"] def test_where_in() -> None: diff --git a/tests/unit_tests/migrations/viz/dual_line_to_mixed_chart_test.py b/tests/unit_tests/migrations/viz/dual_line_to_mixed_chart_test.py index 3d9dc531224..b7d84b1df81 100644 --- a/tests/unit_tests/migrations/viz/dual_line_to_mixed_chart_test.py +++ b/tests/unit_tests/migrations/viz/dual_line_to_mixed_chart_test.py @@ -30,6 +30,7 @@ ADHOC_FILTERS = [ ] SOURCE_FORM_DATA: dict[str, Any] = { + "datasource": "1__table", "metric": "num_boys", "y_axis_format": ",d", "y_axis_bounds": [50, 100], @@ -44,6 +45,7 @@ SOURCE_FORM_DATA: dict[str, Any] = { } TARGET_FORM_DATA: dict[str, Any] = { + "datasource": "1__table", "metrics": ["num_boys"], "y_axis_format": ",d", "y_axis_bounds": [50, 100], diff --git a/tests/unit_tests/migrations/viz/heatmap_v1_v2_test.py b/tests/unit_tests/migrations/viz/heatmap_v1_v2_test.py index e343df904f7..50b89fe119a 100644 --- a/tests/unit_tests/migrations/viz/heatmap_v1_v2_test.py +++ b/tests/unit_tests/migrations/viz/heatmap_v1_v2_test.py @@ -20,10 +20,15 @@ from superset.migrations.shared.migrate_viz import MigrateHeatmapChart from tests.unit_tests.migrations.viz.utils import migrate_and_assert SOURCE_FORM_DATA: dict[str, Any] = { + "datasource": "1__table", "any_other_key": "untouched", - "all_columns_x": ["category"], - "all_columns_y": ["product"], - "metric": ["sales"], + "all_columns_x": "category", + "all_columns_y": "product", + "metric": { + "label": "sales", + "expressionType": "SQL", + "sqlExpression": "max(sales)", + }, "adhoc_filters": [], "row_limit": 100, "sort_by_metric": True, @@ -47,10 +52,15 @@ SOURCE_FORM_DATA: dict[str, Any] = { } TARGET_FORM_DATA: dict[str, Any] = { + "datasource": "1__table", "any_other_key": "untouched", - "x_axis": ["category"], - "groupby": ["product"], - "metric": ["sales"], + "x_axis": "category", 
+ "groupby": "product", + "metric": { + "label": "sales", + "expressionType": "SQL", + "sqlExpression": "max(sales)", + }, "adhoc_filters": [], "row_limit": 100, "legend_type": "continuous", diff --git a/tests/unit_tests/migrations/viz/histogram_v1_v2_test.py b/tests/unit_tests/migrations/viz/histogram_v1_v2_test.py index 8b63263ac4d..ff6de6bc62b 100644 --- a/tests/unit_tests/migrations/viz/histogram_v1_v2_test.py +++ b/tests/unit_tests/migrations/viz/histogram_v1_v2_test.py @@ -20,6 +20,7 @@ from superset.migrations.shared.migrate_viz import MigrateHistogramChart from tests.unit_tests.migrations.viz.utils import migrate_and_assert SOURCE_FORM_DATA: dict[str, Any] = { + "datasource": "1__table", "all_columns_x": ["category"], "adhoc_filters": [], "cumulative": True, @@ -33,6 +34,7 @@ SOURCE_FORM_DATA: dict[str, Any] = { } TARGET_FORM_DATA: dict[str, Any] = { + "datasource": "1__table", "adhoc_filters": [], "bins": 5, "column": "category", diff --git a/tests/unit_tests/migrations/viz/nvd3_bubble_chart_to_echarts_test.py b/tests/unit_tests/migrations/viz/nvd3_bubble_chart_to_echarts_test.py index 070083b7ae1..2f2f3c78d3d 100644 --- a/tests/unit_tests/migrations/viz/nvd3_bubble_chart_to_echarts_test.py +++ b/tests/unit_tests/migrations/viz/nvd3_bubble_chart_to_echarts_test.py @@ -20,6 +20,7 @@ from superset.migrations.shared.migrate_viz import MigrateBubbleChart from tests.unit_tests.migrations.viz.utils import migrate_and_assert SOURCE_FORM_DATA: dict[str, Any] = { + "datasource": "1__table", "adhoc_filters": [], "bottom_margin": 20, "color_scheme": "default", @@ -29,7 +30,7 @@ SOURCE_FORM_DATA: dict[str, Any] = { "max_bubble_size": 50, "series": ["region"], "show_legend": True, - "size": 75, + "size": {"label": "sales", "expressionType": "SQL", "sqlExpression": "max(sales)"}, "viz_type": "bubble", "x": "year", "x_axis_format": "SMART_DATE", @@ -46,6 +47,7 @@ SOURCE_FORM_DATA: dict[str, Any] = { } TARGET_FORM_DATA: dict[str, Any] = { + "datasource": "1__table", "adhoc_filters": [], "color_scheme": "default", "entity": "count", @@ -56,7 +58,7 @@ TARGET_FORM_DATA: dict[str, Any] = { "row_limit": 100, "series": ["region"], "show_legend": True, - "size": 75, + "size": {"label": "sales", "expressionType": "SQL", "sqlExpression": "max(sales)"}, "truncateYAxis": True, "viz_type": "bubble_v2", "x": "year", diff --git a/tests/unit_tests/migrations/viz/pivot_table_v1_v2_test.py b/tests/unit_tests/migrations/viz/pivot_table_v1_v2_test.py index 788fd14770e..63c7c8f3f02 100644 --- a/tests/unit_tests/migrations/viz/pivot_table_v1_v2_test.py +++ b/tests/unit_tests/migrations/viz/pivot_table_v1_v2_test.py @@ -20,6 +20,7 @@ from superset.migrations.shared.migrate_viz import MigratePivotTable from tests.unit_tests.migrations.viz.utils import migrate_and_assert SOURCE_FORM_DATA: dict[str, Any] = { + "datasource": "1__table", "any_other_key": "untouched", "columns": ["state"], "combine_metric": True, @@ -33,6 +34,7 @@ SOURCE_FORM_DATA: dict[str, Any] = { } TARGET_FORM_DATA: dict[str, Any] = { + "datasource": "1__table", "any_other_key": "untouched", "aggregateFunction": "Sum", "colTotals": True, diff --git a/tests/unit_tests/migrations/viz/time_related_fields_test.py b/tests/unit_tests/migrations/viz/time_related_fields_test.py index c5a94a6adca..c3f48f1c333 100644 --- a/tests/unit_tests/migrations/viz/time_related_fields_test.py +++ b/tests/unit_tests/migrations/viz/time_related_fields_test.py @@ -20,12 +20,14 @@ from superset.migrations.shared.migrate_viz import MigratePivotTable from 
tests.unit_tests.migrations.viz.utils import migrate_and_assert SOURCE_FORM_DATA: dict[str, Any] = { + "datasource": "1__table", "granularity_sqla": "ds", "time_range": "1925-04-24 : 2025-04-24", "viz_type": "pivot_table", } TARGET_FORM_DATA: dict[str, Any] = { + "datasource": "1__table", "form_data_bak": SOURCE_FORM_DATA, "granularity_sqla": "ds", "rowOrder": "value_z_to_a", diff --git a/tests/unit_tests/migrations/viz/utils.py b/tests/unit_tests/migrations/viz/utils.py index d8eeb833e53..e012fc41cdd 100644 --- a/tests/unit_tests/migrations/viz/utils.py +++ b/tests/unit_tests/migrations/viz/utils.py @@ -20,6 +20,7 @@ from superset.migrations.shared.migrate_viz import MigrateViz from superset.utils import json TIMESERIES_SOURCE_FORM_DATA: dict[str, Any] = { + "datasource": "1__table", "bottom_margin": 20, "comparison_type": "absolute", "contribution": True, @@ -42,6 +43,7 @@ TIMESERIES_SOURCE_FORM_DATA: dict[str, Any] = { } TIMESERIES_TARGET_FORM_DATA: dict[str, Any] = { + "datasource": "1__table", "comparison_type": "difference", "contributionMode": "row", "logAxis": True, @@ -75,7 +77,7 @@ def migrate_and_assert( viz_type=cls.source_viz_type, datasource_type="table", params=dumped_form_data, - query_context=f'{{"form_data": {dumped_form_data}}}', + query_context=f'{{"form_data": {dumped_form_data}, "queries": []}}', ) # upgrade @@ -83,6 +85,7 @@ def migrate_and_assert( # verify form_data new_form_data = json.loads(slc.params) + new_form_data.pop("queries_bak", None) assert new_form_data == target assert new_form_data["form_data_bak"] == source diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py index 50ccde66055..981fab3e8b6 100644 --- a/tests/unit_tests/models/core_test.py +++ b/tests/unit_tests/models/core_test.py @@ -38,6 +38,7 @@ from superset.connectors.sqla.models import SqlaTable, TableColumn from superset.errors import SupersetErrorType from superset.exceptions import OAuth2Error, OAuth2RedirectError from superset.models.core import Database +from superset.sql.parse import LimitMethod from superset.sql_parse import Table from superset.utils import json from tests.unit_tests.conftest import with_feature_flags @@ -910,3 +911,144 @@ def test_get_all_view_names_in_schema(mocker: MockerFixture) -> None: ("third_view", "public", "examples"), } ) + + +@pytest.mark.parametrize( + "sql, limit, force, method, expected", + [ + ( + "SELECT * FROM table", + 100, + False, + LimitMethod.FORCE_LIMIT, + "SELECT\n *\nFROM table\nLIMIT 100", + ), + ( + "SELECT * FROM table LIMIT 100", + 10, + False, + LimitMethod.FORCE_LIMIT, + "SELECT\n *\nFROM table\nLIMIT 10", + ), + ( + "SELECT * FROM table LIMIT 10", + 100, + False, + LimitMethod.FORCE_LIMIT, + "SELECT\n *\nFROM table\nLIMIT 10", + ), + ( + "SELECT * FROM table LIMIT 10", + 100, + True, + LimitMethod.FORCE_LIMIT, + "SELECT\n *\nFROM table\nLIMIT 100", + ), + ( + "SELECT * FROM a \t \n ; \t \n ", + 1000, + False, + LimitMethod.FORCE_LIMIT, + "SELECT\n *\nFROM a\nLIMIT 1000", + ), + ( + "SELECT 'LIMIT 777'", + 1000, + False, + LimitMethod.FORCE_LIMIT, + "SELECT\n 'LIMIT 777'\nLIMIT 1000", + ), + ( + "SELECT * FROM table", + 1000, + False, + LimitMethod.FETCH_MANY, + "SELECT\n *\nFROM table", + ), + ( + "SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 9999", + 1000, + False, + LimitMethod.FORCE_LIMIT, + """SELECT + * +FROM ( + SELECT + * + FROM a + LIMIT 10 +) +LIMIT 1000""", + ), + ( + """ +SELECT + 'LIMIT 777' AS a + , b +FROM + table +LIMIT 99990""", + 1000, + None, + LimitMethod.FORCE_LIMIT, + "SELECT\n 
'LIMIT 777' AS a,\n b\nFROM table\nLIMIT 1000", + ), + ( + """ +SELECT + 'LIMIT 777' AS a + , b +FROM +table +LIMIT 99990 ;""", + 1000, + None, + LimitMethod.FORCE_LIMIT, + "SELECT\n 'LIMIT 777' AS a,\n b\nFROM table\nLIMIT 1000", + ), + ( + """ +SELECT + 'LIMIT 777' AS a + , b +FROM +table +LIMIT 99990, 999999""", + 1000, + None, + LimitMethod.FORCE_LIMIT, + "SELECT\n 'LIMIT 777' AS a,\n b\nFROM table\nLIMIT 1000\nOFFSET 99990", + ), + ( + """ +SELECT + 'LIMIT 777' AS a + , b +FROM +table +LIMIT 99990 +OFFSET 999999""", + 1000, + None, + LimitMethod.FORCE_LIMIT, + "SELECT\n 'LIMIT 777' AS a,\n b\nFROM table\nLIMIT 1000\nOFFSET 999999", + ), + ], +) +def test_apply_limit_to_sql( + sql: str, + limit: int, + force: bool, + method: LimitMethod, + expected: str, + mocker: MockerFixture, +) -> None: + """ + Test the `apply_limit_to_sql` method. + """ + db = Database(database_name="test_database", sqlalchemy_uri="sqlite://") + db_engine_spec = mocker.MagicMock(limit_method=method) + db.get_db_engine_spec = mocker.MagicMock(return_value=db_engine_spec) + + limited = db.apply_limit_to_sql(sql, limit, force) + assert limited == expected diff --git a/tests/unit_tests/security/api_test.py b/tests/unit_tests/security/api_test.py index 39ac115c8f4..869b3308478 100644 --- a/tests/unit_tests/security/api_test.py +++ b/tests/unit_tests/security/api_test.py @@ -29,6 +29,7 @@ def test_csrf_not_exempt(app_context: None) -> None: Test that REST API is not exempt from CSRF. """ assert {blueprint.name for blueprint in csrf._exempt_blueprints} == { + "GroupApi", "MenuApi", "SecurityApi", "OpenApi", diff --git a/tests/unit_tests/sql/parse_tests.py b/tests/unit_tests/sql/parse_tests.py index 8dc06aeea39..df837518b52 100644 --- a/tests/unit_tests/sql/parse_tests.py +++ b/tests/unit_tests/sql/parse_tests.py @@ -18,12 +18,15 @@ import pytest -from sqlglot import Dialects +from sqlglot import Dialects, parse_one from superset.exceptions import SupersetParseError from superset.sql.parse import ( + CTASMethod, extract_tables_from_statement, KustoKQLStatement, + LimitMethod, + RLSMethod, split_kql, SQLGLOT_DIALECTS, SQLScript, @@ -302,7 +305,13 @@ def test_format_no_dialect() -> None: """ assert ( SQLScript("SELECT col FROM t WHERE col NOT IN (1, 2)", "dremio").format() - == "SELECT col\nFROM t\nWHERE col NOT IN (1,\n 2)" + == """ +SELECT + col +FROM t +WHERE + NOT col IN (1, 2) + """.strip() ) @@ -313,9 +322,9 @@ def test_split_no_dialect() -> None: sql = "SELECT col FROM t WHERE col NOT IN (1, 2); SELECT * FROM t; SELECT foo" statements = SQLScript(sql, "dremio").statements assert len(statements) == 3 - assert statements[0]._sql == "SELECT col FROM t WHERE col NOT IN (1, 2)" - assert statements[1]._sql == "SELECT * FROM t" - assert statements[2]._sql == "SELECT foo" + assert statements[0].format() == "SELECT\n col\nFROM t\nWHERE\n NOT col IN (1, 2)" + assert statements[1].format() == "SELECT\n *\nFROM t" + assert statements[2].format() == "SELECT\n foo" def test_extract_tables_show_columns_from() -> None: @@ -742,6 +751,39 @@ Events | take 100""", assert query.get_settings() == {"querytrace": True} +@pytest.mark.parametrize( + "sql, engine, expected", + [ + ( + " SELECT foo FROM tbl ; ", + "postgresql", + ["SELECT\n foo\nFROM tbl"], + ), + ( + "SELECT foo FROM tbl1; SELECT bar FROM tbl2;", + "postgresql", + ["SELECT\n foo\nFROM tbl1", "SELECT\n bar\nFROM tbl2"], + ), + ( + "let foo = 1; tbl | where bar == foo", + "kustokql", + ["let foo = 1", "tbl | where bar == foo"], + ), + ( + "SELECT 1; -- extraneous comment", + 
"postgresql", + ["SELECT\n 1 /* extraneous comment */"], + ), + ], +) +def test_sqlscript_split(sql: str, engine: str, expected: list[str]) -> None: + """ + Test the `SQLScript` class with a script that has a single statement. + """ + script = SQLScript(sql, engine) + assert [statement.format() for statement in script.statements] == expected + + def test_sqlstatement() -> None: """ Test the `SQLStatement` class. @@ -1085,7 +1127,8 @@ FROM some_table) AS anon_1 WHERE anon_1.a > 1 AND anon_1.b = 2 """ - optimized = """SELECT + optimized = """ +SELECT anon_1.a, anon_1.b FROM ( @@ -1098,18 +1141,23 @@ FROM ( some_table.a > 1 AND some_table.b = 2 ) AS anon_1 WHERE - TRUE AND TRUE""" + TRUE AND TRUE + """.strip() not_optimized = """ -SELECT anon_1.a, - anon_1.b -FROM - (SELECT some_table.a AS a, - some_table.b AS b, - some_table.c AS c - FROM some_table) AS anon_1 -WHERE anon_1.a > 1 - AND anon_1.b = 2""" +SELECT + anon_1.a, + anon_1.b +FROM ( + SELECT + some_table.a AS a, + some_table.b AS b, + some_table.c AS c + FROM some_table +) AS anon_1 +WHERE + anon_1.a > 1 AND anon_1.b = 2 + """.strip() assert SQLStatement(sql, "sqlite").optimize().format() == optimized assert SQLStatement(sql, "dremio").optimize().format() == not_optimized @@ -1160,9 +1208,11 @@ def test_firebolt_old() -> None: sql = "SELECT * FROM t1 UNNEST(col1 AS foo)" assert ( SQLStatement(sql, "firebolt").format() - == """SELECT + == """ +SELECT * -FROM t1 UNNEST(col1 AS foo)""" +FROM t1 UNNEST(col1 AS foo) + """.strip() ) @@ -1181,7 +1231,1157 @@ def test_firebolt_old_escape_string() -> None: # but they normalize to '' assert ( SQLStatement(sql, "firebolt").format() - == """SELECT + == """ +SELECT 'foo''bar', - 'foo''bar'""" + 'foo''bar' + """.strip() ) + + +@pytest.mark.parametrize( + "sql, engine, expected", + [ + ("SELECT * FROM users LIMIT 10", "postgresql", 10), + ( + """ +WITH cte_example AS ( + SELECT * FROM my_table + LIMIT 100 +) +SELECT * FROM cte_example +LIMIT 10; + """, + "postgresql", + 10, + ), + ("SELECT * FROM users ORDER BY id DESC LIMIT 25", "postgresql", 25), + ("SELECT * FROM users", "postgresql", None), + ("SELECT TOP 5 name FROM employees", "teradatasql", 5), + ("SELECT TOP (42) * FROM table_name", "teradatasql", 42), + ("select * from table", "postgresql", None), + ("select * from mytable limit 10", "postgresql", 10), + ( + "select * from (select * from my_subquery limit 10) where col=1 limit 20", + "postgresql", + 20, + ), + ("select * from (select * from my_subquery limit 10);", "postgresql", None), + ( + "select * from (select * from my_subquery limit 10) where col=1 limit 20;", + "postgresql", + 20, + ), + ("select * from mytable limit 20, 10", "postgresql", 10), + ("select * from mytable limit 10 offset 20", "postgresql", 10), + ( + """ +SELECT id, value, i +FROM (SELECT * FROM my_table LIMIT 10), +LATERAL generate_series(1, value) AS i; + """, + "postgresql", + None, + ), + ], +) +def test_get_limit_value(sql: str, engine: str, expected: str) -> None: + assert SQLStatement(sql, engine).get_limit_value() == expected + + +@pytest.mark.parametrize( + "kql, expected", + [ + ("StormEvents | take 10", 10), + ("StormEvents | limit 20", 20), + ("StormEvents | where State == 'FL' | summarize count()", None), + ("StormEvents | where name has 'limit 10'", None), + ("AnotherTable | take 5", 5), + ("datatable(x:int) [1, 2, 3] | take 100", 100), + ( + """ + Table1 | where msg contains 'abc;xyz' + | limit 5 + """, + 5, + ), + ], +) +def test_get_kql_limit_value(kql: str, expected: str) -> None: + assert 
KustoKQLStatement(kql, "kustokql").get_limit_value() == expected + + +@pytest.mark.parametrize( + "sql, engine, limit, method, expected", + [ + ( + "SELECT * FROM t", + "postgresql", + 10, + LimitMethod.FORCE_LIMIT, + "SELECT\n *\nFROM t\nLIMIT 10", + ), + ( + "SELECT * FROM t LIMIT 1000", + "postgresql", + 10, + LimitMethod.FORCE_LIMIT, + "SELECT\n *\nFROM t\nLIMIT 10", + ), + ( + "SELECT * FROM t", + "mssql", + 10, + LimitMethod.FORCE_LIMIT, + "SELECT\nTOP 10\n *\nFROM t", + ), + ( + "SELECT * FROM t", + "teradatasql", + 10, + LimitMethod.FORCE_LIMIT, + "SELECT\nTOP 10\n *\nFROM t", + ), + ( + "SELECT * FROM t", + "oracle", + 10, + LimitMethod.FORCE_LIMIT, + "SELECT\n *\nFROM t\nFETCH FIRST 10 ROWS ONLY", + ), + ( + "SELECT * FROM t", + "db2", + 10, + LimitMethod.WRAP_SQL, + "SELECT\n *\nFROM (\n SELECT\n *\n FROM t\n)\nLIMIT 10", + ), + ( + "SEL TOP 1000 * FROM My_table", + "teradatasql", + 100, + LimitMethod.FORCE_LIMIT, + "SELECT\nTOP 100\n *\nFROM My_table", + ), + ( + "SEL TOP 1000 * FROM My_table;", + "teradatasql", + 100, + LimitMethod.FORCE_LIMIT, + "SELECT\nTOP 100\n *\nFROM My_table", + ), + ( + "SEL TOP 1000 * FROM My_table;", + "teradatasql", + 1000, + LimitMethod.FORCE_LIMIT, + "SELECT\nTOP 1000\n *\nFROM My_table", + ), + ( + "SELECT TOP 1000 * FROM My_table;", + "teradatasql", + 100, + LimitMethod.FORCE_LIMIT, + "SELECT\nTOP 100\n *\nFROM My_table", + ), + ( + "SELECT TOP 1000 * FROM My_table;", + "teradatasql", + 10000, + LimitMethod.FORCE_LIMIT, + "SELECT\nTOP 10000\n *\nFROM My_table", + ), + ( + "SELECT TOP 1000 * FROM My_table", + "mssql", + 100, + LimitMethod.FORCE_LIMIT, + "SELECT\nTOP 100\n *\nFROM My_table", + ), + ( + "SELECT TOP 1000 * FROM My_table;", + "mssql", + 100, + LimitMethod.FORCE_LIMIT, + "SELECT\nTOP 100\n *\nFROM My_table", + ), + ( + "SELECT TOP 1000 * FROM My_table;", + "mssql", + 10000, + LimitMethod.FORCE_LIMIT, + "SELECT\nTOP 10000\n *\nFROM My_table", + ), + ( + "SELECT TOP 1000 * FROM My_table;", + "mssql", + 1000, + LimitMethod.FORCE_LIMIT, + "SELECT\nTOP 1000\n *\nFROM My_table", + ), + ( + """ +with abc as (select * from test union select * from test1) +select TOP 100 * from currency + """, + "mssql", + 1000, + LimitMethod.FORCE_LIMIT, + """ +WITH abc AS ( + SELECT + * + FROM test + UNION + SELECT + * + FROM test1 +) +SELECT +TOP 1000 + * +FROM currency + """.strip(), + ), + ( + "SELECT DISTINCT x from tbl", + "mssql", + 100, + LimitMethod.FORCE_LIMIT, + "SELECT DISTINCT\nTOP 100\n x\nFROM tbl", + ), + ( + "SELECT 1 as cnt", + "mssql", + 10, + LimitMethod.FORCE_LIMIT, + "SELECT\nTOP 10\n 1 AS cnt", + ), + ( + "select TOP 1000 * from abc where id=1", + "mssql", + 10, + LimitMethod.FORCE_LIMIT, + "SELECT\nTOP 10\n *\nFROM abc\nWHERE\n id = 1", + ), + ( + "SELECT * FROM birth_names -- SOME COMMENT", + "postgresql", + 1000, + LimitMethod.FORCE_LIMIT, + "SELECT\n *\nFROM birth_names /* SOME COMMENT */\nLIMIT 1000", + ), + ( + "SELECT * FROM birth_names -- SOME COMMENT WITH LIMIT 555", + "postgresql", + 1000, + LimitMethod.FORCE_LIMIT, + """ +SELECT + * +FROM birth_names /* SOME COMMENT WITH LIMIT 555 */ +LIMIT 1000 + """.strip(), + ), + ( + "SELECT * FROM birth_names LIMIT 555", + "postgresql", + 1000, + LimitMethod.FORCE_LIMIT, + "SELECT\n *\nFROM birth_names\nLIMIT 1000", + ), + ], +) +def test_set_limit_value( + sql: str, + engine: str, + limit: int, + method: LimitMethod, + expected: str, +) -> None: + statement = SQLStatement(sql, engine) + statement.set_limit_value(limit, method) + assert statement.format() == expected + + 
+@pytest.mark.parametrize(
+    "kql, limit, expected",
+    [
+        ("StormEvents | take 10", 100, "StormEvents | take 100"),
+        ("StormEvents | limit 20", 10, "StormEvents | limit 10"),
+        (
+            "StormEvents | where State == 'FL' | summarize count()",
+            10,
+            "StormEvents | where State == 'FL' | summarize count() | take 10",
+        ),
+        (
+            "StormEvents | where name has 'limit 10'",
+            10,
+            "StormEvents | where name has 'limit 10' | take 10",
+        ),
+        ("AnotherTable | take 5", 50, "AnotherTable | take 50"),
+        (
+            "datatable(x:int) [1, 2, 3] | take 100",
+            10,
+            "datatable(x:int) [1, 2, 3] | take 10",
+        ),
+        (
+            """
+            Table1 | where msg contains 'abc;xyz'
+            | limit 5
+            """,
+            10,
+            """Table1 | where msg contains 'abc;xyz'
+            | limit 10""",
+        ),
+    ],
+)
+def test_set_kql_limit_value(kql: str, limit: int, expected: str) -> None:
+    statement = KustoKQLStatement(kql, "kustokql")
+    statement.set_limit_value(limit)
+    assert statement.format() == expected
+
+
+@pytest.mark.parametrize(
+    "sql, engine, expected",
+    [
+        ("SELECT 1", "postgresql", False),
+        ("SELECT 1 AS cnt", "postgresql", False),
+        (
+            """
+SELECT 'INR' AS cur
+UNION
+SELECT 'USD' AS cur
+UNION
+SELECT 'EUR' AS cur
+ """,
+            "postgresql",
+            False,
+        ),
+        ("WITH cte AS (SELECT 1) SELECT * FROM cte", "postgresql", True),
+        (
+            """
+WITH
+  x AS (SELECT a FROM t1),
+  y AS (SELECT a AS b FROM t2),
+  z AS (SELECT b AS c FROM t3)
+SELECT c FROM z
+ """,
+            "postgresql",
+            True,
+        ),
+        (
+            """
+WITH
+  x AS (SELECT a FROM t1),
+  y AS (SELECT a AS b FROM x),
+  z AS (SELECT b AS c FROM y)
+SELECT c FROM z
+ """,
+            "postgresql",
+            True,
+        ),
+        (
+            """
+WITH CTE__test (SalesPersonID, SalesOrderID, SalesYear)
+AS (
+    SELECT SalesPersonID, SalesOrderID, YEAR(OrderDate) AS SalesYear
+    FROM SalesOrderHeader
+    WHERE SalesPersonID IS NOT NULL
+)
+SELECT SalesPersonID, COUNT(SalesOrderID) AS TotalSales, SalesYear
+FROM CTE__test
+GROUP BY SalesYear, SalesPersonID
+ORDER BY SalesPersonID, SalesYear;
+ """,
+            "postgresql",
+            True,
+        ),
+    ],
+)
+def test_has_cte(sql: str, engine: str, expected: bool) -> None:
+    """
+    Test that the parser detects CTEs correctly.
+    """
+    assert SQLStatement(sql, engine).has_cte() == expected
+
+
+@pytest.mark.parametrize(
+    "sql, engine, expected",
+    [
+        (
+            "SELECT 1",
+            "postgresql",
+            "WITH __cte AS (\n  SELECT\n    1\n)",
+        ),
+        (
+            """
+WITH currency AS (SELECT 'INR' AS cur),
+     currency_2 AS (SELECT 'USD' AS cur)
+SELECT * FROM currency
+UNION ALL
+SELECT * FROM currency_2
+ """,
+            "postgresql",
+            """
+WITH currency AS (
+  SELECT
+    'INR' AS cur
+), currency_2 AS (
+  SELECT
+    'USD' AS cur
+), __cte AS (
+  SELECT
+    *
+  FROM currency
+  UNION ALL
+  SELECT
+    *
+  FROM currency_2
+)
+ """.strip(),
+        ),
+    ],
+)
+def test_as_cte(sql: str, engine: str, expected: str) -> None:
+    """
+    Test that we can convert a SELECT to a CTE.
+ """ + assert SQLStatement(sql, engine).as_cte().format() == expected + + +@pytest.mark.parametrize( + "sql, rules, expected", + [ + ( + "SELECT t.foo FROM some_table AS t", + {Table("some_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + t.foo +FROM ( + SELECT + * + FROM some_table + WHERE + id = 42 +) AS t + """.strip(), + ), + ( + "SELECT t.foo FROM some_table AS t WHERE bar = 'baz'", + {Table("some_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + t.foo +FROM ( + SELECT + * + FROM some_table + WHERE + id = 42 +) AS t +WHERE + bar = 'baz' + """.strip(), + ), + ( + "SELECT t.foo FROM schema1.some_table AS t", + {Table("some_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + t.foo +FROM ( + SELECT + * + FROM schema1.some_table + WHERE + id = 42 +) AS t + """.strip(), + ), + ( + "SELECT t.foo FROM schema1.some_table AS t", + {Table("some_table", "schema2"): "id = 42"}, + "SELECT\n t.foo\nFROM schema1.some_table AS t", + ), + ( + "SELECT t.foo FROM catalog1.schema1.some_table AS t", + {Table("some_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + t.foo +FROM ( + SELECT + * + FROM catalog1.schema1.some_table + WHERE + id = 42 +) AS t + """.strip(), + ), + ( + "SELECT t.foo FROM catalog1.schema1.some_table AS t", + {Table("some_table", "schema1", "catalog2"): "id = 42"}, + "SELECT\n t.foo\nFROM catalog1.schema1.some_table AS t", + ), + ( + "SELECT * FROM some_table WHERE 1=1", + {Table("some_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM ( + SELECT + * + FROM some_table + WHERE + id = 42 +) AS some_table +WHERE + 1 = 1 + """.strip(), + ), + ( + "SELECT * FROM table WHERE 1=1", + {Table("table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM ( + SELECT + * + FROM table + WHERE + id = 42 +) AS table +WHERE + 1 = 1 + """.strip(), + ), + ( + 'SELECT * FROM "table" WHERE 1=1', + {Table("table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM ( + SELECT + * + FROM "table" + WHERE + id = 42 +) AS "table" +WHERE + 1 = 1 + """.strip(), + ), + ( + "SELECT * FROM table WHERE 1=1", + {Table("other_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM table +WHERE + 1 = 1 + """.strip(), + ), + ( + "SELECT * FROM other_table WHERE 1=1", + {Table("table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM other_table +WHERE + 1 = 1 + """.strip(), + ), + ( + "SELECT * FROM table JOIN other_table ON table.id = other_table.id", + {Table("other_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM table +JOIN ( + SELECT + * + FROM other_table + WHERE + id = 42 +) AS other_table + ON table.id = other_table.id + """.strip(), + ), + ( + 'SELECT * FROM "table" JOIN other_table ON "table".id = other_table.id', + {Table("table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM ( + SELECT + * + FROM "table" + WHERE + id = 42 +) AS "table" +JOIN other_table + ON "table".id = other_table.id + """.strip(), + ), + ( + "SELECT * FROM (SELECT * FROM some_table)", + {Table("some_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM ( + SELECT + * + FROM ( + SELECT + * + FROM some_table + WHERE + id = 42 + ) AS some_table +) + """.strip(), + ), + ( + "SELECT * FROM table UNION ALL SELECT * FROM other_table", + {Table("table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM ( + SELECT + * + FROM table + WHERE + id = 42 +) AS table +UNION ALL +SELECT + * +FROM other_table + """.strip(), + ), + ( + "SELECT * FROM table UNION ALL SELECT * FROM other_table", + 
{Table("other_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM table +UNION ALL +SELECT + * +FROM ( + SELECT + * + FROM other_table + WHERE + id = 42 +) AS other_table + """.strip(), + ), + ( + "SELECT a.*, b.* FROM tbl_a AS a INNER JOIN tbl_b AS b ON a.col = b.col", + {Table("tbl_a", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + a.*, + b.* +FROM ( + SELECT + * + FROM tbl_a + WHERE + id = 42 +) AS a +INNER JOIN tbl_b AS b + ON a.col = b.col + """.strip(), + ), + ( + "SELECT a.*, b.* FROM tbl_a a INNER JOIN tbl_b b ON a.col = b.col", + {Table("tbl_a", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + a.*, + b.* +FROM ( + SELECT + * + FROM tbl_a + WHERE + id = 42 +) AS a +INNER JOIN tbl_b AS b + ON a.col = b.col + """.strip(), + ), + ], +) +def test_rls_subquery_transformer( + sql: str, + rules: dict[Table, str], + expected: str, +) -> None: + """ + Test `RLSAsSubqueryTransformer`. + """ + statement = SQLStatement(sql) + statement.apply_rls( + "catalog1", + "schema1", + {k: [parse_one(v)] for k, v in rules.items()}, + RLSMethod.AS_SUBQUERY, + ) + assert statement.format() == expected + + +@pytest.mark.parametrize( + "sql, rules, expected", + [ + ( + "SELECT t.foo FROM some_table AS t", + {Table("some_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + t.foo +FROM some_table AS t +WHERE + t.id = 42 + """.strip(), + ), + ( + "SELECT t.foo FROM schema2.some_table AS t", + {Table("some_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + t.foo +FROM schema2.some_table AS t + """.strip(), + ), + ( + "SELECT t.foo FROM catalog2.schema1.some_table AS t", + {Table("some_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + t.foo +FROM catalog2.schema1.some_table AS t + """.strip(), + ), + ( + "SELECT t.foo FROM some_table AS t WHERE bar = 'baz'", + {Table("some_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + t.foo +FROM some_table AS t +WHERE + t.id = 42 AND ( + bar = 'baz' + ) + """.strip(), + ), + ( + "SELECT t.foo FROM some_table AS t WHERE bar = 'baz' OR foo = 'qux'", + {Table("some_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + t.foo +FROM some_table AS t +WHERE + t.id = 42 AND ( + bar = 'baz' OR foo = 'qux' + ) + """.strip(), + ), + ( + "SELECT * FROM some_table WHERE 1=1", + {Table("some_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM some_table +WHERE + some_table.id = 42 AND ( + 1 = 1 + ) + """.strip(), + ), + ( + "SELECT * FROM some_table WHERE TRUE OR FALSE", + {Table("some_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM some_table +WHERE + some_table.id = 42 AND ( + TRUE OR FALSE + ) + """.strip(), + ), + ( + "SELECT * FROM table WHERE 1=1", + {Table("table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM table +WHERE + table.id = 42 AND ( + 1 = 1 + ) + """.strip(), + ), + ( + 'SELECT * FROM "table" WHERE 1=1', + {Table("table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM "table" +WHERE + "table".id = 42 AND ( + 1 = 1 + ) + """.strip(), + ), + ( + "SELECT * FROM table WHERE 1=1", + {Table("other_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM table +WHERE + 1 = 1 + """.strip(), + ), + ( + "SELECT * FROM other_table WHERE 1=1", + {Table("table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM other_table +WHERE + 1 = 1 + """.strip(), + ), + ( + "SELECT * FROM table", + {Table("table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM table +WHERE + table.id = 42 + """.strip(), + ), + ( + "SELECT * FROM 
some_table", + {Table("some_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM some_table +WHERE + some_table.id = 42 + """.strip(), + ), + ( + "SELECT * FROM table ORDER BY id", + {Table("table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM table +WHERE + table.id = 42 +ORDER BY + id + """.strip(), + ), + ( + "SELECT * FROM table WHERE 1=1 AND table.id=42", + {Table("table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM table +WHERE + table.id = 42 AND ( + 1 = 1 AND table.id = 42 + ) + """.strip(), + ), + ( + """ +SELECT * FROM table +JOIN other_table +ON table.id = other_table.id +AND other_table.id=42 + """, + {Table("other_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM table +JOIN other_table + ON other_table.id = 42 AND ( + table.id = other_table.id AND other_table.id = 42 + ) + """.strip(), + ), + ( + "SELECT * FROM table WHERE 1=1 AND id=42", + {Table("table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM table +WHERE + table.id = 42 AND ( + 1 = 1 AND id = 42 + ) + """.strip(), + ), + ( + "SELECT * FROM table JOIN other_table ON table.id = other_table.id", + {Table("other_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM table +JOIN other_table + ON other_table.id = 42 AND ( + table.id = other_table.id + ) + """.strip(), + ), + ( + """ +SELECT * +FROM table +JOIN other_table +ON table.id = other_table.id +WHERE 1=1 + """, + {Table("other_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM table +JOIN other_table + ON other_table.id = 42 AND ( + table.id = other_table.id + ) +WHERE + 1 = 1 + """.strip(), + ), + ( + "SELECT * FROM (SELECT * FROM other_table)", + {Table("other_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM ( + SELECT + * + FROM other_table + WHERE + other_table.id = 42 +) + """.strip(), + ), + ( + "SELECT * FROM table UNION ALL SELECT * FROM other_table", + {Table("table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM table +WHERE + table.id = 42 +UNION ALL +SELECT + * +FROM other_table + """.strip(), + ), + ( + "SELECT * FROM table UNION ALL SELECT * FROM other_table", + {Table("other_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM table +UNION ALL +SELECT + * +FROM other_table +WHERE + other_table.id = 42 + """.strip(), + ), + ], +) +def test_rls_predicate_transformer( + sql: str, + rules: dict[Table, str], + expected: str, +) -> None: + """ + Test `RLSPredicateTransformer`. + """ + statement = SQLStatement(sql) + statement.apply_rls( + "catalog1", + "schema1", + {k: [parse_one(v)] for k, v in rules.items()}, + RLSMethod.AS_PREDICATE, + ) + assert statement.format() == expected + + +@pytest.mark.parametrize( + "sql, table, expected", + [ + ( + "SELECT * FROM some_table", + Table("some_table"), + """ +CREATE TABLE some_table AS +SELECT + * +FROM some_table + """.strip(), + ), + ( + "SELECT * FROM some_table", + Table("some_table", "schema1", "catalog1"), + """ +CREATE TABLE catalog1.schema1.some_table AS +SELECT + * +FROM some_table + """.strip(), + ), + ], +) +def test_as_create_table(sql: str, table: Table, expected: str) -> None: + """ + Test the `as_create_table` method. 
+ """ + statement = SQLStatement(sql) + create_table = statement.as_create_table(table, CTASMethod.TABLE) + assert create_table.format() == expected + + +@pytest.mark.parametrize( + "sql, engine, expected", + [ + ("SELECT * FROM table", "postgresql", True), + ( + """ +-- comment +SELECT * FROM table +-- comment 2 + """, + "mysql", + True, + ), + ( + """ +-- comment +SET @value = 42; +SELECT @value as foo; +-- comment 2 + """, + "mysql", + True, + ), + ( + """ +-- comment +EXPLAIN SELECT * FROM table +-- comment 2 + """, + "mysql", + False, + ), + ( + """ +SELECT * FROM table; +INSERT INTO TABLE (foo) VALUES (42); + """, + "mysql", + False, + ), + ], +) +def test_is_valid_ctas(sql: str, engine: str, expected: bool) -> None: + """ + Test the `is_valid_ctas` method. + """ + assert SQLScript(sql, engine).is_valid_ctas() == expected + + +@pytest.mark.parametrize( + "sql, engine, expected", + [ + ("SELECT * FROM table", "postgresql", True), + ( + """ +-- comment +SELECT * FROM table +-- comment 2 + """, + "mysql", + True, + ), + ( + """ +-- comment +SET @value = 42; +SELECT @value as foo; +-- comment 2 + """, + "mysql", + False, + ), + ( + """ +-- comment +SELECT value as foo; +-- comment 2 + """, + "mysql", + True, + ), + ( + """ +SELECT * FROM table; +INSERT INTO TABLE (foo) VALUES (42); + """, + "mysql", + False, + ), + ], +) +def test_is_valid_cvas(sql: str, engine: str, expected: bool) -> None: + """ + Test the `is_valid_cvas` method. + """ + assert SQLScript(sql, engine).is_valid_cvas() == expected diff --git a/tests/unit_tests/sql_lab_test.py b/tests/unit_tests/sql_lab_test.py index 21c9a95247b..1bb2cde5f02 100644 --- a/tests/unit_tests/sql_lab_test.py +++ b/tests/unit_tests/sql_lab_test.py @@ -21,107 +21,46 @@ from unittest import mock from uuid import UUID import pytest -import sqlparse from freezegun import freeze_time from pytest_mock import MockerFixture -from sqlalchemy.orm.session import Session -from superset import db from superset.common.db_query_status import QueryStatus +from superset.db_engine_specs.postgres import PostgresEngineSpec from superset.errors import ErrorLevel, SupersetErrorType from superset.exceptions import OAuth2Error, SupersetErrorException from superset.models.core import Database -from superset.sql_lab import execute_sql_statements, get_sql_results -from superset.utils.core import override_user +from superset.sql.parse import SQLStatement, Table +from superset.sql_lab import ( + apply_rls, + execute_query, + execute_sql_statements, + get_predicates_for_table, + get_sql_results, +) from tests.unit_tests.models.core_test import oauth2_client_info -def test_execute_sql_statement(mocker: MockerFixture, app: None) -> None: +def test_execute_query(mocker: MockerFixture, app: None) -> None: """ Simple test for `execute_sql_statement`. 
""" - from superset.sql_lab import execute_sql_statement - - sql_statement = "SELECT 42 AS answer" - query = mocker.MagicMock() + query.executed_sql = "SELECT 42 AS answer" + query.limit = 1 - query.select_as_cta_used = False database = query.database database.allow_dml = False - database.apply_limit_to_sql.return_value = "SELECT 42 AS answer LIMIT 2" - database.mutate_sql_based_on_config.return_value = "SELECT 42 AS answer LIMIT 2" db_engine_spec = database.db_engine_spec - db_engine_spec.is_select_query.return_value = True db_engine_spec.fetch_data.return_value = [(42,)] cursor = mocker.MagicMock() SupersetResultSet = mocker.patch("superset.sql_lab.SupersetResultSet") # noqa: N806 - execute_sql_statement( - sql_statement, - query, - cursor=cursor, - log_params={}, - apply_ctas=False, - ) + execute_query(query, cursor=cursor, log_params={}) - database.apply_limit_to_sql.assert_called_with("SELECT 42 AS answer", 2, force=True) db_engine_spec.execute_with_cursor.assert_called_with( cursor, - "SELECT 42 AS answer LIMIT 2", - query, - ) - SupersetResultSet.assert_called_with([(42,)], cursor.description, db_engine_spec) - - -def test_execute_sql_statement_with_rls( - mocker: MockerFixture, -) -> None: - """ - Test for `execute_sql_statement` when an RLS rule is in place. - """ - from superset.sql_lab import execute_sql_statement - - sql_statement = "SELECT * FROM sales" - sql_statement_with_rls = f"{sql_statement} WHERE organization_id=42" - sql_statement_with_rls_and_limit = f"{sql_statement_with_rls} LIMIT 101" - - query = mocker.MagicMock() - query.limit = 100 - query.select_as_cta_used = False - database = query.database - database.allow_dml = False - database.apply_limit_to_sql.return_value = sql_statement_with_rls_and_limit - database.mutate_sql_based_on_config.return_value = sql_statement_with_rls_and_limit - db_engine_spec = database.db_engine_spec - db_engine_spec.is_select_query.return_value = True - db_engine_spec.fetch_data.return_value = [(42,)] - - cursor = mocker.MagicMock() - SupersetResultSet = mocker.patch("superset.sql_lab.SupersetResultSet") # noqa: N806 - mocker.patch( - "superset.sql_lab.insert_rls_as_subquery", - return_value=sqlparse.parse("SELECT * FROM sales WHERE organization_id=42")[0], - ) - mocker.patch("superset.sql_lab.is_feature_enabled", return_value=True) - - execute_sql_statement( - sql_statement, - query, - cursor=cursor, - log_params={}, - apply_ctas=False, - ) - - database.apply_limit_to_sql.assert_called_with( - "SELECT * FROM sales WHERE organization_id=42", - 101, - force=True, - ) - db_engine_spec.execute_with_cursor.assert_called_with( - cursor, - "SELECT * FROM sales WHERE organization_id=42 LIMIT 101", + "SELECT 42 AS answer", query, ) SupersetResultSet.assert_called_with([(42,)], cursor.description, db_engine_spec) @@ -140,7 +79,6 @@ def test_execute_sql_statement_exceeds_payload_limit(mocker: MockerFixture) -> N query = mocker.MagicMock() query.limit = 1 query.database = mocker.MagicMock() - query.database.db_engine_spec.is_select_query.return_value = True query.database.cache_timeout = 100 query.status = "RUNNING" query.select_as_cta = False @@ -193,7 +131,6 @@ def test_execute_sql_statement_within_payload_limit(mocker: MockerFixture) -> No query = mocker.MagicMock() query.limit = 1 query.database = mocker.MagicMock() - query.database.db_engine_spec.is_select_query.return_value = True query.database.cache_timeout = 100 query.status = "RUNNING" query.select_as_cta = False @@ -236,111 +173,6 @@ def 
test_execute_sql_statement_within_payload_limit(mocker: MockerFixture) -> No ) -def test_sql_lab_insert_rls_as_subquery( - mocker: MockerFixture, - session: Session, -) -> None: - """ - Integration test for `insert_rls_as_subquery`. - """ - from flask_appbuilder.security.sqla.models import Role, User - - from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable - from superset.models.core import Database - from superset.models.sql_lab import Query - from superset.security.manager import SupersetSecurityManager - from superset.sql_lab import execute_sql_statement - from superset.utils.core import RowLevelSecurityFilterType - - engine = db.session.connection().engine - Query.metadata.create_all(engine) # pylint: disable=no-member - - connection = engine.raw_connection() - connection.execute("CREATE TABLE t (c INTEGER)") - for i in range(10): - connection.execute("INSERT INTO t VALUES (?)", (i,)) - - cursor = connection.cursor() - - query = Query( - sql="SELECT c FROM t", - client_id="abcde", - database=Database(database_name="test_db", sqlalchemy_uri="sqlite://"), - schema=None, - limit=5, - select_as_cta_used=False, - ) - db.session.add(query) - db.session.commit() - - admin = User( - first_name="Alice", - last_name="Doe", - email="adoe@example.org", - username="admin", - roles=[Role(name="Admin")], - ) - - # first without RLS - with override_user(admin): - superset_result_set = execute_sql_statement( - sql_statement=query.sql, - query=query, - cursor=cursor, - log_params=None, - apply_ctas=False, - ) - assert ( - superset_result_set.to_pandas_df().to_markdown() - == """ -| | c | -|---:|----:| -| 0 | 0 | -| 1 | 1 | -| 2 | 2 | -| 3 | 3 | -| 4 | 4 |""".strip() - ) - assert query.executed_sql == "SELECT c FROM t\nLIMIT 6" - - # now with RLS - rls = RowLevelSecurityFilter( - name="sqllab_rls1", - filter_type=RowLevelSecurityFilterType.REGULAR, - tables=[SqlaTable(database_id=1, schema=None, table_name="t")], - roles=[admin.roles[0]], - group_key=None, - clause="c > 5", - ) - db.session.add(rls) - db.session.flush() - mocker.patch.object(SupersetSecurityManager, "find_user", return_value=admin) - mocker.patch("superset.sql_lab.is_feature_enabled", return_value=True) - - with override_user(admin): - superset_result_set = execute_sql_statement( - sql_statement=query.sql, - query=query, - cursor=cursor, - log_params=None, - apply_ctas=False, - ) - assert ( - superset_result_set.to_pandas_df().to_markdown() - == """ -| | c | -|---:|----:| -| 0 | 6 | -| 1 | 7 | -| 2 | 8 | -| 3 | 9 |""".strip() - ) - assert ( - query.executed_sql - == "SELECT c FROM (SELECT * FROM t WHERE (t.c > 5)) AS t\nLIMIT 6" - ) - - @freeze_time("2021-04-01T00:00:00Z") def test_get_sql_results_oauth2(mocker: MockerFixture, app) -> None: """ @@ -370,8 +202,7 @@ def test_get_sql_results_oauth2(mocker: MockerFixture, app) -> None: "OAuth2 required" ) - query = mocker.MagicMock() - query.database = database + query = mocker.MagicMock(select_as_cta=False, database=database) mocker.patch("superset.sql_lab.get_query", return_value=query) payload = get_sql_results(query_id=1, rendered_query="SELECT 1") @@ -391,3 +222,66 @@ def test_get_sql_results_oauth2(mocker: MockerFixture, app) -> None: } ], } + + +def test_apply_rls(mocker: MockerFixture) -> None: + """ + Test the ``apply_rls`` helper function. 
+ """ + database = mocker.MagicMock() + database.get_default_schema_for_query.return_value = "public" + database.get_default_catalog.return_value = "examples" + database.db_engine_spec = PostgresEngineSpec + get_predicates_for_table = mocker.patch( + "superset.sql_lab.get_predicates_for_table", + side_effect=[["c1 = 1"], ["c2 = 2"]], + ) + + parsed_statement = SQLStatement("SELECT * FROM t1, t2", "postgresql") + parsed_statement.tables = sorted(parsed_statement.tables, key=lambda x: x.table) # type: ignore + + apply_rls(database, "examples", "public", parsed_statement) + + get_predicates_for_table.assert_has_calls( + [ + mocker.call(Table("t1", "public", "examples"), database, "examples"), + mocker.call(Table("t2", "public", "examples"), database, "examples"), + ] + ) + + assert ( + parsed_statement.format() + == """ +SELECT + * +FROM ( + SELECT + * + FROM t1 + WHERE + c1 = 1 +) AS t1, ( + SELECT + * + FROM t2 + WHERE + c2 = 2 +) AS t2 + """.strip() + ) + + +def test_get_predicates_for_table(mocker: MockerFixture) -> None: + """ + Test the ``get_predicates_for_table`` helper function. + """ + database = mocker.MagicMock() + dataset = mocker.MagicMock() + predicate = mocker.MagicMock() + predicate.compile.return_value = "c1 = 1" + dataset.get_sqla_row_level_filters.return_value = [predicate] + db = mocker.patch("superset.sql_lab.db") + db.session.query().filter().one_or_none.return_value = dataset + + table = Table("t1", "public", "examples") + assert get_predicates_for_table(table, database, "examples") == ["c1 = 1"] diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index 23aa6b0b125..f0c55e4a267 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -1104,46 +1104,6 @@ def test_unknown_select() -> None: assert not ParsedQuery(sql).is_select() -def test_get_query_with_new_limit_comment() -> None: - """ - Test that limit is applied correctly. - """ - query = ParsedQuery("SELECT * FROM birth_names -- SOME COMMENT") - assert query.set_or_update_query_limit(1000) == ( - "SELECT * FROM birth_names -- SOME COMMENT\nLIMIT 1000" - ) - - -def test_get_query_with_new_limit_comment_with_limit() -> None: - """ - Test that limits in comments are ignored. - """ - query = ParsedQuery("SELECT * FROM birth_names -- SOME COMMENT WITH LIMIT 555") - assert query.set_or_update_query_limit(1000) == ( - "SELECT * FROM birth_names -- SOME COMMENT WITH LIMIT 555\nLIMIT 1000" - ) - - -def test_get_query_with_new_limit_lower() -> None: - """ - Test that lower limits are not replaced. - """ - query = ParsedQuery("SELECT * FROM birth_names LIMIT 555") - assert query.set_or_update_query_limit(1000) == ( - "SELECT * FROM birth_names LIMIT 555" - ) - - -def test_get_query_with_new_limit_upper() -> None: - """ - Test that higher limits are replaced. - """ - query = ParsedQuery("SELECT * FROM birth_names LIMIT 2000") - assert query.set_or_update_query_limit(1000) == ( - "SELECT * FROM birth_names LIMIT 1000" - ) - - def test_basic_breakdown_statements() -> None: """ Test that multiple statements are parsed correctly.