mirror of
https://github.com/apache/superset.git
synced 2026-04-08 02:45:22 +00:00
Co-authored-by: bito-code-review[bot] <188872107+bito-code-review[bot]@users.noreply.github.com>
501 lines
12 KiB
Python
501 lines
12 KiB
Python
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
|
|
import pytest
|
|
import sqlglot
|
|
|
|
from superset.sql.dialects.pinot import Pinot
|
|
|
|
|
|
def test_pinot_dialect_registered() -> None:
    """
    Check that the ``"pinot"`` key in ``SQLGLOT_DIALECTS`` maps to ``Pinot``.
    """
    # Imported locally, mirroring the original test's function-scope import.
    from superset.sql.parse import SQLGLOT_DIALECTS as registry

    assert "pinot" in registry
    assert registry["pinot"] == Pinot
|
|
|
|
|
|
def test_double_quotes_as_identifiers() -> None:
    """
    Double-quoted names must round-trip as identifiers, not string literals.
    """
    tree = sqlglot.parse_one('SELECT "column_name" FROM "table_name"', Pinot)

    # Pretty-printed layout, written line by line so the indentation is explicit.
    expected = "\n".join(
        [
            "SELECT",
            '  "column_name"',
            'FROM "table_name"',
        ]
    )
    assert Pinot().generate(expression=tree, pretty=True) == expected
|
|
|
|
|
|
def test_single_quotes_for_strings() -> None:
    """
    Single-quoted values must be emitted as string literals.
    """
    tree = sqlglot.parse_one("SELECT * FROM users WHERE name = 'John'", Pinot)

    expected = "\n".join(
        [
            "SELECT",
            "  *",
            "FROM users",
            "WHERE",
            "  name = 'John'",
        ]
    )
    assert Pinot().generate(expression=tree, pretty=True) == expected
|
|
|
|
|
|
def test_backticks_as_identifiers() -> None:
    """
    Backtick-quoted names (MySQL style) must be accepted as identifiers.

    The generated SQL is expected to use double quotes for them.
    """
    tree = sqlglot.parse_one("SELECT `column_name` FROM `table_name`", Pinot)

    expected = "\n".join(
        [
            "SELECT",
            '  "column_name"',
            'FROM "table_name"',
        ]
    )
    assert Pinot().generate(expression=tree, pretty=True) == expected
|
|
|
|
|
|
def test_mixed_identifier_quotes() -> None:
    """
    A query mixing double-quoted and backtick-quoted identifiers must parse.

    Every quoted identifier is expected to come out double-quoted.
    """
    query = (
        'SELECT "col1", `col2` FROM "table1" JOIN `table2` ON "table1".id = `table2`.id'
    )
    tree = sqlglot.parse_one(query, Pinot)

    expected = "\n".join(
        [
            "SELECT",
            '  "col1",',
            '  "col2"',
            'FROM "table1"',
            'JOIN "table2"',
            '  ON "table1".id = "table2".id',
        ]
    )
    assert Pinot().generate(expression=tree, pretty=True) == expected
|
|
|
|
|
|
def test_string_with_escaped_quotes() -> None:
    """
    A doubled single quote inside a string literal must survive a round-trip.
    """
    tree = sqlglot.parse_one("SELECT * FROM users WHERE name = 'O''Brien'", Pinot)

    expected = "\n".join(
        [
            "SELECT",
            "  *",
            "FROM users",
            "WHERE",
            "  name = 'O''Brien'",
        ]
    )
    assert Pinot().generate(expression=tree, pretty=True) == expected
|
|
|
|
|
|
def test_string_with_backslash_escape() -> None:
    """
    A string literal containing backslash escapes must parse and regenerate.

    Only the presence of the WHERE clause and the column name is checked; the
    exact rendering of the escaped literal is not pinned.
    """
    query = r"SELECT * FROM users WHERE path = 'C:\\Users\\John'"
    tree = sqlglot.parse_one(query, Pinot)
    generated = Pinot().generate(expression=tree, pretty=True)

    for fragment in ("WHERE", "path"):
        assert fragment in generated
|
|
|
|
|
|
@pytest.mark.parametrize(
    "sql, expected",
    [
        (
            'SELECT COUNT(*) FROM "events" WHERE "type" = \'click\'',
            "\n".join(
                [
                    "SELECT",
                    "  COUNT(*)",
                    'FROM "events"',
                    "WHERE",
                    "  \"type\" = 'click'",
                ]
            ),
        ),
        (
            'SELECT "user_id", SUM("amount") FROM "transactions" GROUP BY "user_id"',
            "\n".join(
                [
                    "SELECT",
                    '  "user_id",',
                    '  SUM("amount")',
                    'FROM "transactions"',
                    "GROUP BY",
                    '  "user_id"',
                ]
            ),
        ),
        (
            "SELECT * FROM \"orders\" WHERE \"status\" IN ('pending', 'shipped')",
            "\n".join(
                [
                    "SELECT",
                    "  *",
                    'FROM "orders"',
                    "WHERE",
                    "  \"status\" IN ('pending', 'shipped')",
                ]
            ),
        ),
    ],
)
def test_various_queries(sql: str, expected: str) -> None:
    """
    Round-trip several queries through the Pinot dialect.

    Each case pairs an input query with the exact pretty-printed SQL that the
    generator is expected to produce.
    """
    tree = sqlglot.parse_one(sql, Pinot)
    generated = Pinot().generate(expression=tree, pretty=True)
    assert generated == expected
|
|
|
|
|
|
def test_aggregate_functions() -> None:
    """
    COUNT, AVG and MAX over quoted columns must round-trip with GROUP BY.
    """
    query = """
        SELECT
            "category",
            COUNT(*),
            AVG("price"),
            MAX("quantity")
        FROM "products"
        GROUP BY "category"
    """
    tree = sqlglot.parse_one(query, Pinot)

    expected = "\n".join(
        [
            "SELECT",
            '  "category",',
            "  COUNT(*),",
            '  AVG("price"),',
            '  MAX("quantity")',
            'FROM "products"',
            "GROUP BY",
            '  "category"',
        ]
    )
    assert Pinot().generate(expression=tree, pretty=True) == expected
|
|
|
|
|
|
def test_join_with_quoted_identifiers() -> None:
    """
    A JOIN whose tables, aliases and columns are all double-quoted must
    round-trip with the quoting intact.
    """
    query = """
        SELECT "u"."name", "o"."total"
        FROM "users" AS "u"
        JOIN "orders" AS "o" ON "u"."id" = "o"."user_id"
    """
    tree = sqlglot.parse_one(query, Pinot)

    expected = "\n".join(
        [
            "SELECT",
            '  "u"."name",',
            '  "o"."total"',
            'FROM "users" AS "u"',
            'JOIN "orders" AS "o"',
            '  ON "u"."id" = "o"."user_id"',
        ]
    )
    assert Pinot().generate(expression=tree, pretty=True) == expected
|
|
|
|
|
|
def test_subquery_with_quoted_identifiers() -> None:
    """
    A derived table with double-quoted columns and alias must round-trip.
    """
    query = 'SELECT * FROM (SELECT "id", "name" FROM "users") AS "subquery"'
    tree = sqlglot.parse_one(query, Pinot)

    # The inner SELECT is indented one extra level by the pretty printer.
    expected = "\n".join(
        [
            "SELECT",
            "  *",
            "FROM (",
            "  SELECT",
            '    "id",',
            '    "name"',
            '  FROM "users"',
            ') AS "subquery"',
        ]
    )
    assert Pinot().generate(expression=tree, pretty=True) == expected
|
|
|
|
|
|
def test_case_expression() -> None:
    """
    A CASE expression keeps its quoted identifiers and string literals.

    Only the presence of each identifier and literal is checked, not the
    overall layout of the generated SQL.
    """
    query = """
        SELECT "name",
            CASE WHEN "age" < 18 THEN 'minor'
                 WHEN "age" >= 18 THEN 'adult'
            END AS "category"
        FROM "persons"
    """
    tree = sqlglot.parse_one(query, Pinot)
    generated = Pinot().generate(expression=tree, pretty=True)

    for fragment in ('"name"', '"age"', '"category"', "'minor'", "'adult'"):
        assert fragment in generated
|
|
|
|
|
|
def test_cte_with_quoted_identifiers() -> None:
    """
    A WITH clause (CTE) using double-quoted names must round-trip.

    The check is by fragment: the CTE header and each quoted identifier must
    appear in the generated SQL.
    """
    query = """
        WITH "high_value_orders" AS (
            SELECT * FROM "orders" WHERE "total" > 1000
        )
        SELECT "customer_id", COUNT(*) FROM "high_value_orders" GROUP BY "customer_id"
    """
    tree = sqlglot.parse_one(query, Pinot)
    generated = Pinot().generate(expression=tree, pretty=True)

    for fragment in (
        'WITH "high_value_orders" AS',
        '"orders"',
        '"total"',
        '"customer_id"',
    ):
        assert fragment in generated
|
|
|
|
|
|
def test_order_by_with_quoted_identifiers() -> None:
    """
    ORDER BY over double-quoted columns must round-trip.

    The expected output keeps the explicit ``ASC`` that the input specifies.
    """
    query = 'SELECT "name", "salary" FROM "employees" ORDER BY "salary" DESC, "name" ASC'
    tree = sqlglot.parse_one(query, Pinot)

    expected = "\n".join(
        [
            "SELECT",
            '  "name",',
            '  "salary"',
            'FROM "employees"',
            "ORDER BY",
            '  "salary" DESC,',
            '  "name" ASC',
        ]
    )
    assert Pinot().generate(expression=tree, pretty=True) == expected
|
|
|
|
|
|
def test_limit_and_offset() -> None:
    """
    Test LIMIT and OFFSET clauses.

    Both clauses must survive the parse/generate round-trip, along with the
    double-quoted table name.
    """
    sql = 'SELECT * FROM "products" LIMIT 10 OFFSET 20'
    ast = sqlglot.parse_one(sql, Pinot)

    generated = Pinot().generate(expression=ast, pretty=True)
    assert '"products"' in generated
    assert "LIMIT 10" in generated
    # Previously only LIMIT was asserted, so a dropped OFFSET went unnoticed.
    assert "OFFSET 20" in generated
|
|
|
|
|
|
def test_distinct() -> None:
    """
    SELECT DISTINCT over a double-quoted column must round-trip.
    """
    tree = sqlglot.parse_one('SELECT DISTINCT "category" FROM "products"', Pinot)

    expected = "\n".join(
        [
            "SELECT DISTINCT",
            '  "category"',
            'FROM "products"',
        ]
    )
    assert Pinot().generate(expression=tree, pretty=True) == expected
|
|
|
|
|
|
def test_cast_to_string() -> None:
    """
    CAST ... AS STRING must keep the STRING type rather than become CHAR.
    """
    tree = sqlglot.parse_one("SELECT CAST(cohort_size AS STRING) FROM table", Pinot)
    generated = Pinot().generate(expression=tree)

    assert "STRING" in generated
    assert "CHAR" not in generated
|
|
|
|
|
|
def test_concat_with_cast_string() -> None:
    """
    CONCAT wrapping a lowercase ``cast(... AS string)`` must keep the STRING
    type (regression check for the originally reported issue).
    """
    query = """
SELECT concat(a, cast(b AS string), ' - ')
FROM "default".c"""
    tree = sqlglot.parse_one(query, Pinot)
    generated = Pinot().generate(expression=tree)

    # Case-insensitive check that the STRING type survived; it must not have
    # been rewritten to CHAR.
    assert "string" in generated.lower()
    assert "CHAR" not in generated
|
|
|
|
|
|
@pytest.mark.parametrize(
    "cast_type, expected_type",
    [
        ("INT", "INT"),
        ("TINYINT", "INT"),
        ("SMALLINT", "INT"),
        ("BIGINT", "LONG"),
        ("LONG", "LONG"),
        ("FLOAT", "FLOAT"),
        ("DOUBLE", "DOUBLE"),
        ("BOOLEAN", "BOOLEAN"),
        ("TIMESTAMP", "TIMESTAMP"),
        ("STRING", "STRING"),
        ("VARCHAR", "STRING"),
        ("CHAR", "STRING"),
        ("TEXT", "STRING"),
        ("BYTES", "BYTES"),
        ("BINARY", "BYTES"),
        ("VARBINARY", "BYTES"),
        ("JSON", "JSON"),
    ],
)
def test_type_mappings(cast_type: str, expected_type: str) -> None:
    """
    Test that Pinot type mappings work correctly for all basic types.

    Each case casts a column to ``cast_type`` and checks that the generated
    SQL contains ``expected_type`` as a whole word.
    """
    import re

    sql = f"SELECT CAST(col AS {cast_type}) FROM table"  # noqa: S608
    ast = sqlglot.parse_one(sql, Pinot)
    generated = Pinot().generate(expression=ast)

    # Match on word boundaries: a plain substring test would accept an
    # unmapped "TINYINT" or "SMALLINT" because both contain "INT".
    pattern = rf"\b{re.escape(expected_type)}\b"
    assert re.search(pattern, generated), generated
|
|
|
|
|
|
def test_unsigned_type() -> None:
    """
    Unsigned integer types must render with both the base type and UNSIGNED.

    Calls ``datatype_sql`` directly on a UBIGINT data type, which is meant to
    exercise the generator's UNSIGNED_TYPE_MAPPING path.
    """
    from sqlglot import exp

    # UBIGINT is an unsigned type, so the rendered SQL should mention both parts.
    ubigint = exp.DataType(this=exp.DataType.Type.UBIGINT)
    rendered = Pinot.Generator().datatype_sql(ubigint)

    for fragment in ("UNSIGNED", "BIGINT"):
        assert fragment in rendered
|
|
|
|
|
|
def test_date_trunc_preserved() -> None:
    """
    DATE_TRUNC must stay DATE_TRUNC instead of becoming MySQL's DATE().
    """
    tree = sqlglot.parse_one("SELECT DATE_TRUNC('day', dt_column) FROM table", Pinot)
    result = tree.sql(Pinot)

    assert "DATE_TRUNC" in result
    assert "date_trunc('day'" in result.lower()
    # Guard against the MySQL rewrite to DATE(...).
    assert result != "SELECT DATE(dt_column) FROM table"
|
|
|
|
|
|
def test_cast_timestamp_preserved() -> None:
    """
    CAST ... AS TIMESTAMP must stay a CAST instead of becoming TIMESTAMP().
    """
    tree = sqlglot.parse_one("SELECT CAST(dt_column AS TIMESTAMP) FROM table", Pinot)
    result = tree.sql(Pinot)

    for fragment in ("CAST", "AS TIMESTAMP"):
        assert fragment in result
    # Guard against the MySQL rewrite to the TIMESTAMP(...) function form.
    assert "TIMESTAMP(dt_column)" not in result
|
|
|
|
|
|
def test_date_trunc_with_cast_timestamp() -> None:
    """
    Round-trip the original complex query that nests DATE_TRUNC inside
    CAST ... AS TIMESTAMP in both the SELECT list and the GROUP BY clause.

    Both constructs must still be present after parse/generate.
    """
    query = """
        SELECT
            CAST(
                DATE_TRUNC(
                    'day',
                    CAST(
                        DATETIMECONVERT(
                            dt_epoch_ms, '1:MILLISECONDS:EPOCH',
                            '1:MILLISECONDS:EPOCH', '1:MILLISECONDS'
                        ) AS TIMESTAMP
                    )
                ) AS TIMESTAMP
            ),
            SUM(a) + SUM(b)
        FROM
            "default".c
        WHERE
            dt_epoch_ms >= 1735690800000
            AND dt_epoch_ms < 1759328588000
            AND locality != 'US'
        GROUP BY
            CAST(
                DATE_TRUNC(
                    'day',
                    CAST(
                        DATETIMECONVERT(
                            dt_epoch_ms, '1:MILLISECONDS:EPOCH',
                            '1:MILLISECONDS:EPOCH', '1:MILLISECONDS'
                        ) AS TIMESTAMP
                    )
                ) AS TIMESTAMP
            )
        LIMIT
            10000
    """
    tree = sqlglot.parse_one(query, Pinot)
    result = tree.sql(Pinot)

    # DATE_TRUNC and CAST must both be kept.
    for fragment in ("DATE_TRUNC", "CAST"):
        assert fragment in result

    # Neither may be rewritten to its MySQL function form.
    assert "TIMESTAMP(DATETIMECONVERT" not in result
    # One occurrence in the SELECT list and one in GROUP BY.
    assert result.count("DATE_TRUNC") == 2
|