# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Tests for the Pinot SQLGlot dialect: identifier/string quoting, type mappings,
and preservation of Pinot-specific functions across parse/generate round trips.
"""

import pytest
import sqlglot

from superset.sql.dialects.pinot import Pinot


def test_pinot_dialect_registered() -> None:
    """
    Test that Pinot dialect is properly registered.
    """
    from superset.sql.parse import SQLGLOT_DIALECTS

    assert "pinot" in SQLGLOT_DIALECTS
    assert SQLGLOT_DIALECTS["pinot"] == Pinot


def test_double_quotes_as_identifiers() -> None:
    """
    Test that double quotes are treated as identifiers, not string literals.
    """
    sql = 'SELECT "column_name" FROM "table_name"'
    ast = sqlglot.parse_one(sql, Pinot)

    assert (
        Pinot().generate(expression=ast, pretty=True)
        == """
SELECT
  "column_name"
FROM "table_name"
    """.strip()
    )


def test_single_quotes_for_strings() -> None:
    """
    Test that single quotes are used for string literals.
    """
    sql = "SELECT * FROM users WHERE name = 'John'"
    ast = sqlglot.parse_one(sql, Pinot)

    assert (
        Pinot().generate(expression=ast, pretty=True)
        == """
SELECT
  *
FROM users
WHERE
  name = 'John'
    """.strip()
    )


def test_backticks_as_identifiers() -> None:
    """
    Test that backticks work as identifiers (MySQL-style).

    Backticks are normalized to double quotes in output.
    """
    sql = "SELECT `column_name` FROM `table_name`"
    ast = sqlglot.parse_one(sql, Pinot)

    assert (
        Pinot().generate(expression=ast, pretty=True)
        == """
SELECT
  "column_name"
FROM "table_name"
    """.strip()
    )


def test_mixed_identifier_quotes() -> None:
    """
    Test mixing double quotes and backticks for identifiers.

    All identifiers are normalized to double quotes in output.
    """
    sql = (
        'SELECT "col1", `col2` FROM "table1" JOIN `table2` ON "table1".id = `table2`.id'
    )
    ast = sqlglot.parse_one(sql, Pinot)

    assert (
        Pinot().generate(expression=ast, pretty=True)
        == """
SELECT
  "col1",
  "col2"
FROM "table1"
JOIN "table2"
  ON "table1".id = "table2".id
    """.strip()
    )


def test_string_with_escaped_quotes() -> None:
    """
    Test string literals with escaped single quotes.
    """
    sql = "SELECT * FROM users WHERE name = 'O''Brien'"
    ast = sqlglot.parse_one(sql, Pinot)

    assert (
        Pinot().generate(expression=ast, pretty=True)
        == """
SELECT
  *
FROM users
WHERE
  name = 'O''Brien'
    """.strip()
    )


def test_string_with_backslash_escape() -> None:
    """
    Test string literals with backslash escapes.
    """
    sql = r"SELECT * FROM users WHERE path = 'C:\\Users\\John'"
    ast = sqlglot.parse_one(sql, Pinot)

    generated = Pinot().generate(expression=ast, pretty=True)
    # Only checks that the statement survives the round trip; the exact
    # rendering of the backslashes in the literal is intentionally not pinned.
    assert "WHERE" in generated
    assert "path" in generated


@pytest.mark.parametrize(
    "sql, expected",
    [
        (
            'SELECT COUNT(*) FROM "events" WHERE "type" = \'click\'',
            """
SELECT
  COUNT(*)
FROM "events"
WHERE
  "type" = 'click'
            """.strip(),
        ),
        (
            'SELECT "user_id", SUM("amount") FROM "transactions" GROUP BY "user_id"',
            """
SELECT
  "user_id",
  SUM("amount")
FROM "transactions"
GROUP BY
  "user_id"
            """.strip(),
        ),
        (
            "SELECT * FROM \"orders\" WHERE \"status\" IN ('pending', 'shipped')",
            """
SELECT
  *
FROM "orders"
WHERE
  "status" IN ('pending', 'shipped')
            """.strip(),
        ),
    ],
)
def test_various_queries(sql: str, expected: str) -> None:
    """
    Test various SQL queries with Pinot dialect.
    """
    ast = sqlglot.parse_one(sql, Pinot)
    assert Pinot().generate(expression=ast, pretty=True) == expected


def test_aggregate_functions() -> None:
    """
    Test aggregate functions with quoted identifiers.
    """
    sql = """
    SELECT "category", COUNT(*), AVG("price"), MAX("quantity")
    FROM "products"
    GROUP BY "category"
    """
    ast = sqlglot.parse_one(sql, Pinot)

    assert (
        Pinot().generate(expression=ast, pretty=True)
        == """
SELECT
  "category",
  COUNT(*),
  AVG("price"),
  MAX("quantity")
FROM "products"
GROUP BY
  "category"
    """.strip()
    )


def test_join_with_quoted_identifiers() -> None:
    """
    Test JOIN operations with double-quoted identifiers.
    """
    sql = """
    SELECT "u"."name", "o"."total"
    FROM "users" AS "u"
    JOIN "orders" AS "o" ON "u"."id" = "o"."user_id"
    """
    ast = sqlglot.parse_one(sql, Pinot)

    assert (
        Pinot().generate(expression=ast, pretty=True)
        == """
SELECT
  "u"."name",
  "o"."total"
FROM "users" AS "u"
JOIN "orders" AS "o"
  ON "u"."id" = "o"."user_id"
    """.strip()
    )


def test_subquery_with_quoted_identifiers() -> None:
    """
    Test subqueries with double-quoted identifiers.
    """
    sql = 'SELECT * FROM (SELECT "id", "name" FROM "users") AS "subquery"'
    ast = sqlglot.parse_one(sql, Pinot)

    assert (
        Pinot().generate(expression=ast, pretty=True)
        == """
SELECT
  *
FROM (
  SELECT
    "id",
    "name"
  FROM "users"
) AS "subquery"
    """.strip()
    )


def test_case_expression() -> None:
    """
    Test CASE expressions with quoted identifiers.
    """
    sql = """
    SELECT
        "name",
        CASE
            WHEN "age" < 18 THEN 'minor'
            WHEN "age" >= 18 THEN 'adult'
        END AS "category"
    FROM "persons"
    """
    ast = sqlglot.parse_one(sql, Pinot)

    generated = Pinot().generate(expression=ast, pretty=True)
    assert '"name"' in generated
    assert '"age"' in generated
    assert '"category"' in generated
    assert "'minor'" in generated
    assert "'adult'" in generated


def test_cte_with_quoted_identifiers() -> None:
    """
    Test Common Table Expressions (CTE) with quoted identifiers.
    """
    sql = """
    WITH "high_value_orders" AS (
        SELECT * FROM "orders" WHERE "total" > 1000
    )
    SELECT "customer_id", COUNT(*)
    FROM "high_value_orders"
    GROUP BY "customer_id"
    """
    ast = sqlglot.parse_one(sql, Pinot)

    generated = Pinot().generate(expression=ast, pretty=True)
    assert 'WITH "high_value_orders" AS' in generated
    assert '"orders"' in generated
    assert '"total"' in generated
    assert '"customer_id"' in generated


def test_order_by_with_quoted_identifiers() -> None:
    """
    Test ORDER BY clause with quoted identifiers.

    SQLGlot explicitly includes ASC in the output.
    """
    sql = (
        'SELECT "name", "salary" FROM "employees" ORDER BY "salary" DESC, "name" ASC'
    )
    ast = sqlglot.parse_one(sql, Pinot)

    assert (
        Pinot().generate(expression=ast, pretty=True)
        == """
SELECT
  "name",
  "salary"
FROM "employees"
ORDER BY
  "salary" DESC,
  "name" ASC
    """.strip()
    )


def test_limit_and_offset() -> None:
    """
    Test LIMIT and OFFSET clauses.
    """
    sql = 'SELECT * FROM "products" LIMIT 10 OFFSET 20'
    ast = sqlglot.parse_one(sql, Pinot)

    generated = Pinot().generate(expression=ast, pretty=True)
    assert '"products"' in generated
    assert "LIMIT 10" in generated
    # The docstring promises OFFSET coverage, so make sure it is not dropped.
    assert "OFFSET 20" in generated


def test_distinct() -> None:
    """
    Test DISTINCT keyword with quoted identifiers.
    """
    sql = 'SELECT DISTINCT "category" FROM "products"'
    ast = sqlglot.parse_one(sql, Pinot)

    assert (
        Pinot().generate(expression=ast, pretty=True)
        == """
SELECT DISTINCT
  "category"
FROM "products"
    """.strip()
    )


def test_cast_to_string() -> None:
    """
    Test that CAST to STRING is preserved (not converted to CHAR).
    """
    sql = "SELECT CAST(cohort_size AS STRING) FROM table"
    ast = sqlglot.parse_one(sql, Pinot)
    generated = Pinot().generate(expression=ast)

    assert "STRING" in generated
    assert "CHAR" not in generated


def test_concat_with_cast_string() -> None:
    """
    Test CONCAT with CAST to STRING - verifies the original issue is fixed.
    """
    sql = """
SELECT concat(a, cast(b AS string), ' - ')
FROM "default".c"""
    ast = sqlglot.parse_one(sql, Pinot)
    generated = Pinot().generate(expression=ast)

    # Verify STRING type is preserved (not converted to CHAR); the type name
    # may be emitted in either case, so compare case-insensitively.
    assert "STRING" in generated.upper()
    assert "CHAR" not in generated


@pytest.mark.parametrize(
    "cast_type, expected_type",
    [
        ("INT", "INT"),
        ("TINYINT", "INT"),
        ("SMALLINT", "INT"),
        ("BIGINT", "LONG"),
        ("LONG", "LONG"),
        ("FLOAT", "FLOAT"),
        ("DOUBLE", "DOUBLE"),
        ("BOOLEAN", "BOOLEAN"),
        ("TIMESTAMP", "TIMESTAMP"),
        ("STRING", "STRING"),
        ("VARCHAR", "STRING"),
        ("CHAR", "STRING"),
        ("TEXT", "STRING"),
        ("BYTES", "BYTES"),
        ("BINARY", "BYTES"),
        ("VARBINARY", "BYTES"),
        ("JSON", "JSON"),
    ],
)
def test_type_mappings(cast_type: str, expected_type: str) -> None:
    """
    Test that Pinot type mappings work correctly for all basic types.

    When a source type is mapped to a different Pinot type, the source type
    name must no longer appear in the generated SQL.
    """
    sql = f"SELECT CAST(col AS {cast_type}) FROM table"  # noqa: S608
    ast = sqlglot.parse_one(sql, Pinot)
    generated = Pinot().generate(expression=ast)

    assert expected_type in generated
    if cast_type != expected_type:
        # A plain substring check cannot catch a missing mapping for e.g.
        # TINYINT/SMALLINT -> INT ("INT" is a substring of both), so also
        # require that the original type name was rewritten away.
        assert cast_type not in generated


def test_unsigned_type() -> None:
    """
    Test that unsigned integer types are handled correctly.

    Tests the UNSIGNED_TYPE_MAPPING path in datatype_sql method.
    """
    from sqlglot import exp

    # Create a UBIGINT DataType which is in UNSIGNED_TYPE_MAPPING
    dt = exp.DataType(this=exp.DataType.Type.UBIGINT)
    result = Pinot.Generator().datatype_sql(dt)

    assert "UNSIGNED" in result
    assert "BIGINT" in result


def test_date_trunc_preserved() -> None:
    """
    Test that DATE_TRUNC is preserved and not converted to MySQL's DATE() function.
    """
    sql = "SELECT DATE_TRUNC('day', dt_column) FROM table"
    result = sqlglot.parse_one(sql, Pinot).sql(Pinot)

    assert "DATE_TRUNC" in result
    assert "date_trunc('day'" in result.lower()
    # Should not be converted to MySQL's DATE() function. A substring check is
    # used (rather than comparing against one exact string) so that any
    # rewrite into DATE(...) is caught, not just one specific spelling.
    assert "DATE(dt_column)" not in result


def test_cast_timestamp_preserved() -> None:
    """
    Test that CAST AS TIMESTAMP is preserved and not converted to TIMESTAMP() function.
    """
    sql = "SELECT CAST(dt_column AS TIMESTAMP) FROM table"
    result = sqlglot.parse_one(sql, Pinot).sql(Pinot)

    assert "CAST" in result
    assert "AS TIMESTAMP" in result
    # Should not be converted to MySQL's TIMESTAMP() function
    assert "TIMESTAMP(dt_column)" not in result


def test_date_trunc_with_cast_timestamp() -> None:
    """
    Test the original complex query with DATE_TRUNC and CAST AS TIMESTAMP.

    Verifies that both are preserved in parse/generate round-trip.
    """
    sql = """
SELECT
  CAST(
    DATE_TRUNC(
      'day',
      CAST(
        DATETIMECONVERT(
          dt_epoch_ms,
          '1:MILLISECONDS:EPOCH',
          '1:MILLISECONDS:EPOCH',
          '1:MILLISECONDS'
        ) AS TIMESTAMP
      )
    ) AS TIMESTAMP
  ),
  SUM(a) + SUM(b)
FROM "default".c
WHERE
  dt_epoch_ms >= 1735690800000
  AND dt_epoch_ms < 1759328588000
  AND locality != 'US'
GROUP BY
  CAST(
    DATE_TRUNC(
      'day',
      CAST(
        DATETIMECONVERT(
          dt_epoch_ms,
          '1:MILLISECONDS:EPOCH',
          '1:MILLISECONDS:EPOCH',
          '1:MILLISECONDS'
        ) AS TIMESTAMP
      )
    ) AS TIMESTAMP
  )
LIMIT 10000
"""
    result = sqlglot.parse_one(sql, Pinot).sql(Pinot)

    # Verify DATE_TRUNC and CAST are preserved
    assert "DATE_TRUNC" in result
    assert "CAST" in result

    # Verify these are NOT converted to MySQL functions
    assert "TIMESTAMP(DATETIMECONVERT" not in result
    assert result.count("DATE_TRUNC") == 2  # Should appear twice (SELECT and GROUP BY)


def test_pinot_date_add_parsing() -> None:
    """
    Test that Pinot's DATE_ADD function with Presto-like syntax can be parsed.
    """
    from superset.sql.parse import SQLScript

    sql = """
    SELECT dt_epoch_ms FROM my_table WHERE dt_epoch_ms >= date_add('day', -180, now())
    """
    script = SQLScript(sql, "pinot")
    assert len(script.statements) == 1
    assert not script.has_mutation()


def test_pinot_date_add_simple() -> None:
    """
    Test parsing of simple DATE_ADD expressions.
    """
    test_cases = [
        "date_add('day', -180, now())",
        "DATE_ADD('month', 5, current_timestamp())",
        "date_add('year', 1, my_date_column)",
    ]

    for sql in test_cases:
        parsed = sqlglot.parse_one(sql, Pinot)
        assert parsed is not None
        # Verify that it generates valid SQL
        generated = parsed.sql(dialect=Pinot)
        assert "DATE_ADD" in generated.upper()


def test_pinot_date_add_unit_quoted() -> None:
    """
    Test that DATE_ADD preserves quotes around the unit argument.

    Pinot requires the unit to be a quoted string, not an identifier.
    """
    sql = "dt_epoch_ms >= date_add('day', -180, now())"
    result = sqlglot.parse_one(sql, Pinot).sql(Pinot)

    # The unit should be quoted: 'DAY' not DAY
    assert "DATE_ADD('DAY', -180, NOW())" in result
    assert "DATE_ADD(DAY," not in result


def test_pinot_date_sub_parsing() -> None:
    """
    Test that Pinot's DATE_SUB function with Presto-like syntax can be parsed.
    """
    from superset.sql.parse import SQLScript

    sql = "SELECT * FROM my_table WHERE dt >= date_sub('day', 7, now())"
    script = SQLScript(sql, "pinot")
    assert len(script.statements) == 1
    assert not script.has_mutation()


def test_pinot_date_sub_simple() -> None:
    """
    Test parsing of simple DATE_SUB expressions.
    """
    test_cases = [
        "date_sub('day', 7, now())",
        "DATE_SUB('month', 3, current_timestamp())",
        "date_sub('hour', 24, my_date_column)",
    ]

    for sql in test_cases:
        parsed = sqlglot.parse_one(sql, Pinot)
        assert parsed is not None
        # Verify that it generates valid SQL
        generated = parsed.sql(dialect=Pinot)
        assert "DATE_SUB" in generated.upper()


def test_pinot_date_sub_unit_quoted() -> None:
    """
    Test that DATE_SUB preserves quotes around the unit argument.

    Pinot requires the unit to be a quoted string, not an identifier.
    """
    sql = "dt_epoch_ms >= date_sub('day', -180, now())"
    result = sqlglot.parse_one(sql, Pinot).sql(Pinot)

    # The unit should be quoted: 'DAY' not DAY
    assert "DATE_SUB('DAY', -180, NOW())" in result
    assert "DATE_SUB(DAY," not in result


def test_substr_cross_dialect_generation() -> None:
    """
    Test that SUBSTR is preserved when generating Pinot SQL.

    Note that the MySQL dialect (on which Pinot is based) uses SUBSTRING instead of
    SUBSTR.
    """
    # Parse with Pinot dialect
    pinot_sql = "SELECT SUBSTR('hello', 0, 3) FROM users"
    parsed = sqlglot.parse_one(pinot_sql, Pinot)

    # Generate back to Pinot → should preserve SUBSTR
    pinot_output = parsed.sql(dialect=Pinot)
    assert "SUBSTR(" in pinot_output
    assert "SUBSTRING(" not in pinot_output

    # Generate to MySQL → should convert to SUBSTRING
    mysql_output = parsed.sql(dialect="mysql")
    assert "SUBSTRING(" in mysql_output

    assert pinot_output != mysql_output  # They should be different


@pytest.mark.parametrize(
    "function_name,sample_args",
    [
        # Math functions
        ("ABS", "-5"),
        ("CEIL", "3.14"),
        ("FLOOR", "3.14"),
        ("EXP", "2"),
        ("LN", "10"),
        ("SQRT", "16"),
        ("ROUNDDECIMAL", "3.14159, 2"),
        ("ADD", "1, 2, 3"),
        ("SUB", "10, 3"),
        ("MULT", "5, 4"),
        ("MOD", "10, 3"),
        # String functions
        ("UPPER", "'hello'"),
        ("LOWER", "'HELLO'"),
        ("REVERSE", "'hello'"),
        ("SUBSTR", "'hello', 0, 3"),
        ("CONCAT", "'hello', ' ', 'world'"),
        ("TRIM", "' hello '"),
        ("LTRIM", "' hello'"),
        ("RTRIM", "'hello '"),
        ("LENGTH", "'hello'"),
        ("STRPOS", "'hello', 'l', 1"),
        ("STARTSWITH", "'hello', 'he'"),
        ("REPLACE", "'hello', 'l', 'r'"),
        ("RPAD", "'hello', 10, 'x'"),
        ("LPAD", "'hello', 10, 'x'"),
        ("CODEPOINT", "'A'"),
        ("CHR", "65"),
        ("regexpExtract", "'foo123bar', '[0-9]+'"),
        ("regexpReplace", "'hello', 'l', 'r'"),
        ("remove", "'hello', 'l'"),
        ("urlEncoding", "'hello world'"),
        ("urlDecoding", "'hello%20world'"),
        ("fromBase64", "'aGVsbG8='"),
        ("toUtf8", "'hello'"),
        ("isSubnetOf", "'192.168.1.1', '192.168.0.0/16'"),
        # DateTime functions
        ("DATETRUNC", "'day', timestamp_col"),
        ("DATETIMECONVERT", "dt_col, '1:HOURS:EPOCH', '1:DAYS:EPOCH', '1:DAYS'"),
        ("TIMECONVERT", "timestamp_col, 'MILLISECONDS', 'SECONDS'"),
        ("NOW", ""),
        ("AGO", "'P1D'"),
        ("YEAR", "timestamp_col"),
        ("QUARTER", "timestamp_col"),
        ("MONTH", "timestamp_col"),
        ("WEEK", "timestamp_col"),
        ("DAY", "timestamp_col"),
        ("HOUR", "timestamp_col"),
        ("MINUTE", "timestamp_col"),
        ("SECOND", "timestamp_col"),
        ("MILLISECOND", "timestamp_col"),
        ("DAYOFWEEK", "timestamp_col"),
        ("DAYOFYEAR", "timestamp_col"),
        ("YEAROFWEEK", "timestamp_col"),
        ("toEpochSeconds", "timestamp_col"),
        ("toEpochMinutes", "timestamp_col"),
        ("toEpochHours", "timestamp_col"),
        ("toEpochDays", "timestamp_col"),
        ("fromEpochSeconds", "1234567890"),
        ("fromEpochMinutes", "20576131"),
        ("fromEpochHours", "342935"),
        ("fromEpochDays", "14288"),
        ("toDateTime", "timestamp_col, 'yyyy-MM-dd'"),
        ("fromDateTime", "'2024-01-01', 'yyyy-MM-dd'"),
        ("timezoneHour", "timestamp_col"),
        ("timezoneMinute", "timestamp_col"),
        ("DATE_ADD", "'day', 7, NOW()"),
        ("DATE_SUB", "'day', 7, NOW()"),
        ("TIMESTAMPADD", "'day', 7, timestamp_col"),
        ("TIMESTAMPDIFF", "'day', timestamp1, timestamp2"),
        ("dateTrunc", "'day', timestamp_col"),
        ("dateDiff", "'day', timestamp1, timestamp2"),
        ("dateAdd", "'day', 7, timestamp_col"),
        ("dateBin", "'day', timestamp_col, NOW()"),
        ("toIso8601", "timestamp_col"),
        ("fromIso8601", "'2024-01-01T00:00:00Z'"),
        # Aggregation functions
        ("COUNT", "*"),
        ("SUM", "amount"),
        ("AVG", "value"),
        ("MIN", "value"),
        ("MAX", "value"),
        ("DISTINCTCOUNT", "user_id"),
        ("DISTINCTCOUNTBITMAP", "user_id"),
        ("DISTINCTCOUNTHLL", "user_id"),
        ("DISTINCTCOUNTRAWHLL", "user_id"),
        ("DISTINCTCOUNTHLLPLUS", "user_id"),
        ("DISTINCTCOUNTRAWHLLPLUS", "user_id"),
        ("DISTINCTCOUNTSMARTHLL", "user_id"),
        ("DISTINCTCOUNTCPCSKETCH", "user_id"),
        ("DISTINCTCOUNTRAWCPCSKETCH", "user_id"),
        ("DISTINCTCOUNTTHETASKETCH", "user_id"),
        ("DISTINCTCOUNTRAWTHETASKETCH", "user_id"),
        ("DISTINCTCOUNTTUPLESKETCH", "user_id"),
        ("DISTINCTCOUNTRAWINTEGERSUMTUPLESKETCH", "user_id"),
        ("DISTINCTCOUNTULL", "user_id"),
        ("DISTINCTCOUNTRAWULL", "user_id"),
        ("SEGMENTPARTITIONEDDISTINCTCOUNT", "user_id"),
        ("SUMVALUESINTEGERSUMTUPLESKETCH", "value"),
        ("PERCENTILE", "value, 95"),
        ("PERCENTILEEST", "value, 95"),
        ("PERCENTILETDIGEST", "value, 95"),
        ("PERCENTILESMARTTDIGEST", "value, 95"),
        ("PERCENTILEKLL", "value, 95"),
        ("PERCENTILEKLLRAW", "value, 95"),
        ("HISTOGRAM", "value, 10"),
        ("MODE", "category"),
        ("MINMAXRANGE", "value"),
        ("SUMPRECISION", "value, 10"),
        ("ARG_MIN", "value, id"),
        ("ARG_MAX", "value, id"),
        ("COVAR_POP", "x, y"),
        ("COVAR_SAMP", "x, y"),
        ("LASTWITHTIME", "value, timestamp_col, 'LONG'"),
        ("FIRSTWITHTIME", "value, timestamp_col, 'LONG'"),
        ("ARRAY_AGG", "value"),
        # Multi-value functions
        ("COUNTMV", "tags"),
        ("MAXMV", "scores"),
        ("MINMV", "scores"),
        ("SUMMV", "scores"),
        ("AVGMV", "scores"),
        ("MINMAXRANGEMV", "scores"),
        ("PERCENTILEMV", "scores, 95"),
        ("PERCENTILEESTMV", "scores, 95"),
        ("PERCENTILETDIGESTMV", "scores, 95"),
        ("PERCENTILEKLLMV", "scores, 95"),
        ("DISTINCTCOUNTMV", "tags"),
        ("DISTINCTCOUNTBITMAPMV", "tags"),
        ("DISTINCTCOUNTHLLMV", "tags"),
        ("DISTINCTCOUNTRAWHLLMV", "tags"),
        ("DISTINCTCOUNTHLLPLUSMV", "tags"),
        ("DISTINCTCOUNTRAWHLLPLUSMV", "tags"),
        ("ARRAYLENGTH", "array_col"),
        ("MAP_VALUE", "map_col, 'key'"),
        ("VALUEIN", "value, 'val1', 'val2'"),
        # JSON functions
        ("JSONEXTRACTSCALAR", "json_col, '$.name', 'STRING'"),
        ("JSONEXTRACTKEY", "json_col, '$.data'"),
        ("JSONFORMAT", "json_col"),
        ("JSONPATH", "json_col, '$.name'"),
        ("JSONPATHLONG", "json_col, '$.id'"),
        ("JSONPATHDOUBLE", "json_col, '$.price'"),
        ("JSONPATHSTRING", "json_col, '$.name'"),
        ("JSONPATHARRAY", "json_col, '$.items'"),
        ("JSONPATHARRAYDEFAULTEMPTY", "json_col, '$.items'"),
        ("TOJSONMAPSTR", "map_col"),
        ("JSON_MATCH", "json_col, '\"$.name\"=''value'''"),
        ("JSON_EXTRACT_SCALAR", "json_col, '$.name', 'STRING'"),
        # Array functions
        ("arrayReverseInt", "int_array"),
        ("arrayReverseString", "string_array"),
        ("arraySortInt", "int_array"),
        ("arraySortString", "string_array"),
        ("arrayIndexOfInt", "int_array, 5"),
        ("arrayIndexOfString", "string_array, 'value'"),
        ("arrayContainsInt", "int_array, 5"),
        ("arrayContainsString", "string_array, 'value'"),
        ("arraySliceInt", "int_array, 0, 3"),
        ("arraySliceString", "string_array, 0, 3"),
        ("arrayDistinctInt", "int_array"),
        ("arrayDistinctString", "string_array"),
        ("arrayRemoveInt", "int_array, 5"),
        ("arrayRemoveString", "string_array, 'value'"),
        ("arrayUnionInt", "int_array1, int_array2"),
        ("arrayUnionString", "string_array1, string_array2"),
        ("arrayConcatInt", "int_array1, int_array2"),
        ("arrayConcatString", "string_array1, string_array2"),
        ("arrayElementAtInt", "int_array, 0"),
        ("arrayElementAtString", "string_array, 0"),
        ("arraySumInt", "int_array"),
        ("arrayValueConstructor", "1, 2, 3"),
        ("arrayToString", "array_col, ','"),
        # Geospatial functions
        ("ST_DISTANCE", "point1, point2"),
        ("ST_CONTAINS", "polygon, point"),
        ("ST_AREA", "polygon"),
        ("ST_GEOMFROMTEXT", "'POINT(1 2)'"),
        ("ST_GEOMFROMWKB", "wkb_col"),
        ("ST_GEOGFROMWKB", "wkb_col"),
        ("ST_GEOGFROMTEXT", "'POINT(1 2)'"),
        ("ST_POINT", "1.0, 2.0"),
        ("ST_POLYGON", "'POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))'"),
        ("ST_ASBINARY", "geom_col"),
        ("ST_ASTEXT", "geom_col"),
        ("ST_GEOMETRYTYPE", "geom_col"),
        ("ST_EQUALS", "geom1, geom2"),
        ("ST_WITHIN", "geom1, geom2"),
        ("ST_UNION", "geom1, geom2"),
        ("ST_GEOMFROMGEOJSON", '\'{"type":"Point","coordinates":[1,2]}\''),
        ("ST_GEOGFROMGEOJSON", '\'{"type":"Point","coordinates":[1,2]}\''),
        ("ST_ASGEOJSON", "geom_col"),
        ("toSphericalGeography", "geom_col"),
        ("toGeometry", "geog_col"),
        # Binary/Hash functions
        ("SHA", "'hello'"),
        ("SHA256", "'hello'"),
        ("SHA512", "'hello'"),
        ("SHA224", "'hello'"),
        ("MD5", "'hello'"),
        ("MD2", "'hello'"),
        ("toBase64", "'hello'"),
        ("fromUtf8", "bytes_col"),
        ("MurmurHash2", "'hello'"),
        ("MurmurHash3Bit32", "'hello'"),
        # Window functions
        ("ROW_NUMBER", ""),
        ("RANK", ""),
        ("DENSE_RANK", ""),
        # Funnel analysis
        ("FunnelMaxStep", "event_col, 'step1', 'step2', 'step3'"),
        ("FunnelMatchStep", "event_col, 'step1', 'step2', 'step3'"),
        ("FunnelCompleteCount", "event_col, 'step1', 'step2', 'step3'"),
        # Text search
        ("TEXT_MATCH", "text_col, 'search query'"),
        # Vector functions
        ("VECTOR_SIMILARITY", "vector1, vector2"),
        ("l2_distance", "vector1, vector2"),
        # Lookup
        ("LOOKUP", "'lookupTable', 'lookupColumn', 'keyColumn', keyValue"),
        # URL functions
        ("urlProtocol", "'https://example.com/path'"),
        ("urlDomain", "'https://example.com/path'"),
        ("urlPath", "'https://example.com/path'"),
        ("urlPort", "'https://example.com:8080/path'"),
        ("urlEncode", "'hello world'"),
        ("urlDecode", "'hello%20world'"),
        # Conditional
        ("COALESCE", "val1, val2, 'default'"),
        ("NULLIF", "val1, val2"),
        ("GREATEST", "1, 2, 3"),
        ("LEAST", "1, 2, 3"),
        # Other
        ("REGEXP_LIKE", "'hello', 'h.*'"),
        ("GROOVY", "'{return arg0 + arg1}', col1, col2"),
    ],
)
def test_pinot_function_names_preserved(function_name: str, sample_args: str) -> None:
    """
    Test that Pinot function names are preserved during parse/generate roundtrip.

    This ensures that when we parse Pinot SQL and generate it back, the function
    names remain unchanged. This is critical for maintaining compatibility with
    Pinot's function library.
    """
    # Window functions need an OVER clause to be valid SQL.
    if function_name in {"ROW_NUMBER", "RANK", "DENSE_RANK"}:
        sql = f"SELECT {function_name}() OVER (ORDER BY col) FROM table"  # noqa: S608
    else:
        sql = f"SELECT {function_name}({sample_args}) FROM table"  # noqa: S608

    # Parse with Pinot dialect
    parsed = sqlglot.parse_one(sql, Pinot)

    # Generate back to Pinot
    generated = parsed.sql(dialect=Pinot)

    # The function name should be preserved (case-insensitive check)
    assert function_name.upper() in generated.upper(), (
        f"Function {function_name} not preserved in output: {generated}"
    )