# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for the custom Pinot sqlglot dialect.

Every test parses a SQL string with the ``Pinot`` dialect and inspects the SQL
that the same dialect generates back.  Tests that compare with ``==`` against a
multi-line literal rely on sqlglot's ``pretty=True`` layout, so the line breaks
and two-space indentation inside those literals are significant.
"""

import pytest
import sqlglot

from superset.sql.dialects.pinot import Pinot


def test_pinot_dialect_registered() -> None:
    """
    Test that Pinot dialect is properly registered.
    """
    # Imported inside the test, as in the original module.
    from superset.sql.parse import SQLGLOT_DIALECTS

    assert "pinot" in SQLGLOT_DIALECTS
    assert SQLGLOT_DIALECTS["pinot"] == Pinot


def test_double_quotes_as_identifiers() -> None:
    """
    Test that double quotes are treated as identifiers, not string literals.
    """
    sql = 'SELECT "column_name" FROM "table_name"'
    ast = sqlglot.parse_one(sql, Pinot)

    assert (
        Pinot().generate(expression=ast, pretty=True)
        == """
SELECT
  "column_name"
FROM "table_name"
    """.strip()
    )


def test_single_quotes_for_strings() -> None:
    """
    Test that single quotes are used for string literals.
    """
    sql = "SELECT * FROM users WHERE name = 'John'"
    ast = sqlglot.parse_one(sql, Pinot)

    assert (
        Pinot().generate(expression=ast, pretty=True)
        == """
SELECT
  *
FROM users
WHERE
  name = 'John'
    """.strip()
    )


def test_backticks_as_identifiers() -> None:
    """
    Test that backticks work as identifiers (MySQL-style).

    Backticks are normalized to double quotes in output.
    """
    sql = "SELECT `column_name` FROM `table_name`"
    ast = sqlglot.parse_one(sql, Pinot)

    assert (
        Pinot().generate(expression=ast, pretty=True)
        == """
SELECT
  "column_name"
FROM "table_name"
    """.strip()
    )


def test_mixed_identifier_quotes() -> None:
    """
    Test mixing double quotes and backticks for identifiers.

    All identifiers are normalized to double quotes in output.
    """
    sql = (
        'SELECT "col1", `col2` FROM "table1" JOIN `table2` ON "table1".id = `table2`.id'
    )
    ast = sqlglot.parse_one(sql, Pinot)

    assert (
        Pinot().generate(expression=ast, pretty=True)
        == """
SELECT
  "col1",
  "col2"
FROM "table1"
JOIN "table2"
  ON "table1".id = "table2".id
    """.strip()
    )


def test_string_with_escaped_quotes() -> None:
    """
    Test string literals with escaped single quotes.
    """
    sql = "SELECT * FROM users WHERE name = 'O''Brien'"
    ast = sqlglot.parse_one(sql, Pinot)

    assert (
        Pinot().generate(expression=ast, pretty=True)
        == """
SELECT
  *
FROM users
WHERE
  name = 'O''Brien'
    """.strip()
    )


def test_string_with_backslash_escape() -> None:
    """
    Test string literals with backslash escapes.

    Only checks that the statement parses and that the WHERE clause and the
    column name survive generation; the escaped literal itself is not compared.
    """
    sql = r"SELECT * FROM users WHERE path = 'C:\\Users\\John'"
    ast = sqlglot.parse_one(sql, Pinot)
    generated = Pinot().generate(expression=ast, pretty=True)

    assert "WHERE" in generated
    assert "path" in generated


@pytest.mark.parametrize(
    "sql, expected",
    [
        (
            'SELECT COUNT(*) FROM "events" WHERE "type" = \'click\'',
            """
SELECT
  COUNT(*)
FROM "events"
WHERE
  "type" = 'click'
            """.strip(),
        ),
        (
            'SELECT "user_id", SUM("amount") FROM "transactions" GROUP BY "user_id"',
            """
SELECT
  "user_id",
  SUM("amount")
FROM "transactions"
GROUP BY
  "user_id"
            """.strip(),
        ),
        (
            "SELECT * FROM \"orders\" WHERE \"status\" IN ('pending', 'shipped')",
            """
SELECT
  *
FROM "orders"
WHERE
  "status" IN ('pending', 'shipped')
            """.strip(),
        ),
    ],
)
def test_various_queries(sql: str, expected: str) -> None:
    """
    Test various SQL queries with Pinot dialect.
    """
    ast = sqlglot.parse_one(sql, Pinot)
    assert Pinot().generate(expression=ast, pretty=True) == expected


def test_aggregate_functions() -> None:
    """
    Test aggregate functions with quoted identifiers.
    """
    sql = """
SELECT
    "category",
    COUNT(*),
    AVG("price"),
    MAX("quantity")
FROM "products"
GROUP BY "category"
    """
    ast = sqlglot.parse_one(sql, Pinot)

    assert (
        Pinot().generate(expression=ast, pretty=True)
        == """
SELECT
  "category",
  COUNT(*),
  AVG("price"),
  MAX("quantity")
FROM "products"
GROUP BY
  "category"
    """.strip()
    )


def test_join_with_quoted_identifiers() -> None:
    """
    Test JOIN operations with double-quoted identifiers.
    """
    sql = """
SELECT "u"."name", "o"."total"
FROM "users" AS "u"
JOIN "orders" AS "o" ON "u"."id" = "o"."user_id"
    """
    ast = sqlglot.parse_one(sql, Pinot)

    assert (
        Pinot().generate(expression=ast, pretty=True)
        == """
SELECT
  "u"."name",
  "o"."total"
FROM "users" AS "u"
JOIN "orders" AS "o"
  ON "u"."id" = "o"."user_id"
    """.strip()
    )


def test_subquery_with_quoted_identifiers() -> None:
    """
    Test subqueries with double-quoted identifiers.
    """
    sql = 'SELECT * FROM (SELECT "id", "name" FROM "users") AS "subquery"'
    ast = sqlglot.parse_one(sql, Pinot)

    assert (
        Pinot().generate(expression=ast, pretty=True)
        == """
SELECT
  *
FROM (
  SELECT
    "id",
    "name"
  FROM "users"
) AS "subquery"
    """.strip()
    )


def test_case_expression() -> None:
    """
    Test CASE expressions with quoted identifiers.
    """
    sql = """
SELECT "name",
       CASE
           WHEN "age" < 18 THEN 'minor'
           WHEN "age" >= 18 THEN 'adult'
       END AS "category"
FROM "persons"
    """
    ast = sqlglot.parse_one(sql, Pinot)
    generated = Pinot().generate(expression=ast, pretty=True)

    assert '"name"' in generated
    assert '"age"' in generated
    assert '"category"' in generated
    assert "'minor'" in generated
    assert "'adult'" in generated


def test_cte_with_quoted_identifiers() -> None:
    """
    Test Common Table Expressions (CTE) with quoted identifiers.
    """
    sql = """
WITH "high_value_orders" AS (
    SELECT * FROM "orders" WHERE "total" > 1000
)
SELECT "customer_id", COUNT(*)
FROM "high_value_orders"
GROUP BY "customer_id"
    """
    ast = sqlglot.parse_one(sql, Pinot)
    generated = Pinot().generate(expression=ast, pretty=True)

    assert 'WITH "high_value_orders" AS' in generated
    assert '"orders"' in generated
    assert '"total"' in generated
    assert '"customer_id"' in generated


def test_order_by_with_quoted_identifiers() -> None:
    """
    Test ORDER BY clause with quoted identifiers.

    SQLGlot explicitly includes ASC in the output.
    """
    sql = 'SELECT "name", "salary" FROM "employees" ORDER BY "salary" DESC, "name" ASC'
    ast = sqlglot.parse_one(sql, Pinot)

    assert (
        Pinot().generate(expression=ast, pretty=True)
        == """
SELECT
  "name",
  "salary"
FROM "employees"
ORDER BY
  "salary" DESC,
  "name" ASC
    """.strip()
    )


def test_limit_and_offset() -> None:
    """
    Test LIMIT and OFFSET clauses.
    """
    sql = 'SELECT * FROM "products" LIMIT 10 OFFSET 20'
    ast = sqlglot.parse_one(sql, Pinot)
    generated = Pinot().generate(expression=ast, pretty=True)

    assert '"products"' in generated
    assert "LIMIT 10" in generated
    # The docstring promises OFFSET coverage, so check it is not silently dropped.
    assert "OFFSET 20" in generated


def test_distinct() -> None:
    """
    Test DISTINCT keyword with quoted identifiers.
    """
    sql = 'SELECT DISTINCT "category" FROM "products"'
    ast = sqlglot.parse_one(sql, Pinot)

    assert (
        Pinot().generate(expression=ast, pretty=True)
        == """
SELECT DISTINCT
  "category"
FROM "products"
    """.strip()
    )


def test_cast_to_string() -> None:
    """
    Test that CAST to STRING is preserved (not converted to CHAR).
    """
    sql = "SELECT CAST(cohort_size AS STRING) FROM table"
    ast = sqlglot.parse_one(sql, Pinot)
    generated = Pinot().generate(expression=ast)

    assert "STRING" in generated
    assert "CHAR" not in generated


def test_concat_with_cast_string() -> None:
    """
    Test CONCAT with CAST to STRING - verifies the original issue is fixed.
    """
    sql = """
SELECT concat(a, cast(b AS string), ' - ')
FROM "default".c"""
    ast = sqlglot.parse_one(sql, Pinot)
    generated = Pinot().generate(expression=ast)

    # Verify STRING type is preserved (not converted to CHAR).  The check is
    # case-insensitive; the former `"STRING" in generated or ...` disjunct was
    # implied by this one and has been dropped.
    assert "string" in generated.lower()
    assert "CHAR" not in generated


@pytest.mark.parametrize(
    "cast_type, expected_type",
    [
        ("INT", "INT"),
        ("TINYINT", "INT"),
        ("SMALLINT", "INT"),
        ("BIGINT", "LONG"),
        ("LONG", "LONG"),
        ("FLOAT", "FLOAT"),
        ("DOUBLE", "DOUBLE"),
        ("BOOLEAN", "BOOLEAN"),
        ("TIMESTAMP", "TIMESTAMP"),
        ("STRING", "STRING"),
        ("VARCHAR", "STRING"),
        ("CHAR", "STRING"),
        ("TEXT", "STRING"),
        ("BYTES", "BYTES"),
        ("BINARY", "BYTES"),
        ("VARBINARY", "BYTES"),
        ("JSON", "JSON"),
    ],
)
def test_type_mappings(cast_type: str, expected_type: str) -> None:
    """
    Test that Pinot type mappings work correctly for all basic types.
    """
    sql = f"SELECT CAST(col AS {cast_type}) FROM table"  # noqa: S608
    ast = sqlglot.parse_one(sql, Pinot)
    generated = Pinot().generate(expression=ast)

    # Match the whole CAST fragment rather than the bare type name: "INT" is a
    # substring of "TINYINT" and "SMALLINT", so a plain substring check would
    # pass for those cases even if no mapping took place.
    assert f"CAST(col AS {expected_type})" in generated


def test_unsigned_type() -> None:
    """
    Test that unsigned integer types are handled correctly.

    Tests the UNSIGNED_TYPE_MAPPING path in datatype_sql method.
    """
    from sqlglot import exp

    # Create a UBIGINT DataType which is in UNSIGNED_TYPE_MAPPING
    dt = exp.DataType(this=exp.DataType.Type.UBIGINT)
    result = Pinot.Generator().datatype_sql(dt)

    assert "UNSIGNED" in result
    assert "BIGINT" in result


def test_date_trunc_preserved() -> None:
    """
    Test that DATE_TRUNC is preserved and not converted to MySQL's DATE() function.
    """
    sql = "SELECT DATE_TRUNC('day', dt_column) FROM table"
    result = sqlglot.parse_one(sql, Pinot).sql(Pinot)

    assert "DATE_TRUNC" in result
    assert "date_trunc('day'" in result.lower()
    # Should not be converted to MySQL's DATE() function
    assert result != "SELECT DATE(dt_column) FROM table"


def test_cast_timestamp_preserved() -> None:
    """
    Test that CAST AS TIMESTAMP is preserved and not converted to TIMESTAMP() function.
    """
    sql = "SELECT CAST(dt_column AS TIMESTAMP) FROM table"
    result = sqlglot.parse_one(sql, Pinot).sql(Pinot)

    assert "CAST" in result
    assert "AS TIMESTAMP" in result
    # Should not be converted to MySQL's TIMESTAMP() function
    assert "TIMESTAMP(dt_column)" not in result


def test_date_trunc_with_cast_timestamp() -> None:
    """
    Test the original complex query with DATE_TRUNC and CAST AS TIMESTAMP.

    Verifies that both are preserved in parse/generate round-trip.
    """
    sql = """
SELECT
  CAST(
    DATE_TRUNC(
      'day',
      CAST(
        DATETIMECONVERT(
          dt_epoch_ms, '1:MILLISECONDS:EPOCH',
          '1:MILLISECONDS:EPOCH', '1:MILLISECONDS'
        ) AS TIMESTAMP
      )
    ) AS TIMESTAMP
  ),
  SUM(a) + SUM(b)
FROM
  "default".c
WHERE
  dt_epoch_ms >= 1735690800000
  AND dt_epoch_ms < 1759328588000
  AND locality != 'US'
GROUP BY
  CAST(
    DATE_TRUNC(
      'day',
      CAST(
        DATETIMECONVERT(
          dt_epoch_ms, '1:MILLISECONDS:EPOCH',
          '1:MILLISECONDS:EPOCH', '1:MILLISECONDS'
        ) AS TIMESTAMP
      )
    ) AS TIMESTAMP
  )
LIMIT
  10000
    """
    result = sqlglot.parse_one(sql, Pinot).sql(Pinot)

    # Verify DATE_TRUNC and CAST are preserved
    assert "DATE_TRUNC" in result
    assert "CAST" in result

    # Verify these are NOT converted to MySQL functions
    assert "TIMESTAMP(DATETIMECONVERT" not in result
    assert result.count("DATE_TRUNC") == 2  # Should appear twice (SELECT and GROUP BY)