diff --git a/superset/sql/dialects/__init__.py b/superset/sql/dialects/__init__.py
index 71c8958a80f..0334efb5f11 100644
--- a/superset/sql/dialects/__init__.py
+++ b/superset/sql/dialects/__init__.py
@@ -18,6 +18,7 @@
 from .db2 import DB2
 from .dremio import Dremio
 from .firebolt import Firebolt, FireboltOld
+from .opensearch import OpenSearch
 from .pinot import Pinot
 
-__all__ = ["DB2", "Dremio", "Firebolt", "FireboltOld", "Pinot"]
+__all__ = ["DB2", "Dremio", "Firebolt", "FireboltOld", "OpenSearch", "Pinot"]
diff --git a/superset/sql/dialects/opensearch.py b/superset/sql/dialects/opensearch.py
new file mode 100644
index 00000000000..5cde7469b68
--- /dev/null
+++ b/superset/sql/dialects/opensearch.py
@@ -0,0 +1,34 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+OpenSearch SQL dialect.
+
+OpenSearch SQL is syntactically close to MySQL but accepts both backticks and
+double-quotes as identifier delimiters. Treating ``"`` as an identifier (rather
+than a string delimiter, as MySQL does) is what keeps mixed-case column names
+from being emitted as string literals after a SQLGlot round-trip.
+"""
+
+from __future__ import annotations
+
+from sqlglot.dialects.mysql import MySQL
+
+
+class OpenSearch(MySQL):  # NOTE(review): QUOTES is inherited; confirm '"' precedence
+    class Tokenizer(MySQL.Tokenizer):  # only identifier delimiters are overridden
+        IDENTIFIERS = ['"', "`"]  # double quote or backtick delimits an identifier
diff --git a/superset/sql/parse.py b/superset/sql/parse.py
index 20fe7f2b0c8..6f07c15c164 100644
--- a/superset/sql/parse.py
+++ b/superset/sql/parse.py
@@ -45,7 +45,7 @@ from sqlglot.optimizer.scope import (
 )
 
 from superset.exceptions import QueryClauseValidationException, SupersetParseError
-from superset.sql.dialects import DB2, Dremio, Firebolt, Pinot
+from superset.sql.dialects import DB2, Dremio, Firebolt, OpenSearch, Pinot
 
 if TYPE_CHECKING:
     from superset.models.core import Database
@@ -93,7 +93,7 @@ SQLGLOT_DIALECTS = {
     "netezza": Dialects.POSTGRES,
     "oceanbase": Dialects.MYSQL,
     # "ocient": ???
-    # "odelasticsearch": ???
+    "odelasticsearch": OpenSearch,
     "oracle": Dialects.ORACLE,
     "parseable": Dialects.POSTGRES,
     "pinot": Pinot,
diff --git a/tests/unit_tests/sql/dialects/opensearch_tests.py b/tests/unit_tests/sql/dialects/opensearch_tests.py
new file mode 100644
index 00000000000..c68c343a7ad
--- /dev/null
+++ b/tests/unit_tests/sql/dialects/opensearch_tests.py
@@ -0,0 +1,240 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+import sqlglot
+
+from superset.sql.dialects.opensearch import OpenSearch
+
+
+def test_opensearch_dialect_registered() -> None:
+    """
+    Test that OpenSearch dialect is properly registered for odelasticsearch.
+    """
+    from superset.sql.parse import SQLGLOT_DIALECTS
+
+    assert "odelasticsearch" in SQLGLOT_DIALECTS
+    assert SQLGLOT_DIALECTS["odelasticsearch"] == OpenSearch
+
+
+def test_double_quotes_as_identifiers() -> None:
+    """
+    Test that double quotes are treated as identifiers, not string literals.
+    """
+    sql = 'SELECT "AvgTicketPrice" FROM "flights"'
+    ast = sqlglot.parse_one(sql, OpenSearch)
+
+    assert (
+        OpenSearch().generate(expression=ast, pretty=True)
+        == """
+SELECT
+  "AvgTicketPrice"
+FROM "flights"
+    """.strip()
+    )
+
+
+def test_single_quotes_for_strings() -> None:
+    """
+    Test that single-quoted text is kept as a string literal.
+    """
+    sql = "SELECT * FROM flights WHERE Carrier = 'Kibana Airlines'"
+    ast = sqlglot.parse_one(sql, OpenSearch)
+
+    assert (
+        OpenSearch().generate(expression=ast, pretty=True)
+        == """
+SELECT
+  *
+FROM flights
+WHERE
+  Carrier = 'Kibana Airlines'
+    """.strip()
+    )
+
+
+def test_backticks_as_identifiers() -> None:
+    """
+    Test that backticks work as identifiers (MySQL-style).
+    Backticks are normalized to double quotes in output.
+    """
+    sql = "SELECT `AvgTicketPrice` FROM `flights`"
+    ast = sqlglot.parse_one(sql, OpenSearch)
+
+    assert (
+        OpenSearch().generate(expression=ast, pretty=True)
+        == """
+SELECT
+  "AvgTicketPrice"
+FROM "flights"
+    """.strip()
+    )
+
+
+def test_mixed_identifier_quotes() -> None:
+    """
+    Test that mixed double-quote/backtick identifiers all come out double-quoted.
+    """
+    sql = 'SELECT "AvgTicketPrice" AS `AvgTicketPrice` FROM `default`.`flights`'
+    ast = sqlglot.parse_one(sql, OpenSearch)
+
+    assert (
+        OpenSearch().generate(expression=ast, pretty=True)
+        == """
+SELECT
+  "AvgTicketPrice" AS "AvgTicketPrice"
+FROM "default"."flights"
+    """.strip()
+    )
+
+
+@pytest.mark.parametrize(
+    "sql, expected",
+    [
+        (
+            'SELECT COUNT(*) FROM "flights" WHERE "Cancelled" = true',
+            """
+SELECT
+  COUNT(*)
+FROM "flights"
+WHERE
+  "Cancelled" = TRUE
+        """.strip(),
+        ),
+        (
+            'SELECT "Carrier", SUM("AvgTicketPrice") FROM "flights" GROUP BY "Carrier"',
+            """
+SELECT
+  "Carrier",
+  SUM("AvgTicketPrice")
+FROM "flights"
+GROUP BY
+  "Carrier"
+        """.strip(),
+        ),
+        (
+            "SELECT * FROM \"flights\" WHERE \"DestCountry\" IN ('US', 'CA', 'MX')",
+            """
+SELECT
+  *
+FROM "flights"
+WHERE
+  "DestCountry" IN ('US', 'CA', 'MX')
+        """.strip(),
+        ),
+    ],
+)
+def test_various_queries(sql: str, expected: str) -> None:
+    """
+    Test pretty-printed output for a boolean filter, a GROUP BY and an IN list.
+    """
+    ast = sqlglot.parse_one(sql, OpenSearch)
+    assert OpenSearch().generate(expression=ast, pretty=True) == expected
+
+
+def test_aggregate_functions() -> None:
+    """
+    Test aggregate functions with quoted identifiers.
+    """
+    sql = """
+SELECT
+    "Carrier",
+    COUNT(*),
+    AVG("AvgTicketPrice"),
+    MAX("FlightDelayMin")
+FROM "flights"
+GROUP BY "Carrier"
+    """
+    ast = sqlglot.parse_one(sql, OpenSearch)
+
+    assert (
+        OpenSearch().generate(expression=ast, pretty=True)
+        == """
+SELECT
+  "Carrier",
+  COUNT(*),
+  AVG("AvgTicketPrice"),
+  MAX("FlightDelayMin")
+FROM "flights"
+GROUP BY
+  "Carrier"
+    """.strip()
+    )
+
+
+def test_subquery_with_quoted_identifiers() -> None:
+    """
+    Test subqueries with quoted identifiers.
+    """
+    sql = 'SELECT * FROM (SELECT "Carrier", "AvgTicketPrice" FROM "flights") AS "sub"'
+    ast = sqlglot.parse_one(sql, OpenSearch)
+
+    assert (
+        OpenSearch().generate(expression=ast, pretty=True)
+        == """
+SELECT
+  *
+FROM (
+  SELECT
+    "Carrier",
+    "AvgTicketPrice"
+  FROM "flights"
+) AS "sub"
+    """.strip()
+    )
+
+
+def test_order_by_with_quoted_identifiers() -> None:
+    """
+    Test ORDER BY clause with quoted identifiers.
+    """
+    sql = (
+        'SELECT "Carrier", "AvgTicketPrice" FROM "flights" '
+        'ORDER BY "AvgTicketPrice" DESC, "Carrier" ASC'
+    )
+    ast = sqlglot.parse_one(sql, OpenSearch)
+
+    assert (
+        OpenSearch().generate(expression=ast, pretty=True)
+        == """
+SELECT
+  "Carrier",
+  "AvgTicketPrice"
+FROM "flights"
+ORDER BY
+  "AvgTicketPrice" DESC,
+  "Carrier" ASC
+    """.strip()
+    )
+
+
+def test_limit_clause() -> None:
+    """
+    Test LIMIT clause with quoted identifiers.
+    """
+    sql = 'SELECT * FROM "flights" LIMIT 10'
+    ast = sqlglot.parse_one(sql, OpenSearch)  # the table name is the quoted identifier
+
+    assert (
+        OpenSearch().generate(expression=ast, pretty=True)
+        == """
+SELECT
+  *
+FROM "flights"
+LIMIT 10
+    """.strip()
+    )