# Mirror of https://github.com/apache/superset.git (synced 2026-04-08 02:45:22 +00:00)
# Unit tests for the Pinot sqlglot dialect (882 lines, 26 KiB, Python).
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
|
|
import pytest
|
|
import sqlglot
|
|
|
|
from superset.sql.dialects.pinot import Pinot
|
|
|
|
|
|
def test_pinot_dialect_registered() -> None:
    """
    Verify the Pinot dialect is registered under the "pinot" key.
    """
    from superset.sql.parse import SQLGLOT_DIALECTS

    # Import is deferred so the registry is read at test time.
    assert "pinot" in SQLGLOT_DIALECTS
    assert SQLGLOT_DIALECTS.get("pinot") == Pinot
|
|
|
|
|
|
def test_double_quotes_as_identifiers() -> None:
    """
    Test that double quotes are treated as identifiers, not string literals.
    """
    sql = 'SELECT "column_name" FROM "table_name"'
    ast = sqlglot.parse_one(sql, Pinot)

    # The double quotes survive the round trip, proving they were parsed as
    # identifier quotes rather than string delimiters.
    assert (
        Pinot().generate(expression=ast, pretty=True)
        == """
SELECT
  "column_name"
FROM "table_name"
""".strip()
    )
|
|
|
|
|
|
def test_single_quotes_for_strings() -> None:
    """
    Test that single quotes are used for string literals.
    """
    sql = "SELECT * FROM users WHERE name = 'John'"
    ast = sqlglot.parse_one(sql, Pinot)

    # The literal 'John' keeps its single quotes in the generated output.
    assert (
        Pinot().generate(expression=ast, pretty=True)
        == """
SELECT
  *
FROM users
WHERE
  name = 'John'
""".strip()
    )
|
|
|
|
|
|
def test_backticks_as_identifiers() -> None:
    """
    Test that backticks work as identifiers (MySQL-style).

    Backticks are normalized to double quotes in output.
    """
    sql = "SELECT `column_name` FROM `table_name`"
    ast = sqlglot.parse_one(sql, Pinot)

    # Input used backticks; the generator emits double-quoted identifiers.
    assert (
        Pinot().generate(expression=ast, pretty=True)
        == """
SELECT
  "column_name"
FROM "table_name"
""".strip()
    )
|
|
|
|
|
|
def test_mixed_identifier_quotes() -> None:
    """
    Test mixing double quotes and backticks for identifiers.

    All identifiers are normalized to double quotes in output.
    """
    sql = (
        'SELECT "col1", `col2` FROM "table1" JOIN `table2` ON "table1".id = `table2`.id'
    )
    ast = sqlglot.parse_one(sql, Pinot)

    # Both quoting styles are accepted on input and unified on output.
    assert (
        Pinot().generate(expression=ast, pretty=True)
        == """
SELECT
  "col1",
  "col2"
FROM "table1"
JOIN "table2"
  ON "table1".id = "table2".id
""".strip()
    )
|
|
|
|
|
|
def test_string_with_escaped_quotes() -> None:
    """
    Test string literals with escaped single quotes.
    """
    sql = "SELECT * FROM users WHERE name = 'O''Brien'"
    ast = sqlglot.parse_one(sql, Pinot)

    # The doubled single quote ('') escape is preserved verbatim.
    assert (
        Pinot().generate(expression=ast, pretty=True)
        == """
SELECT
  *
FROM users
WHERE
  name = 'O''Brien'
""".strip()
    )
|
|
|
|
|
|
def test_string_with_backslash_escape() -> None:
    """
    Parse a string literal containing backslash escapes and make sure the
    query still renders with its WHERE clause and column intact.
    """
    query = r"SELECT * FROM users WHERE path = 'C:\\Users\\John'"
    rendered = Pinot().generate(
        expression=sqlglot.parse_one(query, Pinot), pretty=True
    )

    # Only structural presence is checked; exact escape rendering may vary.
    for fragment in ("WHERE", "path"):
        assert fragment in rendered
|
|
|
|
|
|
@pytest.mark.parametrize(
    "sql, expected",
    [
        (
            'SELECT COUNT(*) FROM "events" WHERE "type" = \'click\'',
            """
SELECT
  COUNT(*)
FROM "events"
WHERE
  "type" = 'click'
""".strip(),
        ),
        (
            'SELECT "user_id", SUM("amount") FROM "transactions" GROUP BY "user_id"',
            """
SELECT
  "user_id",
  SUM("amount")
FROM "transactions"
GROUP BY
  "user_id"
""".strip(),
        ),
        (
            "SELECT * FROM \"orders\" WHERE \"status\" IN ('pending', 'shipped')",
            """
SELECT
  *
FROM "orders"
WHERE
  "status" IN ('pending', 'shipped')
""".strip(),
        ),
    ],
)
def test_various_queries(sql: str, expected: str) -> None:
    """
    Test various SQL queries with Pinot dialect.
    """
    ast = sqlglot.parse_one(sql, Pinot)
    # Round trip: parsing and pretty-printing must yield the canonical form.
    assert Pinot().generate(expression=ast, pretty=True) == expected
|
|
|
|
|
|
def test_aggregate_functions() -> None:
    """
    Test aggregate functions with quoted identifiers.
    """
    sql = """
SELECT
  "category",
  COUNT(*),
  AVG("price"),
  MAX("quantity")
FROM "products"
GROUP BY "category"
"""
    ast = sqlglot.parse_one(sql, Pinot)

    # Aggregates pass through untouched; GROUP BY is reformatted onto its
    # own indented line by the pretty printer.
    assert (
        Pinot().generate(expression=ast, pretty=True)
        == """
SELECT
  "category",
  COUNT(*),
  AVG("price"),
  MAX("quantity")
FROM "products"
GROUP BY
  "category"
""".strip()
    )
|
|
|
|
|
|
def test_join_with_quoted_identifiers() -> None:
    """
    Test JOIN operations with double-quoted identifiers.
    """
    sql = """
SELECT "u"."name", "o"."total"
FROM "users" AS "u"
JOIN "orders" AS "o" ON "u"."id" = "o"."user_id"
"""
    ast = sqlglot.parse_one(sql, Pinot)

    # Aliases and qualified column references keep their double quotes; the
    # ON clause is moved to its own indented line by the pretty printer.
    assert (
        Pinot().generate(expression=ast, pretty=True)
        == """
SELECT
  "u"."name",
  "o"."total"
FROM "users" AS "u"
JOIN "orders" AS "o"
  ON "u"."id" = "o"."user_id"
""".strip()
    )
|
|
|
|
|
|
def test_subquery_with_quoted_identifiers() -> None:
    """
    Test subqueries with double-quoted identifiers.
    """
    sql = 'SELECT * FROM (SELECT "id", "name" FROM "users") AS "subquery"'
    ast = sqlglot.parse_one(sql, Pinot)

    # The derived table keeps its quoted alias and nested identifiers.
    assert (
        Pinot().generate(expression=ast, pretty=True)
        == """
SELECT
  *
FROM (
  SELECT
    "id",
    "name"
  FROM "users"
) AS "subquery"
""".strip()
    )
|
|
|
|
|
|
def test_case_expression() -> None:
    """
    Test CASE expressions with quoted identifiers.
    """
    sql = """
    SELECT "name",
        CASE WHEN "age" < 18 THEN 'minor'
             WHEN "age" >= 18 THEN 'adult'
        END AS "category"
    FROM "persons"
    """
    ast = sqlglot.parse_one(sql, Pinot)

    # Spot-check that identifiers stay double-quoted and literals stay
    # single-quoted; exact CASE formatting is left to the generator.
    generated = Pinot().generate(expression=ast, pretty=True)
    assert '"name"' in generated
    assert '"age"' in generated
    assert '"category"' in generated
    assert "'minor'" in generated
    assert "'adult'" in generated
|
|
|
|
|
|
def test_cte_with_quoted_identifiers() -> None:
    """
    Test Common Table Expressions (CTE) with quoted identifiers.
    """
    sql = """
    WITH "high_value_orders" AS (
        SELECT * FROM "orders" WHERE "total" > 1000
    )
    SELECT "customer_id", COUNT(*) FROM "high_value_orders" GROUP BY "customer_id"
    """
    ast = sqlglot.parse_one(sql, Pinot)

    # The CTE name and the identifiers it references all keep their quotes.
    generated = Pinot().generate(expression=ast, pretty=True)
    assert 'WITH "high_value_orders" AS' in generated
    assert '"orders"' in generated
    assert '"total"' in generated
    assert '"customer_id"' in generated
|
|
|
|
|
|
def test_order_by_with_quoted_identifiers() -> None:
    """
    Test ORDER BY clause with quoted identifiers.

    SQLGlot explicitly includes ASC in the output.
    """
    sql = 'SELECT "name", "salary" FROM "employees" ORDER BY "salary" DESC, "name" ASC'
    ast = sqlglot.parse_one(sql, Pinot)

    # Both sort directions are emitted explicitly (ASC is not elided).
    assert (
        Pinot().generate(expression=ast, pretty=True)
        == """
SELECT
  "name",
  "salary"
FROM "employees"
ORDER BY
  "salary" DESC,
  "name" ASC
""".strip()
    )
|
|
|
|
|
|
def test_limit_and_offset() -> None:
    """
    Test LIMIT and OFFSET clauses.

    Both the limit and the offset must survive the parse/generate round
    trip, whatever clause shape the generator chooses for the offset.
    """
    sql = 'SELECT * FROM "products" LIMIT 10 OFFSET 20'
    ast = sqlglot.parse_one(sql, Pinot)

    generated = Pinot().generate(expression=ast, pretty=True)
    assert '"products"' in generated
    assert "LIMIT 10" in generated
    # Previously the offset was parsed but never checked; make sure the
    # value 20 is not silently dropped during generation.
    assert "20" in generated
|
|
|
|
|
|
def test_distinct() -> None:
    """
    Test DISTINCT keyword with quoted identifiers.
    """
    sql = 'SELECT DISTINCT "category" FROM "products"'
    ast = sqlglot.parse_one(sql, Pinot)

    # DISTINCT stays on the SELECT line; the column is indented below it.
    assert (
        Pinot().generate(expression=ast, pretty=True)
        == """
SELECT DISTINCT
  "category"
FROM "products"
""".strip()
    )
|
|
|
|
|
|
def test_cast_to_string() -> None:
    """
    CAST(... AS STRING) must be emitted as STRING, never rewritten to the
    MySQL CHAR type.
    """
    ast = sqlglot.parse_one("SELECT CAST(cohort_size AS STRING) FROM table", Pinot)
    rendered = Pinot().generate(expression=ast)

    assert "STRING" in rendered
    assert "CHAR" not in rendered
|
|
|
|
|
|
def test_concat_with_cast_string() -> None:
    """
    Test CONCAT with CAST to STRING - verifies the original issue is fixed.
    """
    sql = """
    SELECT concat(a, cast(b AS string), ' - ')
    FROM "default".c"""
    ast = sqlglot.parse_one(sql, Pinot)
    generated = Pinot().generate(expression=ast)

    # Verify the STRING type is preserved, regardless of case.  The old
    # check (`"STRING" in generated or "string" in generated.lower()`) was
    # redundant: the first clause implies the second, so only the
    # case-insensitive test ever mattered.
    assert "string" in generated.lower()
    # ...and it must not have been rewritten to MySQL's CHAR.
    assert "CHAR" not in generated
|
|
|
|
|
|
@pytest.mark.parametrize(
    "cast_type, expected_type",
    [
        ("INT", "INT"),
        ("TINYINT", "INT"),
        ("SMALLINT", "INT"),
        ("BIGINT", "LONG"),
        ("LONG", "LONG"),
        ("FLOAT", "FLOAT"),
        ("DOUBLE", "DOUBLE"),
        ("BOOLEAN", "BOOLEAN"),
        ("TIMESTAMP", "TIMESTAMP"),
        ("STRING", "STRING"),
        ("VARCHAR", "STRING"),
        ("CHAR", "STRING"),
        ("TEXT", "STRING"),
        ("BYTES", "BYTES"),
        ("BINARY", "BYTES"),
        ("VARBINARY", "BYTES"),
        ("JSON", "JSON"),
    ],
)
def test_type_mappings(cast_type: str, expected_type: str) -> None:
    """
    Each source type in a CAST must be rendered as its Pinot-native
    equivalent (e.g. BIGINT -> LONG, VARCHAR -> STRING).
    """
    query = f"SELECT CAST(col AS {cast_type}) FROM table"  # noqa: S608
    rendered = Pinot().generate(expression=sqlglot.parse_one(query, Pinot))

    assert expected_type in rendered
|
|
|
|
|
|
def test_unsigned_type() -> None:
    """
    Unsigned integer types must render with an UNSIGNED marker.

    Exercises the UNSIGNED_TYPE_MAPPING path in the generator's
    datatype_sql method.
    """
    from sqlglot import exp

    # UBIGINT is one of the types covered by UNSIGNED_TYPE_MAPPING.
    unsigned = exp.DataType(this=exp.DataType.Type.UBIGINT)
    rendered = Pinot.Generator().datatype_sql(unsigned)

    for token in ("UNSIGNED", "BIGINT"):
        assert token in rendered
|
|
|
|
|
|
def test_date_trunc_preserved() -> None:
    """
    DATE_TRUNC must survive the round trip rather than being collapsed to
    MySQL's DATE() function.
    """
    rendered = sqlglot.parse_one(
        "SELECT DATE_TRUNC('day', dt_column) FROM table", Pinot
    ).sql(Pinot)

    assert "DATE_TRUNC" in rendered
    assert "date_trunc('day'" in rendered.lower()
    # MySQL would rewrite this to DATE(dt_column); Pinot must not.
    assert rendered != "SELECT DATE(dt_column) FROM table"
|
|
|
|
|
|
def test_cast_timestamp_preserved() -> None:
    """
    CAST(... AS TIMESTAMP) must stay a CAST and not become MySQL's
    TIMESTAMP() function call.
    """
    rendered = sqlglot.parse_one(
        "SELECT CAST(dt_column AS TIMESTAMP) FROM table", Pinot
    ).sql(Pinot)

    assert "CAST" in rendered
    assert "AS TIMESTAMP" in rendered
    # The function-call form would indicate a MySQL-style rewrite.
    assert "TIMESTAMP(dt_column)" not in rendered
|
|
|
|
|
|
def test_date_trunc_with_cast_timestamp() -> None:
    """
    Test the original complex query with DATE_TRUNC and CAST AS TIMESTAMP.

    Verifies that both are preserved in parse/generate round-trip.
    """
    # Real-world query shape: DATE_TRUNC wrapped around nested CASTs of a
    # Pinot DATETIMECONVERT call, repeated in SELECT and GROUP BY.
    sql = """
    SELECT
      CAST(
        DATE_TRUNC(
          'day',
          CAST(
            DATETIMECONVERT(
              dt_epoch_ms, '1:MILLISECONDS:EPOCH',
              '1:MILLISECONDS:EPOCH', '1:MILLISECONDS'
            ) AS TIMESTAMP
          )
        ) AS TIMESTAMP
      ),
      SUM(a) + SUM(b)
    FROM
      "default".c
    WHERE
      dt_epoch_ms >= 1735690800000
      AND dt_epoch_ms < 1759328588000
      AND locality != 'US'
    GROUP BY
      CAST(
        DATE_TRUNC(
          'day',
          CAST(
            DATETIMECONVERT(
              dt_epoch_ms, '1:MILLISECONDS:EPOCH',
              '1:MILLISECONDS:EPOCH', '1:MILLISECONDS'
            ) AS TIMESTAMP
          )
        ) AS TIMESTAMP
      )
    LIMIT
      10000
    """
    result = sqlglot.parse_one(sql, Pinot).sql(Pinot)

    # Verify DATE_TRUNC and CAST are preserved
    assert "DATE_TRUNC" in result
    assert "CAST" in result

    # Verify these are NOT converted to MySQL functions
    assert "TIMESTAMP(DATETIMECONVERT" not in result
    assert result.count("DATE_TRUNC") == 2  # Should appear twice (SELECT and GROUP BY)
|
|
|
|
|
|
def test_pinot_date_add_parsing() -> None:
    """
    Test that Pinot's DATE_ADD function with Presto-like syntax can be parsed.
    """
    from superset.sql.parse import SQLScript

    sql = """
    SELECT dt_epoch_ms FROM my_table WHERE dt_epoch_ms >= date_add('day', -180, now())
    """
    # A successful SQLScript construction implies the dialect parsed the
    # three-argument date_add form without error.
    script = SQLScript(sql, "pinot")
    assert len(script.statements) == 1
    assert not script.has_mutation()
|
|
|
|
|
|
def test_pinot_date_add_simple() -> None:
    """
    Round-trip a few DATE_ADD expressions and check the function name
    survives in the generated SQL.
    """
    for expression in (
        "date_add('day', -180, now())",
        "DATE_ADD('month', 5, current_timestamp())",
        "date_add('year', 1, my_date_column)",
    ):
        tree = sqlglot.parse_one(expression, Pinot)
        assert tree is not None
        # Generating back to Pinot SQL must keep DATE_ADD (any case).
        assert "DATE_ADD" in tree.sql(dialect=Pinot).upper()
|
|
|
|
|
|
def test_pinot_date_add_unit_quoted() -> None:
    """
    DATE_ADD must keep its unit argument as a quoted string.

    Pinot requires the unit to be a quoted string, not an identifier.
    """
    rendered = sqlglot.parse_one(
        "dt_epoch_ms >= date_add('day', -180, now())", Pinot
    ).sql(Pinot)

    # 'DAY' must stay quoted; a bare DAY identifier would break on Pinot.
    assert "DATE_ADD('DAY', -180, NOW())" in rendered
    assert "DATE_ADD(DAY," not in rendered
|
|
|
|
|
|
def test_pinot_date_sub_parsing() -> None:
    """
    Verify Presto-style DATE_SUB calls parse cleanly as a Pinot script.
    """
    from superset.sql.parse import SQLScript

    query = "SELECT * FROM my_table WHERE dt >= date_sub('day', 7, now())"
    script = SQLScript(query, "pinot")

    assert len(script.statements) == 1
    assert not script.has_mutation()
|
|
|
|
|
|
def test_pinot_date_sub_simple() -> None:
    """
    Round-trip a few DATE_SUB expressions and check the function name
    survives in the generated SQL.
    """
    for expression in (
        "date_sub('day', 7, now())",
        "DATE_SUB('month', 3, current_timestamp())",
        "date_sub('hour', 24, my_date_column)",
    ):
        tree = sqlglot.parse_one(expression, Pinot)
        assert tree is not None
        # Generating back to Pinot SQL must keep DATE_SUB (any case).
        assert "DATE_SUB" in tree.sql(dialect=Pinot).upper()
|
|
|
|
|
|
def test_pinot_date_sub_unit_quoted() -> None:
    """
    DATE_SUB must keep its unit argument as a quoted string.

    Pinot requires the unit to be a quoted string, not an identifier.
    """
    rendered = sqlglot.parse_one(
        "dt_epoch_ms >= date_sub('day', -180, now())", Pinot
    ).sql(Pinot)

    # 'DAY' must stay quoted; a bare DAY identifier would break on Pinot.
    assert "DATE_SUB('DAY', -180, NOW())" in rendered
    assert "DATE_SUB(DAY," not in rendered
|
|
|
|
|
|
def test_substr_cross_dialect_generation() -> None:
    """
    SUBSTR must be preserved when generating Pinot SQL.

    Note that the MySQL dialect (in which Pinot is based) uses SUBSTRING
    instead of SUBSTR, so the two outputs must diverge.
    """
    parsed = sqlglot.parse_one("SELECT SUBSTR('hello', 0, 3) FROM users", Pinot)

    # Pinot round trip keeps SUBSTR untouched.
    via_pinot = parsed.sql(dialect=Pinot)
    assert "SUBSTR(" in via_pinot
    assert "SUBSTRING(" not in via_pinot

    # The MySQL generator rewrites it to SUBSTRING.
    via_mysql = parsed.sql(dialect="mysql")
    assert "SUBSTRING(" in via_mysql
    assert via_pinot != via_mysql  # They should be different
|
|
|
|
|
|
@pytest.mark.parametrize(
    "function_name,sample_args",
    [
        # Math functions
        ("ABS", "-5"),
        ("CEIL", "3.14"),
        ("FLOOR", "3.14"),
        ("EXP", "2"),
        ("LN", "10"),
        ("SQRT", "16"),
        ("ROUNDDECIMAL", "3.14159, 2"),
        ("ADD", "1, 2, 3"),
        ("SUB", "10, 3"),
        ("MULT", "5, 4"),
        ("MOD", "10, 3"),
        # String functions
        ("UPPER", "'hello'"),
        ("LOWER", "'HELLO'"),
        ("REVERSE", "'hello'"),
        ("SUBSTR", "'hello', 0, 3"),
        ("CONCAT", "'hello', ' ', 'world'"),
        ("TRIM", "' hello '"),
        ("LTRIM", "' hello'"),
        ("RTRIM", "'hello '"),
        ("LENGTH", "'hello'"),
        ("STRPOS", "'hello', 'l', 1"),
        ("STARTSWITH", "'hello', 'he'"),
        ("REPLACE", "'hello', 'l', 'r'"),
        ("RPAD", "'hello', 10, 'x'"),
        ("LPAD", "'hello', 10, 'x'"),
        ("CODEPOINT", "'A'"),
        ("CHR", "65"),
        ("regexpExtract", "'foo123bar', '[0-9]+'"),
        ("regexpReplace", "'hello', 'l', 'r'"),
        ("remove", "'hello', 'l'"),
        ("urlEncoding", "'hello world'"),
        ("urlDecoding", "'hello%20world'"),
        ("fromBase64", "'aGVsbG8='"),
        ("toUtf8", "'hello'"),
        ("isSubnetOf", "'192.168.1.1', '192.168.0.0/16'"),
        # DateTime functions
        ("DATETRUNC", "'day', timestamp_col"),
        ("DATETIMECONVERT", "dt_col, '1:HOURS:EPOCH', '1:DAYS:EPOCH', '1:DAYS'"),
        ("TIMECONVERT", "timestamp_col, 'MILLISECONDS', 'SECONDS'"),
        ("NOW", ""),
        ("AGO", "'P1D'"),
        ("YEAR", "timestamp_col"),
        ("QUARTER", "timestamp_col"),
        ("MONTH", "timestamp_col"),
        ("WEEK", "timestamp_col"),
        ("DAY", "timestamp_col"),
        ("HOUR", "timestamp_col"),
        ("MINUTE", "timestamp_col"),
        ("SECOND", "timestamp_col"),
        ("MILLISECOND", "timestamp_col"),
        ("DAYOFWEEK", "timestamp_col"),
        ("DAYOFYEAR", "timestamp_col"),
        ("YEAROFWEEK", "timestamp_col"),
        ("toEpochSeconds", "timestamp_col"),
        ("toEpochMinutes", "timestamp_col"),
        ("toEpochHours", "timestamp_col"),
        ("toEpochDays", "timestamp_col"),
        ("fromEpochSeconds", "1234567890"),
        ("fromEpochMinutes", "20576131"),
        ("fromEpochHours", "342935"),
        ("fromEpochDays", "14288"),
        ("toDateTime", "timestamp_col, 'yyyy-MM-dd'"),
        ("fromDateTime", "'2024-01-01', 'yyyy-MM-dd'"),
        ("timezoneHour", "timestamp_col"),
        ("timezoneMinute", "timestamp_col"),
        ("DATE_ADD", "'day', 7, NOW()"),
        ("DATE_SUB", "'day', 7, NOW()"),
        ("TIMESTAMPADD", "'day', 7, timestamp_col"),
        ("TIMESTAMPDIFF", "'day', timestamp1, timestamp2"),
        ("dateTrunc", "'day', timestamp_col"),
        ("dateDiff", "'day', timestamp1, timestamp2"),
        ("dateAdd", "'day', 7, timestamp_col"),
        ("dateBin", "'day', timestamp_col, NOW()"),
        ("toIso8601", "timestamp_col"),
        ("fromIso8601", "'2024-01-01T00:00:00Z'"),
        # Aggregation functions
        ("COUNT", "*"),
        ("SUM", "amount"),
        ("AVG", "value"),
        ("MIN", "value"),
        ("MAX", "value"),
        ("DISTINCTCOUNT", "user_id"),
        ("DISTINCTCOUNTBITMAP", "user_id"),
        ("DISTINCTCOUNTHLL", "user_id"),
        ("DISTINCTCOUNTRAWHLL", "user_id"),
        ("DISTINCTCOUNTHLLPLUS", "user_id"),
        ("DISTINCTCOUNTRAWHLLPLUS", "user_id"),
        ("DISTINCTCOUNTSMARTHLL", "user_id"),
        ("DISTINCTCOUNTCPCSKETCH", "user_id"),
        ("DISTINCTCOUNTRAWCPCSKETCH", "user_id"),
        ("DISTINCTCOUNTTHETASKETCH", "user_id"),
        ("DISTINCTCOUNTRAWTHETASKETCH", "user_id"),
        ("DISTINCTCOUNTTUPLESKETCH", "user_id"),
        ("DISTINCTCOUNTRAWINTEGERSUMTUPLESKETCH", "user_id"),
        ("DISTINCTCOUNTULL", "user_id"),
        ("DISTINCTCOUNTRAWULL", "user_id"),
        ("SEGMENTPARTITIONEDDISTINCTCOUNT", "user_id"),
        ("SUMVALUESINTEGERSUMTUPLESKETCH", "value"),
        ("PERCENTILE", "value, 95"),
        ("PERCENTILEEST", "value, 95"),
        ("PERCENTILETDIGEST", "value, 95"),
        ("PERCENTILESMARTTDIGEST", "value, 95"),
        ("PERCENTILEKLL", "value, 95"),
        ("PERCENTILEKLLRAW", "value, 95"),
        ("HISTOGRAM", "value, 10"),
        ("MODE", "category"),
        ("MINMAXRANGE", "value"),
        ("SUMPRECISION", "value, 10"),
        ("ARG_MIN", "value, id"),
        ("ARG_MAX", "value, id"),
        ("COVAR_POP", "x, y"),
        ("COVAR_SAMP", "x, y"),
        ("LASTWITHTIME", "value, timestamp_col, 'LONG'"),
        ("FIRSTWITHTIME", "value, timestamp_col, 'LONG'"),
        ("ARRAY_AGG", "value"),
        # Multi-value functions
        ("COUNTMV", "tags"),
        ("MAXMV", "scores"),
        ("MINMV", "scores"),
        ("SUMMV", "scores"),
        ("AVGMV", "scores"),
        ("MINMAXRANGEMV", "scores"),
        ("PERCENTILEMV", "scores, 95"),
        ("PERCENTILEESTMV", "scores, 95"),
        ("PERCENTILETDIGESTMV", "scores, 95"),
        ("PERCENTILEKLLMV", "scores, 95"),
        ("DISTINCTCOUNTMV", "tags"),
        ("DISTINCTCOUNTBITMAPMV", "tags"),
        ("DISTINCTCOUNTHLLMV", "tags"),
        ("DISTINCTCOUNTRAWHLLMV", "tags"),
        ("DISTINCTCOUNTHLLPLUSMV", "tags"),
        ("DISTINCTCOUNTRAWHLLPLUSMV", "tags"),
        ("ARRAYLENGTH", "array_col"),
        ("MAP_VALUE", "map_col, 'key'"),
        ("VALUEIN", "value, 'val1', 'val2'"),
        # JSON functions
        ("JSONEXTRACTSCALAR", "json_col, '$.name', 'STRING'"),
        ("JSONEXTRACTKEY", "json_col, '$.data'"),
        ("JSONFORMAT", "json_col"),
        ("JSONPATH", "json_col, '$.name'"),
        ("JSONPATHLONG", "json_col, '$.id'"),
        ("JSONPATHDOUBLE", "json_col, '$.price'"),
        ("JSONPATHSTRING", "json_col, '$.name'"),
        ("JSONPATHARRAY", "json_col, '$.items'"),
        ("JSONPATHARRAYDEFAULTEMPTY", "json_col, '$.items'"),
        ("TOJSONMAPSTR", "map_col"),
        ("JSON_MATCH", "json_col, '\"$.name\"=''value'''"),
        ("JSON_EXTRACT_SCALAR", "json_col, '$.name', 'STRING'"),
        # Array functions
        ("arrayReverseInt", "int_array"),
        ("arrayReverseString", "string_array"),
        ("arraySortInt", "int_array"),
        ("arraySortString", "string_array"),
        ("arrayIndexOfInt", "int_array, 5"),
        ("arrayIndexOfString", "string_array, 'value'"),
        ("arrayContainsInt", "int_array, 5"),
        ("arrayContainsString", "string_array, 'value'"),
        ("arraySliceInt", "int_array, 0, 3"),
        ("arraySliceString", "string_array, 0, 3"),
        ("arrayDistinctInt", "int_array"),
        ("arrayDistinctString", "string_array"),
        ("arrayRemoveInt", "int_array, 5"),
        ("arrayRemoveString", "string_array, 'value'"),
        ("arrayUnionInt", "int_array1, int_array2"),
        ("arrayUnionString", "string_array1, string_array2"),
        ("arrayConcatInt", "int_array1, int_array2"),
        ("arrayConcatString", "string_array1, string_array2"),
        ("arrayElementAtInt", "int_array, 0"),
        ("arrayElementAtString", "string_array, 0"),
        ("arraySumInt", "int_array"),
        ("arrayValueConstructor", "1, 2, 3"),
        ("arrayToString", "array_col, ','"),
        # Geospatial functions
        ("ST_DISTANCE", "point1, point2"),
        ("ST_CONTAINS", "polygon, point"),
        ("ST_AREA", "polygon"),
        ("ST_GEOMFROMTEXT", "'POINT(1 2)'"),
        ("ST_GEOMFROMWKB", "wkb_col"),
        ("ST_GEOGFROMWKB", "wkb_col"),
        ("ST_GEOGFROMTEXT", "'POINT(1 2)'"),
        ("ST_POINT", "1.0, 2.0"),
        ("ST_POLYGON", "'POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))'"),
        ("ST_ASBINARY", "geom_col"),
        ("ST_ASTEXT", "geom_col"),
        ("ST_GEOMETRYTYPE", "geom_col"),
        ("ST_EQUALS", "geom1, geom2"),
        ("ST_WITHIN", "geom1, geom2"),
        ("ST_UNION", "geom1, geom2"),
        ("ST_GEOMFROMGEOJSON", '\'{"type":"Point","coordinates":[1,2]}\''),
        ("ST_GEOGFROMGEOJSON", '\'{"type":"Point","coordinates":[1,2]}\''),
        ("ST_ASGEOJSON", "geom_col"),
        ("toSphericalGeography", "geom_col"),
        ("toGeometry", "geog_col"),
        # Binary/Hash functions
        ("SHA", "'hello'"),
        ("SHA256", "'hello'"),
        ("SHA512", "'hello'"),
        ("SHA224", "'hello'"),
        ("MD5", "'hello'"),
        ("MD2", "'hello'"),
        ("toBase64", "'hello'"),
        ("fromUtf8", "bytes_col"),
        ("MurmurHash2", "'hello'"),
        ("MurmurHash3Bit32", "'hello'"),
        # Window functions
        ("ROW_NUMBER", ""),
        ("RANK", ""),
        ("DENSE_RANK", ""),
        # Funnel analysis
        ("FunnelMaxStep", "event_col, 'step1', 'step2', 'step3'"),
        ("FunnelMatchStep", "event_col, 'step1', 'step2', 'step3'"),
        ("FunnelCompleteCount", "event_col, 'step1', 'step2', 'step3'"),
        # Text search
        ("TEXT_MATCH", "text_col, 'search query'"),
        # Vector functions
        ("VECTOR_SIMILARITY", "vector1, vector2"),
        ("l2_distance", "vector1, vector2"),
        # Lookup
        ("LOOKUP", "'lookupTable', 'lookupColumn', 'keyColumn', keyValue"),
        # URL functions
        ("urlProtocol", "'https://example.com/path'"),
        ("urlDomain", "'https://example.com/path'"),
        ("urlPath", "'https://example.com/path'"),
        ("urlPort", "'https://example.com:8080/path'"),
        ("urlEncode", "'hello world'"),
        ("urlDecode", "'hello%20world'"),
        # Conditional
        ("COALESCE", "val1, val2, 'default'"),
        ("NULLIF", "val1, val2"),
        ("GREATEST", "1, 2, 3"),
        ("LEAST", "1, 2, 3"),
        # Other
        ("REGEXP_LIKE", "'hello', 'h.*'"),
        ("GROOVY", "'{return arg0 + arg1}', col1, col2"),
    ],
)
def test_pinot_function_names_preserved(function_name: str, sample_args: str) -> None:
    """
    Test that Pinot function names are preserved during parse/generate roundtrip.

    This ensures that when we parse Pinot SQL and generate it back, the function
    names remain unchanged. This is critical for maintaining compatibility with
    Pinot's function library.
    """
    # Special handling for window functions
    # (they require an OVER clause to be syntactically valid)
    if function_name in ["ROW_NUMBER", "RANK", "DENSE_RANK"]:
        sql = f"SELECT {function_name}() OVER (ORDER BY col) FROM table"  # noqa: S608
    else:
        sql = f"SELECT {function_name}({sample_args}) FROM table"  # noqa: S608

    # Parse with Pinot dialect
    parsed = sqlglot.parse_one(sql, Pinot)

    # Generate back to Pinot
    generated = parsed.sql(dialect=Pinot)

    # The function name should be preserved (case-insensitive check)
    assert function_name.upper() in generated.upper(), (
        f"Function {function_name} not preserved in output: {generated}"
    )
|