mirror of
https://github.com/apache/superset.git
synced 2026-04-28 04:25:07 +00:00
fix(OpenSearch): OpenSearch dialect for sqlglot (#39538)
This commit is contained in:
240
tests/unit_tests/sql/dialects/opensearch_tests.py
Normal file
240
tests/unit_tests/sql/dialects/opensearch_tests.py
Normal file
@@ -0,0 +1,240 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import pytest
|
||||
import sqlglot
|
||||
|
||||
from superset.sql.dialects.opensearch import OpenSearch
|
||||
|
||||
|
||||
def test_opensearch_dialect_registered() -> None:
|
||||
"""
|
||||
Test that OpenSearch dialect is properly registered for odelasticsearch.
|
||||
"""
|
||||
from superset.sql.parse import SQLGLOT_DIALECTS
|
||||
|
||||
assert "odelasticsearch" in SQLGLOT_DIALECTS
|
||||
assert SQLGLOT_DIALECTS["odelasticsearch"] == OpenSearch
|
||||
|
||||
|
||||
def test_double_quotes_as_identifiers() -> None:
|
||||
"""
|
||||
Test that double quotes are treated as identifiers, not string literals.
|
||||
"""
|
||||
sql = 'SELECT "AvgTicketPrice" FROM "flights"'
|
||||
ast = sqlglot.parse_one(sql, OpenSearch)
|
||||
|
||||
assert (
|
||||
OpenSearch().generate(expression=ast, pretty=True)
|
||||
== """
|
||||
SELECT
|
||||
"AvgTicketPrice"
|
||||
FROM "flights"
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
||||
def test_single_quotes_for_strings() -> None:
|
||||
"""
|
||||
Test that single quotes are used for string literals.
|
||||
"""
|
||||
sql = "SELECT * FROM flights WHERE Carrier = 'Kibana Airlines'"
|
||||
ast = sqlglot.parse_one(sql, OpenSearch)
|
||||
|
||||
assert (
|
||||
OpenSearch().generate(expression=ast, pretty=True)
|
||||
== """
|
||||
SELECT
|
||||
*
|
||||
FROM flights
|
||||
WHERE
|
||||
Carrier = 'Kibana Airlines'
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
||||
def test_backticks_as_identifiers() -> None:
|
||||
"""
|
||||
Test that backticks work as identifiers (MySQL-style).
|
||||
Backticks are normalized to double quotes in output.
|
||||
"""
|
||||
sql = "SELECT `AvgTicketPrice` FROM `flights`"
|
||||
ast = sqlglot.parse_one(sql, OpenSearch)
|
||||
|
||||
assert (
|
||||
OpenSearch().generate(expression=ast, pretty=True)
|
||||
== """
|
||||
SELECT
|
||||
"AvgTicketPrice"
|
||||
FROM "flights"
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
||||
def test_mixed_identifier_quotes() -> None:
|
||||
"""
|
||||
Test mixing double quotes and backticks for identifiers.
|
||||
"""
|
||||
sql = 'SELECT "AvgTicketPrice" AS `AvgTicketPrice` FROM `default`.`flights`'
|
||||
ast = sqlglot.parse_one(sql, OpenSearch)
|
||||
|
||||
assert (
|
||||
OpenSearch().generate(expression=ast, pretty=True)
|
||||
== """
|
||||
SELECT
|
||||
"AvgTicketPrice" AS "AvgTicketPrice"
|
||||
FROM "default"."flights"
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sql, expected",
|
||||
[
|
||||
(
|
||||
'SELECT COUNT(*) FROM "flights" WHERE "Cancelled" = true',
|
||||
"""
|
||||
SELECT
|
||||
COUNT(*)
|
||||
FROM "flights"
|
||||
WHERE
|
||||
"Cancelled" = TRUE
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
'SELECT "Carrier", SUM("AvgTicketPrice") FROM "flights" GROUP BY "Carrier"',
|
||||
"""
|
||||
SELECT
|
||||
"Carrier",
|
||||
SUM("AvgTicketPrice")
|
||||
FROM "flights"
|
||||
GROUP BY
|
||||
"Carrier"
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"SELECT * FROM \"flights\" WHERE \"DestCountry\" IN ('US', 'CA', 'MX')",
|
||||
"""
|
||||
SELECT
|
||||
*
|
||||
FROM "flights"
|
||||
WHERE
|
||||
"DestCountry" IN ('US', 'CA', 'MX')
|
||||
""".strip(),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_various_queries(sql: str, expected: str) -> None:
|
||||
"""
|
||||
Test various SQL queries with OpenSearch dialect.
|
||||
"""
|
||||
ast = sqlglot.parse_one(sql, OpenSearch)
|
||||
assert OpenSearch().generate(expression=ast, pretty=True) == expected
|
||||
|
||||
|
||||
def test_aggregate_functions() -> None:
|
||||
"""
|
||||
Test aggregate functions with quoted identifiers.
|
||||
"""
|
||||
sql = """
|
||||
SELECT
|
||||
"Carrier",
|
||||
COUNT(*),
|
||||
AVG("AvgTicketPrice"),
|
||||
MAX("FlightDelayMin")
|
||||
FROM "flights"
|
||||
GROUP BY "Carrier"
|
||||
"""
|
||||
ast = sqlglot.parse_one(sql, OpenSearch)
|
||||
|
||||
assert (
|
||||
OpenSearch().generate(expression=ast, pretty=True)
|
||||
== """
|
||||
SELECT
|
||||
"Carrier",
|
||||
COUNT(*),
|
||||
AVG("AvgTicketPrice"),
|
||||
MAX("FlightDelayMin")
|
||||
FROM "flights"
|
||||
GROUP BY
|
||||
"Carrier"
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
||||
def test_subquery_with_quoted_identifiers() -> None:
|
||||
"""
|
||||
Test subqueries with quoted identifiers.
|
||||
"""
|
||||
sql = 'SELECT * FROM (SELECT "Carrier", "AvgTicketPrice" FROM "flights") AS "sub"'
|
||||
ast = sqlglot.parse_one(sql, OpenSearch)
|
||||
|
||||
assert (
|
||||
OpenSearch().generate(expression=ast, pretty=True)
|
||||
== """
|
||||
SELECT
|
||||
*
|
||||
FROM (
|
||||
SELECT
|
||||
"Carrier",
|
||||
"AvgTicketPrice"
|
||||
FROM "flights"
|
||||
) AS "sub"
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
||||
def test_order_by_with_quoted_identifiers() -> None:
|
||||
"""
|
||||
Test ORDER BY clause with quoted identifiers.
|
||||
"""
|
||||
sql = (
|
||||
'SELECT "Carrier", "AvgTicketPrice" FROM "flights" '
|
||||
'ORDER BY "AvgTicketPrice" DESC, "Carrier" ASC'
|
||||
)
|
||||
ast = sqlglot.parse_one(sql, OpenSearch)
|
||||
|
||||
assert (
|
||||
OpenSearch().generate(expression=ast, pretty=True)
|
||||
== """
|
||||
SELECT
|
||||
"Carrier",
|
||||
"AvgTicketPrice"
|
||||
FROM "flights"
|
||||
ORDER BY
|
||||
"AvgTicketPrice" DESC,
|
||||
"Carrier" ASC
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
||||
def test_limit_clause() -> None:
|
||||
"""
|
||||
Test LIMIT clause with quoted identifiers.
|
||||
"""
|
||||
sql = 'SELECT * FROM "flights" LIMIT 10'
|
||||
ast = sqlglot.parse_one(sql, OpenSearch)
|
||||
|
||||
assert (
|
||||
OpenSearch().generate(expression=ast, pretty=True)
|
||||
== """
|
||||
SELECT
|
||||
*
|
||||
FROM "flights"
|
||||
LIMIT 10
|
||||
""".strip()
|
||||
)
|
||||
Reference in New Issue
Block a user