fix: OpenSearch dialect identifier delimiters (#39953)

This commit is contained in:
Vitor Avila
2026-05-07 16:19:27 -03:00
committed by GitHub
parent aa710672ed
commit ad5e3170dd
3 changed files with 51 additions and 38 deletions

View File

@@ -142,7 +142,7 @@ druid = ["pydruid>=0.6.5,<0.7"]
duckdb = ["duckdb>=1.4.2,<2", "duckdb-engine>=0.17.0"]
dynamodb = ["pydynamodb>=0.4.2"]
solr = ["sqlalchemy-solr >= 0.2.0"]
elasticsearch = ["elasticsearch-dbapi>=0.2.12, <0.3.0"]
elasticsearch = ["elasticsearch-dbapi>=0.2.13, <0.3.0"]
exasol = ["sqlalchemy-exasol >= 2.4.0, <3.0"]
excel = ["xlrd>=1.2.0, <1.3"]
fastmcp = [

View File

@@ -19,9 +19,7 @@
OpenSearch SQL dialect.
OpenSearch SQL is syntactically close to MySQL but accepts both backticks and
double-quotes as identifier delimiters. Treating ``"`` as an identifier (rather
than a string delimiter, as MySQL does) is what keeps mixed-case column names
from being emitted as string literals after a SQLGlot round-trip.
double-quotes as identifier delimiters.
"""
from __future__ import annotations
@@ -31,4 +29,4 @@ from sqlglot.dialects.mysql import MySQL
class OpenSearch(MySQL):
class Tokenizer(MySQL.Tokenizer):
IDENTIFIERS = ['"', "`"]
IDENTIFIERS = ["`", '"']

View File

@@ -33,7 +33,8 @@ def test_opensearch_dialect_registered() -> None:
def test_double_quotes_as_identifiers() -> None:
"""
Test that double quotes are treated as identifiers, not string literals.
Test that double quotes are treated as identifiers, not string literals,
and normalized to backticks in output.
"""
sql = 'SELECT "AvgTicketPrice" FROM "flights"'
ast = sqlglot.parse_one(sql, OpenSearch)
@@ -42,8 +43,8 @@ def test_double_quotes_as_identifiers() -> None:
OpenSearch().generate(expression=ast, pretty=True)
== """
SELECT
"AvgTicketPrice"
FROM "flights"
`AvgTicketPrice`
FROM `flights`
""".strip()
)
@@ -69,8 +70,7 @@ WHERE
def test_backticks_as_identifiers() -> None:
"""
Test that backticks work as identifiers (MySQL-style).
Backticks are normalized to double quotes in output.
Test that backticks are accepted as identifiers and preserved on output.
"""
sql = "SELECT `AvgTicketPrice` FROM `flights`"
ast = sqlglot.parse_one(sql, OpenSearch)
@@ -79,15 +79,16 @@ def test_backticks_as_identifiers() -> None:
OpenSearch().generate(expression=ast, pretty=True)
== """
SELECT
"AvgTicketPrice"
FROM "flights"
`AvgTicketPrice`
FROM `flights`
""".strip()
)
def test_mixed_identifier_quotes() -> None:
"""
Test mixing double quotes and backticks for identifiers.
Test mixing double quotes and backticks for identifiers are all normalized to
backticks on output.
"""
sql = 'SELECT "AvgTicketPrice" AS `AvgTicketPrice` FROM `default`.`flights`'
ast = sqlglot.parse_one(sql, OpenSearch)
@@ -96,12 +97,26 @@ def test_mixed_identifier_quotes() -> None:
OpenSearch().generate(expression=ast, pretty=True)
== """
SELECT
"AvgTicketPrice" AS "AvgTicketPrice"
FROM "default"."flights"
`AvgTicketPrice` AS `AvgTicketPrice`
FROM `default`.`flights`
""".strip()
)
def test_alias_with_space() -> None:
"""
Test that an alias containing a space (e.g. a metric key like ``my test``)
is preserved as a backtick-quoted identifier through the round-trip.
"""
sql = 'SELECT COUNT(*) AS "my test" FROM `flights`'
ast = sqlglot.parse_one(sql, OpenSearch)
assert (
OpenSearch().generate(expression=ast, pretty=False)
== "SELECT COUNT(*) AS `my test` FROM `flights`"
)
@pytest.mark.parametrize(
"sql, expected",
[
@@ -110,20 +125,20 @@ FROM "default"."flights"
"""
SELECT
COUNT(*)
FROM "flights"
FROM `flights`
WHERE
"Cancelled" = TRUE
`Cancelled` = TRUE
""".strip(),
),
(
'SELECT "Carrier", SUM("AvgTicketPrice") FROM "flights" GROUP BY "Carrier"',
"""
SELECT
"Carrier",
SUM("AvgTicketPrice")
FROM "flights"
`Carrier`,
SUM(`AvgTicketPrice`)
FROM `flights`
GROUP BY
"Carrier"
`Carrier`
""".strip(),
),
(
@@ -131,9 +146,9 @@ GROUP BY
"""
SELECT
*
FROM "flights"
FROM `flights`
WHERE
"DestCountry" IN ('US', 'CA', 'MX')
`DestCountry` IN ('US', 'CA', 'MX')
""".strip(),
),
],
@@ -165,13 +180,13 @@ GROUP BY "Carrier"
OpenSearch().generate(expression=ast, pretty=True)
== """
SELECT
"Carrier",
`Carrier`,
COUNT(*),
AVG("AvgTicketPrice"),
MAX("FlightDelayMin")
FROM "flights"
AVG(`AvgTicketPrice`),
MAX(`FlightDelayMin`)
FROM `flights`
GROUP BY
"Carrier"
`Carrier`
""".strip()
)
@@ -190,10 +205,10 @@ SELECT
*
FROM (
SELECT
"Carrier",
"AvgTicketPrice"
FROM "flights"
) AS "sub"
`Carrier`,
`AvgTicketPrice`
FROM `flights`
) AS `sub`
""".strip()
)
@@ -212,12 +227,12 @@ def test_order_by_with_quoted_identifiers() -> None:
OpenSearch().generate(expression=ast, pretty=True)
== """
SELECT
"Carrier",
"AvgTicketPrice"
FROM "flights"
`Carrier`,
`AvgTicketPrice`
FROM `flights`
ORDER BY
"AvgTicketPrice" DESC,
"Carrier" ASC
`AvgTicketPrice` DESC,
`Carrier` ASC
""".strip()
)
@@ -234,7 +249,7 @@ def test_limit_clause() -> None:
== """
SELECT
*
FROM "flights"
FROM `flights`
LIMIT 10
""".strip()
)