Compare commits

...

10 Commits

Author · SHA1 · Message · Date
Beto Dealmeida · 7e7b9f84aa · WIP · 2024-07-03 15:44:01 -04:00
Tyler Eason · 5231e86b6c · docs(docker compose): fix step 4 list formatting (#29468) · 2024-07-03 11:39:11 -06:00
mknadh · 6b73b69b41 · feat(CLI command): Apache Superset "Factory Reset" CLI command #27207 (#27221) · 2024-07-03 09:20:05 -07:00
Michael S. Molina · 35da6ac270 · fix: Dashboard hangs when initial filters cannot be loaded (#29456) · 2024-07-03 09:16:07 -03:00
Beto Dealmeida · d5c0506faa · fix: OAuth2 in async DBs (#29461) · 2024-07-02 21:12:07 -04:00
Joe Li · fb1f2c4f18 · fix: re-add missing code from PR #28132 (#29446) · 2024-07-02 16:58:55 -07:00
Ville Brofeldt · 7f3c8efab0 · fix(metastore-cache): import dao in methods (#29451) · 2024-07-02 15:28:42 +03:00
dependabot[bot] · 7bb7fc0f49 · chore(deps): bump deck.gl from 9.0.12 to 9.0.20 in /superset-frontend/plugins/legacy-preset-chart-deckgl (#29426) · 2024-07-01 21:21:44 -06:00
  Signed-off-by: dependabot[bot] <support@github.com>
  Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
dependabot[bot] · 3449b8f9dc · chore(deps-dev): update @types/lodash requirement from ^4.17.4 to ^4.17.6 in /superset-frontend/plugins/plugin-chart-handlebars (#29425) · 2024-07-01 17:02:44 -06:00
  Signed-off-by: dependabot[bot] <support@github.com>
  Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
  Co-authored-by: Evan Rusackas <evan@rusackas.com>
dependabot[bot] · 7a0ae36c4a · chore(deps): bump actions/checkout from 2 to 4 (#29434) · 2024-07-01 15:16:11 -07:00
  Signed-off-by: dependabot[bot] <support@github.com>
  Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
23 changed files with 987 additions and 218 deletions

View File

@@ -66,20 +66,20 @@ jobs:
      # Conditional checkout based on context
      - name: Checkout for push or pull_request event
        if: github.event_name == 'push' || github.event_name == 'pull_request'
        uses: actions/checkout@v2
        uses: actions/checkout@v4
        with:
          persist-credentials: false
          submodules: recursive
      - name: Checkout using ref (workflow_dispatch)
        if: github.event_name == 'workflow_dispatch' && github.event.inputs.ref != ''
        uses: actions/checkout@v2
        uses: actions/checkout@v4
        with:
          persist-credentials: false
          ref: ${{ github.event.inputs.ref }}
          submodules: recursive
      - name: Checkout using PR ID (workflow_dispatch)
        if: github.event_name == 'workflow_dispatch' && github.event.inputs.pr_id != ''
        uses: actions/checkout@v2
        uses: actions/checkout@v4
        with:
          persist-credentials: false
          ref: refs/pull/${{ github.event.inputs.pr_id }}/merge

View File

@@ -214,13 +214,14 @@ connections from the Docker involves making one-line changes to the files `postg
`pg_hba.conf`; you can find helpful links tailored to your OS / PG version on the web easily for
this task. For Docker it suffices to only whitelist IPs `172.0.0.0/8` instead of `*`, but in any
case you are _warned_ that doing this in a production database _may_ have disastrous consequences as
you are opening your database to the public internet. 2. Instead of `localhost`, try using
`host.docker.internal` (Mac users, Ubuntu) or `172.18.0.1` (Linux users) as the hostname when
attempting to connect to the database. This is a Docker internal detail -- what is happening is
that, in Mac systems, Docker Desktop creates a dns entry for the hostname `host.docker.internal`
which resolves to the correct address for the host machine, whereas in Linux this is not the case
(at least by default). If neither of these 2 hostnames work then you may want to find the exact
hostname you want to use, for that you can do `ifconfig` or `ip addr show` and look at the IP
address of `docker0` interface that must have been created by Docker for you. Alternately if you
don't even see the `docker0` interface try (if needed with sudo) `docker network inspect bridge` and
see if there is an entry for `"Gateway"` and note the IP address.
you are opening your database to the public internet.
1. Instead of `localhost`, try using `host.docker.internal` (Mac users, Ubuntu) or `172.18.0.1`
(Linux users) as the hostname when attempting to connect to the database. This is a Docker internal
detail -- on macOS, Docker Desktop creates a DNS entry for the hostname `host.docker.internal`
that resolves to the correct address for the host machine, whereas on Linux this is not the case
(at least by default). If neither of these two hostnames works, you may need to find the exact
hostname to use; for that, run `ifconfig` or `ip addr show` and look at the IP address of the
`docker0` interface that Docker created for you. Alternatively, if you don't see a `docker0`
interface at all, try `docker network inspect bridge` (with sudo if needed), look for a
`"Gateway"` entry, and note its IP address.
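
For anyone scripting that last lookup, here is a minimal sketch (not part of this diff; the
function name is illustrative) that parses the JSON printed by `docker network inspect bridge`
and pulls out the `"Gateway"` address the docs point to:

import json
import subprocess


def docker_bridge_gateway() -> str | None:
    """Return the gateway IP of Docker's default bridge network, if any."""
    # `docker network inspect` prints a JSON array with one object per network.
    raw = subprocess.run(
        ["docker", "network", "inspect", "bridge"],
        capture_output=True,
        text=True,
        check=True,
    ).stdout
    for network in json.loads(raw):
        for config in network.get("IPAM", {}).get("Config", []):
            if gateway := config.get("Gateway"):
                return gateway
    return None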

View File

@@ -70520,7 +70520,7 @@
},
"devDependencies": {
"@types/jest": "^29.5.12",
"@types/lodash": "^4.17.4",
"@types/lodash": "^4.17.6",
"jest": "^29.7.0"
},
"peerDependencies": {
@@ -70535,9 +70535,9 @@
}
},
"plugins/plugin-chart-handlebars/node_modules/@types/lodash": {
"version": "4.17.4",
"resolved": "https://registry.npmjs.org/@types/lodash/-/lodash-4.17.4.tgz",
"integrity": "sha512-wYCP26ZLxaT3R39kiN2+HcJ4kTd3U1waI/cY7ivWYqFP6pW3ZNpvi6Wd6PHZx7T/t8z0vlkXMg3QYLa7DZ/IJQ==",
"version": "4.17.6",
"resolved": "https://registry.npmjs.org/@types/lodash/-/lodash-4.17.6.tgz",
"integrity": "sha512-OpXEVoCKSS3lQqjx9GGGOapBeuW5eUboYHRlHP9urXPX25IKZ6AnP5ZRxtVf63iieUbsHxLn8NQ5Nlftc6yzAA==",
"dev": true
},
"plugins/plugin-chart-handlebars/node_modules/just-handlebars-helpers": {
@@ -89053,16 +89053,16 @@
"version": "file:plugins/plugin-chart-handlebars",
"requires": {
"@types/jest": "^29.5.12",
"@types/lodash": "^4.17.4",
"@types/lodash": "^4.17.6",
"handlebars": "^4.7.7",
"jest": "^29.7.0",
"just-handlebars-helpers": "^1.0.19"
},
"dependencies": {
"@types/lodash": {
"version": "4.17.4",
"resolved": "https://registry.npmjs.org/@types/lodash/-/lodash-4.17.4.tgz",
"integrity": "sha512-wYCP26ZLxaT3R39kiN2+HcJ4kTd3U1waI/cY7ivWYqFP6pW3ZNpvi6Wd6PHZx7T/t8z0vlkXMg3QYLa7DZ/IJQ==",
"version": "4.17.6",
"resolved": "https://registry.npmjs.org/@types/lodash/-/lodash-4.17.6.tgz",
"integrity": "sha512-OpXEVoCKSS3lQqjx9GGGOapBeuW5eUboYHRlHP9urXPX25IKZ6AnP5ZRxtVf63iieUbsHxLn8NQ5Nlftc6yzAA==",
"dev": true
},
"just-handlebars-helpers": {

View File

@@ -31,7 +31,7 @@
"d3-array": "^1.2.4",
"d3-color": "^1.4.1",
"d3-scale": "^3.0.0",
"deck.gl": "9.0.12",
"deck.gl": "9.0.20",
"lodash": "^4.17.21",
"moment": "^2.30.1",
"mousetrap": "^1.6.5",

View File

@@ -42,7 +42,7 @@
},
"devDependencies": {
"@types/jest": "^29.5.12",
"@types/lodash": "^4.17.4",
"@types/lodash": "^4.17.6",
"jest": "^29.7.0"
}
}

View File

@@ -45,7 +45,7 @@ const StyledTitle = styled.span`
interface BasicErrorAlertProps {
title: string;
body: string;
level: ErrorLevel;
level?: ErrorLevel;
}
export default function BasicErrorAlert({

View File

@@ -76,6 +76,7 @@ import {
OPEN_FILTER_BAR_WIDTH,
EMPTY_CONTAINER_Z_INDEX,
} from 'src/dashboard/constants';
import BasicErrorAlert from 'src/components/ErrorMessage/BasicErrorAlert';
import { getRootLevelTabsComponent, shouldFocusTabs } from './utils';
import DashboardContainer from './DashboardContainer';
import { useNativeFilters } from './state';
@@ -462,6 +463,7 @@ const DashboardBuilder: FC<DashboardBuilderProps> = () => {
const {
showDashboard,
missingInitialFilters,
dashboardFiltersOpen,
toggleDashboardFiltersOpen,
nativeFiltersEnabled,
@@ -673,7 +675,30 @@ const DashboardBuilder: FC<DashboardBuilderProps> = () => {
editMode={editMode}
marginLeft={dashboardContentMarginLeft}
>
{showDashboard ? (
{missingInitialFilters.length > 0 ? (
<div
css={css`
display: flex;
flex-direction: row;
align-items: center;
justify-content: center;
flex: 1;
& div {
width: 500px;
}
`}
>
<BasicErrorAlert
title={t('Unable to load dashboard')}
body={t(
`The following filters have the 'Select first filter value by default'
option checked and could not be loaded, which is preventing the dashboard
from rendering: %s`,
missingInitialFilters.join(', '),
)}
/>
</div>
) : showDashboard ? (
<DashboardContainer topLevelTabs={topLevelTabs} />
) : (
<Loading />

View File

@@ -47,17 +47,14 @@ export const useNativeFilters = () => {
filter => filter.requiredFirst,
);
const dataMask = useNativeFiltersDataMask();
const missingInitialFilters = requiredFirstFilter
.filter(({ id }) => dataMask[id]?.filterState?.value === undefined)
.map(({ name }) => name);
const showDashboard =
isInitialized ||
!nativeFiltersEnabled ||
!(
nativeFiltersEnabled &&
requiredFirstFilter.length &&
requiredFirstFilter.find(
({ id }) => dataMask[id]?.filterState?.value === undefined,
)
);
missingInitialFilters.length === 0;
const toggleDashboardFiltersOpen = useCallback(
(visible?: boolean) => {
setDashboardFiltersOpen(visible ?? !dashboardFiltersOpen);
@@ -84,6 +81,7 @@ export const useNativeFilters = () => {
return {
showDashboard,
missingInitialFilters,
dashboardFiltersOpen,
toggleDashboardFiltersOpen,
nativeFiltersEnabled,

View File

@@ -0,0 +1,121 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import {
render,
screen,
fireEvent,
waitFor,
} from 'spec/helpers/testing-library';
import { Comparator } from '@superset-ui/chart-controls';
import { ColorSchemeEnum } from '@superset-ui/plugin-chart-table';
import { FormattingPopoverContent } from './FormattingPopoverContent';
const mockOnChange = jest.fn();
const columns = [
{ label: 'Column 1', value: 'column1' },
{ label: 'Column 2', value: 'column2' },
];
const extraColorChoices = [
{
value: ColorSchemeEnum.Green,
label: 'Green for increase, red for decrease',
},
{
value: ColorSchemeEnum.Red,
label: 'Red for increase, green for decrease',
},
];
test('renders FormattingPopoverContent component', () => {
render(
<FormattingPopoverContent
onChange={mockOnChange}
columns={columns}
extraColorChoices={extraColorChoices}
/>,
);
// Assert that the component renders correctly
expect(screen.getByLabelText('Column')).toBeInTheDocument();
expect(screen.getAllByLabelText('Color scheme')).toHaveLength(2);
expect(screen.getAllByLabelText('Operator')).toHaveLength(2);
expect(screen.queryByLabelText('Left value')).not.toBeInTheDocument();
expect(screen.queryByLabelText('Right value')).not.toBeInTheDocument();
expect(screen.getByText('Apply')).toBeInTheDocument();
});
test('calls onChange when Apply button is clicked', async () => {
render(
<FormattingPopoverContent
onChange={mockOnChange}
columns={columns}
extraColorChoices={extraColorChoices}
/>,
);
// Simulate user interaction by clicking the Apply button
fireEvent.click(screen.getByText('Apply'));
// Assert that the onChange function is called with the correct config
await waitFor(() => {
expect(mockOnChange).toHaveBeenCalled();
});
});
test('renders the correct input fields based on the selected operator', async () => {
render(
<FormattingPopoverContent
onChange={mockOnChange}
columns={columns}
extraColorChoices={extraColorChoices}
/>,
);
// Select the 'Between' operator
fireEvent.change(screen.getAllByLabelText('Operator')[0], {
target: { value: Comparator.Between },
});
fireEvent.click(await screen.findByTitle('< x <'));
// Assert that the left and right value inputs are rendered
expect(await screen.findByLabelText('Left value')).toBeInTheDocument();
expect(await screen.findByLabelText('Right value')).toBeInTheDocument();
});
test('renders None for operator when Green for increase is selected', async () => {
render(
<FormattingPopoverContent
onChange={mockOnChange}
columns={columns}
extraColorChoices={extraColorChoices}
/>,
);
// Select the 'Green for increase' color scheme
fireEvent.change(screen.getAllByLabelText(/color scheme/i)[0], {
target: { value: ColorSchemeEnum.Green },
});
fireEvent.click(await screen.findByTitle(/green for increase/i));
// Assert that the operator is set to 'None'
expect(screen.getByText(/none/i)).toBeInTheDocument();
});

View File

@@ -124,7 +124,7 @@ const shouldFormItemUpdate = (
isOperatorMultiValue(prevValues.operator) !==
isOperatorMultiValue(currentValues.operator);
const operatorField = (showOnlyNone?: boolean) => (
const renderOperator = ({ showOnlyNone }: { showOnlyNone?: boolean } = {}) => (
<FormItem
name="operator"
label={t('Operator')}
@@ -141,7 +141,7 @@ const operatorField = (showOnlyNone?: boolean) => (
const renderOperatorFields = ({ getFieldValue }: GetFieldValue) =>
isOperatorNone(getFieldValue('operator')) ? (
<Row gutter={12}>
<Col span={6}>{operatorField()}</Col>
<Col span={6}>{renderOperator()}</Col>
</Row>
) : isOperatorMultiValue(getFieldValue('operator')) ? (
<Row gutter={12}>
@@ -157,7 +157,7 @@ const renderOperatorFields = ({ getFieldValue }: GetFieldValue) =>
<FullWidthInputNumber />
</FormItem>
</Col>
<Col span={6}>{operatorField}</Col>
<Col span={6}>{renderOperator()}</Col>
<Col span={9}>
<FormItem
name="targetValueRight"
@@ -173,7 +173,7 @@ const renderOperatorFields = ({ getFieldValue }: GetFieldValue) =>
</Row>
) : (
<Row gutter={12}>
<Col span={6}>{operatorField}</Col>
<Col span={6}>{renderOperator()}</Col>
<Col span={18}>
<FormItem
name="targetValue"
@@ -248,7 +248,7 @@ export const FormattingPopoverContent = ({
renderOperatorFields
) : (
<Row gutter={12}>
<Col span={6}>{operatorField(true)}</Col>
<Col span={6}>{renderOperator({ showOnlyNone: true })}</Col>
</Row>
)}
</FormItem>

superset/cli/reset.py (new file, 74 lines)
View File

@@ -0,0 +1,74 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import sys

import click
from flask.cli import with_appcontext
from werkzeug.security import check_password_hash

from superset.cli.lib import feature_flags

if feature_flags.get("ENABLE_FACTORY_RESET_COMMAND"):

    @click.command()
    @with_appcontext
    @click.option("--username", prompt="Admin Username", help="Admin Username")
    @click.option(
        "--silent",
        is_flag=True,
        prompt=(
            "Are you sure you want to reset Superset? "
            "This action cannot be undone. Continue?"
        ),
        help="Confirmation flag",
    )
    @click.option(
        "--exclude-users",
        default=None,
        help="Comma separated list of users to exclude from reset",
    )
    @click.option(
        "--exclude-roles",
        default=None,
        help="Comma separated list of roles to exclude from reset",
    )
    def factory_reset(
        username: str, silent: bool, exclude_users: str, exclude_roles: str
    ) -> None:
        """Factory Reset Apache Superset"""
        # pylint: disable=import-outside-toplevel
        from superset import security_manager
        from superset.commands.security.reset import ResetSupersetCommand

        # Validate the user
        password = click.prompt("Admin Password", hide_input=True)
        user = security_manager.find_user(username)
        if not user or not check_password_hash(user.password, password):
            click.secho("Invalid credentials", fg="red")
            sys.exit(1)
        if not any(role.name == "Admin" for role in user.roles):
            click.secho("Permission Denied", fg="red")
            sys.exit(1)
        try:
            ResetSupersetCommand(silent, user, exclude_users, exclude_roles).run()
            click.secho("Factory reset complete", fg="green")
        except Exception as ex:  # pylint: disable=broad-except
            click.secho(f"Factory reset failed: {ex}", fg="red")
            sys.exit(1)
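
A wiring note, sketched under stated assumptions rather than taken from the diff: the command
above is only registered when the `ENABLE_FACTORY_RESET_COMMAND` feature flag is on (see the
config hunk below), so a deployment would first enable it in `superset_config.py`:

# superset_config.py -- the flag defaults to False
FEATURE_FLAGS = {
    "ENABLE_FACTORY_RESET_COMMAND": True,
}

Assuming click's default command naming (underscores become dashes), it would then be invoked as
`superset factory-reset --username admin`, prompting for the admin password, and for confirmation
unless `--silent` is passed.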

View File

@@ -0,0 +1,94 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Optional

from superset import db, security_manager
from superset.commands.base import BaseCommand
from superset.connectors.sqla.models import SqlaTable
from superset.key_value.models import KeyValueEntry
from superset.models.core import Database, FavStar, Log
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice

logger = logging.getLogger(__name__)


class ResetSupersetCommand(BaseCommand):
    def __init__(
        self,
        confirm: bool,
        user: Any,
        exclude_users: Optional[str] = None,
        exclude_roles: Optional[str] = None,
    ) -> None:
        self._user = user
        self._confirm = confirm
        self._users_to_exclude = ["admin"]
        if exclude_users:
            self._users_to_exclude.extend(exclude_users.split(","))
        self._roles_to_exclude = ["Admin", "Public", "Gamma", "Alpha", "sql_lab"]
        if exclude_roles:
            self._roles_to_exclude.extend(exclude_roles.split(","))

    def validate(self) -> None:
        if not self._confirm:
            raise Exception("Reset aborted.")  # pylint: disable=broad-exception-raised
        if not self._user or not self._user.is_active:
            raise Exception("User not found.")  # pylint: disable=broad-exception-raised

    def run(self) -> None:
        self.validate()
        logger.debug("Resetting Superset Started")
        db.session.query(SqlaTable).delete()
        databases = db.session.query(Database)
        for database in databases:
            db.session.delete(database)
        db.session.query(Dashboard).delete()
        db.session.query(Slice).delete()
        db.session.query(KeyValueEntry).delete()
        db.session.query(Log).delete()
        db.session.query(FavStar).delete()
        logger.debug("Ignoring Users: %s", self._users_to_exclude)
        users_to_delete = (
            db.session.query(security_manager.user_model)
            .filter(security_manager.user_model.username.not_in(self._users_to_exclude))
            .all()
        )
        for user in users_to_delete:
            if not any(role.name == "Admin" for role in user.roles):
                db.session.delete(user)
        logger.debug("Ignoring Roles: %s", self._roles_to_exclude)
        roles_to_delete = (
            db.session.query(security_manager.role_model)
            .filter(security_manager.role_model.name.not_in(self._roles_to_exclude))
            .all()
        )
        for role in roles_to_delete:
            db.session.delete(role)
        # Insert new record into Log table
        log = Log(
            action="Factory Reset", json="{}", user_id=self._user.id, user=self._user
        )
        db.session.add(log)
        db.session.commit()  # pylint: disable=consider-using-transaction
        logger.debug("Resetting Superset Completed")

View File

@@ -539,6 +539,8 @@ DEFAULT_FEATURE_FLAGS: dict[str, bool] = {
"CHART_PLUGINS_EXPERIMENTAL": False,
# Regardless of database configuration settings, force SQLLAB to run async using Celery
"SQLLAB_FORCE_RUN_ASYNC": False,
# Set to True to enable the factory reset CLI command
"ENABLE_FACTORY_RESET_COMMAND": False,
}
# ------------------------------

View File

@@ -172,7 +172,9 @@ class DatasourceKind(StrEnum):
PHYSICAL = "physical"
class BaseDatasource(AuditMixinNullable, ImportExportMixin): # pylint: disable=too-many-public-methods
class BaseDatasource(
AuditMixinNullable, ImportExportMixin
): # pylint: disable=too-many-public-methods
"""A common interface to objects that are queryable
(tables and datasources)"""

View File

@@ -49,7 +49,7 @@ CONNECTION_HOST_DOWN_REGEX = re.compile(
class MssqlEngineSpec(BaseEngineSpec):
engine = "mssql"
engine_name = "Microsoft SQL Server"
limit_method = LimitMethod.WRAP_SQL
limit_method = LimitMethod.FORCE_LIMIT
max_column_name_length = 128
allows_cte_in_subquery = False
allow_limit_clause = False

View File

@@ -23,7 +23,7 @@ class TeradataEngineSpec(BaseEngineSpec):
engine = "teradatasql"
engine_name = "Teradata"
limit_method = LimitMethod.WRAP_SQL
limit_method = LimitMethod.FORCE_LIMIT
max_column_name_length = 30 # since 14.10 this is 128
allow_limit_clause = False
select_keywords = {"SELECT", "SEL"}
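
Both engine-spec hunks above move from `WRAP_SQL` to `FORCE_LIMIT`, which rewrites an existing
`TOP`/`LIMIT` clause in place instead of wrapping the whole query in a subquery. A short usage
sketch; the expected outputs are copied from the parametrized `test_apply_limit` cases at the end
of this compare:

from superset.sql_parse import SQLStatement

statement = SQLStatement("SELECT TOP 10 * FROM Customers", "teradatasql")
# FORCE_LIMIT lowers an existing TOP in place...
print(statement.apply_limit(5).format())   # "SELECT\nTOP 5\n *\nFROM Customers"
# ...keeps a lower existing limit when not forced...
print(statement.apply_limit(15).format())  # "SELECT\nTOP 10\n *\nFROM Customers"
# ...and replaces it unconditionally when force=True.
print(statement.apply_limit(15, force=True).format())  # TOP 15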

View File

@@ -24,7 +24,6 @@ from flask_caching import BaseCache
from sqlalchemy.exc import SQLAlchemyError
from superset import db
from superset.daos.key_value import KeyValueDAO
from superset.key_value.exceptions import KeyValueCreateFailedError
from superset.key_value.types import (
KeyValueCodec,
@@ -79,6 +78,9 @@ class SupersetMetastoreCache(BaseCache):
return None
def set(self, key: str, value: Any, timeout: Optional[int] = None) -> bool:
# pylint: disable=import-outside-toplevel
from superset.daos.key_value import KeyValueDAO
KeyValueDAO.upsert_entry(
resource=RESOURCE,
key=self.get_key(key),
@@ -90,6 +92,9 @@ class SupersetMetastoreCache(BaseCache):
return True
def add(self, key: str, value: Any, timeout: Optional[int] = None) -> bool:
# pylint: disable=import-outside-toplevel
from superset.daos.key_value import KeyValueDAO
try:
KeyValueDAO.delete_expired_entries(RESOURCE)
KeyValueDAO.create_entry(
@@ -106,6 +111,9 @@ class SupersetMetastoreCache(BaseCache):
return False
def get(self, key: str) -> Any:
# pylint: disable=import-outside-toplevel
from superset.daos.key_value import KeyValueDAO
return KeyValueDAO.get_value(RESOURCE, self.get_key(key), self.codec)
def has(self, key: str) -> bool:
@@ -116,4 +124,7 @@ class SupersetMetastoreCache(BaseCache):
@transaction()
def delete(self, key: str) -> Any:
# pylint: disable=import-outside-toplevel
from superset.daos.key_value import KeyValueDAO
return KeyValueDAO.delete_entry(RESOURCE, self.get_key(key))
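
The fix in this hunk defers the `KeyValueDAO` import into each method, which is the standard way
to break an import cycle: the import is resolved at call time, when every module in the cycle has
already finished loading. A minimal sketch with hypothetical module names (`myapp.dao` is not
Superset's real layout):

def set_value(key: str, value: str) -> None:
    # Deferred import: evaluated when the function runs, not when this
    # module is imported, so a module-level cycle never materializes.
    from myapp.dao import KeyValueDAO  # hypothetical module

    KeyValueDAO.upsert(key, value)  # hypothetical DAO method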

View File

@@ -17,6 +17,8 @@
# pylint: disable=too-many-lines
"""a collection of model-related helper classes and functions"""
from __future__ import annotations
import builtins
import dataclasses
import logging
@@ -806,7 +808,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
def get_sqla_row_level_filters(
self,
template_processor: Optional[BaseTemplateProcessor] = None,
template_processor: BaseTemplateProcessor | None = None,
) -> list[TextClause]:
"""
Return the appropriate row level security filters for this table and the

View File

@@ -26,6 +26,7 @@ from typing import Any, cast, Optional, Union
import backoff
import msgpack
from celery.exceptions import SoftTimeLimitExceeded
from flask import current_app
from flask_babel import gettext as __
from superset import (
@@ -52,9 +53,8 @@ from superset.models.sql_lab import Query
from superset.result_set import SupersetResultSet
from superset.sql_parse import (
CtasMethod,
insert_rls_as_subquery,
insert_rls_in_predicate,
ParsedQuery,
SQLStatement,
Table,
)
from superset.sqllab.limiting_factor import LimitingFactor
@@ -128,7 +128,6 @@ def handle_query_error(
def get_query_backoff_handler(details: dict[Any, Any]) -> None:
print(details)
query_id = details["kwargs"]["query_id"]
logger.error(
"Query with id `%s` could not be retrieved", str(query_id), exc_info=True
@@ -175,22 +174,23 @@ def get_sql_results( # pylint: disable=too-many-arguments
log_params: Optional[dict[str, Any]] = None,
) -> Optional[dict[str, Any]]:
"""Executes the sql query returns the results."""
with override_user(security_manager.find_user(username)):
try:
return execute_sql_statements(
query_id,
rendered_query,
return_results,
store_results,
start_time=start_time,
expand_data=expand_data,
log_params=log_params,
)
except Exception as ex: # pylint: disable=broad-except
logger.debug("Query %d: %s", query_id, ex)
stats_logger.incr("error_sqllab_unhandled")
query = get_query(query_id)
return handle_query_error(ex, query)
with current_app.test_request_context():
with override_user(security_manager.find_user(username)):
try:
return execute_sql_statements(
query_id,
rendered_query,
return_results,
store_results,
start_time=start_time,
expand_data=expand_data,
log_params=log_params,
)
except Exception as ex: # pylint: disable=broad-except
logger.debug("Query %d: %s", query_id, ex)
stats_logger.incr("error_sqllab_unhandled")
query = get_query(query_id)
return handle_query_error(ex, query)
def execute_sql_statement( # pylint: disable=too-many-statements
@@ -204,67 +204,49 @@ def execute_sql_statement( # pylint: disable=too-many-statements
database: Database = query.database
db_engine_spec = database.db_engine_spec
parsed_query = ParsedQuery(sql_statement, engine=db_engine_spec.engine)
parsed_statement = SQLStatement(sql_statement, engine=db_engine_spec.engine)
if is_feature_enabled("RLS_IN_SQLLAB"):
# There are two ways to insert RLS: either replacing the table with a subquery
# that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is
# safer, but not supported in all databases.
insert_rls = (
insert_rls_as_subquery
if database.db_engine_spec.allows_subqueries
and database.db_engine_spec.allows_alias_in_select
else insert_rls_in_predicate
)
default_schema = database.get_default_schema_for_query(query)
parsed_statement = parsed_statement.apply_rls(query.catalog, default_schema)
# Insert any applicable RLS predicates
parsed_query = ParsedQuery(
str(
insert_rls(
parsed_query._parsed[0], # pylint: disable=protected-access
database.id,
query.schema,
)
),
engine=db_engine_spec.engine,
)
sql = parsed_query.stripped()
# This is a test to see if the query is being
# limited by either the dropdown or the sql.
# We are testing to see if more rows exist than the limit.
increased_limit = None if query.limit is None else query.limit + 1
if not db_engine_spec.is_readonly_query(parsed_query) and not database.allow_dml:
if parsed_statement.is_dml() and not database.allow_dml:
raise SupersetErrorException(
SupersetError(
message=__("Only SELECT statements are allowed against this database."),
message=__("DML statements are not allowed in this database."),
error_type=SupersetErrorType.DML_NOT_ALLOWED_ERROR,
level=ErrorLevel.ERROR,
)
)
if apply_ctas:
if not query.tmp_table_name:
start_dttm = datetime.fromtimestamp(query.start_time)
query.tmp_table_name = (
f'tmp_{query.user_id}_table_{start_dttm.strftime("%Y_%m_%d_%H_%M_%S")}'
)
sql = parsed_query.as_create_table(
query.tmp_table_name,
schema_name=query.tmp_schema_name,
parsed_statement = parsed_statement.as_create_table(
Table(query.tmp_table_name, query.tmp_schema_name, query.catalog),
method=query.ctas_method,
)
query.select_as_cta_used = True
increased_limit = None if query.limit is None else query.limit + 1
# Do not apply limit to the CTA queries when SQLLAB_CTAS_NO_LIMIT is set to true
if db_engine_spec.is_select_query(parsed_query) and not (
if parsed_statement.is_select() and not (
query.select_as_cta_used and SQLLAB_CTAS_NO_LIMIT
):
if SQL_MAX_ROW and (not query.limit or query.limit > SQL_MAX_ROW):
query.limit = SQL_MAX_ROW
sql = apply_limit_if_exists(database, increased_limit, query, sql)
if query.limit:
# Increase limit by one so we can test if there are more rows when the
# database returns exactly the number of rows requested by the user.
parsed_statement = parsed_statement.apply_limit(increased_limit)
# Hook to allow environment-specific mutation (usually comments) to the SQL
sql = parsed_statement.format(strip=True)
sql = database.mutate_sql_based_on_config(sql)
try:
query.executed_sql = sql
@@ -332,19 +314,6 @@ def execute_sql_statement( # pylint: disable=too-many-statements
return SupersetResultSet(data, cursor_description, db_engine_spec)
def apply_limit_if_exists(
database: Database, increased_limit: Optional[int], query: Query, sql: str
) -> str:
if query.limit and increased_limit:
# We are fetching one more than the requested limit in order
# to test whether there are more rows than the limit. According to the DB
# Engine support it will choose top or limit parse
# Later, the extra row will be dropped before sending
# the results back to the user.
sql = database.apply_limit_to_sql(sql, increased_limit, force=True)
return sql
def _serialize_payload(
payload: dict[Any, Any], use_msgpack: Optional[bool] = False
) -> Union[bytes, str]:
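
The `increased_limit` bookkeeping above is the classic fetch-one-extra-row trick: request
`limit + 1` rows so that "exactly `limit` rows" and "more rows available" remain distinguishable.
In isolation (an illustrative helper, not Superset code):

def trim_overflow(rows: list[tuple], limit: int) -> tuple[list[tuple], bool]:
    """Given rows fetched with LIMIT limit + 1, drop the sentinel row and
    report whether the result was truncated by the limit."""
    has_more = len(rows) > limit
    return rows[:limit], has_more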

View File

@@ -32,7 +32,7 @@ import sqlparse
from flask_babel import gettext as __
from jinja2 import nodes
from sqlalchemy import and_
from sqlglot import exp, parse, parse_one
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, Dialects
from sqlglot.errors import ParseError, SqlglotError
from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
@@ -308,7 +308,7 @@ def extract_tables_from_statement(
return set()
try:
pseudo_query = parse_one(f"SELECT {literal.this}", dialect=dialect)
pseudo_query = sqlglot.parse_one(f"SELECT {literal.this}", dialect=dialect)
except ParseError:
return set()
sources = pseudo_query.find_all(exp.Table)
@@ -433,7 +433,7 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
"""
raise NotImplementedError()
def format(self, comments: bool = True) -> str:
def format(self, comments: bool = True, strip: bool = False) -> str:
"""
Format the statement, optionally omitting comments.
"""
@@ -451,10 +451,93 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
"""
raise NotImplementedError()
def apply_rls(
self,
catalog: str | None,
schema: str | None,
) -> InternalRepresentation:
"""
Apply Row Level Security to the SQL.
:param catalog: The default catalog for non-qualified table names
:param schema: The default schema for non-qualified table names
:return: The SQL with RLS applied
"""
raise NotImplementedError()
def __str__(self) -> str:
return self.format()
class RLSAsPredicate:
"""
Apply Row Level Security role as a predicate.
This transformer will apply any RLS predicates to the relevant tables. For example,
given the RLS rule:
table: some_table
clause: id = 42
If a user subject to the rule runs the following query:
SELECT foo FROM some_table WHERE bar = 'baz'
The query will be modified to:
SELECT foo FROM some_table WHERE bar = 'baz' AND id = 42
This approach is probably less secure than using subqueries, so it's only used for
databases without support for subqueries.
"""
def __init__(self, rules: dict[Table, str]) -> None:
self.rules = rules
def __call__(self, node: exp.Expression) -> exp.Expression:
if not isinstance(node, exp.Select):
return node
table_node = node.find(exp.Table)
if not table_node:
return node
table = Table(
str(table_node.this),
str(table_node.db) if table_node.db else None,
str(table_node.catalog) if table_node.catalog else None,
)
if predicate := self.rules.get(table):
if where := node.args.get("where"):
predicate = exp.And(this=predicate, expression=where.this)
node.set("where", exp.Where(this=predicate))
return node
class RLSAsSubquery:
def __init__(self, rules: dict[Table, str]) -> None:
self.rules = rules
def __call__(self, node: exp.Expression) -> exp.Expression:
if not isinstance(node, exp.Table):
return node
table = Table(
str(node.this),
str(node.db) if node.db else None,
str(node.catalog) if node.catalog else None,
)
if predicate := self.rules.get(table):
alias = node.alias
node.set("alias", None)
return f"(SELECT * FROM {node} WHERE {predicate}) AS {alias}"
return node
class SQLStatement(BaseSQLStatement[exp.Expression]):
"""
A SQL statement.
@@ -521,12 +604,19 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
dialect = SQLGLOT_DIALECTS.get(engine)
return extract_tables_from_statement(parsed, dialect)
def format(self, comments: bool = True) -> str:
def format(self, comments: bool = True, strip: bool = False) -> str:
"""
Pretty-format the SQL statement.
"""
write = Dialect.get_or_raise(self._dialect)
return write.generate(self._parsed, copy=False, comments=comments, pretty=True)
output = write.generate(
self._parsed,
copy=False,
comments=comments,
pretty=True,
)
return output.strip(" \t\r\n;") if strip else output
def get_settings(self) -> dict[str, str | bool]:
"""
@@ -543,6 +633,186 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
for eq in set_item.find_all(exp.EQ)
}
def apply_rls(
self,
catalog: str | None,
schema: str | None,
) -> SQLStatement:
"""
Apply Row Level Security to the SQL.
:param catalog: The default catalog for non-qualified table names
:param schema: The default schema for non-qualified table names
:return: The SQL with RLS applied
"""
from superset.db_engine_specs import load_engine_specs
statement = self._parsed.copy()
# collect all relevant RLS rules
rules = {}
for table in self.tables:
if rls := self._get_rls_for_table(table, catalog, schema):
rules[table] = rls
if not rules:
return SQLStatement(statement, self.engine)
use_subquery = all(
engine_spec.allows_subqueries
for engine_spec in load_engine_specs()
if engine_spec.engine == self.engine
)
transformer = RLSAsSubquery(rules) if use_subquery else RLSAsPredicate(rules)
return SQLStatement(statement.transform(transformer), self.engine)
def _get_rls_for_table(
self,
database: Database,
table: Table,
catalog: str | None,
schema: str | None,
) -> exp.Expression | None:
"""
Get the RLS for a table.
:param database: The database where the SQL will run
:param table: The table to get the RLS for
:param catalog: The default catalog for non-qualified table names
:param schema: The default schema for non-qualified table names
:return: The RLS for the table
"""
# pylint: disable=import-outside-toplevel
from superset import db
from superset.connectors.sqla.models import SqlaTable
dataset = (
    db.session.query(SqlaTable)
    .filter(
        and_(
            SqlaTable.database_id == database.id,
            SqlaTable.catalog == (table.catalog or catalog),
            SqlaTable.schema == (table.schema or schema),
            SqlaTable.table_name == table.table,
        )
    )
    .one_or_none()
)
if not dataset:
return None
filters = dataset.get_sqla_row_level_filters()
if not filters:
return None
rls = and_(*filters).compile(
dialect=database.get_dialect(),
compile_kwargs={"literal_binds": True},
)
return sqlglot.parse_one(str(rls), dialect=self._dialect)
def is_dml(self) -> bool:
"""
Check if the statement is DML.
:return: True if the statement is DML
"""
for node in self._parsed.walk():
if isinstance(
node,
(
exp.Insert,
exp.Update,
exp.Delete,
exp.Merge,
exp.Create,
exp.Alter,
exp.Drop,
exp.TruncateTable,
),
):
return True
return False
def as_create_table(self, table: Table, method: CtasMethod) -> SQLStatement:
"""
Convert the statement to a CREATE TABLE statement.
"""
create_table = exp.Create(
this=sqlglot.parse_one(str(table), into=exp.Table),
kind=method.value,
expression=self._parsed.copy(),
)
return SQLStatement(create_table, self.engine)
def is_select(self) -> bool:
"""
Check if the statement is a SELECT statement.
:return: True if the statement is a SELECT statement
"""
return isinstance(self._parsed, exp.Select)
def apply_limit(self, limit: int, force: bool = False) -> SQLStatement:
"""
Apply a limit to the SQL.
There are 3 strategies to limit queries, defined in the DB engine spec:
1. `FORCE_LIMIT`: a limit is added to the query, or the existing one is
replaced. This is the most efficient, since the database will produce at
most the number of rows that Superset will display.
2. `WRAP_SQL`: the query is wrapped in a subquery, and the limit is applied
to the outer query. This might be inefficient, since the database
optimizer might not be able to push the limit down to the inner query.
3. `FETCH_MANY`: no limit is applied, but only `LIMIT` rows are fetched from
the database. This is the least efficient, unless the database computes
rows as they are read by the cursor, which is unlikely.
:param limit: The limit to apply
:param force: Apply limit even when a lower one is present
:return: The SQL with the limit applied
"""
from superset.db_engine_specs import load_engine_specs
from superset.db_engine_specs.base import LimitMethod
methods = {
engine_spec.limit_method
for engine_spec in load_engine_specs()
if engine_spec.engine == self.engine
}
if not methods:
methods = {LimitMethod.FETCH_MANY}
# When multiple methods are supported, we prefer the more generic one --
# usually less efficient.
preference = [
LimitMethod.FETCH_MANY,
LimitMethod.WRAP_SQL,
LimitMethod.FORCE_LIMIT,
]
method = sorted(methods, key=preference.index)[0]
if not self.is_select() or method == LimitMethod.FETCH_MANY:
return SQLStatement(self._parsed.copy(), self.engine)
if method == LimitMethod.WRAP_SQL:
    # Wrap the original query as a subquery and limit the outer SELECT.
    limited = (
        exp.select("*")
        .from_(self._parsed.copy().subquery("inner_qry"))
        .limit(limit)
    )
    return SQLStatement(limited, self.engine)
current_limit: int | None = None
for node in self._parsed.find_all(exp.Limit):
current_limit = int(node.expression.this)
break
if force or current_limit is None or limit < current_limit:
return SQLStatement(self._parsed.limit(limit), self.engine)
return SQLStatement(self._parsed.copy(), self.engine)
class KQLSplitState(enum.Enum):
"""
@@ -666,11 +936,11 @@ class KustoKQLStatement(BaseSQLStatement[str]):
)
return set()
def format(self, comments: bool = True) -> str:
def format(self, comments: bool = True, strip: bool = False) -> str:
"""
Pretty-format the SQL statement.
"""
return self._parsed
return self._parsed.strip(" \t\r\n;") if strip else self._parsed
def get_settings(self) -> dict[str, str | bool]:
"""
@@ -712,7 +982,10 @@ class SQLScript:
"""
Pretty-format the SQL query.
"""
return ";\n".join(statement.format(comments) for statement in self.statements)
return (
";\n".join(statement.format(comments) for statement in self.statements)
+ ";"
)
def get_settings(self) -> dict[str, str | bool]:
"""
@@ -792,7 +1065,7 @@ class ParsedQuery:
Note: this uses sqlglot, since it's better at catching more edge cases.
"""
try:
statements = parse(self.stripped(), dialect=self._dialect)
statements = sqlglot.parse(self.stripped(), dialect=self._dialect)
except SqlglotError as ex:
logger.warning("Unable to parse SQL (%s): %s", self._dialect, self.sql)
@@ -846,7 +1119,7 @@ class ParsedQuery:
return set()
try:
pseudo_query = parse_one(
pseudo_query = sqlglot.parse_one(
f"SELECT {literal.this}",
dialect=self._dialect,
)
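
Since the WIP hunk registers `RLSAsPredicate` (and `RLSAsSubquery`) as plain sqlglot transformers,
they can be exercised directly. This usage sketch mirrors the new `test_RLSAsPredicate` cases at
the end of this compare, which is where the expected output comes from:

import sqlglot

from superset.sql_parse import RLSAsPredicate, Table

# Append the RLS clause for `some_table` ("id = 42") to the WHERE predicate.
statement = sqlglot.parse_one("SELECT t.foo FROM some_table AS t WHERE bar = 'baz'")
transformer = RLSAsPredicate({Table("some_table"): "id = 42"})
print(statement.transform(transformer))
# SELECT t.foo FROM some_table AS t WHERE id = 42 AND bar = 'baz'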

View File

@@ -18,13 +18,14 @@ from __future__ import annotations
from contextlib import nullcontext
from datetime import datetime, timedelta
from typing import Any, TYPE_CHECKING
from typing import Any
from uuid import UUID
import pytest
from flask.ctx import AppContext
from freezegun import freeze_time
from superset.extensions.metastore_cache import SupersetMetastoreCache
from superset.key_value.exceptions import KeyValueCodecEncodeException
from superset.key_value.types import (
JsonKeyValueCodec,
@@ -32,9 +33,6 @@ from superset.key_value.types import (
PickleKeyValueCodec,
)
if TYPE_CHECKING:
from superset.extensions.metastore_cache import SupersetMetastoreCache
NAMESPACE = UUID("ee173d1b-ccf3-40aa-941c-985c15224496")
FIRST_KEY = "foo"
@@ -47,8 +45,6 @@ SECOND_VALUE = "qwerty"
@pytest.fixture
def cache() -> SupersetMetastoreCache:
from superset.extensions.metastore_cache import SupersetMetastoreCache
return SupersetMetastoreCache(
namespace=NAMESPACE,
default_timeout=600,

View File

@@ -273,9 +273,9 @@ def test_get_sql_results_oauth2(mocker: MockerFixture, app) -> None:
"error_type": SupersetErrorType.OAUTH2_REDIRECT,
"level": ErrorLevel.WARNING,
"extra": {
"url": "https://abcd1234.snowflakecomputing.com/oauth/authorize?scope=refresh_token+session%3Arole%3ASYSADMIN&access_type=offline&include_granted_scopes=false&response_type=code&state=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9%252EeyJleHAiOjE2MTcyMzU1MDAsImRhdGFiYXNlX2lkIjoxLCJ1c2VyX2lkIjo0MiwiZGVmYXVsdF9yZWRpcmVjdF91cmkiOiJodHRwOi8vZXhhbXBsZS5jb20vYXBpL3YxL2RhdGFiYXNlL29hdXRoMi8iLCJ0YWJfaWQiOiJmYjExZjUyOC02ZWJhLTRhOGEtODM3ZS02YjBkMzllZTkxODcifQ%252Ec_m_35xwwSrLgCXwV4aKO1928flOEFQIqqg9ctiXjcM&redirect_uri=http%3A%2F%2Fexample.com%2Fapi%2Fv1%2Fdatabase%2Foauth2%2F&client_id=my_client_id&prompt=consent",
"url": "https://abcd1234.snowflakecomputing.com/oauth/authorize?scope=refresh_token+session%3Arole%3ASYSADMIN&access_type=offline&include_granted_scopes=false&response_type=code&state=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9%252EeyJleHAiOjE2MTcyMzU1MDAsImRhdGFiYXNlX2lkIjoxLCJ1c2VyX2lkIjo0MiwiZGVmYXVsdF9yZWRpcmVjdF91cmkiOiJodHRwOi8vbG9jYWxob3N0L2FwaS92MS9kYXRhYmFzZS9vYXV0aDIvIiwidGFiX2lkIjoiZmIxMWY1MjgtNmViYS00YThhLTgzN2UtNmIwZDM5ZWU5MTg3In0%252E7nLkei6-V8sVk_Pgm8cFhk0tnKRKayRE1Vc7RxuM9mw&redirect_uri=http%3A%2F%2Flocalhost%2Fapi%2Fv1%2Fdatabase%2Foauth2%2F&client_id=my_client_id&prompt=consent",
"tab_id": "fb11f528-6eba-4a8a-837e-6b0d39ee9187",
"redirect_uri": "http://example.com/api/v1/database/oauth2/",
"redirect_uri": "http://localhost/api/v1/database/oauth2/",
},
}
],

View File

@@ -20,6 +20,7 @@ from typing import Optional
from unittest.mock import Mock
import pytest
import sqlglot
import sqlparse
from pytest_mock import MockerFixture
from sqlalchemy import text
@@ -41,6 +42,8 @@ from superset.sql_parse import (
insert_rls_in_predicate,
KustoKQLStatement,
ParsedQuery,
RLSAsPredicate,
RLSAsSubquery,
sanitize_clause,
split_kql,
SQLScript,
@@ -119,8 +122,9 @@ def test_extract_tables_subselect() -> None:
"""
Test that tables inside subselects are parsed correctly.
"""
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT sub.*
FROM (
SELECT *
@@ -129,10 +133,13 @@ FROM (
) sub, s2.t2
WHERE sub.resolution = 'NONE'
"""
) == {Table("t1", "s1"), Table("t2", "s2")}
)
== {Table("t1", "s1"), Table("t2", "s2")}
)
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT sub.*
FROM (
SELECT *
@@ -141,10 +148,13 @@ FROM (
) sub
WHERE sub.resolution = 'NONE'
"""
) == {Table("t1", "s1")}
)
== {Table("t1", "s1")}
)
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT * FROM t1
WHERE s11 > ANY (
SELECT COUNT(*) /* no hint */ FROM t2
@@ -156,7 +166,9 @@ WHERE s11 > ANY (
)
)
"""
) == {Table("t1"), Table("t2"), Table("t3"), Table("t4")}
)
== {Table("t1"), Table("t2"), Table("t3"), Table("t4")}
)
def test_extract_tables_select_in_expression() -> None:
@@ -227,24 +239,30 @@ def test_extract_tables_select_array() -> None:
"""
Test that queries selecting arrays work as expected.
"""
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT ARRAY[1, 2, 3] AS my_array
FROM t1 LIMIT 10
"""
) == {Table("t1")}
)
== {Table("t1")}
)
def test_extract_tables_select_if() -> None:
"""
Test that queries with an ``IF`` work as expected.
"""
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT IF(CARDINALITY(my_array) >= 3, my_array[3], NULL)
FROM t1 LIMIT 10
"""
) == {Table("t1")}
)
== {Table("t1")}
)
def test_extract_tables_with_catalog() -> None:
@@ -312,29 +330,38 @@ def test_extract_tables_where_subquery() -> None:
"""
Test that tables in a ``WHERE`` subquery are parsed correctly.
"""
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT name
FROM t1
WHERE regionkey = (SELECT max(regionkey) FROM t2)
"""
) == {Table("t1"), Table("t2")}
)
== {Table("t1"), Table("t2")}
)
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT name
FROM t1
WHERE regionkey IN (SELECT regionkey FROM t2)
"""
) == {Table("t1"), Table("t2")}
)
== {Table("t1"), Table("t2")}
)
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT name
FROM t1
WHERE EXISTS (SELECT 1 FROM t2 WHERE t1.regionkey = t2.regionkey);
"""
) == {Table("t1"), Table("t2")}
)
== {Table("t1"), Table("t2")}
)
def test_extract_tables_describe() -> None:
@@ -348,12 +375,15 @@ def test_extract_tables_show_partitions() -> None:
"""
Test ``SHOW PARTITIONS``.
"""
assert extract_tables(
"""
assert (
extract_tables(
"""
SHOW PARTITIONS FROM orders
WHERE ds >= '2013-01-01' ORDER BY ds DESC
"""
) == {Table("orders")}
)
== {Table("orders")}
)
def test_extract_tables_join() -> None:
@@ -365,8 +395,9 @@ def test_extract_tables_join() -> None:
Table("t2"),
}
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT a.date, b.name
FROM left_table a
JOIN (
@@ -377,10 +408,13 @@ JOIN (
) b
ON a.date = b.date
"""
) == {Table("left_table"), Table("right_table")}
)
== {Table("left_table"), Table("right_table")}
)
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT a.date, b.name
FROM left_table a
LEFT INNER JOIN (
@@ -391,10 +425,13 @@ LEFT INNER JOIN (
) b
ON a.date = b.date
"""
) == {Table("left_table"), Table("right_table")}
)
== {Table("left_table"), Table("right_table")}
)
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT a.date, b.name
FROM left_table a
RIGHT OUTER JOIN (
@@ -405,10 +442,13 @@ RIGHT OUTER JOIN (
) b
ON a.date = b.date
"""
) == {Table("left_table"), Table("right_table")}
)
== {Table("left_table"), Table("right_table")}
)
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT a.date, b.name
FROM left_table a
FULL OUTER JOIN (
@@ -419,15 +459,18 @@ FULL OUTER JOIN (
) b
ON a.date = b.date
"""
) == {Table("left_table"), Table("right_table")}
)
== {Table("left_table"), Table("right_table")}
)
def test_extract_tables_semi_join() -> None:
"""
Test ``LEFT SEMI JOIN``.
"""
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT a.date, b.name
FROM left_table a
LEFT SEMI JOIN (
@@ -438,15 +481,18 @@ LEFT SEMI JOIN (
) b
ON a.data = b.date
"""
) == {Table("left_table"), Table("right_table")}
)
== {Table("left_table"), Table("right_table")}
)
def test_extract_tables_combinations() -> None:
"""
Test a complex case with nested queries.
"""
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT * FROM t1
WHERE s11 > ANY (
SELECT * FROM t1 UNION ALL SELECT * FROM (
@@ -460,10 +506,13 @@ WHERE s11 > ANY (
)
)
"""
) == {Table("t1"), Table("t3"), Table("t4"), Table("t6")}
)
== {Table("t1"), Table("t3"), Table("t4"), Table("t6")}
)
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT * FROM (
SELECT * FROM (
SELECT * FROM (
@@ -472,45 +521,56 @@ SELECT * FROM (
) AS S2
) AS S3
"""
) == {Table("EmployeeS")}
)
== {Table("EmployeeS")}
)
def test_extract_tables_with() -> None:
"""
Test ``WITH``.
"""
assert extract_tables(
"""
assert (
extract_tables(
"""
WITH
x AS (SELECT a FROM t1),
y AS (SELECT a AS b FROM t2),
z AS (SELECT b AS c FROM t3)
SELECT c FROM z
"""
) == {Table("t1"), Table("t2"), Table("t3")}
)
== {Table("t1"), Table("t2"), Table("t3")}
)
assert extract_tables(
"""
assert (
extract_tables(
"""
WITH
x AS (SELECT a FROM t1),
y AS (SELECT a AS b FROM x),
z AS (SELECT b AS c FROM y)
SELECT c FROM z
"""
) == {Table("t1")}
)
== {Table("t1")}
)
def test_extract_tables_reusing_aliases() -> None:
"""
Test that the parser follows aliases.
"""
assert extract_tables(
"""
assert (
extract_tables(
"""
with q1 as ( select key from q2 where key = '5'),
q2 as ( select key from src where key = '5')
select * from (select key from q1) a
"""
) == {Table("src")}
)
== {Table("src")}
)
# weird query with circular dependency
assert (
@@ -547,8 +607,9 @@ def test_extract_tables_complex() -> None:
"""
Test a few complex queries.
"""
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT sum(m_examples) AS "sum__m_example"
FROM (
SELECT
@@ -569,23 +630,29 @@ FROM (
ORDER BY "sum__m_example" DESC
LIMIT 10;
"""
) == {
Table("my_l_table"),
Table("my_b_table"),
Table("my_t_table"),
Table("inner_table"),
}
)
== {
Table("my_l_table"),
Table("my_b_table"),
Table("my_t_table"),
Table("inner_table"),
}
)
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT *
FROM table_a AS a, table_b AS b, table_c as c
WHERE a.id = b.id and b.id = c.id
"""
) == {Table("table_a"), Table("table_b"), Table("table_c")}
)
== {Table("table_a"), Table("table_b"), Table("table_c")}
)
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT somecol AS somecol
FROM (
WITH bla AS (
@@ -629,51 +696,63 @@ FROM (
LIMIT 50000
)
"""
) == {Table("a"), Table("b"), Table("c"), Table("d"), Table("e"), Table("f")}
)
== {Table("a"), Table("b"), Table("c"), Table("d"), Table("e"), Table("f")}
)
def test_extract_tables_mixed_from_clause() -> None:
"""
Test that the parser handles a ``FROM`` clause with table and subselect.
"""
assert extract_tables(
"""
assert (
extract_tables(
"""
SELECT *
FROM table_a AS a, (select * from table_b) AS b, table_c as c
WHERE a.id = b.id and b.id = c.id
"""
) == {Table("table_a"), Table("table_b"), Table("table_c")}
)
== {Table("table_a"), Table("table_b"), Table("table_c")}
)
def test_extract_tables_nested_select() -> None:
"""
Test that the parser handles selects inside functions.
"""
assert extract_tables(
"""
assert (
extract_tables(
"""
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME)
from INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
""",
"mysql",
) == {Table("COLUMNS", "INFORMATION_SCHEMA")}
"mysql",
)
== {Table("COLUMNS", "INFORMATION_SCHEMA")}
)
assert extract_tables(
"""
assert (
extract_tables(
"""
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
from INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_NAME="bi_achievement_daily"),0x7e)));
""",
"mysql",
) == {Table("COLUMNS", "INFORMATION_SCHEMA")}
"mysql",
)
== {Table("COLUMNS", "INFORMATION_SCHEMA")}
)
def test_extract_tables_complex_cte_with_prefix() -> None:
"""
Test that the parser handles CTEs with prefixes.
"""
assert extract_tables(
"""
assert (
extract_tables(
"""
WITH CTE__test (SalesPersonID, SalesOrderID, SalesYear)
AS (
SELECT SalesPersonID, SalesOrderID, YEAR(OrderDate) AS SalesYear
@@ -685,21 +764,26 @@ FROM CTE__test
GROUP BY SalesYear, SalesPersonID
ORDER BY SalesPersonID, SalesYear;
"""
) == {Table("SalesOrderHeader")}
)
== {Table("SalesOrderHeader")}
)
def test_extract_tables_identifier_list_with_keyword_as_alias() -> None:
"""
Test that aliases that are keywords are parsed correctly.
"""
assert extract_tables(
"""
assert (
extract_tables(
"""
WITH
f AS (SELECT * FROM foo),
match AS (SELECT * FROM f)
SELECT * FROM match
"""
) == {Table("foo")}
)
== {Table("foo")}
)
def test_update() -> None:
@@ -1841,7 +1925,7 @@ def test_sqlquery() -> None:
script = SQLScript("SELECT 1; SELECT 2;", "sqlite")
assert len(script.statements) == 2
assert script.format() == "SELECT\n 1;\nSELECT\n 2"
assert script.format() == "SELECT\n 1;\nSELECT\n 2;"
assert script.statements[0].format() == "SELECT\n 1"
script = SQLScript("SET a=1; SET a=2; SELECT 3;", "sqlite")
@@ -2058,3 +2142,120 @@ on $left.Day1 == $right.Day
| project Day1, Day2, Percentage = count_*100.0/count_1
""",
]
@pytest.mark.parametrize(
"sql,rules,expected",
[
(
"SELECT t.foo FROM some_table AS t",
{Table("some_table"): "id = 42"},
"SELECT t.foo FROM (SELECT * FROM some_table WHERE id = 42) AS t",
),
(
"SELECT t.foo FROM some_table AS t WHERE bar = 'baz'",
{Table("some_table"): "id = 42"},
(
"SELECT t.foo FROM (SELECT * FROM some_table WHERE id = 42) AS t "
"WHERE bar = 'baz'"
),
),
(
"SELECT t.foo FROM schema1.some_table AS t",
{Table("some_table", "schema1"): "id = 42"},
"SELECT t.foo FROM (SELECT * FROM schema1.some_table WHERE id = 42) AS t",
),
(
"SELECT t.foo FROM schema1.some_table AS t",
{Table("some_table", "schema2"): "id = 42"},
"SELECT t.foo FROM schema1.some_table AS t",
),
(
"SELECT t.foo FROM catalog1.schema1.some_table AS t",
{Table("some_table", "schema1", "catalog1"): "id = 42"},
"SELECT t.foo FROM (SELECT * FROM catalog1.schema1.some_table WHERE id = 42) AS t",
),
(
"SELECT t.foo FROM catalog1.schema1.some_table AS t",
{Table("some_table", "schema1", "catalog2"): "id = 42"},
"SELECT t.foo FROM catalog1.schema1.some_table AS t",
),
],
)
def test_RLSAsSubquery(sql: str, rules: dict[Table, str], expected: str) -> None:
"""
Test the `RLSAsSubquery` transformer.
"""
statement = sqlglot.parse_one(sql)
transformer = RLSAsSubquery(rules)
assert str(statement.transform(transformer)) == expected
@pytest.mark.parametrize(
"sql,rules,expected",
[
(
"SELECT t.foo FROM some_table AS t",
{Table("some_table"): "id = 42"},
"SELECT t.foo FROM some_table AS t WHERE id = 42",
),
(
"SELECT t.foo FROM some_table AS t WHERE bar = 'baz'",
{Table("some_table"): "id = 42"},
"SELECT t.foo FROM some_table AS t WHERE id = 42 AND bar = 'baz'",
),
],
)
def test_RLSAsPredicate(sql: str, rules: dict[Table, str], expected: str) -> None:
"""
Test the `RLSAsPredicate` transformer.
"""
statement = sqlglot.parse_one(sql)
transformer = RLSAsPredicate(rules)
assert str(statement.transform(transformer)) == expected
@pytest.mark.parametrize(
"sql,engine,limit,force,expected",
[
(
"SELECT TOP 10 * FROM Customers",
"teradatasql",
5,
False,
"SELECT\nTOP 5\n *\nFROM Customers",
),
(
"SELECT TOP 10 * FROM Customers",
"teradatasql",
15,
False,
"SELECT\nTOP 10\n *\nFROM Customers",
),
(
"SELECT TOP 10 * FROM Customers",
"teradatasql",
15,
True,
"SELECT\nTOP 15\n *\nFROM Customers",
),
(
"SELECT TOP 10 * FROM Customers",
"mssql",
15,
True,
"SELECT\nTOP 15\n *\nFROM Customers",
),
],
)
def test_apply_limit(
sql: str,
engine: str,
limit: int,
force: bool,
expected: str,
) -> None:
"""
Test the `apply_limit` function.
"""
assert SQLStatement(sql, engine).apply_limit(limit, force).format() == expected