mirror of
https://github.com/apache/superset.git
synced 2026-04-30 13:34:20 +00:00
Compare commits
5 Commits
feat/toolt
...
improve-fu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
503933756e | ||
|
|
60261a5dc6 | ||
|
|
456512c508 | ||
|
|
c1d9b06649 | ||
|
|
7a64a82cd9 |
@@ -99,6 +99,7 @@ export function ColumnSelect({
|
|||||||
'columns.column_name',
|
'columns.column_name',
|
||||||
'columns.is_dttm',
|
'columns.is_dttm',
|
||||||
'columns.type_generic',
|
'columns.type_generic',
|
||||||
|
'columns.filterable',
|
||||||
],
|
],
|
||||||
})}`,
|
})}`,
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -20,11 +20,38 @@ import { Filter, NativeFilterType } from '@superset-ui/core';
|
|||||||
import { render, screen, userEvent } from 'spec/helpers/testing-library';
|
import { render, screen, userEvent } from 'spec/helpers/testing-library';
|
||||||
import { FormInstance } from 'src/components';
|
import { FormInstance } from 'src/components';
|
||||||
import getControlItemsMap, { ControlItemsProps } from './getControlItemsMap';
|
import getControlItemsMap, { ControlItemsProps } from './getControlItemsMap';
|
||||||
import { getControlItems, setNativeFilterFieldValues } from './utils';
|
import {
|
||||||
|
getControlItems,
|
||||||
|
setNativeFilterFieldValues,
|
||||||
|
doesColumnMatchFilterType,
|
||||||
|
} from './utils';
|
||||||
|
|
||||||
jest.mock('./utils', () => ({
|
jest.mock('./utils', () => ({
|
||||||
getControlItems: jest.fn(),
|
getControlItems: jest.fn(),
|
||||||
setNativeFilterFieldValues: jest.fn(),
|
setNativeFilterFieldValues: jest.fn(),
|
||||||
|
doesColumnMatchFilterType: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
// Mock ColumnSelect to test filterValues logic
|
||||||
|
jest.mock('./ColumnSelect', () => ({
|
||||||
|
ColumnSelect: ({
|
||||||
|
filterValues,
|
||||||
|
}: {
|
||||||
|
filterValues: (column: any) => boolean;
|
||||||
|
}) => {
|
||||||
|
const columns = [
|
||||||
|
{ name: 'col1', filterable: true },
|
||||||
|
{ name: 'col2', filterable: false },
|
||||||
|
{ name: 'col3', filterable: true },
|
||||||
|
];
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
{columns.filter(filterValues).map(column => (
|
||||||
|
<div key={column.name}>{column.name}</div>
|
||||||
|
))}
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
},
|
||||||
}));
|
}));
|
||||||
|
|
||||||
const formMock: FormInstance = {
|
const formMock: FormInstance = {
|
||||||
@@ -62,7 +89,7 @@ const filterMock: Filter = {
|
|||||||
description: '',
|
description: '',
|
||||||
};
|
};
|
||||||
|
|
||||||
const createProps: () => ControlItemsProps = () => ({
|
const createProps = (): ControlItemsProps => ({
|
||||||
expanded: false,
|
expanded: false,
|
||||||
datasetId: 1,
|
datasetId: 1,
|
||||||
disabled: false,
|
disabled: false,
|
||||||
@@ -179,3 +206,42 @@ test('Clicking on checkbox when resetConfig:false', () => {
|
|||||||
expect(props.forceUpdate).toHaveBeenCalled();
|
expect(props.forceUpdate).toHaveBeenCalled();
|
||||||
expect(setNativeFilterFieldValues).not.toHaveBeenCalled();
|
expect(setNativeFilterFieldValues).not.toHaveBeenCalled();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('ColumnSelect filterValues behavior', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
(getControlItems as jest.Mock).mockReturnValue([
|
||||||
|
{
|
||||||
|
name: 'groupby',
|
||||||
|
config: { label: 'Column', multiple: false, required: false },
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('only renders filterable columns when doesColumnMatchFilterType returns true', () => {
|
||||||
|
(doesColumnMatchFilterType as jest.Mock).mockReturnValue(true);
|
||||||
|
const props = {
|
||||||
|
...createProps(),
|
||||||
|
formFilter: { filterType: 'filterType' },
|
||||||
|
};
|
||||||
|
const element = getControlItemsMap(props).mainControlItems.groupby
|
||||||
|
.element as React.ReactElement;
|
||||||
|
render(element);
|
||||||
|
expect(screen.getByText('col1')).toBeInTheDocument();
|
||||||
|
expect(screen.getByText('col3')).toBeInTheDocument();
|
||||||
|
expect(screen.queryByText('col2')).not.toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test('renders no columns when doesColumnMatchFilterType returns false', () => {
|
||||||
|
(doesColumnMatchFilterType as jest.Mock).mockReturnValue(false);
|
||||||
|
const props = {
|
||||||
|
...createProps(),
|
||||||
|
formFilter: { filterType: 'filterType' },
|
||||||
|
};
|
||||||
|
const element = getControlItemsMap(props).mainControlItems.groupby
|
||||||
|
.element as React.ReactElement;
|
||||||
|
render(element);
|
||||||
|
expect(screen.queryByText('col1')).not.toBeInTheDocument();
|
||||||
|
expect(screen.queryByText('col3')).not.toBeInTheDocument();
|
||||||
|
expect(screen.queryByText('col2')).not.toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|||||||
@@ -131,7 +131,10 @@ export default function getControlItemsMap({
|
|||||||
filterId={filterId}
|
filterId={filterId}
|
||||||
datasetId={datasetId}
|
datasetId={datasetId}
|
||||||
filterValues={column =>
|
filterValues={column =>
|
||||||
doesColumnMatchFilterType(formFilter?.filterType || '', column)
|
doesColumnMatchFilterType(
|
||||||
|
formFilter?.filterType || '',
|
||||||
|
column,
|
||||||
|
) && column.filterable
|
||||||
}
|
}
|
||||||
onChange={() => {
|
onChange={() => {
|
||||||
// We need reset default value when column changed
|
// We need reset default value when column changed
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ from deprecation import deprecated
|
|||||||
from sqlglot import exp
|
from sqlglot import exp
|
||||||
from sqlglot.dialects.dialect import Dialect, Dialects
|
from sqlglot.dialects.dialect import Dialect, Dialects
|
||||||
from sqlglot.errors import ParseError
|
from sqlglot.errors import ParseError
|
||||||
|
from sqlglot.expressions import Func
|
||||||
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
|
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
|
||||||
from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
|
from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
|
||||||
|
|
||||||
@@ -453,6 +454,23 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
|||||||
|
|
||||||
return SQLStatement(sql, self.engine, optimized)
|
return SQLStatement(sql, self.engine, optimized)
|
||||||
|
|
||||||
|
def check_functions_present(self, functions: set[str]) -> bool:
|
||||||
|
"""
|
||||||
|
Check if any of the given functions are present in the script.
|
||||||
|
|
||||||
|
:param functions: List of functions to check for
|
||||||
|
:return: True if any of the functions are present
|
||||||
|
"""
|
||||||
|
present = {
|
||||||
|
(
|
||||||
|
function.sql_name()
|
||||||
|
if function.sql_name() != "ANONYMOUS"
|
||||||
|
else function.name.upper()
|
||||||
|
)
|
||||||
|
for function in self._parsed.find_all(Func)
|
||||||
|
}
|
||||||
|
return any(function.upper() in present for function in functions)
|
||||||
|
|
||||||
|
|
||||||
class KQLSplitState(enum.Enum):
|
class KQLSplitState(enum.Enum):
|
||||||
"""
|
"""
|
||||||
@@ -619,6 +637,16 @@ class KustoKQLStatement(BaseSQLStatement[str]):
|
|||||||
"""
|
"""
|
||||||
return KustoKQLStatement(self._sql, self.engine, self._parsed)
|
return KustoKQLStatement(self._sql, self.engine, self._parsed)
|
||||||
|
|
||||||
|
def check_functions_present(self, functions: set[str]) -> bool:
|
||||||
|
"""
|
||||||
|
Check if any of the given functions are present in the script.
|
||||||
|
|
||||||
|
:param functions: List of functions to check for
|
||||||
|
:return: True if any of the functions are present
|
||||||
|
"""
|
||||||
|
logger.warning("Kusto KQL doesn't support checking for functions present.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
class SQLScript:
|
class SQLScript:
|
||||||
"""
|
"""
|
||||||
@@ -684,6 +712,18 @@ class SQLScript:
|
|||||||
|
|
||||||
return script
|
return script
|
||||||
|
|
||||||
|
def check_functions_present(self, functions: set[str]) -> bool:
|
||||||
|
"""
|
||||||
|
Check if any of the given functions are present in the script.
|
||||||
|
|
||||||
|
:param functions: List of functions to check for
|
||||||
|
:return: True if any of the functions are present
|
||||||
|
"""
|
||||||
|
return any(
|
||||||
|
statement.check_functions_present(functions)
|
||||||
|
for statement in self.statements
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def extract_tables_from_statement(
|
def extract_tables_from_statement(
|
||||||
statement: exp.Expression,
|
statement: exp.Expression,
|
||||||
|
|||||||
@@ -31,7 +31,6 @@ from sqlalchemy import and_
|
|||||||
from sqlparse import keywords
|
from sqlparse import keywords
|
||||||
from sqlparse.lexer import Lexer
|
from sqlparse.lexer import Lexer
|
||||||
from sqlparse.sql import (
|
from sqlparse.sql import (
|
||||||
Function,
|
|
||||||
Identifier,
|
Identifier,
|
||||||
IdentifierList,
|
IdentifierList,
|
||||||
Parenthesis,
|
Parenthesis,
|
||||||
@@ -181,7 +180,7 @@ def check_sql_functions_exist(
|
|||||||
:param function_list: The list of functions to search for
|
:param function_list: The list of functions to search for
|
||||||
:param engine: The engine to use for parsing the SQL statement
|
:param engine: The engine to use for parsing the SQL statement
|
||||||
"""
|
"""
|
||||||
return ParsedQuery(sql, engine=engine).check_functions_exist(function_list)
|
return SQLScript(sql, engine=engine).check_functions_present(function_list)
|
||||||
|
|
||||||
|
|
||||||
def strip_comments_from_sql(statement: str, engine: str = "base") -> str:
|
def strip_comments_from_sql(statement: str, engine: str = "base") -> str:
|
||||||
@@ -229,34 +228,6 @@ class ParsedQuery:
|
|||||||
self._tables = self._extract_tables_from_sql()
|
self._tables = self._extract_tables_from_sql()
|
||||||
return self._tables
|
return self._tables
|
||||||
|
|
||||||
def _check_functions_exist_in_token(
|
|
||||||
self, token: Token, functions: set[str]
|
|
||||||
) -> bool:
|
|
||||||
if (
|
|
||||||
isinstance(token, Function)
|
|
||||||
and token.get_name() is not None
|
|
||||||
and token.get_name().lower() in functions
|
|
||||||
):
|
|
||||||
return True
|
|
||||||
if hasattr(token, "tokens"):
|
|
||||||
for inner_token in token.tokens:
|
|
||||||
if self._check_functions_exist_in_token(inner_token, functions):
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def check_functions_exist(self, functions: set[str]) -> bool:
|
|
||||||
"""
|
|
||||||
Check if the SQL statement contains any of the specified functions.
|
|
||||||
|
|
||||||
:param functions: A set of functions to search for
|
|
||||||
:return: True if the statement contains any of the specified functions
|
|
||||||
"""
|
|
||||||
for statement in self._parsed:
|
|
||||||
for token in statement.tokens:
|
|
||||||
if self._check_functions_exist_in_token(token, functions):
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _extract_tables_from_sql(self) -> set[Table]:
|
def _extract_tables_from_sql(self) -> set[Table]:
|
||||||
"""
|
"""
|
||||||
Extract all table references in a query.
|
Extract all table references in a query.
|
||||||
|
|||||||
@@ -580,10 +580,7 @@ def test_get_samples_with_incorrect_cc(test_client, login_as_admin, virtual_data
|
|||||||
)
|
)
|
||||||
rv = test_client.post(uri, json={})
|
rv = test_client.post(uri, json={})
|
||||||
assert rv.status_code == 422
|
assert rv.status_code == 422
|
||||||
|
assert rv.json["errors"][0]["error_type"] == "INVALID_SQL_ERROR"
|
||||||
assert "error" in rv.json
|
|
||||||
if virtual_dataset.database.db_engine_spec.engine_name == "PostgreSQL":
|
|
||||||
assert "INCORRECT SQL" in rv.json.get("error")
|
|
||||||
|
|
||||||
|
|
||||||
@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True)
|
@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True)
|
||||||
|
|||||||
@@ -1237,6 +1237,35 @@ def test_check_sql_functions_exist() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_sql_functions_exist_with_comments() -> None:
|
||||||
|
"""
|
||||||
|
Test sql functions are detected correctly with comments
|
||||||
|
"""
|
||||||
|
assert not (
|
||||||
|
check_sql_functions_exist(
|
||||||
|
"select a, b from version/**/", {"version"}, "postgresql"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert check_sql_functions_exist("select version/**/()", {"version"}, "postgresql")
|
||||||
|
|
||||||
|
assert check_sql_functions_exist(
|
||||||
|
"select version from version/**/()", {"version"}, "postgresql"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert check_sql_functions_exist(
|
||||||
|
"select 1, a.version from (select version from version/**/()) as a",
|
||||||
|
{"version"},
|
||||||
|
"postgresql",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert check_sql_functions_exist(
|
||||||
|
"select 1, a.version from (select version/**/()) as a",
|
||||||
|
{"version"},
|
||||||
|
"postgresql",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_sanitize_clause_valid():
|
def test_sanitize_clause_valid():
|
||||||
# regular clauses
|
# regular clauses
|
||||||
assert sanitize_clause("col = 1") == "col = 1"
|
assert sanitize_clause("col = 1") == "col = 1"
|
||||||
|
|||||||
Reference in New Issue
Block a user