fix(pinot): restrict types in dialect (#35337)

This commit is contained in:
Beto Dealmeida
2025-09-30 16:34:53 -04:00
committed by GitHub
parent d51b35f61b
commit bf88d9bb1c
2 changed files with 129 additions and 0 deletions

View File

@@ -346,3 +346,78 @@ SELECT DISTINCT
FROM "products"
""".strip()
)
def test_cast_to_string() -> None:
"""
Test that CAST to STRING is preserved (not converted to CHAR).
"""
sql = "SELECT CAST(cohort_size AS STRING) FROM table"
ast = sqlglot.parse_one(sql, Pinot)
generated = Pinot().generate(expression=ast)
assert "STRING" in generated
assert "CHAR" not in generated
def test_concat_with_cast_string() -> None:
"""
Test CONCAT with CAST to STRING - verifies the original issue is fixed.
"""
sql = """
SELECT concat(a, cast(b AS string), ' - ')
FROM "default".c"""
ast = sqlglot.parse_one(sql, Pinot)
generated = Pinot().generate(expression=ast)
# Verify STRING type is preserved (not converted to CHAR)
assert "STRING" in generated or "string" in generated.lower()
assert "CHAR" not in generated
@pytest.mark.parametrize(
"cast_type, expected_type",
[
("INT", "INT"),
("TINYINT", "INT"),
("SMALLINT", "INT"),
("BIGINT", "LONG"),
("LONG", "LONG"),
("FLOAT", "FLOAT"),
("DOUBLE", "DOUBLE"),
("BOOLEAN", "BOOLEAN"),
("TIMESTAMP", "TIMESTAMP"),
("STRING", "STRING"),
("VARCHAR", "STRING"),
("CHAR", "STRING"),
("TEXT", "STRING"),
("BYTES", "BYTES"),
("BINARY", "BYTES"),
("VARBINARY", "BYTES"),
("JSON", "JSON"),
],
)
def test_type_mappings(cast_type: str, expected_type: str) -> None:
"""
Test that Pinot type mappings work correctly for all basic types.
"""
sql = f"SELECT CAST(col AS {cast_type}) FROM table" # noqa: S608
ast = sqlglot.parse_one(sql, Pinot)
generated = Pinot().generate(expression=ast)
assert expected_type in generated
def test_unsigned_type() -> None:
"""
Test that unsigned integer types are handled correctly.
Tests the UNSIGNED_TYPE_MAPPING path in datatype_sql method.
"""
from sqlglot import exp
# Create a UBIGINT DataType which is in UNSIGNED_TYPE_MAPPING
dt = exp.DataType(this=exp.DataType.Type.UBIGINT)
result = Pinot.Generator().datatype_sql(dt)
assert "UNSIGNED" in result
assert "BIGINT" in result