Compare commits

...

5 Commits

Author SHA1 Message Date
Beto Dealmeida
503933756e fix: show only filterable columns on filter dropdown 2025-05-01 17:20:02 -04:00
Beto Dealmeida
60261a5dc6 Fix docstring 2025-05-01 12:33:44 -04:00
Beto Dealmeida
456512c508 Fix test 2025-05-01 11:56:13 -04:00
Beto Dealmeida
c1d9b06649 Remove old method 2025-05-01 09:49:50 -04:00
Beto Dealmeida
7a64a82cd9 fix: improve function detection 2025-04-30 16:58:11 -04:00
7 changed files with 144 additions and 37 deletions

View File

@@ -99,6 +99,7 @@ export function ColumnSelect({
'columns.column_name', 'columns.column_name',
'columns.is_dttm', 'columns.is_dttm',
'columns.type_generic', 'columns.type_generic',
'columns.filterable',
], ],
})}`, })}`,
}) })

View File

@@ -20,11 +20,38 @@ import { Filter, NativeFilterType } from '@superset-ui/core';
import { render, screen, userEvent } from 'spec/helpers/testing-library'; import { render, screen, userEvent } from 'spec/helpers/testing-library';
import { FormInstance } from 'src/components'; import { FormInstance } from 'src/components';
import getControlItemsMap, { ControlItemsProps } from './getControlItemsMap'; import getControlItemsMap, { ControlItemsProps } from './getControlItemsMap';
import { getControlItems, setNativeFilterFieldValues } from './utils'; import {
getControlItems,
setNativeFilterFieldValues,
doesColumnMatchFilterType,
} from './utils';
jest.mock('./utils', () => ({ jest.mock('./utils', () => ({
getControlItems: jest.fn(), getControlItems: jest.fn(),
setNativeFilterFieldValues: jest.fn(), setNativeFilterFieldValues: jest.fn(),
doesColumnMatchFilterType: jest.fn(),
}));
// Mock ColumnSelect to test filterValues logic
jest.mock('./ColumnSelect', () => ({
ColumnSelect: ({
filterValues,
}: {
filterValues: (column: any) => boolean;
}) => {
const columns = [
{ name: 'col1', filterable: true },
{ name: 'col2', filterable: false },
{ name: 'col3', filterable: true },
];
return (
<>
{columns.filter(filterValues).map(column => (
<div key={column.name}>{column.name}</div>
))}
</>
);
},
})); }));
const formMock: FormInstance = { const formMock: FormInstance = {
@@ -62,7 +89,7 @@ const filterMock: Filter = {
description: '', description: '',
}; };
const createProps: () => ControlItemsProps = () => ({ const createProps = (): ControlItemsProps => ({
expanded: false, expanded: false,
datasetId: 1, datasetId: 1,
disabled: false, disabled: false,
@@ -179,3 +206,42 @@ test('Clicking on checkbox when resetConfig:false', () => {
expect(props.forceUpdate).toHaveBeenCalled(); expect(props.forceUpdate).toHaveBeenCalled();
expect(setNativeFilterFieldValues).not.toHaveBeenCalled(); expect(setNativeFilterFieldValues).not.toHaveBeenCalled();
}); });
describe('ColumnSelect filterValues behavior', () => {
beforeEach(() => {
(getControlItems as jest.Mock).mockReturnValue([
{
name: 'groupby',
config: { label: 'Column', multiple: false, required: false },
},
]);
});
test('only renders filterable columns when doesColumnMatchFilterType returns true', () => {
(doesColumnMatchFilterType as jest.Mock).mockReturnValue(true);
const props = {
...createProps(),
formFilter: { filterType: 'filterType' },
};
const element = getControlItemsMap(props).mainControlItems.groupby
.element as React.ReactElement;
render(element);
expect(screen.getByText('col1')).toBeInTheDocument();
expect(screen.getByText('col3')).toBeInTheDocument();
expect(screen.queryByText('col2')).not.toBeInTheDocument();
});
test('renders no columns when doesColumnMatchFilterType returns false', () => {
(doesColumnMatchFilterType as jest.Mock).mockReturnValue(false);
const props = {
...createProps(),
formFilter: { filterType: 'filterType' },
};
const element = getControlItemsMap(props).mainControlItems.groupby
.element as React.ReactElement;
render(element);
expect(screen.queryByText('col1')).not.toBeInTheDocument();
expect(screen.queryByText('col3')).not.toBeInTheDocument();
expect(screen.queryByText('col2')).not.toBeInTheDocument();
});
});

View File

@@ -131,7 +131,10 @@ export default function getControlItemsMap({
filterId={filterId} filterId={filterId}
datasetId={datasetId} datasetId={datasetId}
filterValues={column => filterValues={column =>
doesColumnMatchFilterType(formFilter?.filterType || '', column) doesColumnMatchFilterType(
formFilter?.filterType || '',
column,
) && column.filterable
} }
onChange={() => { onChange={() => {
// We need reset default value when column changed // We need reset default value when column changed

View File

@@ -32,6 +32,7 @@ from deprecation import deprecated
from sqlglot import exp from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, Dialects from sqlglot.dialects.dialect import Dialect, Dialects
from sqlglot.errors import ParseError from sqlglot.errors import ParseError
from sqlglot.expressions import Func
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
@@ -453,6 +454,23 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
return SQLStatement(sql, self.engine, optimized) return SQLStatement(sql, self.engine, optimized)
def check_functions_present(self, functions: set[str]) -> bool:
"""
Check if any of the given functions are present in the script.
:param functions: List of functions to check for
:return: True if any of the functions are present
"""
present = {
(
function.sql_name()
if function.sql_name() != "ANONYMOUS"
else function.name.upper()
)
for function in self._parsed.find_all(Func)
}
return any(function.upper() in present for function in functions)
class KQLSplitState(enum.Enum): class KQLSplitState(enum.Enum):
""" """
@@ -619,6 +637,16 @@ class KustoKQLStatement(BaseSQLStatement[str]):
""" """
return KustoKQLStatement(self._sql, self.engine, self._parsed) return KustoKQLStatement(self._sql, self.engine, self._parsed)
def check_functions_present(self, functions: set[str]) -> bool:
"""
Check if any of the given functions are present in the script.
:param functions: List of functions to check for
:return: True if any of the functions are present
"""
logger.warning("Kusto KQL doesn't support checking for functions present.")
return True
class SQLScript: class SQLScript:
""" """
@@ -684,6 +712,18 @@ class SQLScript:
return script return script
def check_functions_present(self, functions: set[str]) -> bool:
"""
Check if any of the given functions are present in the script.
:param functions: List of functions to check for
:return: True if any of the functions are present
"""
return any(
statement.check_functions_present(functions)
for statement in self.statements
)
def extract_tables_from_statement( def extract_tables_from_statement(
statement: exp.Expression, statement: exp.Expression,

View File

@@ -31,7 +31,6 @@ from sqlalchemy import and_
from sqlparse import keywords from sqlparse import keywords
from sqlparse.lexer import Lexer from sqlparse.lexer import Lexer
from sqlparse.sql import ( from sqlparse.sql import (
Function,
Identifier, Identifier,
IdentifierList, IdentifierList,
Parenthesis, Parenthesis,
@@ -181,7 +180,7 @@ def check_sql_functions_exist(
:param function_list: The list of functions to search for :param function_list: The list of functions to search for
:param engine: The engine to use for parsing the SQL statement :param engine: The engine to use for parsing the SQL statement
""" """
return ParsedQuery(sql, engine=engine).check_functions_exist(function_list) return SQLScript(sql, engine=engine).check_functions_present(function_list)
def strip_comments_from_sql(statement: str, engine: str = "base") -> str: def strip_comments_from_sql(statement: str, engine: str = "base") -> str:
@@ -229,34 +228,6 @@ class ParsedQuery:
self._tables = self._extract_tables_from_sql() self._tables = self._extract_tables_from_sql()
return self._tables return self._tables
def _check_functions_exist_in_token(
self, token: Token, functions: set[str]
) -> bool:
if (
isinstance(token, Function)
and token.get_name() is not None
and token.get_name().lower() in functions
):
return True
if hasattr(token, "tokens"):
for inner_token in token.tokens:
if self._check_functions_exist_in_token(inner_token, functions):
return True
return False
def check_functions_exist(self, functions: set[str]) -> bool:
"""
Check if the SQL statement contains any of the specified functions.
:param functions: A set of functions to search for
:return: True if the statement contains any of the specified functions
"""
for statement in self._parsed:
for token in statement.tokens:
if self._check_functions_exist_in_token(token, functions):
return True
return False
def _extract_tables_from_sql(self) -> set[Table]: def _extract_tables_from_sql(self) -> set[Table]:
""" """
Extract all table references in a query. Extract all table references in a query.

View File

@@ -580,10 +580,7 @@ def test_get_samples_with_incorrect_cc(test_client, login_as_admin, virtual_data
) )
rv = test_client.post(uri, json={}) rv = test_client.post(uri, json={})
assert rv.status_code == 422 assert rv.status_code == 422
assert rv.json["errors"][0]["error_type"] == "INVALID_SQL_ERROR"
assert "error" in rv.json
if virtual_dataset.database.db_engine_spec.engine_name == "PostgreSQL":
assert "INCORRECT SQL" in rv.json.get("error")
@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True) @with_feature_flags(ALLOW_ADHOC_SUBQUERY=True)

View File

@@ -1237,6 +1237,35 @@ def test_check_sql_functions_exist() -> None:
) )
def test_check_sql_functions_exist_with_comments() -> None:
"""
Test sql functions are detected correctly with comments
"""
assert not (
check_sql_functions_exist(
"select a, b from version/**/", {"version"}, "postgresql"
)
)
assert check_sql_functions_exist("select version/**/()", {"version"}, "postgresql")
assert check_sql_functions_exist(
"select version from version/**/()", {"version"}, "postgresql"
)
assert check_sql_functions_exist(
"select 1, a.version from (select version from version/**/()) as a",
{"version"},
"postgresql",
)
assert check_sql_functions_exist(
"select 1, a.version from (select version/**/()) as a",
{"version"},
"postgresql",
)
def test_sanitize_clause_valid(): def test_sanitize_clause_valid():
# regular clauses # regular clauses
assert sanitize_clause("col = 1") == "col = 1" assert sanitize_clause("col = 1") == "col = 1"