Mirror of https://github.com/apache/superset.git (synced 2026-04-28 20:44:24 +00:00)

Compare commits: 10 commits, comparing `request-co...` with `rls-sqlglo...`
| Author | SHA1 | Date |
|---|---|---|
|  | 7e7b9f84aa |  |
|  | 5231e86b6c |  |
|  | 6b73b69b41 |  |
|  | 35da6ac270 |  |
|  | d5c0506faa |  |
|  | fb1f2c4f18 |  |
|  | 7f3c8efab0 |  |
|  | 7bb7fc0f49 |  |
|  | 3449b8f9dc |  |
|  | 7a0ae36c4a |  |
6  .github/workflows/superset-e2e.yml  (vendored)
@@ -66,20 +66,20 @@ jobs:
      # Conditional checkout based on context
      - name: Checkout for push or pull_request event
        if: github.event_name == 'push' || github.event_name == 'pull_request'
-       uses: actions/checkout@v2
+       uses: actions/checkout@v4
        with:
          persist-credentials: false
          submodules: recursive
      - name: Checkout using ref (workflow_dispatch)
        if: github.event_name == 'workflow_dispatch' && github.event.inputs.ref != ''
-       uses: actions/checkout@v2
+       uses: actions/checkout@v4
        with:
          persist-credentials: false
          ref: ${{ github.event.inputs.ref }}
          submodules: recursive
      - name: Checkout using PR ID (workflow_dispatch)
        if: github.event_name == 'workflow_dispatch' && github.event.inputs.pr_id != ''
-       uses: actions/checkout@v2
+       uses: actions/checkout@v4
        with:
          persist-credentials: false
          ref: refs/pull/${{ github.event.inputs.pr_id }}/merge
@@ -214,13 +214,14 @@ connections from the Docker involves making one-line changes to the files `postg
`pg_hba.conf`; you can find helpful links tailored to your OS / PG version on the web easily for
this task. For Docker it suffices to only whitelist IPs `172.0.0.0/8` instead of `*`, but in any
case you are _warned_ that doing this in a production database _may_ have disastrous consequences as
-you are opening your database to the public internet. 2. Instead of `localhost`, try using
-`host.docker.internal` (Mac users, Ubuntu) or `172.18.0.1` (Linux users) as the hostname when
-attempting to connect to the database. This is a Docker internal detail -- what is happening is
-that, in Mac systems, Docker Desktop creates a dns entry for the hostname `host.docker.internal`
-which resolves to the correct address for the host machine, whereas in Linux this is not the case
-(at least by default). If neither of these 2 hostnames work then you may want to find the exact
-hostname you want to use, for that you can do `ifconfig` or `ip addr show` and look at the IP
-address of `docker0` interface that must have been created by Docker for you. Alternately if you
-don't even see the `docker0` interface try (if needed with sudo) `docker network inspect bridge` and
-see if there is an entry for `"Gateway"` and note the IP address.
+you are opening your database to the public internet.
+1. Instead of `localhost`, try using `host.docker.internal` (Mac users, Ubuntu) or `172.18.0.1`
+(Linux users) as the hostname when attempting to connect to the database. This is a Docker internal
+detail -- what is happening is that, in Mac systems, Docker Desktop creates a dns entry for the
+hostname `host.docker.internal` which resolves to the correct address for the host machine, whereas
+in Linux this is not the case (at least by default). If neither of these 2 hostnames work then you
+may want to find the exact hostname you want to use, for that you can do `ifconfig` or
+`ip addr show` and look at the IP address of `docker0` interface that must have been created by
+Docker for you. Alternately if you don't even see the `docker0` interface try (if needed with sudo)
+`docker network inspect bridge` and see if there is an entry for `"Gateway"` and note the IP
+address.
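For illustration only (not part of the diff above): a minimal sketch of reaching the host's Postgres from inside a container once `pg_hba.conf`/`postgresql.conf` allow it. The credentials and database name below are placeholders.

```python
# Minimal connectivity check from inside a Docker container.
# "host.docker.internal" is the Docker Desktop DNS name for the host machine;
# on Linux use the docker0 gateway address (e.g. 172.17.0.1 or 172.18.0.1) instead.
from sqlalchemy import create_engine, text

# superset/superset/examples are placeholder credentials -- replace with your own.
engine = create_engine(
    "postgresql+psycopg2://superset:superset@host.docker.internal:5432/examples"
)

with engine.connect() as conn:
    print(conn.execute(text("SELECT 1")).scalar())  # prints 1 if the host DB is reachable
```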
16  superset-frontend/package-lock.json  (generated)
@@ -70520,7 +70520,7 @@
      },
      "devDependencies": {
        "@types/jest": "^29.5.12",
-       "@types/lodash": "^4.17.4",
+       "@types/lodash": "^4.17.6",
        "jest": "^29.7.0"
      },
      "peerDependencies": {
@@ -70535,9 +70535,9 @@
      }
    },
    "plugins/plugin-chart-handlebars/node_modules/@types/lodash": {
-     "version": "4.17.4",
-     "resolved": "https://registry.npmjs.org/@types/lodash/-/lodash-4.17.4.tgz",
-     "integrity": "sha512-wYCP26ZLxaT3R39kiN2+HcJ4kTd3U1waI/cY7ivWYqFP6pW3ZNpvi6Wd6PHZx7T/t8z0vlkXMg3QYLa7DZ/IJQ==",
+     "version": "4.17.6",
+     "resolved": "https://registry.npmjs.org/@types/lodash/-/lodash-4.17.6.tgz",
+     "integrity": "sha512-OpXEVoCKSS3lQqjx9GGGOapBeuW5eUboYHRlHP9urXPX25IKZ6AnP5ZRxtVf63iieUbsHxLn8NQ5Nlftc6yzAA==",
      "dev": true
    },
    "plugins/plugin-chart-handlebars/node_modules/just-handlebars-helpers": {
@@ -89053,16 +89053,16 @@
      "version": "file:plugins/plugin-chart-handlebars",
      "requires": {
        "@types/jest": "^29.5.12",
-       "@types/lodash": "^4.17.4",
+       "@types/lodash": "^4.17.6",
        "handlebars": "^4.7.7",
        "jest": "^29.7.0",
        "just-handlebars-helpers": "^1.0.19"
      },
      "dependencies": {
        "@types/lodash": {
-         "version": "4.17.4",
-         "resolved": "https://registry.npmjs.org/@types/lodash/-/lodash-4.17.4.tgz",
-         "integrity": "sha512-wYCP26ZLxaT3R39kiN2+HcJ4kTd3U1waI/cY7ivWYqFP6pW3ZNpvi6Wd6PHZx7T/t8z0vlkXMg3QYLa7DZ/IJQ==",
+         "version": "4.17.6",
+         "resolved": "https://registry.npmjs.org/@types/lodash/-/lodash-4.17.6.tgz",
+         "integrity": "sha512-OpXEVoCKSS3lQqjx9GGGOapBeuW5eUboYHRlHP9urXPX25IKZ6AnP5ZRxtVf63iieUbsHxLn8NQ5Nlftc6yzAA==",
          "dev": true
        },
        "just-handlebars-helpers": {
@@ -31,7 +31,7 @@
    "d3-array": "^1.2.4",
    "d3-color": "^1.4.1",
    "d3-scale": "^3.0.0",
-   "deck.gl": "9.0.12",
+   "deck.gl": "9.0.20",
    "lodash": "^4.17.21",
    "moment": "^2.30.1",
    "mousetrap": "^1.6.5",
@@ -42,7 +42,7 @@
  },
  "devDependencies": {
    "@types/jest": "^29.5.12",
-   "@types/lodash": "^4.17.4",
+   "@types/lodash": "^4.17.6",
    "jest": "^29.7.0"
  }
}
@@ -45,7 +45,7 @@ const StyledTitle = styled.span`
interface BasicErrorAlertProps {
  title: string;
  body: string;
- level: ErrorLevel;
+ level?: ErrorLevel;
}

export default function BasicErrorAlert({
@@ -76,6 +76,7 @@ import {
  OPEN_FILTER_BAR_WIDTH,
  EMPTY_CONTAINER_Z_INDEX,
} from 'src/dashboard/constants';
+import BasicErrorAlert from 'src/components/ErrorMessage/BasicErrorAlert';
import { getRootLevelTabsComponent, shouldFocusTabs } from './utils';
import DashboardContainer from './DashboardContainer';
import { useNativeFilters } from './state';
@@ -462,6 +463,7 @@ const DashboardBuilder: FC<DashboardBuilderProps> = () => {

  const {
    showDashboard,
+   missingInitialFilters,
    dashboardFiltersOpen,
    toggleDashboardFiltersOpen,
    nativeFiltersEnabled,
@@ -673,7 +675,30 @@ const DashboardBuilder: FC<DashboardBuilderProps> = () => {
          editMode={editMode}
          marginLeft={dashboardContentMarginLeft}
        >
-         {showDashboard ? (
+         {missingInitialFilters.length > 0 ? (
+           <div
+             css={css`
+               display: flex;
+               flex-direction: row;
+               align-items: center;
+               justify-content: center;
+               flex: 1;
+               & div {
+                 width: 500px;
+               }
+             `}
+           >
+             <BasicErrorAlert
+               title={t('Unable to load dashboard')}
+               body={t(
+                 `The following filters have the 'Select first filter value by default'
+                 option checked and could not be loaded, which is preventing the dashboard
+                 from rendering: %s`,
+                 missingInitialFilters.join(', '),
+               )}
+             />
+           </div>
+         ) : showDashboard ? (
            <DashboardContainer topLevelTabs={topLevelTabs} />
          ) : (
            <Loading />
@@ -47,17 +47,14 @@ export const useNativeFilters = () => {
    filter => filter.requiredFirst,
  );
  const dataMask = useNativeFiltersDataMask();

+ const missingInitialFilters = requiredFirstFilter
+   .filter(({ id }) => dataMask[id]?.filterState?.value === undefined)
+   .map(({ name }) => name);
  const showDashboard =
    isInitialized ||
    !nativeFiltersEnabled ||
-   !(
-     nativeFiltersEnabled &&
-     requiredFirstFilter.length &&
-     requiredFirstFilter.find(
-       ({ id }) => dataMask[id]?.filterState?.value === undefined,
-     )
-   );
+   missingInitialFilters.length === 0;
  const toggleDashboardFiltersOpen = useCallback(
    (visible?: boolean) => {
      setDashboardFiltersOpen(visible ?? !dashboardFiltersOpen);
@@ -84,6 +81,7 @@ export const useNativeFilters = () => {

  return {
    showDashboard,
+   missingInitialFilters,
    dashboardFiltersOpen,
    toggleDashboardFiltersOpen,
    nativeFiltersEnabled,
@@ -0,0 +1,121 @@
|
||||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
import {
|
||||
render,
|
||||
screen,
|
||||
fireEvent,
|
||||
waitFor,
|
||||
} from 'spec/helpers/testing-library';
|
||||
import { Comparator } from '@superset-ui/chart-controls';
|
||||
import { ColorSchemeEnum } from '@superset-ui/plugin-chart-table';
|
||||
import { FormattingPopoverContent } from './FormattingPopoverContent';
|
||||
|
||||
const mockOnChange = jest.fn();
|
||||
|
||||
const columns = [
|
||||
{ label: 'Column 1', value: 'column1' },
|
||||
{ label: 'Column 2', value: 'column2' },
|
||||
];
|
||||
|
||||
const extraColorChoices = [
|
||||
{
|
||||
value: ColorSchemeEnum.Green,
|
||||
label: 'Green for increase, red for decrease',
|
||||
},
|
||||
{
|
||||
value: ColorSchemeEnum.Red,
|
||||
label: 'Red for increase, green for decrease',
|
||||
},
|
||||
];
|
||||
|
||||
test('renders FormattingPopoverContent component', () => {
|
||||
render(
|
||||
<FormattingPopoverContent
|
||||
onChange={mockOnChange}
|
||||
columns={columns}
|
||||
extraColorChoices={extraColorChoices}
|
||||
/>,
|
||||
);
|
||||
|
||||
// Assert that the component renders correctly
|
||||
expect(screen.getByLabelText('Column')).toBeInTheDocument();
|
||||
expect(screen.getAllByLabelText('Color scheme')).toHaveLength(2);
|
||||
expect(screen.getAllByLabelText('Operator')).toHaveLength(2);
|
||||
expect(screen.queryByLabelText('Left value')).not.toBeInTheDocument();
|
||||
expect(screen.queryByLabelText('Right value')).not.toBeInTheDocument();
|
||||
expect(screen.getByText('Apply')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test('calls onChange when Apply button is clicked', async () => {
|
||||
render(
|
||||
<FormattingPopoverContent
|
||||
onChange={mockOnChange}
|
||||
columns={columns}
|
||||
extraColorChoices={extraColorChoices}
|
||||
/>,
|
||||
);
|
||||
|
||||
// Simulate user interaction by clicking the Apply button
|
||||
fireEvent.click(screen.getByText('Apply'));
|
||||
|
||||
// Assert that the onChange function is called with the correct config
|
||||
await waitFor(() => {
|
||||
expect(mockOnChange).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
test('renders the correct input fields based on the selected operator', async () => {
|
||||
render(
|
||||
<FormattingPopoverContent
|
||||
onChange={mockOnChange}
|
||||
columns={columns}
|
||||
extraColorChoices={extraColorChoices}
|
||||
/>,
|
||||
);
|
||||
|
||||
// Select the 'Between' operator
|
||||
fireEvent.change(screen.getAllByLabelText('Operator')[0], {
|
||||
target: { value: Comparator.Between },
|
||||
});
|
||||
fireEvent.click(await screen.findByTitle('< x <'));
|
||||
|
||||
// Assert that the left and right value inputs are rendered
|
||||
expect(await screen.findByLabelText('Left value')).toBeInTheDocument();
|
||||
expect(await screen.findByLabelText('Right value')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test('renders None for operator when Green for increase is selected', async () => {
|
||||
render(
|
||||
<FormattingPopoverContent
|
||||
onChange={mockOnChange}
|
||||
columns={columns}
|
||||
extraColorChoices={extraColorChoices}
|
||||
/>,
|
||||
);
|
||||
|
||||
// Select the 'Green for increase' color scheme
|
||||
fireEvent.change(screen.getAllByLabelText(/color scheme/i)[0], {
|
||||
target: { value: ColorSchemeEnum.Green },
|
||||
});
|
||||
|
||||
fireEvent.click(await screen.findByTitle(/green for increase/i));
|
||||
|
||||
// Assert that the operator is set to 'None'
|
||||
expect(screen.getByText(/none/i)).toBeInTheDocument();
|
||||
});
|
||||
@@ -124,7 +124,7 @@ const shouldFormItemUpdate = (
  isOperatorMultiValue(prevValues.operator) !==
  isOperatorMultiValue(currentValues.operator);

-const operatorField = (showOnlyNone?: boolean) => (
+const renderOperator = ({ showOnlyNone }: { showOnlyNone?: boolean } = {}) => (
  <FormItem
    name="operator"
    label={t('Operator')}
@@ -141,7 +141,7 @@ const operatorField = (showOnlyNone?: boolean) => (
const renderOperatorFields = ({ getFieldValue }: GetFieldValue) =>
  isOperatorNone(getFieldValue('operator')) ? (
    <Row gutter={12}>
-     <Col span={6}>{operatorField()}</Col>
+     <Col span={6}>{renderOperator()}</Col>
    </Row>
  ) : isOperatorMultiValue(getFieldValue('operator')) ? (
    <Row gutter={12}>
@@ -157,7 +157,7 @@ const renderOperatorFields = ({ getFieldValue }: GetFieldValue) =>
        <FullWidthInputNumber />
      </FormItem>
    </Col>
-   <Col span={6}>{operatorField}</Col>
+   <Col span={6}>{renderOperator()}</Col>
    <Col span={9}>
      <FormItem
        name="targetValueRight"
@@ -173,7 +173,7 @@ const renderOperatorFields = ({ getFieldValue }: GetFieldValue) =>
    </Row>
  ) : (
    <Row gutter={12}>
-     <Col span={6}>{operatorField}</Col>
+     <Col span={6}>{renderOperator()}</Col>
      <Col span={18}>
        <FormItem
          name="targetValue"
@@ -248,7 +248,7 @@ export const FormattingPopoverContent = ({
        renderOperatorFields
      ) : (
        <Row gutter={12}>
-         <Col span={6}>{operatorField(true)}</Col>
+         <Col span={6}>{renderOperator({ showOnlyNone: true })}</Col>
        </Row>
      )}
    </FormItem>
74  superset/cli/reset.py  (new file)

@@ -0,0 +1,74 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import sys

import click
from flask.cli import with_appcontext
from werkzeug.security import check_password_hash

from superset.cli.lib import feature_flags

if feature_flags.get("ENABLE_FACTORY_RESET_COMMAND"):

    @click.command()
    @with_appcontext
    @click.option("--username", prompt="Admin Username", help="Admin Username")
    @click.option(
        "--silent",
        is_flag=True,
        prompt=(
            "Are you sure you want to reset Superset? "
            "This action cannot be undone. Continue?"
        ),
        help="Confirmation flag",
    )
    @click.option(
        "--exclude-users",
        default=None,
        help="Comma separated list of users to exclude from reset",
    )
    @click.option(
        "--exclude-roles",
        default=None,
        help="Comma separated list of roles to exclude from reset",
    )
    def factory_reset(
        username: str, silent: bool, exclude_users: str, exclude_roles: str
    ) -> None:
        """Factory Reset Apache Superset"""

        # pylint: disable=import-outside-toplevel
        from superset import security_manager
        from superset.commands.security.reset import ResetSupersetCommand

        # Validate the user
        password = click.prompt("Admin Password", hide_input=True)
        user = security_manager.find_user(username)
        if not user or not check_password_hash(user.password, password):
            click.secho("Invalid credentials", fg="red")
            sys.exit(1)
        if not any(role.name == "Admin" for role in user.roles):
            click.secho("Permission Denied", fg="red")
            sys.exit(1)

        try:
            ResetSupersetCommand(silent, user, exclude_users, exclude_roles).run()
            click.secho("Factory reset complete", fg="green")
        except Exception as ex:  # pylint: disable=broad-except
            click.secho(f"Factory reset failed: {ex}", fg="red")
            sys.exit(1)
94  superset/commands/security/reset.py  (new file)

@@ -0,0 +1,94 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import logging
from typing import Any, Optional

from superset import db, security_manager
from superset.commands.base import BaseCommand
from superset.connectors.sqla.models import SqlaTable
from superset.key_value.models import KeyValueEntry
from superset.models.core import Database, FavStar, Log
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice

logger = logging.getLogger(__name__)


class ResetSupersetCommand(BaseCommand):
    def __init__(
        self,
        confirm: bool,
        user: Any,
        exclude_users: Optional[str] = None,
        exclude_roles: Optional[str] = None,
    ) -> None:
        self._user = user
        self._confirm = confirm
        self._users_to_exclude = ["admin"]
        if exclude_users:
            self._users_to_exclude.extend(exclude_users.split(","))
        self._roles_to_exclude = ["Admin", "Public", "Gamma", "Alpha", "sql_lab"]
        if exclude_roles:
            self._roles_to_exclude.extend(exclude_roles.split(","))

    def validate(self) -> None:
        if not self._confirm:
            raise Exception("Reset aborted.")  # pylint: disable=broad-exception-raised
        if not self._user or not self._user.is_active:
            raise Exception("User not found.")  # pylint: disable=broad-exception-raised

    def run(self) -> None:
        self.validate()
        logger.debug("Resetting Superset Started")
        db.session.query(SqlaTable).delete()
        databases = db.session.query(Database)
        for database in databases:
            db.session.delete(database)
        db.session.query(Dashboard).delete()
        db.session.query(Slice).delete()
        db.session.query(KeyValueEntry).delete()
        db.session.query(Log).delete()
        db.session.query(FavStar).delete()

        logger.debug("Ignoring Users: %s", self._users_to_exclude)
        users_to_delete = (
            db.session.query(security_manager.user_model)
            .filter(security_manager.user_model.username.not_in(self._users_to_exclude))
            .all()
        )
        for user in users_to_delete:
            if not any(role.name == "Admin" for role in user.roles):
                db.session.delete(user)

        logger.debug("Ignoring Roles: %s", self._roles_to_exclude)
        roles_to_delete = (
            db.session.query(security_manager.role_model)
            .filter(security_manager.role_model.name.not_in(self._roles_to_exclude))
            .all()
        )
        for role in roles_to_delete:
            db.session.delete(role)

        # Insert new record into Log table
        log = Log(
            action="Factory Reset", json="{}", user_id=self._user.id, user=self._user
        )
        db.session.add(log)

        db.session.commit()  # pylint: disable=consider-using-transaction
        logger.debug("Resetting Superset Completed")
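Illustration only, not part of the diff: assuming the command above is picked up by Superset's CLI loader, the shell invocation would presumably be `superset factory-reset --username admin`. A programmatic equivalent built from the same pieces, with a placeholder username and intended only for a disposable metadata database:

```python
from superset.app import create_app

app = create_app()
with app.app_context():
    # Imported after app creation so the security manager is initialized.
    from superset import security_manager
    from superset.commands.security.reset import ResetSupersetCommand

    admin = security_manager.find_user("admin")  # placeholder username
    # confirm=True plays the role of the CLI's confirmation prompt (--silent).
    ResetSupersetCommand(True, admin, exclude_users=None, exclude_roles=None).run()
```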
@@ -539,6 +539,8 @@ DEFAULT_FEATURE_FLAGS: dict[str, bool] = {
    "CHART_PLUGINS_EXPERIMENTAL": False,
    # Regardless of database configuration settings, force SQLLAB to run async using Celery
    "SQLLAB_FORCE_RUN_ASYNC": False,
+   # Set to True to enable the factory reset CLI command
+   "ENABLE_FACTORY_RESET_COMMAND": False,
}

# ------------------------------
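For context (not part of the diff): feature flags like this one are typically switched on in a deployment's `superset_config.py`, e.g.:

```python
# superset_config.py -- overrides the default feature flags shown above.
FEATURE_FLAGS = {
    "ENABLE_FACTORY_RESET_COMMAND": True,
}
```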
@@ -172,7 +172,9 @@ class DatasourceKind(StrEnum):
    PHYSICAL = "physical"


-class BaseDatasource(AuditMixinNullable, ImportExportMixin):  # pylint: disable=too-many-public-methods
+class BaseDatasource(
+    AuditMixinNullable, ImportExportMixin
+):  # pylint: disable=too-many-public-methods
    """A common interface to objects that are queryable
    (tables and datasources)"""
@@ -49,7 +49,7 @@ CONNECTION_HOST_DOWN_REGEX = re.compile(
class MssqlEngineSpec(BaseEngineSpec):
    engine = "mssql"
    engine_name = "Microsoft SQL Server"
-   limit_method = LimitMethod.WRAP_SQL
+   limit_method = LimitMethod.FORCE_LIMIT
    max_column_name_length = 128
    allows_cte_in_subquery = False
    allow_limit_clause = False
@@ -23,7 +23,7 @@ class TeradataEngineSpec(BaseEngineSpec):

    engine = "teradatasql"
    engine_name = "Teradata"
-   limit_method = LimitMethod.WRAP_SQL
+   limit_method = LimitMethod.FORCE_LIMIT
    max_column_name_length = 30  # since 14.10 this is 128
    allow_limit_clause = False
    select_keywords = {"SELECT", "SEL"}
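Not part of the diff: a sketch of why `FORCE_LIMIT` is generally preferable to `WRAP_SQL`, using sqlglot (the library the new SQL parsing layer is built on) to rewrite a hypothetical query. Table and limit values are made up.

```python
import sqlglot

sql = "SELECT name FROM some_table ORDER BY name"

# FORCE_LIMIT: add (or replace) the limit on the statement itself.
forced = sqlglot.parse_one(sql).limit(100)
print(forced.sql())                # SELECT name FROM some_table ORDER BY name LIMIT 100
print(forced.sql(dialect="tsql"))  # rendered with TOP-style limiting on SQL Server dialects

# WRAP_SQL (the old behaviour): wrap the query and limit the outer SELECT,
# which the database optimizer may not be able to push down.
wrapped = sqlglot.parse_one(f"SELECT * FROM ({sql}) AS inner_qry").limit(100)
print(wrapped.sql())
```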
@@ -24,7 +24,6 @@ from flask_caching import BaseCache
from sqlalchemy.exc import SQLAlchemyError

from superset import db
-from superset.daos.key_value import KeyValueDAO
from superset.key_value.exceptions import KeyValueCreateFailedError
from superset.key_value.types import (
    KeyValueCodec,
@@ -79,6 +78,9 @@ class SupersetMetastoreCache(BaseCache):
        return None

    def set(self, key: str, value: Any, timeout: Optional[int] = None) -> bool:
+       # pylint: disable=import-outside-toplevel
+       from superset.daos.key_value import KeyValueDAO

        KeyValueDAO.upsert_entry(
            resource=RESOURCE,
            key=self.get_key(key),
@@ -90,6 +92,9 @@ class SupersetMetastoreCache(BaseCache):
        return True

    def add(self, key: str, value: Any, timeout: Optional[int] = None) -> bool:
+       # pylint: disable=import-outside-toplevel
+       from superset.daos.key_value import KeyValueDAO

        try:
            KeyValueDAO.delete_expired_entries(RESOURCE)
            KeyValueDAO.create_entry(
@@ -106,6 +111,9 @@ class SupersetMetastoreCache(BaseCache):
        return False

    def get(self, key: str) -> Any:
+       # pylint: disable=import-outside-toplevel
+       from superset.daos.key_value import KeyValueDAO

        return KeyValueDAO.get_value(RESOURCE, self.get_key(key), self.codec)

    def has(self, key: str) -> bool:
@@ -116,4 +124,7 @@ class SupersetMetastoreCache(BaseCache):

    @transaction()
    def delete(self, key: str) -> Any:
+       # pylint: disable=import-outside-toplevel
+       from superset.daos.key_value import KeyValueDAO

        return KeyValueDAO.delete_entry(RESOURCE, self.get_key(key))
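Aside (not part of the diff): `SupersetMetastoreCache` is the backend commonly wired up for form-data style caches in `superset_config.py`, along the lines of:

```python
# superset_config.py -- assumed deployment config; the timeout value is arbitrary.
EXPLORE_FORM_DATA_CACHE_CONFIG = {
    "CACHE_TYPE": "SupersetMetastoreCache",
    "CACHE_DEFAULT_TIMEOUT": 86400,  # one day, in seconds
    "REFRESH_TIMEOUT_ON_RETRIEVAL": True,
}
```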
@@ -17,6 +17,8 @@
# pylint: disable=too-many-lines
"""a collection of model-related helper classes and functions"""

+from __future__ import annotations
+
import builtins
import dataclasses
import logging
@@ -806,7 +808,7 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods

    def get_sqla_row_level_filters(
        self,
-       template_processor: Optional[BaseTemplateProcessor] = None,
+       template_processor: BaseTemplateProcessor | None = None,
    ) -> list[TextClause]:
        """
        Return the appropriate row level security filters for this table and the
@@ -26,6 +26,7 @@ from typing import Any, cast, Optional, Union
import backoff
import msgpack
from celery.exceptions import SoftTimeLimitExceeded
+from flask import current_app
from flask_babel import gettext as __

from superset import (
@@ -52,9 +53,8 @@ from superset.models.sql_lab import Query
from superset.result_set import SupersetResultSet
from superset.sql_parse import (
    CtasMethod,
-   insert_rls_as_subquery,
-   insert_rls_in_predicate,
    ParsedQuery,
+   SQLStatement,
    Table,
)
from superset.sqllab.limiting_factor import LimitingFactor
@@ -128,7 +128,6 @@ def handle_query_error(
|
||||
|
||||
|
||||
def get_query_backoff_handler(details: dict[Any, Any]) -> None:
|
||||
print(details)
|
||||
query_id = details["kwargs"]["query_id"]
|
||||
logger.error(
|
||||
"Query with id `%s` could not be retrieved", str(query_id), exc_info=True
|
||||
@@ -175,22 +174,23 @@ def get_sql_results( # pylint: disable=too-many-arguments
|
||||
log_params: Optional[dict[str, Any]] = None,
|
||||
) -> Optional[dict[str, Any]]:
|
||||
"""Executes the sql query returns the results."""
|
||||
with override_user(security_manager.find_user(username)):
|
||||
try:
|
||||
return execute_sql_statements(
|
||||
query_id,
|
||||
rendered_query,
|
||||
return_results,
|
||||
store_results,
|
||||
start_time=start_time,
|
||||
expand_data=expand_data,
|
||||
log_params=log_params,
|
||||
)
|
||||
except Exception as ex: # pylint: disable=broad-except
|
||||
logger.debug("Query %d: %s", query_id, ex)
|
||||
stats_logger.incr("error_sqllab_unhandled")
|
||||
query = get_query(query_id)
|
||||
return handle_query_error(ex, query)
|
||||
with current_app.test_request_context():
|
||||
with override_user(security_manager.find_user(username)):
|
||||
try:
|
||||
return execute_sql_statements(
|
||||
query_id,
|
||||
rendered_query,
|
||||
return_results,
|
||||
store_results,
|
||||
start_time=start_time,
|
||||
expand_data=expand_data,
|
||||
log_params=log_params,
|
||||
)
|
||||
except Exception as ex: # pylint: disable=broad-except
|
||||
logger.debug("Query %d: %s", query_id, ex)
|
||||
stats_logger.incr("error_sqllab_unhandled")
|
||||
query = get_query(query_id)
|
||||
return handle_query_error(ex, query)
|
||||
|
||||
|
||||
def execute_sql_statement( # pylint: disable=too-many-statements
|
||||
@@ -204,67 +204,49 @@ def execute_sql_statement(  # pylint: disable=too-many-statements
    database: Database = query.database
    db_engine_spec = database.db_engine_spec

    parsed_query = ParsedQuery(sql_statement, engine=db_engine_spec.engine)
    parsed_statement = SQLStatement(sql_statement, engine=db_engine_spec.engine)

    if is_feature_enabled("RLS_IN_SQLLAB"):
        # There are two ways to insert RLS: either replacing the table with a subquery
        # that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is
        # safer, but not supported in all databases.
        insert_rls = (
            insert_rls_as_subquery
            if database.db_engine_spec.allows_subqueries
            and database.db_engine_spec.allows_alias_in_select
            else insert_rls_in_predicate
        )
        default_schema = database.get_default_schema_for_query(query)
        parsed_statement = parsed_statement.apply_rls(query.catalog, default_schema)

        # Insert any applicable RLS predicates
        parsed_query = ParsedQuery(
            str(
                insert_rls(
                    parsed_query._parsed[0],  # pylint: disable=protected-access
                    database.id,
                    query.schema,
                )
            ),
            engine=db_engine_spec.engine,
        )

    sql = parsed_query.stripped()

    # This is a test to see if the query is being
    # limited by either the dropdown or the sql.
    # We are testing to see if more rows exist than the limit.
    increased_limit = None if query.limit is None else query.limit + 1

    if not db_engine_spec.is_readonly_query(parsed_query) and not database.allow_dml:
    if parsed_statement.is_dml() and not database.allow_dml:
        raise SupersetErrorException(
            SupersetError(
                message=__("Only SELECT statements are allowed against this database."),
                message=__("DML statements are not allowed in this database."),
                error_type=SupersetErrorType.DML_NOT_ALLOWED_ERROR,
                level=ErrorLevel.ERROR,
            )
        )

    if apply_ctas:
        if not query.tmp_table_name:
            start_dttm = datetime.fromtimestamp(query.start_time)
            query.tmp_table_name = (
                f'tmp_{query.user_id}_table_{start_dttm.strftime("%Y_%m_%d_%H_%M_%S")}'
            )
        sql = parsed_query.as_create_table(
            query.tmp_table_name,
            schema_name=query.tmp_schema_name,
        parsed_statement = parsed_statement.as_create_table(
            Table(query.tmp_table_name, query.tmp_schema_name, query.catalog),
            method=query.ctas_method,
        )
        query.select_as_cta_used = True

    increased_limit = None if query.limit is None else query.limit + 1

    # Do not apply limit to the CTA queries when SQLLAB_CTAS_NO_LIMIT is set to true
    if db_engine_spec.is_select_query(parsed_query) and not (
    if parsed_statement.is_select() and not (
        query.select_as_cta_used and SQLLAB_CTAS_NO_LIMIT
    ):
        if SQL_MAX_ROW and (not query.limit or query.limit > SQL_MAX_ROW):
            query.limit = SQL_MAX_ROW
        sql = apply_limit_if_exists(database, increased_limit, query, sql)

        if query.limit:
            # Increase limit by one so we can test if there are more rows when the
            # database returns exactly the number of rows requested by the user.
            parsed_statement = parsed_statement.apply_limit(increased_limit)

    # Hook to allow environment-specific mutation (usually comments) to the SQL
    sql = parsed_statement.format(strip=True)
    sql = database.mutate_sql_based_on_config(sql)
    try:
        query.executed_sql = sql
@@ -332,19 +314,6 @@ def execute_sql_statement(  # pylint: disable=too-many-statements
        return SupersetResultSet(data, cursor_description, db_engine_spec)


def apply_limit_if_exists(
    database: Database, increased_limit: Optional[int], query: Query, sql: str
) -> str:
    if query.limit and increased_limit:
        # We are fetching one more than the requested limit in order
        # to test whether there are more rows than the limit. According to the DB
        # Engine support it will choose top or limit parse
        # Later, the extra row will be dropped before sending
        # the results back to the user.
        sql = database.apply_limit_to_sql(sql, increased_limit, force=True)
    return sql


def _serialize_payload(
    payload: dict[Any, Any], use_msgpack: Optional[bool] = False
) -> Union[bytes, str]:
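Recap for readers (not part of the diff): on this branch the execution path in `execute_sql_statement` boils down to chaining the `SQLStatement` helpers introduced in `superset/sql_parse.py`. A hand-rolled sketch with made-up catalog, schema and limit values; it has to run inside a Superset app context since RLS lookups hit the metadata database:

```python
from superset.sql_parse import SQLStatement

stmt = SQLStatement("SELECT * FROM some_table", engine="postgresql")
stmt = stmt.apply_rls(catalog=None, schema="public")  # only when RLS_IN_SQLLAB is enabled
stmt = stmt.apply_limit(1001)                         # user limit + 1, to detect truncation
sql = stmt.format(strip=True)                         # pretty-printed SQL without a trailing ";"
print(sql)
```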
@@ -32,7 +32,7 @@ import sqlparse
from flask_babel import gettext as __
from jinja2 import nodes
from sqlalchemy import and_
-from sqlglot import exp, parse, parse_one
+from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, Dialects
from sqlglot.errors import ParseError, SqlglotError
from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
@@ -308,7 +308,7 @@ def extract_tables_from_statement(
        return set()

    try:
-       pseudo_query = parse_one(f"SELECT {literal.this}", dialect=dialect)
+       pseudo_query = sqlglot.parse_one(f"SELECT {literal.this}", dialect=dialect)
    except ParseError:
        return set()
    sources = pseudo_query.find_all(exp.Table)
@@ -433,7 +433,7 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
        """
        raise NotImplementedError()

-   def format(self, comments: bool = True) -> str:
+   def format(self, comments: bool = True, strip: bool = False) -> str:
        """
        Format the statement, optionally omitting comments.
        """
@@ -451,10 +451,93 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def apply_rls(
|
||||
self,
|
||||
catalog: str | None,
|
||||
schema: str | None,
|
||||
) -> InternalRepresentation:
|
||||
"""
|
||||
Apply Row Level Security to the SQL.
|
||||
|
||||
:param database: The database where the SQL will run
|
||||
:param catalog: The default catalog for non-qualified table names
|
||||
:param schema: The default schema for non-qualified table names
|
||||
:return: The SQL with RLS applied
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.format()
|
||||
|
||||
|
||||
class RLSAsPredicate:
|
||||
"""
|
||||
Apply Row Level Security role as a predicate.
|
||||
|
||||
This transformer will apply any RLS predicates to the relevant tables. For example,
|
||||
given the RLS rule:
|
||||
|
||||
table: some_table
|
||||
clause: id = 42
|
||||
|
||||
If a user subject to the rule runs the following query:
|
||||
|
||||
SELECT foo FROM some_table WHERE bar = 'baz'
|
||||
|
||||
The query will be modified to:
|
||||
|
||||
SELECT foo FROM some_table WHERE bar = 'baz' AND id = 42
|
||||
|
||||
This approach is probably less secure than using subqueries, so it's only used for
|
||||
databases without support for subqueries.
|
||||
"""
|
||||
|
||||
def __init__(self, rules: dict[Table, str]) -> None:
|
||||
self.rules = rules
|
||||
|
||||
def __call__(self, node: exp.Expression) -> exp.Expression:
|
||||
if not isinstance(node, exp.Select):
|
||||
return node
|
||||
|
||||
table_node = node.find(exp.Table)
|
||||
if not table_node:
|
||||
return node
|
||||
|
||||
table = Table(
|
||||
str(table_node.this),
|
||||
str(table_node.db) if table_node.db else None,
|
||||
str(table_node.catalog) if table_node.catalog else None,
|
||||
)
|
||||
if predicate := self.rules.get(table):
|
||||
if where := node.args.get("where"):
|
||||
predicate = exp.And(this=predicate, expression=where.this)
|
||||
|
||||
node.set("where", exp.Where(this=predicate))
|
||||
|
||||
return node
|
||||
|
||||
|
||||
class RLSAsSubquery:
|
||||
def __init__(self, rules: dict[Table, str]) -> None:
|
||||
self.rules = rules
|
||||
|
||||
def __call__(self, node: exp.Expression) -> exp.Expression:
|
||||
if not isinstance(node, exp.Table):
|
||||
return node
|
||||
|
||||
table = Table(
|
||||
str(node.this),
|
||||
str(node.db) if node.db else None,
|
||||
str(node.catalog) if node.catalog else None,
|
||||
)
|
||||
if predicate := self.rules.get(table):
|
||||
alias = node.alias
|
||||
node.set("alias", None)
|
||||
return f"(SELECT * FROM {node} WHERE {predicate}) AS {alias}"
|
||||
|
||||
return node
|
||||
|
||||
|
||||
class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
"""
|
||||
A SQL statement.
|
||||
@@ -521,12 +604,19 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
dialect = SQLGLOT_DIALECTS.get(engine)
|
||||
return extract_tables_from_statement(parsed, dialect)
|
||||
|
||||
def format(self, comments: bool = True) -> str:
|
||||
def format(self, comments: bool = True, strip: bool = False) -> str:
|
||||
"""
|
||||
Pretty-format the SQL statement.
|
||||
"""
|
||||
write = Dialect.get_or_raise(self._dialect)
|
||||
return write.generate(self._parsed, copy=False, comments=comments, pretty=True)
|
||||
output = write.generate(
|
||||
self._parsed,
|
||||
copy=False,
|
||||
comments=comments,
|
||||
pretty=True,
|
||||
)
|
||||
|
||||
return output.strip(" \t\r\n;") if strip else output
|
||||
|
||||
def get_settings(self) -> dict[str, str | bool]:
|
||||
"""
|
||||
@@ -543,6 +633,186 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
for eq in set_item.find_all(exp.EQ)
|
||||
}
|
||||
|
||||
def apply_rls(
|
||||
self,
|
||||
catalog: str | None,
|
||||
schema: str | None,
|
||||
) -> SQLStatement:
|
||||
"""
|
||||
Apply Row Level Security to the SQL.
|
||||
|
||||
:param catalog: The default catalog for non-qualified table names
|
||||
:param schema: The default schema for non-qualified table names
|
||||
:return: The SQL with RLS applied
|
||||
"""
|
||||
from superset.db_engine_specs import load_engine_specs
|
||||
|
||||
statement = self._parsed.copy()
|
||||
|
||||
# collect all relevant RLS rules
|
||||
rules = {}
|
||||
for table in self.tables:
|
||||
if rls := self._get_rls_for_table(table, catalog, schema):
|
||||
rules[table] = rls
|
||||
|
||||
if not rules:
|
||||
return statement
|
||||
|
||||
use_subquery = all(
|
||||
engine_spec.allows_subqueries
|
||||
for engine_spec in load_engine_specs()
|
||||
if engine_spec.engine == self.engine
|
||||
)
|
||||
transformer = RLSAsSubquery(rules) if use_subquery else RLSAsPredicate(rules)
|
||||
|
||||
return SQLStatement(statement.transform(transformer), self.engine)
|
||||
|
||||
def _get_rls_for_table(
|
||||
self,
|
||||
database: Database,
|
||||
table: Table,
|
||||
catalog: str | None,
|
||||
schema: str | None,
|
||||
) -> exp.Expression | None:
|
||||
"""
|
||||
Get the RLS for a table.
|
||||
|
||||
:param table: The table to get the RLS for
|
||||
:param catalog: The default catalog for non-qualified table names
|
||||
:param schema: The default schema for non-qualified table names
|
||||
:return: The RLS for the table
|
||||
"""
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from superset import db
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
|
||||
dataset = db.session.query(SqlaTable).filter(
|
||||
and_(
|
||||
SqlaTable.database_id == database.id,
|
||||
SqlaTable.catalog == table.catalog or catalog,
|
||||
SqlaTable.schema == table.schema or schema,
|
||||
SqlaTable.table_name == table.table,
|
||||
).one_or_none()
|
||||
)
|
||||
if not dataset:
|
||||
return None
|
||||
|
||||
filters = dataset.get_sqla_row_level_filters()
|
||||
if not filters:
|
||||
return None
|
||||
|
||||
rls = and_(*filters).compile(
|
||||
dialect=database.get_dialect(),
|
||||
compile_kwargs={"literal_binds": True},
|
||||
)
|
||||
|
||||
return sqlglot.parse_one(str(rls), dialect=self._dialect)
|
||||
|
||||
def is_dml(self) -> bool:
|
||||
"""
|
||||
Check if the statement is DML.
|
||||
|
||||
:return: True if the statement is DML
|
||||
"""
|
||||
for node in self._parsed.walk():
|
||||
if isinstance(
|
||||
node,
|
||||
(
|
||||
exp.Insert,
|
||||
exp.Update,
|
||||
exp.Delete,
|
||||
exp.Merge,
|
||||
exp.Create,
|
||||
exp.Alter,
|
||||
exp.Drop,
|
||||
exp.TruncateTable,
|
||||
),
|
||||
):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def as_create_table(self, table: Table, method: CtasMethod) -> SQLStatement:
|
||||
"""
|
||||
Convert the statement to a CREATE TABLE statement.
|
||||
"""
|
||||
create_table = exp.Create(
|
||||
this=sqlglot.parse_one(table, into=exp.Table),
|
||||
kind=method.value,
|
||||
expression=self._parsed.copy(),
|
||||
)
|
||||
|
||||
return SQLStatement(create_table, self.engine)
|
||||
|
||||
def is_select(self) -> bool:
|
||||
"""
|
||||
Check if the statement is a SELECT statement.
|
||||
|
||||
:return: True if the statement is a SELECT statement
|
||||
"""
|
||||
return isinstance(self._parsed, exp.Select)
|
||||
|
||||
def apply_limit(self, limit: int, force: bool = False) -> SQLStatement:
|
||||
"""
|
||||
Apply a limit to the SQL.
|
||||
|
||||
There are 3 strategies to limit queries, defined in the DB engine spec:
|
||||
|
||||
1. `FORCE_LIMIT`: a limit is added to the query, or the existing one is
|
||||
replaced. This is the most efficient, since the database will produce at
|
||||
most the number of rows that Superset will display.
|
||||
2. `WRAP_SQL`: the query is wrapped in a subquery, and the limit is applied
|
||||
to the outer query. This might be inneficient, since the database
|
||||
optimizer might not be able to push the limit down to the inner query.
|
||||
3. `FETCH_MANY`: no limit is applied, but only `LIMIT` rows are fetched from
|
||||
the database. This is the least efficient, unless the database computes
|
||||
rows as they are read by the cursor, which is unlikely.
|
||||
|
||||
:param limit: The limit to apply
|
||||
:param force: Apply limit even when a lower one is present
|
||||
:return: The SQL with the limit applied
|
||||
"""
|
||||
from superset.db_engine_specs import load_engine_specs
|
||||
from superset.db_engine_specs.base import LimitMethod
|
||||
|
||||
methods = {
|
||||
engine_spec.limit_method
|
||||
for engine_spec in load_engine_specs()
|
||||
if engine_spec.engine == self.engine
|
||||
}
|
||||
if not methods:
|
||||
methods = {LimitMethod.FETCH_MANY}
|
||||
|
||||
# When multiple methods are supported, we prefer the more generic one --
|
||||
# usually less efficient.
|
||||
preference = [
|
||||
LimitMethod.FETCH_MANY,
|
||||
LimitMethod.WRAP_SQL,
|
||||
LimitMethod.FORCE_LIMIT,
|
||||
]
|
||||
method = sorted(methods, key=preference.index)[0]
|
||||
|
||||
if not self.is_select() or method == LimitMethod.FETCH_MANY:
|
||||
return SQLStatement(self._parsed.copy(), self.engine)
|
||||
|
||||
if method == LimitMethod.WRAP_SQL:
|
||||
limited = exp.Select(
|
||||
expressions=[exp.Star()],
|
||||
from_=exp.Subquery(subquery=self._parsed.copy(), alias="inner_qry"),
|
||||
limit=exp.Literal.number(limit),
|
||||
)
|
||||
return SQLStatement(limited, self.engine)
|
||||
|
||||
current_limit: int | None = None
|
||||
for node in self._parsed.find_all(exp.Limit):
|
||||
current_limit = int(node.expression.this)
|
||||
break
|
||||
|
||||
if force or current_limit is None or limit < current_limit:
|
||||
return SQLStatement(self._parsed.limit(limit), self.engine)
|
||||
|
||||
return SQLStatement(self._parsed.copy(), self.engine)
|
||||
|
||||
|
||||
class KQLSplitState(enum.Enum):
|
||||
"""
|
||||
@@ -666,11 +936,11 @@ class KustoKQLStatement(BaseSQLStatement[str]):
|
||||
)
|
||||
return set()
|
||||
|
||||
def format(self, comments: bool = True) -> str:
|
||||
def format(self, comments: bool = True, strip: bool = False) -> str:
|
||||
"""
|
||||
Pretty-format the SQL statement.
|
||||
"""
|
||||
return self._parsed
|
||||
return self._parsed.strip(" \t\r\n;") if strip else self._parsed
|
||||
|
||||
def get_settings(self) -> dict[str, str | bool]:
|
||||
"""
|
||||
@@ -712,7 +982,10 @@ class SQLScript:
|
||||
"""
|
||||
Pretty-format the SQL query.
|
||||
"""
|
||||
return ";\n".join(statement.format(comments) for statement in self.statements)
|
||||
return (
|
||||
";\n".join(statement.format(comments) for statement in self.statements)
|
||||
+ ";"
|
||||
)
|
||||
|
||||
def get_settings(self) -> dict[str, str | bool]:
|
||||
"""
|
||||
@@ -792,7 +1065,7 @@ class ParsedQuery:
|
||||
Note: this uses sqlglot, since it's better at catching more edge cases.
|
||||
"""
|
||||
try:
|
||||
statements = parse(self.stripped(), dialect=self._dialect)
|
||||
statements = sqlglot.parse(self.stripped(), dialect=self._dialect)
|
||||
except SqlglotError as ex:
|
||||
logger.warning("Unable to parse SQL (%s): %s", self._dialect, self.sql)
|
||||
|
||||
@@ -846,7 +1119,7 @@ class ParsedQuery:
|
||||
return set()
|
||||
|
||||
try:
|
||||
pseudo_query = parse_one(
|
||||
pseudo_query = sqlglot.parse_one(
|
||||
f"SELECT {literal.this}",
|
||||
dialect=self._dialect,
|
||||
)
|
||||
|
||||
@@ -18,13 +18,14 @@ from __future__ import annotations
|
||||
|
||||
from contextlib import nullcontext
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, TYPE_CHECKING
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from flask.ctx import AppContext
|
||||
from freezegun import freeze_time
|
||||
|
||||
from superset.extensions.metastore_cache import SupersetMetastoreCache
|
||||
from superset.key_value.exceptions import KeyValueCodecEncodeException
|
||||
from superset.key_value.types import (
|
||||
JsonKeyValueCodec,
|
||||
@@ -32,9 +33,6 @@ from superset.key_value.types import (
|
||||
PickleKeyValueCodec,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.extensions.metastore_cache import SupersetMetastoreCache
|
||||
|
||||
NAMESPACE = UUID("ee173d1b-ccf3-40aa-941c-985c15224496")
|
||||
|
||||
FIRST_KEY = "foo"
|
||||
@@ -47,8 +45,6 @@ SECOND_VALUE = "qwerty"
|
||||
|
||||
@pytest.fixture
|
||||
def cache() -> SupersetMetastoreCache:
|
||||
from superset.extensions.metastore_cache import SupersetMetastoreCache
|
||||
|
||||
return SupersetMetastoreCache(
|
||||
namespace=NAMESPACE,
|
||||
default_timeout=600,
|
||||
|
||||
@@ -273,9 +273,9 @@ def test_get_sql_results_oauth2(mocker: MockerFixture, app) -> None:
|
||||
"error_type": SupersetErrorType.OAUTH2_REDIRECT,
|
||||
"level": ErrorLevel.WARNING,
|
||||
"extra": {
|
||||
"url": "https://abcd1234.snowflakecomputing.com/oauth/authorize?scope=refresh_token+session%3Arole%3ASYSADMIN&access_type=offline&include_granted_scopes=false&response_type=code&state=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9%252EeyJleHAiOjE2MTcyMzU1MDAsImRhdGFiYXNlX2lkIjoxLCJ1c2VyX2lkIjo0MiwiZGVmYXVsdF9yZWRpcmVjdF91cmkiOiJodHRwOi8vZXhhbXBsZS5jb20vYXBpL3YxL2RhdGFiYXNlL29hdXRoMi8iLCJ0YWJfaWQiOiJmYjExZjUyOC02ZWJhLTRhOGEtODM3ZS02YjBkMzllZTkxODcifQ%252Ec_m_35xwwSrLgCXwV4aKO1928flOEFQIqqg9ctiXjcM&redirect_uri=http%3A%2F%2Fexample.com%2Fapi%2Fv1%2Fdatabase%2Foauth2%2F&client_id=my_client_id&prompt=consent",
|
||||
"url": "https://abcd1234.snowflakecomputing.com/oauth/authorize?scope=refresh_token+session%3Arole%3ASYSADMIN&access_type=offline&include_granted_scopes=false&response_type=code&state=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9%252EeyJleHAiOjE2MTcyMzU1MDAsImRhdGFiYXNlX2lkIjoxLCJ1c2VyX2lkIjo0MiwiZGVmYXVsdF9yZWRpcmVjdF91cmkiOiJodHRwOi8vbG9jYWxob3N0L2FwaS92MS9kYXRhYmFzZS9vYXV0aDIvIiwidGFiX2lkIjoiZmIxMWY1MjgtNmViYS00YThhLTgzN2UtNmIwZDM5ZWU5MTg3In0%252E7nLkei6-V8sVk_Pgm8cFhk0tnKRKayRE1Vc7RxuM9mw&redirect_uri=http%3A%2F%2Flocalhost%2Fapi%2Fv1%2Fdatabase%2Foauth2%2F&client_id=my_client_id&prompt=consent",
|
||||
"tab_id": "fb11f528-6eba-4a8a-837e-6b0d39ee9187",
|
||||
"redirect_uri": "http://example.com/api/v1/database/oauth2/",
|
||||
"redirect_uri": "http://localhost/api/v1/database/oauth2/",
|
||||
},
|
||||
}
|
||||
],
|
||||
|
||||
@@ -20,6 +20,7 @@ from typing import Optional
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
import sqlglot
|
||||
import sqlparse
|
||||
from pytest_mock import MockerFixture
|
||||
from sqlalchemy import text
|
||||
@@ -41,6 +42,8 @@ from superset.sql_parse import (
|
||||
insert_rls_in_predicate,
|
||||
KustoKQLStatement,
|
||||
ParsedQuery,
|
||||
RLSAsPredicate,
|
||||
RLSAsSubquery,
|
||||
sanitize_clause,
|
||||
split_kql,
|
||||
SQLScript,
|
||||
@@ -119,8 +122,9 @@ def test_extract_tables_subselect() -> None:
|
||||
"""
|
||||
Test that tables inside subselects are parsed correctly.
|
||||
"""
|
||||
assert extract_tables(
|
||||
"""
|
||||
assert (
|
||||
extract_tables(
|
||||
"""
|
||||
SELECT sub.*
|
||||
FROM (
|
||||
SELECT *
|
||||
@@ -129,10 +133,13 @@ FROM (
|
||||
) sub, s2.t2
|
||||
WHERE sub.resolution = 'NONE'
|
||||
"""
|
||||
) == {Table("t1", "s1"), Table("t2", "s2")}
|
||||
)
|
||||
== {Table("t1", "s1"), Table("t2", "s2")}
|
||||
)
|
||||
|
||||
assert extract_tables(
|
||||
"""
|
||||
assert (
|
||||
extract_tables(
|
||||
"""
|
||||
SELECT sub.*
|
||||
FROM (
|
||||
SELECT *
|
||||
@@ -141,10 +148,13 @@ FROM (
|
||||
) sub
|
||||
WHERE sub.resolution = 'NONE'
|
||||
"""
|
||||
) == {Table("t1", "s1")}
|
||||
)
|
||||
== {Table("t1", "s1")}
|
||||
)
|
||||
|
||||
assert extract_tables(
|
||||
"""
|
||||
assert (
|
||||
extract_tables(
|
||||
"""
|
||||
SELECT * FROM t1
|
||||
WHERE s11 > ANY (
|
||||
SELECT COUNT(*) /* no hint */ FROM t2
|
||||
@@ -156,7 +166,9 @@ WHERE s11 > ANY (
|
||||
)
|
||||
)
|
||||
"""
|
||||
) == {Table("t1"), Table("t2"), Table("t3"), Table("t4")}
|
||||
)
|
||||
== {Table("t1"), Table("t2"), Table("t3"), Table("t4")}
|
||||
)
|
||||
|
||||
|
||||
def test_extract_tables_select_in_expression() -> None:
|
||||
@@ -227,24 +239,30 @@ def test_extract_tables_select_array() -> None:
|
||||
"""
|
||||
Test that queries selecting arrays work as expected.
|
||||
"""
|
||||
assert extract_tables(
|
||||
"""
|
||||
assert (
|
||||
extract_tables(
|
||||
"""
|
||||
SELECT ARRAY[1, 2, 3] AS my_array
|
||||
FROM t1 LIMIT 10
|
||||
"""
|
||||
) == {Table("t1")}
|
||||
)
|
||||
== {Table("t1")}
|
||||
)
|
||||
|
||||
|
||||
def test_extract_tables_select_if() -> None:
|
||||
"""
|
||||
Test that queries with an ``IF`` work as expected.
|
||||
"""
|
||||
assert extract_tables(
|
||||
"""
|
||||
assert (
|
||||
extract_tables(
|
||||
"""
|
||||
SELECT IF(CARDINALITY(my_array) >= 3, my_array[3], NULL)
|
||||
FROM t1 LIMIT 10
|
||||
"""
|
||||
) == {Table("t1")}
|
||||
)
|
||||
== {Table("t1")}
|
||||
)
|
||||
|
||||
|
||||
def test_extract_tables_with_catalog() -> None:
|
||||
@@ -312,29 +330,38 @@ def test_extract_tables_where_subquery() -> None:
|
||||
"""
|
||||
Test that tables in a ``WHERE`` subquery are parsed correctly.
|
||||
"""
|
||||
assert extract_tables(
|
||||
"""
|
||||
assert (
|
||||
extract_tables(
|
||||
"""
|
||||
SELECT name
|
||||
FROM t1
|
||||
WHERE regionkey = (SELECT max(regionkey) FROM t2)
|
||||
"""
|
||||
) == {Table("t1"), Table("t2")}
|
||||
)
|
||||
== {Table("t1"), Table("t2")}
|
||||
)
|
||||
|
||||
assert extract_tables(
|
||||
"""
|
||||
assert (
|
||||
extract_tables(
|
||||
"""
|
||||
SELECT name
|
||||
FROM t1
|
||||
WHERE regionkey IN (SELECT regionkey FROM t2)
|
||||
"""
|
||||
) == {Table("t1"), Table("t2")}
|
||||
)
|
||||
== {Table("t1"), Table("t2")}
|
||||
)
|
||||
|
||||
assert extract_tables(
|
||||
"""
|
||||
assert (
|
||||
extract_tables(
|
||||
"""
|
||||
SELECT name
|
||||
FROM t1
|
||||
WHERE EXISTS (SELECT 1 FROM t2 WHERE t1.regionkey = t2.regionkey);
|
||||
"""
|
||||
) == {Table("t1"), Table("t2")}
|
||||
)
|
||||
== {Table("t1"), Table("t2")}
|
||||
)
|
||||
|
||||
|
||||
def test_extract_tables_describe() -> None:
|
||||
@@ -348,12 +375,15 @@ def test_extract_tables_show_partitions() -> None:
|
||||
"""
|
||||
Test ``SHOW PARTITIONS``.
|
||||
"""
|
||||
assert extract_tables(
|
||||
"""
|
||||
assert (
|
||||
extract_tables(
|
||||
"""
|
||||
SHOW PARTITIONS FROM orders
|
||||
WHERE ds >= '2013-01-01' ORDER BY ds DESC
|
||||
"""
|
||||
) == {Table("orders")}
|
||||
)
|
||||
== {Table("orders")}
|
||||
)
|
||||
|
||||
|
||||
def test_extract_tables_join() -> None:
|
||||
@@ -365,8 +395,9 @@ def test_extract_tables_join() -> None:
|
||||
Table("t2"),
|
||||
}
|
||||
|
||||
assert extract_tables(
|
||||
"""
|
||||
assert (
|
||||
extract_tables(
|
||||
"""
|
||||
SELECT a.date, b.name
|
||||
FROM left_table a
|
||||
JOIN (
|
||||
@@ -377,10 +408,13 @@ JOIN (
|
||||
) b
|
ON a.date = b.date
"""
) == {Table("left_table"), Table("right_table")}
)
== {Table("left_table"), Table("right_table")}
)

assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT a.date, b.name
FROM left_table a
LEFT INNER JOIN (
@@ -391,10 +425,13 @@ LEFT INNER JOIN (
) b
ON a.date = b.date
"""
) == {Table("left_table"), Table("right_table")}
)
== {Table("left_table"), Table("right_table")}
)

assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT a.date, b.name
FROM left_table a
RIGHT OUTER JOIN (
@@ -405,10 +442,13 @@ RIGHT OUTER JOIN (
) b
ON a.date = b.date
"""
) == {Table("left_table"), Table("right_table")}
)
== {Table("left_table"), Table("right_table")}
)

assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT a.date, b.name
FROM left_table a
FULL OUTER JOIN (
@@ -419,15 +459,18 @@ FULL OUTER JOIN (
) b
ON a.date = b.date
"""
) == {Table("left_table"), Table("right_table")}
)
== {Table("left_table"), Table("right_table")}
)


def test_extract_tables_semi_join() -> None:
"""
Test ``LEFT SEMI JOIN``.
"""
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT a.date, b.name
FROM left_table a
LEFT SEMI JOIN (
@@ -438,15 +481,18 @@ LEFT SEMI JOIN (
) b
ON a.data = b.date
"""
) == {Table("left_table"), Table("right_table")}
)
== {Table("left_table"), Table("right_table")}
)


def test_extract_tables_combinations() -> None:
"""
Test a complex case with nested queries.
"""
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT * FROM t1
WHERE s11 > ANY (
SELECT * FROM t1 UNION ALL SELECT * FROM (
@@ -460,10 +506,13 @@ WHERE s11 > ANY (
)
)
"""
) == {Table("t1"), Table("t3"), Table("t4"), Table("t6")}
)
== {Table("t1"), Table("t3"), Table("t4"), Table("t6")}
)

assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT * FROM (
SELECT * FROM (
SELECT * FROM (
@@ -472,45 +521,56 @@ SELECT * FROM (
) AS S2
) AS S3
"""
) == {Table("EmployeeS")}
)
== {Table("EmployeeS")}
)


def test_extract_tables_with() -> None:
"""
Test ``WITH``.
"""
assert extract_tables(
"""
assert (
extract_tables(
"""
WITH
x AS (SELECT a FROM t1),
y AS (SELECT a AS b FROM t2),
z AS (SELECT b AS c FROM t3)
SELECT c FROM z
"""
) == {Table("t1"), Table("t2"), Table("t3")}
)
== {Table("t1"), Table("t2"), Table("t3")}
)

assert extract_tables(
"""
assert (
extract_tables(
"""
WITH
x AS (SELECT a FROM t1),
y AS (SELECT a AS b FROM x),
z AS (SELECT b AS c FROM y)
SELECT c FROM z
"""
) == {Table("t1")}
)
== {Table("t1")}
)


def test_extract_tables_reusing_aliases() -> None:
"""
Test that the parser follows aliases.
"""
assert extract_tables(
"""
assert (
extract_tables(
"""
with q1 as ( select key from q2 where key = '5'),
q2 as ( select key from src where key = '5')
select * from (select key from q1) a
"""
) == {Table("src")}
)
== {Table("src")}
)

# weird query with circular dependency
assert (
@@ -547,8 +607,9 @@ def test_extract_tables_complex() -> None:
"""
Test a few complex queries.
"""
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT sum(m_examples) AS "sum__m_example"
FROM (
SELECT
@@ -569,23 +630,29 @@ FROM (
ORDER BY "sum__m_example" DESC
LIMIT 10;
"""
) == {
Table("my_l_table"),
Table("my_b_table"),
Table("my_t_table"),
Table("inner_table"),
}
)
== {
Table("my_l_table"),
Table("my_b_table"),
Table("my_t_table"),
Table("inner_table"),
}
)

assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT *
FROM table_a AS a, table_b AS b, table_c as c
WHERE a.id = b.id and b.id = c.id
"""
) == {Table("table_a"), Table("table_b"), Table("table_c")}
)
== {Table("table_a"), Table("table_b"), Table("table_c")}
)

assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT somecol AS somecol
FROM (
WITH bla AS (
@@ -629,51 +696,63 @@ FROM (
LIMIT 50000
)
"""
) == {Table("a"), Table("b"), Table("c"), Table("d"), Table("e"), Table("f")}
)
== {Table("a"), Table("b"), Table("c"), Table("d"), Table("e"), Table("f")}
)


def test_extract_tables_mixed_from_clause() -> None:
"""
Test that the parser handles a ``FROM`` clause with table and subselect.
"""
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT *
FROM table_a AS a, (select * from table_b) AS b, table_c as c
WHERE a.id = b.id and b.id = c.id
"""
) == {Table("table_a"), Table("table_b"), Table("table_c")}
)
== {Table("table_a"), Table("table_b"), Table("table_c")}
)


def test_extract_tables_nested_select() -> None:
"""
Test that the parser handles selects inside functions.
"""
assert extract_tables(
"""
assert (
extract_tables(
"""
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME)
from INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
""",
"mysql",
) == {Table("COLUMNS", "INFORMATION_SCHEMA")}
"mysql",
)
== {Table("COLUMNS", "INFORMATION_SCHEMA")}
)

assert extract_tables(
"""
assert (
extract_tables(
"""
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
from INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_NAME="bi_achievement_daily"),0x7e)));
""",
"mysql",
) == {Table("COLUMNS", "INFORMATION_SCHEMA")}
"mysql",
)
== {Table("COLUMNS", "INFORMATION_SCHEMA")}
)


def test_extract_tables_complex_cte_with_prefix() -> None:
"""
Test that the parser handles CTEs with prefixes.
"""
assert extract_tables(
"""
assert (
extract_tables(
"""
WITH CTE__test (SalesPersonID, SalesOrderID, SalesYear)
AS (
SELECT SalesPersonID, SalesOrderID, YEAR(OrderDate) AS SalesYear
@@ -685,21 +764,26 @@ FROM CTE__test
GROUP BY SalesYear, SalesPersonID
ORDER BY SalesPersonID, SalesYear;
"""
) == {Table("SalesOrderHeader")}
)
== {Table("SalesOrderHeader")}
)


def test_extract_tables_identifier_list_with_keyword_as_alias() -> None:
"""
Test that aliases that are keywords are parsed correctly.
"""
assert extract_tables(
"""
assert (
extract_tables(
"""
WITH
f AS (SELECT * FROM foo),
match AS (SELECT * FROM f)
SELECT * FROM match
"""
) == {Table("foo")}
)
== {Table("foo")}
)


def test_update() -> None:
@@ -1841,7 +1925,7 @@ def test_sqlquery() -> None:
script = SQLScript("SELECT 1; SELECT 2;", "sqlite")

assert len(script.statements) == 2
assert script.format() == "SELECT\n 1;\nSELECT\n 2"
assert script.format() == "SELECT\n 1;\nSELECT\n 2;"
assert script.statements[0].format() == "SELECT\n 1"

script = SQLScript("SET a=1; SET a=2; SELECT 3;", "sqlite")
@@ -2058,3 +2142,120 @@ on $left.Day1 == $right.Day
| project Day1, Day2, Percentage = count_*100.0/count_1
""",
]


@pytest.mark.parametrize(
"sql,rules,expected",
[
(
"SELECT t.foo FROM some_table AS t",
{Table("some_table"): "id = 42"},
"SELECT t.foo FROM (SELECT * FROM some_table WHERE id = 42) AS t",
),
(
"SELECT t.foo FROM some_table AS t WHERE bar = 'baz'",
{Table("some_table"): "id = 42"},
(
"SELECT t.foo FROM (SELECT * FROM some_table WHERE id = 42) AS t "
"WHERE bar = 'baz'"
),
),
(
"SELECT t.foo FROM schema1.some_table AS t",
{Table("some_table", "schema1"): "id = 42"},
"SELECT t.foo FROM (SELECT * FROM schema1.some_table WHERE id = 42) AS t",
),
(
"SELECT t.foo FROM schema1.some_table AS t",
{Table("some_table", "schema2"): "id = 42"},
"SELECT t.foo FROM schema1.some_table AS t",
),
(
"SELECT t.foo FROM catalog1.schema1.some_table AS t",
{Table("some_table", "schema1", "catalog1"): "id = 42"},
"SELECT t.foo FROM (SELECT * FROM catalog1.schema1.some_table WHERE id = 42) AS t",
),
(
"SELECT t.foo FROM catalog1.schema1.some_table AS t",
{Table("some_table", "schema1", "catalog2"): "id = 42"},
"SELECT t.foo FROM catalog1.schema1.some_table AS t",
),
],
)
def test_RLSAsSubquery(sql: str, rules: dict[Table, str], expected: str) -> None:
"""
Test the `RLSAsSubquery` transformer.
"""
statement = sqlglot.parse_one(sql)
transformer = RLSAsSubquery(rules)
assert str(statement.transform(transformer)) == expected


@pytest.mark.parametrize(
"sql,rules,expected",
[
(
"SELECT t.foo FROM some_table AS t",
{Table("some_table"): "id = 42"},
"SELECT t.foo FROM some_table AS t WHERE id = 42",
),
(
"SELECT t.foo FROM some_table AS t WHERE bar = 'baz'",
{Table("some_table"): "id = 42"},
"SELECT t.foo FROM some_table AS t WHERE id = 42 AND bar = 'baz'",
),
],
)
def test_RLSAsPredicate(sql: str, rules: dict[Table, str], expected: str) -> None:
"""
Test the `RLSAsPredicate` transformer.
"""
statement = sqlglot.parse_one(sql)
transformer = RLSAsPredicate(rules)
assert str(statement.transform(transformer)) == expected


@pytest.mark.parametrize(
"sql,engine,limit,force,expected",
[
(
"SELECT TOP 10 * FROM Customers",
"teradatasql",
5,
False,
"SELECT\nTOP 5\n *\nFROM Customers",
),
(
"SELECT TOP 10 * FROM Customers",
"teradatasql",
15,
False,
"SELECT\nTOP 10\n *\nFROM Customers",
),
(
"SELECT TOP 10 * FROM Customers",
"teradatasql",
15,
True,
"SELECT\nTOP 15\n *\nFROM Customers",
),
(
"SELECT TOP 10 * FROM Customers",
"mssql",
15,
True,
"SELECT\nTOP 15\n *\nFROM Customers",
),
],
)
def test_apply_limit(
sql: str,
engine: str,
limit: int,
force: bool,
expected: str,
) -> None:
"""
Test the `apply_limit` function.
"""
assert SQLStatement(sql, engine).apply_limit(limit, force).format() == expected