mirror of
https://github.com/apache/superset.git
synced 2026-04-19 16:14:52 +00:00
feat(ag-grid): add SQLGlot-based SQL escaping for where and having filter clauses (#36136)
This commit is contained in:
@@ -74,6 +74,9 @@ export type QueryObjectExtras = Partial<{
|
||||
instant_time_comparison_range?: string;
|
||||
|
||||
time_compare?: string;
|
||||
|
||||
/** If true, WHERE/HAVING clauses need transpilation to target dialect */
|
||||
transpile_to_dialect?: boolean;
|
||||
}>;
|
||||
|
||||
export type ResidualQueryObjectData = {
|
||||
|
||||
@@ -453,7 +453,7 @@ const buildQuery: BuildQuery<TableChartFormData> = (
|
||||
];
|
||||
}
|
||||
|
||||
// Apply AG Grid filters converted to SQL WHERE/HAVING clauses
|
||||
// Apply AG Grid filters as SQL WHERE/HAVING clauses
|
||||
if (ownState.sqlClauses) {
|
||||
const { whereClause, havingClause } = classifySQLClauses(
|
||||
ownState.sqlClauses as Record<string, string>,
|
||||
@@ -462,6 +462,7 @@ const buildQuery: BuildQuery<TableChartFormData> = (
|
||||
if (whereClause || havingClause) {
|
||||
queryObject.extras = {
|
||||
...queryObject.extras,
|
||||
transpile_to_dialect: true,
|
||||
...(whereClause && {
|
||||
where: queryObject.extras?.where
|
||||
? `${queryObject.extras.where} AND ${whereClause}`
|
||||
|
||||
@@ -50,6 +50,11 @@ const NUMBER_FILTER_OPERATORS: Record<string, string> = {
|
||||
greaterThanOrEqual: '>=',
|
||||
};
|
||||
|
||||
/** Escapes single quotes in SQL strings: O'Hara → O''Hara */
|
||||
function escapeStringValue(value: string): string {
|
||||
return value.replace(/'/g, "''");
|
||||
}
|
||||
|
||||
function getTextComparator(type: string, value: string): string {
|
||||
if (type === 'contains' || type === 'notContains') {
|
||||
return `%${value}%`;
|
||||
@@ -134,10 +139,12 @@ function convertFilterToSQL(
|
||||
|
||||
if (filter.filterType === 'text' && filter.filter && filter.type) {
|
||||
const op = TEXT_FILTER_OPERATORS[filter.type];
|
||||
const val = getTextComparator(filter.type, String(filter.filter));
|
||||
const escapedFilter = escapeStringValue(String(filter.filter));
|
||||
const val = getTextComparator(filter.type, escapedFilter);
|
||||
|
||||
return op === 'ILIKE' || op === 'NOT ILIKE'
|
||||
? `${colId} ${op} '${val}'`
|
||||
: `${colId} ${op} '${filter.filter}'`;
|
||||
: `${colId} ${op} '${escapedFilter}'`;
|
||||
}
|
||||
|
||||
if (
|
||||
@@ -151,7 +158,8 @@ function convertFilterToSQL(
|
||||
|
||||
if (filter.filterType === 'date' && filter.dateFrom && filter.type) {
|
||||
const op = NUMBER_FILTER_OPERATORS[filter.type];
|
||||
return `${colId} ${op} '${filter.dateFrom}'`;
|
||||
const escapedDate = escapeStringValue(filter.dateFrom);
|
||||
return `${colId} ${op} '${escapedDate}'`;
|
||||
}
|
||||
|
||||
if (
|
||||
@@ -159,7 +167,9 @@ function convertFilterToSQL(
|
||||
Array.isArray(filter.values) &&
|
||||
filter.values.length > 0
|
||||
) {
|
||||
const values = filter.values.map((v: string) => `'${v}'`).join(', ');
|
||||
const values = filter.values
|
||||
.map((v: string) => `'${escapeStringValue(v)}'`)
|
||||
.join(', ');
|
||||
return `${colId} IN (${values})`;
|
||||
}
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ from superset.exceptions import (
|
||||
QueryObjectValidationError,
|
||||
)
|
||||
from superset.extensions import event_logger
|
||||
from superset.sql.parse import sanitize_clause
|
||||
from superset.sql.parse import sanitize_clause, transpile_to_dialect
|
||||
from superset.superset_typing import Column, Metric, OrderBy, QueryObjectDict
|
||||
from superset.utils import json, pandas_postprocessing
|
||||
from superset.utils.core import (
|
||||
@@ -337,6 +337,8 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
|
||||
def _sanitize_filters(self) -> None:
|
||||
from superset.jinja_context import get_template_processor
|
||||
|
||||
needs_transpilation = self.extras.get("transpile_to_dialect", False)
|
||||
|
||||
for param in ("where", "having"):
|
||||
clause = self.extras.get(param)
|
||||
if clause and self.datasource:
|
||||
@@ -352,7 +354,12 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
|
||||
msg=ex.message,
|
||||
)
|
||||
) from ex
|
||||
|
||||
engine = database.db_engine_spec.engine
|
||||
|
||||
if needs_transpilation:
|
||||
clause = transpile_to_dialect(clause, engine)
|
||||
|
||||
sanitized_clause = sanitize_clause(clause, engine)
|
||||
if sanitized_clause != clause:
|
||||
self.extras[param] = sanitized_clause
|
||||
|
||||
@@ -1503,3 +1503,31 @@ def sanitize_clause(clause: str, engine: str) -> str:
|
||||
)
|
||||
except SupersetParseError as ex:
|
||||
raise QueryClauseValidationException(f"Invalid SQL clause: {clause}") from ex
|
||||
|
||||
|
||||
def transpile_to_dialect(sql: str, target_engine: str) -> str:
    """
    Transpile SQL from "generic SQL" to the target database dialect using SQLGlot.

    If the target engine is not in SQLGLOT_DIALECTS, returns the SQL as-is.

    :param sql: SQL snippet (typically a WHERE/HAVING clause) in generic SQL
    :param target_engine: engine name used to look up the SQLGlot dialect
    :returns: the SQL rendered for the target dialect, or ``sql`` unchanged
        when no dialect mapping exists for ``target_engine``
    :raises QueryClauseValidationException: when the SQL cannot be parsed or
        cannot be regenerated for the target dialect
    """
    target_dialect = SQLGLOT_DIALECTS.get(target_engine)
    if target_dialect is None:
        # Unknown engine: nothing to transpile to, so pass through untouched.
        return sql

    try:
        expression = sqlglot.parse_one(sql, dialect=Dialect)
        generator = Dialect.get_or_raise(target_dialect)
        return generator.generate(
            expression,
            copy=True,
            comments=False,
            pretty=False,
        )
    except ParseError as ex:
        raise QueryClauseValidationException(f"Cannot parse SQL clause: {sql}") from ex
    except Exception as ex:
        # Any other SQLGlot failure (e.g. during generation) is wrapped so
        # callers only ever see the project's validation exception type.
        raise QueryClauseValidationException(
            f"Cannot transpile SQL to {target_engine}: {sql}"
        ) from ex
|
||||
|
||||
347 tests/unit_tests/sql/transpile_to_dialect_test.py (new file)
@@ -0,0 +1,347 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Tests for transpile_to_dialect function in superset/sql/parse.py
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from superset.sql.parse import transpile_to_dialect
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "sql,dialect,expected",
    [
        # SQL-92 quote doubling is preserved by these dialects.
        ("name = 'O''Hara'", "postgresql", "name = 'O''Hara'"),
        ("name = 'O''Hara'", "mysql", "name = 'O''Hara'"),
        ("name = 'O''Hara'", "sqlite", "name = 'O''Hara'"),
        ("name = 'O''Hara'", "presto", "name = 'O''Hara'"),
        ("name = 'O''Hara'", "trino", "name = 'O''Hara'"),
        # These dialects rewrite the escape to a backslash.
        ("name = 'O''Hara'", "snowflake", "name = 'O\\'Hara'"),
        ("name = 'O''Hara'", "bigquery", "name = 'O\\'Hara'"),
        ("name = 'O''Hara'", "databricks", "name = 'O\\'Hara'"),
    ],
)
def test_single_quote_escaping(sql: str, dialect: str, expected: str) -> None:
    """Single-quote escaping is rendered per-dialect."""
    result = transpile_to_dialect(sql, dialect)
    assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "sql,dialect,expected",
    [
        # Quote handling inside a compound (AND) predicate.
        (
            "(name = 'O''Hara' AND status = 'active')",
            "postgresql",
            "(name = 'O''Hara' AND status = 'active')",
        ),
        (
            "(name = 'O''Hara' AND status = 'active')",
            "mysql",
            "(name = 'O''Hara' AND status = 'active')",
        ),
        (
            "(name = 'O''Hara' AND status = 'active')",
            "snowflake",
            "(name = 'O\\'Hara' AND status = 'active')",
        ),
        (
            "(name = 'O''Hara' AND status = 'active')",
            "databricks",
            "(name = 'O\\'Hara' AND status = 'active')",
        ),
    ],
)
def test_compound_filter_with_quotes(sql: str, dialect: str, expected: str) -> None:
    """Quoted strings survive inside compound filter expressions."""
    result = transpile_to_dialect(sql, dialect)
    assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "sql,dialect,expected",
    [
        ("name LIKE '%O''Hara%'", "postgresql", "name LIKE '%O''Hara%'"),
        ("name LIKE '%O''Hara%'", "mysql", "name LIKE '%O''Hara%'"),
        ("name LIKE '%O''Hara%'", "snowflake", "name LIKE '%O\\'Hara%'"),
        ("name LIKE '%O''Hara%'", "databricks", "name LIKE '%O\\'Hara%'"),
    ],
)
def test_like_with_special_chars(sql: str, dialect: str, expected: str) -> None:
    """Escaped quotes inside LIKE patterns are rendered per-dialect."""
    result = transpile_to_dialect(sql, dialect)
    assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "sql,dialect,expected",
    [
        # Dialects with native ILIKE keep it as-is.
        ("name ILIKE '%test%'", "postgresql", "name ILIKE '%test%'"),
        ("name ILIKE '%test%'", "snowflake", "name ILIKE '%test%'"),
        ("name ILIKE '%test%'", "databricks", "name ILIKE '%test%'"),
        # Dialects without ILIKE fall back to LOWER(col) LIKE LOWER(pattern).
        ("name ILIKE '%test%'", "mysql", "LOWER(name) LIKE LOWER('%test%')"),
        ("name ILIKE '%test%'", "sqlite", "LOWER(name) LIKE LOWER('%test%')"),
        ("name ILIKE '%test%'", "bigquery", "LOWER(name) LIKE LOWER('%test%')"),
        ("name ILIKE '%test%'", "presto", "LOWER(name) LIKE LOWER('%test%')"),
        ("name ILIKE '%test%'", "trino", "LOWER(name) LIKE LOWER('%test%')"),
    ],
)
def test_ilike_transpilation(sql: str, dialect: str, expected: str) -> None:
    """ILIKE is kept or rewritten depending on dialect support."""
    result = transpile_to_dialect(sql, dialect)
    assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "sql,dialect,expected",
    [
        (
            "name IN ('O''Hara', 'D''Angelo')",
            "postgresql",
            "name IN ('O''Hara', 'D''Angelo')",
        ),
        (
            "name IN ('O''Hara', 'D''Angelo')",
            "mysql",
            "name IN ('O''Hara', 'D''Angelo')",
        ),
        (
            "name IN ('O''Hara', 'D''Angelo')",
            "snowflake",
            "name IN ('O\\'Hara', 'D\\'Angelo')",
        ),
        (
            "name IN ('O''Hara', 'D''Angelo')",
            "databricks",
            "name IN ('O\\'Hara', 'D\\'Angelo')",
        ),
    ],
)
def test_in_clause_with_quoted_strings(sql: str, dialect: str, expected: str) -> None:
    """Every element of an IN list gets per-dialect quote escaping."""
    result = transpile_to_dialect(sql, dialect)
    assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "sql,dialect,expected",
    [
        ("price > 100 AND quantity <= 50", "postgresql", "price > 100 AND quantity <= 50"),
        ("price > 100 AND quantity <= 50", "mysql", "price > 100 AND quantity <= 50"),
        ("price > 100 AND quantity <= 50", "snowflake", "price > 100 AND quantity <= 50"),
    ],
)
def test_number_comparison(sql: str, dialect: str, expected: str) -> None:
    """Numeric comparison predicates pass through untouched."""
    result = transpile_to_dialect(sql, dialect)
    assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "sql,dialect,expected",
    [
        ("created_at > '2024-01-01'", "postgresql", "created_at > '2024-01-01'"),
        ("created_at > '2024-01-01'", "mysql", "created_at > '2024-01-01'"),
        ("created_at > '2024-01-01'", "snowflake", "created_at > '2024-01-01'"),
    ],
)
def test_date_comparison(sql: str, dialect: str, expected: str) -> None:
    """Date-literal comparisons pass through untouched."""
    result = transpile_to_dialect(sql, dialect)
    assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "sql,dialect,expected",
    [
        ("price BETWEEN 10 AND 100", "postgresql", "price BETWEEN 10 AND 100"),
        ("price BETWEEN 10 AND 100", "mysql", "price BETWEEN 10 AND 100"),
        ("price BETWEEN 10 AND 100", "snowflake", "price BETWEEN 10 AND 100"),
    ],
)
def test_between_clause(sql: str, dialect: str, expected: str) -> None:
    """BETWEEN predicates pass through untouched."""
    result = transpile_to_dialect(sql, dialect)
    assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "sql,dialect,expected",
    [
        ("name IS NULL", "postgresql", "name IS NULL"),
        ("name IS NULL", "mysql", "name IS NULL"),
        ("name IS NULL", "snowflake", "name IS NULL"),
    ],
)
def test_is_null(sql: str, dialect: str, expected: str) -> None:
    """IS NULL predicates pass through untouched."""
    result = transpile_to_dialect(sql, dialect)
    assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "sql,dialect,expected",
    [
        # SQLGlot normalizes "IS NOT NULL" into "NOT ... IS NULL".
        ("name IS NOT NULL", "postgresql", "NOT name IS NULL"),
        ("name IS NOT NULL", "mysql", "NOT name IS NULL"),
        ("name IS NOT NULL", "snowflake", "NOT name IS NULL"),
    ],
)
def test_is_not_null(sql: str, dialect: str, expected: str) -> None:
    """IS NOT NULL comes back in SQLGlot's normalized NOT ... IS NULL form."""
    result = transpile_to_dialect(sql, dialect)
    assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "sql,dialect,expected",
    [
        (
            "status = 'active' OR status = 'pending'",
            "postgresql",
            "status = 'active' OR status = 'pending'",
        ),
        (
            "status = 'active' OR status = 'pending'",
            "mysql",
            "status = 'active' OR status = 'pending'",
        ),
    ],
)
def test_or_condition(sql: str, dialect: str, expected: str) -> None:
    """OR-joined predicates pass through untouched."""
    result = transpile_to_dialect(sql, dialect)
    assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "sql,dialect,expected",
    [
        (
            "((a > 1 AND b < 2) OR (c = 3))",
            "postgresql",
            "((a > 1 AND b < 2) OR (c = 3))",
        ),
        ("((a > 1 AND b < 2) OR (c = 3))", "mysql", "((a > 1 AND b < 2) OR (c = 3))"),
    ],
)
def test_nested_conditions(sql: str, dialect: str, expected: str) -> None:
    """Nested parenthesized conditions keep their grouping."""
    result = transpile_to_dialect(sql, dialect)
    assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "dialect",
    [
        # Engines with a SQLGlot dialect mapping.
        "postgresql",
        "mysql",
        "sqlite",
        "snowflake",
        "bigquery",
        "mssql",
        "databricks",
        "presto",
        "trino",
        # Engines without a mapping: SQL must come back unchanged, not error.
        "unknown_database_engine",
        "crate",
        "databend",
        "db2",
        "denodo",
        "dynamodb",
        "elasticsearch",
    ],
)
def test_transpilation_does_not_error(dialect: str) -> None:
    """Transpiling valid SQL never raises, for known or unknown dialects."""
    result = transpile_to_dialect("name = 'test' AND price > 100", dialect)
    # A non-empty string must always come back.
    assert result is not None
    assert len(result) > 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "engine",
    [
        "unknown_database_engine",
        "crate",
        "databend",
        "db2",
        "denodo",
        "dynamodb",
        "elasticsearch",
    ],
)
def test_unknown_engine_returns_sql_unchanged(engine: str) -> None:
    """Engines with no SQLGlot mapping get their SQL back verbatim."""
    original = "name = 'O''Hara'"
    assert transpile_to_dialect(original, engine) == original
|
||||
|
||||
|
||||
def test_invalid_sql_raises_exception() -> None:
    """Unparseable SQL is rejected with QueryClauseValidationException."""
    from superset.exceptions import QueryClauseValidationException

    with pytest.raises(QueryClauseValidationException):
        transpile_to_dialect("INVALID SQL !!!", "postgresql")
|
||||
|
||||
|
||||
def test_empty_sql_raises_exception() -> None:
    """An empty clause is rejected with QueryClauseValidationException."""
    from superset.exceptions import QueryClauseValidationException

    with pytest.raises(QueryClauseValidationException):
        transpile_to_dialect("", "postgresql")
|
||||
|
||||
|
||||
def test_sqlglot_generation_error_raises_exception() -> None:
    """Failures during SQLGlot generation are wrapped in the project exception."""
    from unittest.mock import MagicMock, patch

    from superset.exceptions import QueryClauseValidationException

    # Parsing succeeds with a stub expression...
    parsed_stub = MagicMock()
    with patch("superset.sql.parse.sqlglot.parse_one", return_value=parsed_stub):
        # ...but the dialect's generate() blows up.
        with patch("superset.sql.parse.Dialect.get_or_raise") as get_dialect_mock:
            get_dialect_mock.return_value.generate.side_effect = RuntimeError(
                "SQLGlot internal error"
            )

            with pytest.raises(
                QueryClauseValidationException,
                match="Cannot transpile SQL to postgresql",
            ):
                transpile_to_dialect("name = 'test'", "postgresql")
|
||||
Reference in New Issue
Block a user