Files
superset2/tests/unit_tests/sql/dialects/pinot_tests.py
Beto Dealmeida aa97d2fe03 fix(pinot): dialect date truncation (#35420)
Co-authored-by: bito-code-review[bot] <188872107+bito-code-review[bot]@users.noreply.github.com>
2025-10-01 13:16:46 -04:00

501 lines
12 KiB
Python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import sqlglot
from superset.sql.dialects.pinot import Pinot
def test_pinot_dialect_registered() -> None:
    """Check that the dialect registry maps the "pinot" key to ``Pinot``."""
    # Function-scope import keeps the registry lookup local to this test.
    from superset.sql.parse import SQLGLOT_DIALECTS

    dialect_key = "pinot"

    assert dialect_key in SQLGLOT_DIALECTS
    assert SQLGLOT_DIALECTS[dialect_key] == Pinot
def test_double_quotes_as_identifiers() -> None:
    """
    Test that double-quoted names are emitted as identifiers.

    The generated SQL must keep the double quotes around the column and table
    names rather than turning them into string literals.
    """
    # The expected text carries sqlglot's two-space pretty-print indentation;
    # without it the exact-equality comparison below cannot succeed.
    expected = """
SELECT
  "column_name"
FROM "table_name"
    """.strip()

    sql = 'SELECT "column_name" FROM "table_name"'
    tree = sqlglot.parse_one(sql, Pinot)

    assert Pinot().generate(expression=tree, pretty=True) == expected
def test_single_quotes_for_strings() -> None:
    """
    Test that single-quoted values are emitted as string literals.
    """
    # The expected text carries sqlglot's two-space pretty-print indentation;
    # without it the exact-equality comparison below cannot succeed.
    expected = """
SELECT
  *
FROM users
WHERE
  name = 'John'
    """.strip()

    sql = "SELECT * FROM users WHERE name = 'John'"
    tree = sqlglot.parse_one(sql, Pinot)

    assert Pinot().generate(expression=tree, pretty=True) == expected
def test_backticks_as_identifiers() -> None:
    """
    Test that backtick-quoted names (MySQL-style) are accepted as identifiers.

    The generated SQL is expected to use double quotes in place of the
    backticks.
    """
    # The expected text carries sqlglot's two-space pretty-print indentation;
    # without it the exact-equality comparison below cannot succeed.
    expected = """
SELECT
  "column_name"
FROM "table_name"
    """.strip()

    sql = "SELECT `column_name` FROM `table_name`"
    tree = sqlglot.parse_one(sql, Pinot)

    assert Pinot().generate(expression=tree, pretty=True) == expected
def test_mixed_identifier_quotes() -> None:
    """
    Test a query that mixes double-quoted and backtick-quoted identifiers.

    Every quoted identifier is expected to come out double-quoted.
    """
    # The expected text carries sqlglot's two-space pretty-print indentation
    # (including the indented ON condition under the JOIN); without it the
    # exact-equality comparison below cannot succeed.
    expected = """
SELECT
  "col1",
  "col2"
FROM "table1"
JOIN "table2"
  ON "table1".id = "table2".id
    """.strip()

    sql = (
        'SELECT "col1", `col2` FROM "table1" JOIN `table2` ON "table1".id = `table2`.id'
    )
    tree = sqlglot.parse_one(sql, Pinot)

    assert Pinot().generate(expression=tree, pretty=True) == expected
def test_string_with_escaped_quotes() -> None:
    """
    Test that a doubled single quote inside a string literal is preserved.
    """
    # The expected text carries sqlglot's two-space pretty-print indentation;
    # without it the exact-equality comparison below cannot succeed.
    expected = """
SELECT
  *
FROM users
WHERE
  name = 'O''Brien'
    """.strip()

    sql = "SELECT * FROM users WHERE name = 'O''Brien'"
    tree = sqlglot.parse_one(sql, Pinot)

    assert Pinot().generate(expression=tree, pretty=True) == expected
def test_string_with_backslash_escape() -> None:
    """
    Test that a string literal containing backslashes parses and generates.

    Only the presence of the WHERE keyword and the column name is checked; the
    exact rendering of the backslashes is not asserted.
    """
    sql = r"SELECT * FROM users WHERE path = 'C:\\Users\\John'"
    tree = sqlglot.parse_one(sql, Pinot)
    rendered = Pinot().generate(expression=tree, pretty=True)

    for fragment in ("WHERE", "path"):
        assert fragment in rendered
@pytest.mark.parametrize(
    "sql, expected",
    [
        (
            'SELECT COUNT(*) FROM "events" WHERE "type" = \'click\'',
            """
SELECT
  COUNT(*)
FROM "events"
WHERE
  "type" = 'click'
            """.strip(),
        ),
        (
            'SELECT "user_id", SUM("amount") FROM "transactions" GROUP BY "user_id"',
            """
SELECT
  "user_id",
  SUM("amount")
FROM "transactions"
GROUP BY
  "user_id"
            """.strip(),
        ),
        (
            "SELECT * FROM \"orders\" WHERE \"status\" IN ('pending', 'shipped')",
            """
SELECT
  *
FROM "orders"
WHERE
  "status" IN ('pending', 'shipped')
            """.strip(),
        ),
    ],
)
def test_various_queries(sql: str, expected: str) -> None:
    """
    Test that each parametrized query pretty-prints to its expected text.

    The expected strings carry sqlglot's two-space pretty-print indentation,
    which is required for the exact-equality comparison to succeed.

    :param sql: input statement parsed with the Pinot dialect
    :param expected: exact pretty-printed SQL the generator should produce
    """
    tree = sqlglot.parse_one(sql, Pinot)
    rendered = Pinot().generate(expression=tree, pretty=True)

    assert rendered == expected
def test_aggregate_functions() -> None:
    """
    Test COUNT/AVG/MAX aggregates over quoted identifiers with a GROUP BY.
    """
    # The expected text carries sqlglot's two-space pretty-print indentation;
    # without it the exact-equality comparison below cannot succeed.
    expected = """
SELECT
  "category",
  COUNT(*),
  AVG("price"),
  MAX("quantity")
FROM "products"
GROUP BY
  "category"
    """.strip()

    sql = """
SELECT
"category",
COUNT(*),
AVG("price"),
MAX("quantity")
FROM "products"
GROUP BY "category"
"""
    tree = sqlglot.parse_one(sql, Pinot)

    assert Pinot().generate(expression=tree, pretty=True) == expected
def test_join_with_quoted_identifiers() -> None:
    """
    Test a JOIN whose tables, aliases and columns are all double-quoted.
    """
    # The expected text carries sqlglot's two-space pretty-print indentation
    # (including the indented ON condition under the JOIN); without it the
    # exact-equality comparison below cannot succeed.
    expected = """
SELECT
  "u"."name",
  "o"."total"
FROM "users" AS "u"
JOIN "orders" AS "o"
  ON "u"."id" = "o"."user_id"
    """.strip()

    sql = """
SELECT "u"."name", "o"."total"
FROM "users" AS "u"
JOIN "orders" AS "o" ON "u"."id" = "o"."user_id"
"""
    tree = sqlglot.parse_one(sql, Pinot)

    assert Pinot().generate(expression=tree, pretty=True) == expected
def test_subquery_with_quoted_identifiers() -> None:
    """
    Test a derived-table subquery that uses double-quoted identifiers.
    """
    # The expected text carries sqlglot's pretty-print indentation: two spaces
    # per nesting level, so the inner SELECT list sits four spaces in. Without
    # it the exact-equality comparison below cannot succeed.
    expected = """
SELECT
  *
FROM (
  SELECT
    "id",
    "name"
  FROM "users"
) AS "subquery"
    """.strip()

    sql = 'SELECT * FROM (SELECT "id", "name" FROM "users") AS "subquery"'
    tree = sqlglot.parse_one(sql, Pinot)

    assert Pinot().generate(expression=tree, pretty=True) == expected
def test_case_expression() -> None:
    """
    Test that a CASE expression keeps its quoted identifiers and literals.

    The check is fragment-based: each quoted identifier and each string
    literal from the input must appear somewhere in the generated SQL.
    """
    sql = """
SELECT "name",
CASE WHEN "age" < 18 THEN 'minor'
WHEN "age" >= 18 THEN 'adult'
END AS "category"
FROM "persons"
"""
    tree = sqlglot.parse_one(sql, Pinot)
    rendered = Pinot().generate(expression=tree, pretty=True)

    required_fragments = ('"name"', '"age"', '"category"', "'minor'", "'adult'")
    for fragment in required_fragments:
        assert fragment in rendered
def test_cte_with_quoted_identifiers() -> None:
    """
    Test that a Common Table Expression keeps its quoted identifiers.

    The check is fragment-based: the WITH header and each quoted identifier
    must appear somewhere in the generated SQL.
    """
    sql = """
WITH "high_value_orders" AS (
SELECT * FROM "orders" WHERE "total" > 1000
)
SELECT "customer_id", COUNT(*) FROM "high_value_orders" GROUP BY "customer_id"
"""
    tree = sqlglot.parse_one(sql, Pinot)
    rendered = Pinot().generate(expression=tree, pretty=True)

    required_fragments = (
        'WITH "high_value_orders" AS',
        '"orders"',
        '"total"',
        '"customer_id"',
    )
    for fragment in required_fragments:
        assert fragment in rendered
def test_order_by_with_quoted_identifiers() -> None:
    """
    Test an ORDER BY clause over double-quoted identifiers.

    The expected output keeps the explicit ASC on the second sort key.
    """
    # The expected text carries sqlglot's two-space pretty-print indentation;
    # without it the exact-equality comparison below cannot succeed.
    expected = """
SELECT
  "name",
  "salary"
FROM "employees"
ORDER BY
  "salary" DESC,
  "name" ASC
    """.strip()

    sql = 'SELECT "name", "salary" FROM "employees" ORDER BY "salary" DESC, "name" ASC'
    tree = sqlglot.parse_one(sql, Pinot)

    assert Pinot().generate(expression=tree, pretty=True) == expected
def test_limit_and_offset() -> None:
    """
    Test a query that carries both LIMIT and OFFSET clauses.

    Only the quoted table name and the ``LIMIT 10`` fragment are asserted.
    """
    sql = 'SELECT * FROM "products" LIMIT 10 OFFSET 20'
    tree = sqlglot.parse_one(sql, Pinot)
    rendered = Pinot().generate(expression=tree, pretty=True)

    # NOTE(review): the OFFSET part of the input is never checked here --
    # confirm whether that omission is intentional.
    for fragment in ('"products"', "LIMIT 10"):
        assert fragment in rendered
def test_distinct() -> None:
    """
    Test SELECT DISTINCT over a double-quoted column.
    """
    # The expected text carries sqlglot's two-space pretty-print indentation;
    # without it the exact-equality comparison below cannot succeed.
    expected = """
SELECT DISTINCT
  "category"
FROM "products"
    """.strip()

    sql = 'SELECT DISTINCT "category" FROM "products"'
    tree = sqlglot.parse_one(sql, Pinot)

    assert Pinot().generate(expression=tree, pretty=True) == expected
def test_cast_to_string() -> None:
    """
    Test that a CAST to STRING keeps the STRING type name.

    The generated SQL must mention STRING and must not mention CHAR.
    """
    sql = "SELECT CAST(cohort_size AS STRING) FROM table"
    tree = sqlglot.parse_one(sql, Pinot)
    rendered = Pinot().generate(expression=tree)

    assert "STRING" in rendered
    assert "CHAR" not in rendered
def test_concat_with_cast_string() -> None:
    """
    Test CONCAT wrapping a CAST to string.

    The generated SQL must still mention the string type (in any letter case)
    and must not mention CHAR.
    """
    sql = """
SELECT concat(a, cast(b AS string), ' - ')
FROM "default".c"""
    tree = sqlglot.parse_one(sql, Pinot)
    rendered = Pinot().generate(expression=tree)

    # A case-insensitive search also covers the upper-case "STRING" spelling,
    # so a single check is equivalent to testing both forms.
    assert "string" in rendered.lower()
    assert "CHAR" not in rendered
@pytest.mark.parametrize(
    "cast_type, expected_type",
    [
        ("INT", "INT"),
        ("TINYINT", "INT"),
        ("SMALLINT", "INT"),
        ("BIGINT", "LONG"),
        ("LONG", "LONG"),
        ("FLOAT", "FLOAT"),
        ("DOUBLE", "DOUBLE"),
        ("BOOLEAN", "BOOLEAN"),
        ("TIMESTAMP", "TIMESTAMP"),
        ("STRING", "STRING"),
        ("VARCHAR", "STRING"),
        ("CHAR", "STRING"),
        ("TEXT", "STRING"),
        ("BYTES", "BYTES"),
        ("BINARY", "BYTES"),
        ("VARBINARY", "BYTES"),
        ("JSON", "JSON"),
    ],
)
def test_type_mappings(cast_type: str, expected_type: str) -> None:
    """
    Test that casting to each source type yields the expected Pinot type name.

    :param cast_type: type name placed in the input ``CAST(col AS ...)``
    :param expected_type: type name that must appear in the generated SQL
    """
    statement = f"SELECT CAST(col AS {cast_type}) FROM table"  # noqa: S608
    tree = sqlglot.parse_one(statement, Pinot)
    rendered = Pinot().generate(expression=tree)

    # NOTE(review): this is a substring check, so e.g. "INT" would also match
    # an unmapped "TINYINT" -- consider tightening if that matters.
    assert expected_type in rendered
def test_unsigned_type() -> None:
    """
    Test how the generator renders an unsigned integer data type.

    Intended to exercise the UNSIGNED_TYPE_MAPPING path of ``datatype_sql``:
    a UBIGINT data type must render with both UNSIGNED and BIGINT in it.
    """
    from sqlglot import exp

    # Build the data-type node directly rather than parsing SQL.
    ubigint = exp.DataType(this=exp.DataType.Type.UBIGINT)
    rendered = Pinot.Generator().datatype_sql(ubigint)

    for keyword in ("UNSIGNED", "BIGINT"):
        assert keyword in rendered
def test_date_trunc_preserved() -> None:
    """
    Test that DATE_TRUNC survives a parse/generate round-trip.

    The call must not be rewritten into MySQL's DATE() function.
    """
    sql = "SELECT DATE_TRUNC('day', dt_column) FROM table"
    tree = sqlglot.parse_one(sql, Pinot)
    round_tripped = tree.sql(Pinot)

    assert "DATE_TRUNC" in round_tripped
    assert "date_trunc('day'" in round_tripped.lower()
    # The MySQL-style rewrite would produce exactly this statement.
    mysql_rewrite = "SELECT DATE(dt_column) FROM table"
    assert round_tripped != mysql_rewrite
def test_cast_timestamp_preserved() -> None:
    """
    Test that CAST ... AS TIMESTAMP survives a parse/generate round-trip.

    The cast must not be rewritten into a TIMESTAMP() function call.
    """
    sql = "SELECT CAST(dt_column AS TIMESTAMP) FROM table"
    tree = sqlglot.parse_one(sql, Pinot)
    round_tripped = tree.sql(Pinot)

    assert "CAST" in round_tripped
    assert "AS TIMESTAMP" in round_tripped
    # A MySQL-style rewrite would show up as a TIMESTAMP(...) call.
    assert "TIMESTAMP(dt_column)" not in round_tripped
def test_date_trunc_with_cast_timestamp() -> None:
    """
    Test a larger query that nests DATE_TRUNC inside CAST ... AS TIMESTAMP.

    The same expression appears in the SELECT list and in GROUP BY; both
    occurrences must survive the parse/generate round-trip.
    """
    sql = """
SELECT
CAST(
DATE_TRUNC(
'day',
CAST(
DATETIMECONVERT(
dt_epoch_ms, '1:MILLISECONDS:EPOCH',
'1:MILLISECONDS:EPOCH', '1:MILLISECONDS'
) AS TIMESTAMP
)
) AS TIMESTAMP
),
SUM(a) + SUM(b)
FROM
"default".c
WHERE
dt_epoch_ms >= 1735690800000
AND dt_epoch_ms < 1759328588000
AND locality != 'US'
GROUP BY
CAST(
DATE_TRUNC(
'day',
CAST(
DATETIMECONVERT(
dt_epoch_ms, '1:MILLISECONDS:EPOCH',
'1:MILLISECONDS:EPOCH', '1:MILLISECONDS'
) AS TIMESTAMP
)
) AS TIMESTAMP
)
LIMIT
10000
"""
    tree = sqlglot.parse_one(sql, Pinot)
    round_tripped = tree.sql(Pinot)

    # Both constructs must still be present after the round-trip.
    assert "DATE_TRUNC" in round_tripped
    assert "CAST" in round_tripped
    # A MySQL-style rewrite would wrap the inner call as TIMESTAMP(...).
    assert "TIMESTAMP(DATETIMECONVERT" not in round_tripped
    # One occurrence in the SELECT list and one in GROUP BY.
    assert round_tripped.count("DATE_TRUNC") == 2