mirror of
https://github.com/apache/superset.git
synced 2026-04-28 20:44:24 +00:00
Compare commits
5 Commits
fix_sqllab
...
improve-fu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
503933756e | ||
|
|
60261a5dc6 | ||
|
|
456512c508 | ||
|
|
c1d9b06649 | ||
|
|
7a64a82cd9 |
@@ -99,6 +99,7 @@ export function ColumnSelect({
|
||||
'columns.column_name',
|
||||
'columns.is_dttm',
|
||||
'columns.type_generic',
|
||||
'columns.filterable',
|
||||
],
|
||||
})}`,
|
||||
})
|
||||
|
||||
@@ -20,11 +20,38 @@ import { Filter, NativeFilterType } from '@superset-ui/core';
|
||||
import { render, screen, userEvent } from 'spec/helpers/testing-library';
|
||||
import { FormInstance } from 'src/components';
|
||||
import getControlItemsMap, { ControlItemsProps } from './getControlItemsMap';
|
||||
import { getControlItems, setNativeFilterFieldValues } from './utils';
|
||||
import {
|
||||
getControlItems,
|
||||
setNativeFilterFieldValues,
|
||||
doesColumnMatchFilterType,
|
||||
} from './utils';
|
||||
|
||||
jest.mock('./utils', () => ({
|
||||
getControlItems: jest.fn(),
|
||||
setNativeFilterFieldValues: jest.fn(),
|
||||
doesColumnMatchFilterType: jest.fn(),
|
||||
}));
|
||||
|
||||
// Mock ColumnSelect to test filterValues logic
|
||||
jest.mock('./ColumnSelect', () => ({
|
||||
ColumnSelect: ({
|
||||
filterValues,
|
||||
}: {
|
||||
filterValues: (column: any) => boolean;
|
||||
}) => {
|
||||
const columns = [
|
||||
{ name: 'col1', filterable: true },
|
||||
{ name: 'col2', filterable: false },
|
||||
{ name: 'col3', filterable: true },
|
||||
];
|
||||
return (
|
||||
<>
|
||||
{columns.filter(filterValues).map(column => (
|
||||
<div key={column.name}>{column.name}</div>
|
||||
))}
|
||||
</>
|
||||
);
|
||||
},
|
||||
}));
|
||||
|
||||
const formMock: FormInstance = {
|
||||
@@ -62,7 +89,7 @@ const filterMock: Filter = {
|
||||
description: '',
|
||||
};
|
||||
|
||||
const createProps: () => ControlItemsProps = () => ({
|
||||
const createProps = (): ControlItemsProps => ({
|
||||
expanded: false,
|
||||
datasetId: 1,
|
||||
disabled: false,
|
||||
@@ -179,3 +206,42 @@ test('Clicking on checkbox when resetConfig:false', () => {
|
||||
expect(props.forceUpdate).toHaveBeenCalled();
|
||||
expect(setNativeFilterFieldValues).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
describe('ColumnSelect filterValues behavior', () => {
|
||||
beforeEach(() => {
|
||||
(getControlItems as jest.Mock).mockReturnValue([
|
||||
{
|
||||
name: 'groupby',
|
||||
config: { label: 'Column', multiple: false, required: false },
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
test('only renders filterable columns when doesColumnMatchFilterType returns true', () => {
|
||||
(doesColumnMatchFilterType as jest.Mock).mockReturnValue(true);
|
||||
const props = {
|
||||
...createProps(),
|
||||
formFilter: { filterType: 'filterType' },
|
||||
};
|
||||
const element = getControlItemsMap(props).mainControlItems.groupby
|
||||
.element as React.ReactElement;
|
||||
render(element);
|
||||
expect(screen.getByText('col1')).toBeInTheDocument();
|
||||
expect(screen.getByText('col3')).toBeInTheDocument();
|
||||
expect(screen.queryByText('col2')).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
test('renders no columns when doesColumnMatchFilterType returns false', () => {
|
||||
(doesColumnMatchFilterType as jest.Mock).mockReturnValue(false);
|
||||
const props = {
|
||||
...createProps(),
|
||||
formFilter: { filterType: 'filterType' },
|
||||
};
|
||||
const element = getControlItemsMap(props).mainControlItems.groupby
|
||||
.element as React.ReactElement;
|
||||
render(element);
|
||||
expect(screen.queryByText('col1')).not.toBeInTheDocument();
|
||||
expect(screen.queryByText('col3')).not.toBeInTheDocument();
|
||||
expect(screen.queryByText('col2')).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -131,7 +131,10 @@ export default function getControlItemsMap({
|
||||
filterId={filterId}
|
||||
datasetId={datasetId}
|
||||
filterValues={column =>
|
||||
doesColumnMatchFilterType(formFilter?.filterType || '', column)
|
||||
doesColumnMatchFilterType(
|
||||
formFilter?.filterType || '',
|
||||
column,
|
||||
) && column.filterable
|
||||
}
|
||||
onChange={() => {
|
||||
// We need reset default value when column changed
|
||||
|
||||
@@ -32,6 +32,7 @@ from deprecation import deprecated
|
||||
from sqlglot import exp
|
||||
from sqlglot.dialects.dialect import Dialect, Dialects
|
||||
from sqlglot.errors import ParseError
|
||||
from sqlglot.expressions import Func
|
||||
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
|
||||
from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
|
||||
|
||||
@@ -453,6 +454,23 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
|
||||
return SQLStatement(sql, self.engine, optimized)
|
||||
|
||||
def check_functions_present(self, functions: set[str]) -> bool:
|
||||
"""
|
||||
Check if any of the given functions are present in the script.
|
||||
|
||||
:param functions: List of functions to check for
|
||||
:return: True if any of the functions are present
|
||||
"""
|
||||
present = {
|
||||
(
|
||||
function.sql_name()
|
||||
if function.sql_name() != "ANONYMOUS"
|
||||
else function.name.upper()
|
||||
)
|
||||
for function in self._parsed.find_all(Func)
|
||||
}
|
||||
return any(function.upper() in present for function in functions)
|
||||
|
||||
|
||||
class KQLSplitState(enum.Enum):
|
||||
"""
|
||||
@@ -619,6 +637,16 @@ class KustoKQLStatement(BaseSQLStatement[str]):
|
||||
"""
|
||||
return KustoKQLStatement(self._sql, self.engine, self._parsed)
|
||||
|
||||
def check_functions_present(self, functions: set[str]) -> bool:
|
||||
"""
|
||||
Check if any of the given functions are present in the script.
|
||||
|
||||
:param functions: List of functions to check for
|
||||
:return: True if any of the functions are present
|
||||
"""
|
||||
logger.warning("Kusto KQL doesn't support checking for functions present.")
|
||||
return True
|
||||
|
||||
|
||||
class SQLScript:
|
||||
"""
|
||||
@@ -684,6 +712,18 @@ class SQLScript:
|
||||
|
||||
return script
|
||||
|
||||
def check_functions_present(self, functions: set[str]) -> bool:
|
||||
"""
|
||||
Check if any of the given functions are present in the script.
|
||||
|
||||
:param functions: List of functions to check for
|
||||
:return: True if any of the functions are present
|
||||
"""
|
||||
return any(
|
||||
statement.check_functions_present(functions)
|
||||
for statement in self.statements
|
||||
)
|
||||
|
||||
|
||||
def extract_tables_from_statement(
|
||||
statement: exp.Expression,
|
||||
|
||||
@@ -31,7 +31,6 @@ from sqlalchemy import and_
|
||||
from sqlparse import keywords
|
||||
from sqlparse.lexer import Lexer
|
||||
from sqlparse.sql import (
|
||||
Function,
|
||||
Identifier,
|
||||
IdentifierList,
|
||||
Parenthesis,
|
||||
@@ -181,7 +180,7 @@ def check_sql_functions_exist(
|
||||
:param function_list: The list of functions to search for
|
||||
:param engine: The engine to use for parsing the SQL statement
|
||||
"""
|
||||
return ParsedQuery(sql, engine=engine).check_functions_exist(function_list)
|
||||
return SQLScript(sql, engine=engine).check_functions_present(function_list)
|
||||
|
||||
|
||||
def strip_comments_from_sql(statement: str, engine: str = "base") -> str:
|
||||
@@ -229,34 +228,6 @@ class ParsedQuery:
|
||||
self._tables = self._extract_tables_from_sql()
|
||||
return self._tables
|
||||
|
||||
def _check_functions_exist_in_token(
|
||||
self, token: Token, functions: set[str]
|
||||
) -> bool:
|
||||
if (
|
||||
isinstance(token, Function)
|
||||
and token.get_name() is not None
|
||||
and token.get_name().lower() in functions
|
||||
):
|
||||
return True
|
||||
if hasattr(token, "tokens"):
|
||||
for inner_token in token.tokens:
|
||||
if self._check_functions_exist_in_token(inner_token, functions):
|
||||
return True
|
||||
return False
|
||||
|
||||
def check_functions_exist(self, functions: set[str]) -> bool:
|
||||
"""
|
||||
Check if the SQL statement contains any of the specified functions.
|
||||
|
||||
:param functions: A set of functions to search for
|
||||
:return: True if the statement contains any of the specified functions
|
||||
"""
|
||||
for statement in self._parsed:
|
||||
for token in statement.tokens:
|
||||
if self._check_functions_exist_in_token(token, functions):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _extract_tables_from_sql(self) -> set[Table]:
|
||||
"""
|
||||
Extract all table references in a query.
|
||||
|
||||
@@ -580,10 +580,7 @@ def test_get_samples_with_incorrect_cc(test_client, login_as_admin, virtual_data
|
||||
)
|
||||
rv = test_client.post(uri, json={})
|
||||
assert rv.status_code == 422
|
||||
|
||||
assert "error" in rv.json
|
||||
if virtual_dataset.database.db_engine_spec.engine_name == "PostgreSQL":
|
||||
assert "INCORRECT SQL" in rv.json.get("error")
|
||||
assert rv.json["errors"][0]["error_type"] == "INVALID_SQL_ERROR"
|
||||
|
||||
|
||||
@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True)
|
||||
|
||||
@@ -1237,6 +1237,35 @@ def test_check_sql_functions_exist() -> None:
|
||||
)
|
||||
|
||||
|
||||
def test_check_sql_functions_exist_with_comments() -> None:
|
||||
"""
|
||||
Test sql functions are detected correctly with comments
|
||||
"""
|
||||
assert not (
|
||||
check_sql_functions_exist(
|
||||
"select a, b from version/**/", {"version"}, "postgresql"
|
||||
)
|
||||
)
|
||||
|
||||
assert check_sql_functions_exist("select version/**/()", {"version"}, "postgresql")
|
||||
|
||||
assert check_sql_functions_exist(
|
||||
"select version from version/**/()", {"version"}, "postgresql"
|
||||
)
|
||||
|
||||
assert check_sql_functions_exist(
|
||||
"select 1, a.version from (select version from version/**/()) as a",
|
||||
{"version"},
|
||||
"postgresql",
|
||||
)
|
||||
|
||||
assert check_sql_functions_exist(
|
||||
"select 1, a.version from (select version/**/()) as a",
|
||||
{"version"},
|
||||
"postgresql",
|
||||
)
|
||||
|
||||
|
||||
def test_sanitize_clause_valid():
|
||||
# regular clauses
|
||||
assert sanitize_clause("col = 1") == "col = 1"
|
||||
|
||||
Reference in New Issue
Block a user