diff --git a/pyproject.toml b/pyproject.toml index 933e5175971..210bd260074 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -142,7 +142,7 @@ druid = ["pydruid>=0.6.5,<0.7"] duckdb = ["duckdb>=1.4.2,<2", "duckdb-engine>=0.17.0"] dynamodb = ["pydynamodb>=0.4.2"] solr = ["sqlalchemy-solr >= 0.2.0"] -elasticsearch = ["elasticsearch-dbapi>=0.2.12, <0.3.0"] +elasticsearch = ["elasticsearch-dbapi>=0.2.13, <0.3.0"] exasol = ["sqlalchemy-exasol >= 2.4.0, <3.0"] excel = ["xlrd>=1.2.0, <1.3"] fastmcp = [ diff --git a/superset/sql/dialects/opensearch.py b/superset/sql/dialects/opensearch.py index 5cde7469b68..0f647c2f21e 100644 --- a/superset/sql/dialects/opensearch.py +++ b/superset/sql/dialects/opensearch.py @@ -19,9 +19,7 @@ OpenSearch SQL dialect. OpenSearch SQL is syntactically close to MySQL but accepts both backticks and -double-quotes as identifier delimiters. Treating ``"`` as an identifier (rather -than a string delimiter, as MySQL does) is what keeps mixed-case column names -from being emitted as string literals after a SQLGlot round-trip. +double-quotes as identifier delimiters. """ from __future__ import annotations @@ -31,4 +29,4 @@ from sqlglot.dialects.mysql import MySQL class OpenSearch(MySQL): class Tokenizer(MySQL.Tokenizer): - IDENTIFIERS = ['"', "`"] + IDENTIFIERS = ["`", '"'] diff --git a/tests/unit_tests/sql/dialects/opensearch_tests.py b/tests/unit_tests/sql/dialects/opensearch_tests.py index c68c343a7ad..8805fa4522b 100644 --- a/tests/unit_tests/sql/dialects/opensearch_tests.py +++ b/tests/unit_tests/sql/dialects/opensearch_tests.py @@ -33,7 +33,8 @@ def test_opensearch_dialect_registered() -> None: def test_double_quotes_as_identifiers() -> None: """ - Test that double quotes are treated as identifiers, not string literals. + Test that double quotes are treated as identifiers, not string literals, + and normalized to backticks in output. """ sql = 'SELECT "AvgTicketPrice" FROM "flights"' ast = sqlglot.parse_one(sql, OpenSearch) @@ -42,8 +43,8 @@ def test_double_quotes_as_identifiers() -> None: OpenSearch().generate(expression=ast, pretty=True) == """ SELECT - "AvgTicketPrice" -FROM "flights" + `AvgTicketPrice` +FROM `flights` """.strip() ) @@ -69,8 +70,7 @@ WHERE def test_backticks_as_identifiers() -> None: """ - Test that backticks work as identifiers (MySQL-style). - Backticks are normalized to double quotes in output. + Test that backticks are accepted as identifiers and preserved on output. """ sql = "SELECT `AvgTicketPrice` FROM `flights`" ast = sqlglot.parse_one(sql, OpenSearch) @@ -79,15 +79,16 @@ def test_backticks_as_identifiers() -> None: OpenSearch().generate(expression=ast, pretty=True) == """ SELECT - "AvgTicketPrice" -FROM "flights" + `AvgTicketPrice` +FROM `flights` """.strip() ) def test_mixed_identifier_quotes() -> None: """ - Test mixing double quotes and backticks for identifiers. + Test mixing double quotes and backticks for identifiers are all normalized to + backticks on output. """ sql = 'SELECT "AvgTicketPrice" AS `AvgTicketPrice` FROM `default`.`flights`' ast = sqlglot.parse_one(sql, OpenSearch) @@ -96,12 +97,26 @@ def test_mixed_identifier_quotes() -> None: OpenSearch().generate(expression=ast, pretty=True) == """ SELECT - "AvgTicketPrice" AS "AvgTicketPrice" -FROM "default"."flights" + `AvgTicketPrice` AS `AvgTicketPrice` +FROM `default`.`flights` """.strip() ) +def test_alias_with_space() -> None: + """ + Test that an alias containing a space (e.g. a metric key like ``my test``) + is preserved as a backtick-quoted identifier through the round-trip. + """ + sql = 'SELECT COUNT(*) AS "my test" FROM `flights`' + ast = sqlglot.parse_one(sql, OpenSearch) + + assert ( + OpenSearch().generate(expression=ast, pretty=False) + == "SELECT COUNT(*) AS `my test` FROM `flights`" + ) + + @pytest.mark.parametrize( "sql, expected", [ @@ -110,20 +125,20 @@ FROM "default"."flights" """ SELECT COUNT(*) -FROM "flights" +FROM `flights` WHERE - "Cancelled" = TRUE + `Cancelled` = TRUE """.strip(), ), ( 'SELECT "Carrier", SUM("AvgTicketPrice") FROM "flights" GROUP BY "Carrier"', """ SELECT - "Carrier", - SUM("AvgTicketPrice") -FROM "flights" + `Carrier`, + SUM(`AvgTicketPrice`) +FROM `flights` GROUP BY - "Carrier" + `Carrier` """.strip(), ), ( @@ -131,9 +146,9 @@ GROUP BY """ SELECT * -FROM "flights" +FROM `flights` WHERE - "DestCountry" IN ('US', 'CA', 'MX') + `DestCountry` IN ('US', 'CA', 'MX') """.strip(), ), ], @@ -165,13 +180,13 @@ GROUP BY "Carrier" OpenSearch().generate(expression=ast, pretty=True) == """ SELECT - "Carrier", + `Carrier`, COUNT(*), - AVG("AvgTicketPrice"), - MAX("FlightDelayMin") -FROM "flights" + AVG(`AvgTicketPrice`), + MAX(`FlightDelayMin`) +FROM `flights` GROUP BY - "Carrier" + `Carrier` """.strip() ) @@ -190,10 +205,10 @@ SELECT * FROM ( SELECT - "Carrier", - "AvgTicketPrice" - FROM "flights" -) AS "sub" + `Carrier`, + `AvgTicketPrice` + FROM `flights` +) AS `sub` """.strip() ) @@ -212,12 +227,12 @@ def test_order_by_with_quoted_identifiers() -> None: OpenSearch().generate(expression=ast, pretty=True) == """ SELECT - "Carrier", - "AvgTicketPrice" -FROM "flights" + `Carrier`, + `AvgTicketPrice` +FROM `flights` ORDER BY - "AvgTicketPrice" DESC, - "Carrier" ASC + `AvgTicketPrice` DESC, + `Carrier` ASC """.strip() ) @@ -234,7 +249,7 @@ def test_limit_clause() -> None: == """ SELECT * -FROM "flights" +FROM `flights` LIMIT 10 """.strip() )