Compare commits

...

5 Commits

Author SHA1 Message Date
Beto Dealmeida
503933756e fix: show only filterable columns on filter dropdown 2025-05-01 17:20:02 -04:00
Beto Dealmeida
60261a5dc6 Fix docstring 2025-05-01 12:33:44 -04:00
Beto Dealmeida
456512c508 Fix test 2025-05-01 11:56:13 -04:00
Beto Dealmeida
c1d9b06649 Remove old method 2025-05-01 09:49:50 -04:00
Beto Dealmeida
7a64a82cd9 fix: improve function detection 2025-04-30 16:58:11 -04:00
7 changed files with 144 additions and 37 deletions

View File

@@ -99,6 +99,7 @@ export function ColumnSelect({
'columns.column_name',
'columns.is_dttm',
'columns.type_generic',
'columns.filterable',
],
})}`,
})

View File

@@ -20,11 +20,38 @@ import { Filter, NativeFilterType } from '@superset-ui/core';
import { render, screen, userEvent } from 'spec/helpers/testing-library';
import { FormInstance } from 'src/components';
import getControlItemsMap, { ControlItemsProps } from './getControlItemsMap';
import { getControlItems, setNativeFilterFieldValues } from './utils';
import {
getControlItems,
setNativeFilterFieldValues,
doesColumnMatchFilterType,
} from './utils';
jest.mock('./utils', () => ({
getControlItems: jest.fn(),
setNativeFilterFieldValues: jest.fn(),
doesColumnMatchFilterType: jest.fn(),
}));
// Mock ColumnSelect to test filterValues logic
jest.mock('./ColumnSelect', () => ({
ColumnSelect: ({
filterValues,
}: {
filterValues: (column: any) => boolean;
}) => {
const columns = [
{ name: 'col1', filterable: true },
{ name: 'col2', filterable: false },
{ name: 'col3', filterable: true },
];
return (
<>
{columns.filter(filterValues).map(column => (
<div key={column.name}>{column.name}</div>
))}
</>
);
},
}));
const formMock: FormInstance = {
@@ -62,7 +89,7 @@ const filterMock: Filter = {
description: '',
};
const createProps: () => ControlItemsProps = () => ({
const createProps = (): ControlItemsProps => ({
expanded: false,
datasetId: 1,
disabled: false,
@@ -179,3 +206,42 @@ test('Clicking on checkbox when resetConfig:false', () => {
expect(props.forceUpdate).toHaveBeenCalled();
expect(setNativeFilterFieldValues).not.toHaveBeenCalled();
});
describe('ColumnSelect filterValues behavior', () => {
beforeEach(() => {
(getControlItems as jest.Mock).mockReturnValue([
{
name: 'groupby',
config: { label: 'Column', multiple: false, required: false },
},
]);
});
test('only renders filterable columns when doesColumnMatchFilterType returns true', () => {
(doesColumnMatchFilterType as jest.Mock).mockReturnValue(true);
const props = {
...createProps(),
formFilter: { filterType: 'filterType' },
};
const element = getControlItemsMap(props).mainControlItems.groupby
.element as React.ReactElement;
render(element);
expect(screen.getByText('col1')).toBeInTheDocument();
expect(screen.getByText('col3')).toBeInTheDocument();
expect(screen.queryByText('col2')).not.toBeInTheDocument();
});
test('renders no columns when doesColumnMatchFilterType returns false', () => {
(doesColumnMatchFilterType as jest.Mock).mockReturnValue(false);
const props = {
...createProps(),
formFilter: { filterType: 'filterType' },
};
const element = getControlItemsMap(props).mainControlItems.groupby
.element as React.ReactElement;
render(element);
expect(screen.queryByText('col1')).not.toBeInTheDocument();
expect(screen.queryByText('col3')).not.toBeInTheDocument();
expect(screen.queryByText('col2')).not.toBeInTheDocument();
});
});

View File

@@ -131,7 +131,10 @@ export default function getControlItemsMap({
filterId={filterId}
datasetId={datasetId}
filterValues={column =>
doesColumnMatchFilterType(formFilter?.filterType || '', column)
doesColumnMatchFilterType(
formFilter?.filterType || '',
column,
) && column.filterable
}
onChange={() => {
// We need reset default value when column changed

View File

@@ -32,6 +32,7 @@ from deprecation import deprecated
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, Dialects
from sqlglot.errors import ParseError
from sqlglot.expressions import Func
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
@@ -453,6 +454,23 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
return SQLStatement(sql, self.engine, optimized)
def check_functions_present(self, functions: set[str]) -> bool:
"""
Check if any of the given functions are present in the script.
:param functions: List of functions to check for
:return: True if any of the functions are present
"""
present = {
(
function.sql_name()
if function.sql_name() != "ANONYMOUS"
else function.name.upper()
)
for function in self._parsed.find_all(Func)
}
return any(function.upper() in present for function in functions)
class KQLSplitState(enum.Enum):
"""
@@ -619,6 +637,16 @@ class KustoKQLStatement(BaseSQLStatement[str]):
"""
return KustoKQLStatement(self._sql, self.engine, self._parsed)
def check_functions_present(self, functions: set[str]) -> bool:
"""
Check if any of the given functions are present in the script.
:param functions: List of functions to check for
:return: True if any of the functions are present
"""
logger.warning("Kusto KQL doesn't support checking for functions present.")
return True
class SQLScript:
"""
@@ -684,6 +712,18 @@ class SQLScript:
return script
def check_functions_present(self, functions: set[str]) -> bool:
"""
Check if any of the given functions are present in the script.
:param functions: List of functions to check for
:return: True if any of the functions are present
"""
return any(
statement.check_functions_present(functions)
for statement in self.statements
)
def extract_tables_from_statement(
statement: exp.Expression,

View File

@@ -31,7 +31,6 @@ from sqlalchemy import and_
from sqlparse import keywords
from sqlparse.lexer import Lexer
from sqlparse.sql import (
Function,
Identifier,
IdentifierList,
Parenthesis,
@@ -181,7 +180,7 @@ def check_sql_functions_exist(
:param function_list: The list of functions to search for
:param engine: The engine to use for parsing the SQL statement
"""
return ParsedQuery(sql, engine=engine).check_functions_exist(function_list)
return SQLScript(sql, engine=engine).check_functions_present(function_list)
def strip_comments_from_sql(statement: str, engine: str = "base") -> str:
@@ -229,34 +228,6 @@ class ParsedQuery:
self._tables = self._extract_tables_from_sql()
return self._tables
def _check_functions_exist_in_token(
self, token: Token, functions: set[str]
) -> bool:
if (
isinstance(token, Function)
and token.get_name() is not None
and token.get_name().lower() in functions
):
return True
if hasattr(token, "tokens"):
for inner_token in token.tokens:
if self._check_functions_exist_in_token(inner_token, functions):
return True
return False
def check_functions_exist(self, functions: set[str]) -> bool:
"""
Check if the SQL statement contains any of the specified functions.
:param functions: A set of functions to search for
:return: True if the statement contains any of the specified functions
"""
for statement in self._parsed:
for token in statement.tokens:
if self._check_functions_exist_in_token(token, functions):
return True
return False
def _extract_tables_from_sql(self) -> set[Table]:
"""
Extract all table references in a query.

View File

@@ -580,10 +580,7 @@ def test_get_samples_with_incorrect_cc(test_client, login_as_admin, virtual_data
)
rv = test_client.post(uri, json={})
assert rv.status_code == 422
assert "error" in rv.json
if virtual_dataset.database.db_engine_spec.engine_name == "PostgreSQL":
assert "INCORRECT SQL" in rv.json.get("error")
assert rv.json["errors"][0]["error_type"] == "INVALID_SQL_ERROR"
@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True)

View File

@@ -1237,6 +1237,35 @@ def test_check_sql_functions_exist() -> None:
)
def test_check_sql_functions_exist_with_comments() -> None:
"""
Test sql functions are detected correctly with comments
"""
assert not (
check_sql_functions_exist(
"select a, b from version/**/", {"version"}, "postgresql"
)
)
assert check_sql_functions_exist("select version/**/()", {"version"}, "postgresql")
assert check_sql_functions_exist(
"select version from version/**/()", {"version"}, "postgresql"
)
assert check_sql_functions_exist(
"select 1, a.version from (select version from version/**/()) as a",
{"version"},
"postgresql",
)
assert check_sql_functions_exist(
"select 1, a.version from (select version/**/()) as a",
{"version"},
"postgresql",
)
def test_sanitize_clause_valid():
# regular clauses
assert sanitize_clause("col = 1") == "col = 1"