fix(pinot): more functions (#35451)

This commit is contained in:
Beto Dealmeida
2025-10-02 13:01:47 -04:00
committed by GitHub
parent 553204e613
commit 3202ff4b3f
2 changed files with 297 additions and 0 deletions

View File

@@ -612,3 +612,270 @@ def test_substr_cross_dialect_generation() -> None:
mysql_output = parsed.sql(dialect="mysql")
assert "SUBSTRING(" in mysql_output
assert pinot_output != mysql_output # They should be different
@pytest.mark.parametrize(
"function_name,sample_args",
[
# Math functions
("ABS", "-5"),
("CEIL", "3.14"),
("FLOOR", "3.14"),
("EXP", "2"),
("LN", "10"),
("SQRT", "16"),
("ROUNDDECIMAL", "3.14159, 2"),
("ADD", "1, 2, 3"),
("SUB", "10, 3"),
("MULT", "5, 4"),
("MOD", "10, 3"),
# String functions
("UPPER", "'hello'"),
("LOWER", "'HELLO'"),
("REVERSE", "'hello'"),
("SUBSTR", "'hello', 0, 3"),
("CONCAT", "'hello', ' ', 'world'"),
("TRIM", "' hello '"),
("LTRIM", "' hello'"),
("RTRIM", "'hello '"),
("LENGTH", "'hello'"),
("STRPOS", "'hello', 'l', 1"),
("STARTSWITH", "'hello', 'he'"),
("REPLACE", "'hello', 'l', 'r'"),
("RPAD", "'hello', 10, 'x'"),
("LPAD", "'hello', 10, 'x'"),
("CODEPOINT", "'A'"),
("CHR", "65"),
("regexpExtract", "'foo123bar', '[0-9]+'"),
("regexpReplace", "'hello', 'l', 'r'"),
("remove", "'hello', 'l'"),
("urlEncoding", "'hello world'"),
("urlDecoding", "'hello%20world'"),
("fromBase64", "'aGVsbG8='"),
("toUtf8", "'hello'"),
("isSubnetOf", "'192.168.1.1', '192.168.0.0/16'"),
# DateTime functions
("DATETRUNC", "'day', timestamp_col"),
("DATETIMECONVERT", "dt_col, '1:HOURS:EPOCH', '1:DAYS:EPOCH', '1:DAYS'"),
("TIMECONVERT", "timestamp_col, 'MILLISECONDS', 'SECONDS'"),
("NOW", ""),
("AGO", "'P1D'"),
("YEAR", "timestamp_col"),
("QUARTER", "timestamp_col"),
("MONTH", "timestamp_col"),
("WEEK", "timestamp_col"),
("DAY", "timestamp_col"),
("HOUR", "timestamp_col"),
("MINUTE", "timestamp_col"),
("SECOND", "timestamp_col"),
("MILLISECOND", "timestamp_col"),
("DAYOFWEEK", "timestamp_col"),
("DAYOFYEAR", "timestamp_col"),
("YEAROFWEEK", "timestamp_col"),
("toEpochSeconds", "timestamp_col"),
("toEpochMinutes", "timestamp_col"),
("toEpochHours", "timestamp_col"),
("toEpochDays", "timestamp_col"),
("fromEpochSeconds", "1234567890"),
("fromEpochMinutes", "20576131"),
("fromEpochHours", "342935"),
("fromEpochDays", "14288"),
("toDateTime", "timestamp_col, 'yyyy-MM-dd'"),
("fromDateTime", "'2024-01-01', 'yyyy-MM-dd'"),
("timezoneHour", "timestamp_col"),
("timezoneMinute", "timestamp_col"),
("DATE_ADD", "'day', 7, NOW()"),
("DATE_SUB", "'day', 7, NOW()"),
("TIMESTAMPADD", "'day', 7, timestamp_col"),
("TIMESTAMPDIFF", "'day', timestamp1, timestamp2"),
("dateTrunc", "'day', timestamp_col"),
("dateDiff", "'day', timestamp1, timestamp2"),
("dateAdd", "'day', 7, timestamp_col"),
("dateBin", "'day', timestamp_col, NOW()"),
("toIso8601", "timestamp_col"),
("fromIso8601", "'2024-01-01T00:00:00Z'"),
# Aggregation functions
("COUNT", "*"),
("SUM", "amount"),
("AVG", "value"),
("MIN", "value"),
("MAX", "value"),
("DISTINCTCOUNT", "user_id"),
("DISTINCTCOUNTBITMAP", "user_id"),
("DISTINCTCOUNTHLL", "user_id"),
("DISTINCTCOUNTRAWHLL", "user_id"),
("DISTINCTCOUNTHLLPLUS", "user_id"),
("DISTINCTCOUNTRAWHLLPLUS", "user_id"),
("DISTINCTCOUNTSMARTHLL", "user_id"),
("DISTINCTCOUNTCPCSKETCH", "user_id"),
("DISTINCTCOUNTRAWCPCSKETCH", "user_id"),
("DISTINCTCOUNTTHETASKETCH", "user_id"),
("DISTINCTCOUNTRAWTHETASKETCH", "user_id"),
("DISTINCTCOUNTTUPLESKETCH", "user_id"),
("DISTINCTCOUNTRAWINTEGERSUMTUPLESKETCH", "user_id"),
("DISTINCTCOUNTULL", "user_id"),
("DISTINCTCOUNTRAWULL", "user_id"),
("SEGMENTPARTITIONEDDISTINCTCOUNT", "user_id"),
("SUMVALUESINTEGERSUMTUPLESKETCH", "value"),
("PERCENTILE", "value, 95"),
("PERCENTILEEST", "value, 95"),
("PERCENTILETDIGEST", "value, 95"),
("PERCENTILESMARTTDIGEST", "value, 95"),
("PERCENTILEKLL", "value, 95"),
("PERCENTILEKLLRAW", "value, 95"),
("HISTOGRAM", "value, 10"),
("MODE", "category"),
("MINMAXRANGE", "value"),
("SUMPRECISION", "value, 10"),
("ARG_MIN", "value, id"),
("ARG_MAX", "value, id"),
("COVAR_POP", "x, y"),
("COVAR_SAMP", "x, y"),
("LASTWITHTIME", "value, timestamp_col, 'LONG'"),
("FIRSTWITHTIME", "value, timestamp_col, 'LONG'"),
("ARRAY_AGG", "value"),
# Multi-value functions
("COUNTMV", "tags"),
("MAXMV", "scores"),
("MINMV", "scores"),
("SUMMV", "scores"),
("AVGMV", "scores"),
("MINMAXRANGEMV", "scores"),
("PERCENTILEMV", "scores, 95"),
("PERCENTILEESTMV", "scores, 95"),
("PERCENTILETDIGESTMV", "scores, 95"),
("PERCENTILEKLLMV", "scores, 95"),
("DISTINCTCOUNTMV", "tags"),
("DISTINCTCOUNTBITMAPMV", "tags"),
("DISTINCTCOUNTHLLMV", "tags"),
("DISTINCTCOUNTRAWHLLMV", "tags"),
("DISTINCTCOUNTHLLPLUSMV", "tags"),
("DISTINCTCOUNTRAWHLLPLUSMV", "tags"),
("ARRAYLENGTH", "array_col"),
("MAP_VALUE", "map_col, 'key'"),
("VALUEIN", "value, 'val1', 'val2'"),
# JSON functions
("JSONEXTRACTSCALAR", "json_col, '$.name', 'STRING'"),
("JSONEXTRACTKEY", "json_col, '$.data'"),
("JSONFORMAT", "json_col"),
("JSONPATH", "json_col, '$.name'"),
("JSONPATHLONG", "json_col, '$.id'"),
("JSONPATHDOUBLE", "json_col, '$.price'"),
("JSONPATHSTRING", "json_col, '$.name'"),
("JSONPATHARRAY", "json_col, '$.items'"),
("JSONPATHARRAYDEFAULTEMPTY", "json_col, '$.items'"),
("TOJSONMAPSTR", "map_col"),
("JSON_MATCH", "json_col, '\"$.name\"=''value'''"),
("JSON_EXTRACT_SCALAR", "json_col, '$.name', 'STRING'"),
# Array functions
("arrayReverseInt", "int_array"),
("arrayReverseString", "string_array"),
("arraySortInt", "int_array"),
("arraySortString", "string_array"),
("arrayIndexOfInt", "int_array, 5"),
("arrayIndexOfString", "string_array, 'value'"),
("arrayContainsInt", "int_array, 5"),
("arrayContainsString", "string_array, 'value'"),
("arraySliceInt", "int_array, 0, 3"),
("arraySliceString", "string_array, 0, 3"),
("arrayDistinctInt", "int_array"),
("arrayDistinctString", "string_array"),
("arrayRemoveInt", "int_array, 5"),
("arrayRemoveString", "string_array, 'value'"),
("arrayUnionInt", "int_array1, int_array2"),
("arrayUnionString", "string_array1, string_array2"),
("arrayConcatInt", "int_array1, int_array2"),
("arrayConcatString", "string_array1, string_array2"),
("arrayElementAtInt", "int_array, 0"),
("arrayElementAtString", "string_array, 0"),
("arraySumInt", "int_array"),
("arrayValueConstructor", "1, 2, 3"),
("arrayToString", "array_col, ','"),
# Geospatial functions
("ST_DISTANCE", "point1, point2"),
("ST_CONTAINS", "polygon, point"),
("ST_AREA", "polygon"),
("ST_GEOMFROMTEXT", "'POINT(1 2)'"),
("ST_GEOMFROMWKB", "wkb_col"),
("ST_GEOGFROMWKB", "wkb_col"),
("ST_GEOGFROMTEXT", "'POINT(1 2)'"),
("ST_POINT", "1.0, 2.0"),
("ST_POLYGON", "'POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))'"),
("ST_ASBINARY", "geom_col"),
("ST_ASTEXT", "geom_col"),
("ST_GEOMETRYTYPE", "geom_col"),
("ST_EQUALS", "geom1, geom2"),
("ST_WITHIN", "geom1, geom2"),
("ST_UNION", "geom1, geom2"),
("ST_GEOMFROMGEOJSON", '\'{"type":"Point","coordinates":[1,2]}\''),
("ST_GEOGFROMGEOJSON", '\'{"type":"Point","coordinates":[1,2]}\''),
("ST_ASGEOJSON", "geom_col"),
("toSphericalGeography", "geom_col"),
("toGeometry", "geog_col"),
# Binary/Hash functions
("SHA", "'hello'"),
("SHA256", "'hello'"),
("SHA512", "'hello'"),
("SHA224", "'hello'"),
("MD5", "'hello'"),
("MD2", "'hello'"),
("toBase64", "'hello'"),
("fromUtf8", "bytes_col"),
("MurmurHash2", "'hello'"),
("MurmurHash3Bit32", "'hello'"),
# Window functions
("ROW_NUMBER", ""),
("RANK", ""),
("DENSE_RANK", ""),
# Funnel analysis
("FunnelMaxStep", "event_col, 'step1', 'step2', 'step3'"),
("FunnelMatchStep", "event_col, 'step1', 'step2', 'step3'"),
("FunnelCompleteCount", "event_col, 'step1', 'step2', 'step3'"),
# Text search
("TEXT_MATCH", "text_col, 'search query'"),
# Vector functions
("VECTOR_SIMILARITY", "vector1, vector2"),
("l2_distance", "vector1, vector2"),
# Lookup
("LOOKUP", "'lookupTable', 'lookupColumn', 'keyColumn', keyValue"),
# URL functions
("urlProtocol", "'https://example.com/path'"),
("urlDomain", "'https://example.com/path'"),
("urlPath", "'https://example.com/path'"),
("urlPort", "'https://example.com:8080/path'"),
("urlEncode", "'hello world'"),
("urlDecode", "'hello%20world'"),
# Conditional
("COALESCE", "val1, val2, 'default'"),
("NULLIF", "val1, val2"),
("GREATEST", "1, 2, 3"),
("LEAST", "1, 2, 3"),
# Other
("REGEXP_LIKE", "'hello', 'h.*'"),
("GROOVY", "'{return arg0 + arg1}', col1, col2"),
],
)
def test_pinot_function_names_preserved(function_name: str, sample_args: str) -> None:
"""
Test that Pinot function names are preserved during parse/generate roundtrip.
This ensures that when we parse Pinot SQL and generate it back, the function
names remain unchanged. This is critical for maintaining compatibility with
Pinot's function library.
"""
# Special handling for window functions
if function_name in ["ROW_NUMBER", "RANK", "DENSE_RANK"]:
sql = f"SELECT {function_name}() OVER (ORDER BY col) FROM table" # noqa: S608
else:
sql = f"SELECT {function_name}({sample_args}) FROM table" # noqa: S608
# Parse with Pinot dialect
parsed = sqlglot.parse_one(sql, Pinot)
# Generate back to Pinot
generated = parsed.sql(dialect=Pinot)
# The function name should be preserved (case-insensitive check)
assert function_name.upper() in generated.upper(), (
f"Function {function_name} not preserved in output: {generated}"
)