mirror of
https://github.com/apache/superset.git
synced 2026-04-18 15:44:57 +00:00
fix(pinot): restrict types in dialect (#35337)
This commit is contained in:
@@ -24,7 +24,9 @@ double quotes are used for identifiers instead of string literals.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlglot import exp
|
||||
from sqlglot.dialects.mysql import MySQL
|
||||
from sqlglot.tokens import TokenType
|
||||
|
||||
|
||||
class Pinot(MySQL):
|
||||
@@ -41,3 +43,55 @@ class Pinot(MySQL):
|
||||
QUOTES = ["'"] # Only single quotes for strings
|
||||
IDENTIFIERS = ['"', "`"] # Backticks and double quotes for identifiers
|
||||
STRING_ESCAPES = ["'", "\\"] # Remove double quote from string escapes
|
||||
KEYWORDS = {
|
||||
**MySQL.Tokenizer.KEYWORDS,
|
||||
"STRING": TokenType.TEXT,
|
||||
"LONG": TokenType.BIGINT,
|
||||
"BYTES": TokenType.VARBINARY,
|
||||
}
|
||||
|
||||
class Generator(MySQL.Generator):
|
||||
TYPE_MAPPING = {
|
||||
**MySQL.Generator.TYPE_MAPPING,
|
||||
exp.DataType.Type.TINYINT: "INT",
|
||||
exp.DataType.Type.SMALLINT: "INT",
|
||||
exp.DataType.Type.INT: "INT",
|
||||
exp.DataType.Type.BIGINT: "LONG",
|
||||
exp.DataType.Type.FLOAT: "FLOAT",
|
||||
exp.DataType.Type.DOUBLE: "DOUBLE",
|
||||
exp.DataType.Type.BOOLEAN: "BOOLEAN",
|
||||
exp.DataType.Type.TIMESTAMP: "TIMESTAMP",
|
||||
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
|
||||
exp.DataType.Type.VARCHAR: "STRING",
|
||||
exp.DataType.Type.CHAR: "STRING",
|
||||
exp.DataType.Type.TEXT: "STRING",
|
||||
exp.DataType.Type.BINARY: "BYTES",
|
||||
exp.DataType.Type.VARBINARY: "BYTES",
|
||||
exp.DataType.Type.JSON: "JSON",
|
||||
}
|
||||
|
||||
# Override MySQL's CAST_MAPPING - don't convert integer or string types
|
||||
CAST_MAPPING = {
|
||||
exp.DataType.Type.LONGBLOB: exp.DataType.Type.VARBINARY,
|
||||
exp.DataType.Type.MEDIUMBLOB: exp.DataType.Type.VARBINARY,
|
||||
exp.DataType.Type.TINYBLOB: exp.DataType.Type.VARBINARY,
|
||||
exp.DataType.Type.UBIGINT: "UNSIGNED",
|
||||
}
|
||||
|
||||
def datatype_sql(self, expression: exp.DataType) -> str:
|
||||
# Don't use MySQL's VARCHAR size requirement logic
|
||||
# Just use TYPE_MAPPING for all types
|
||||
type_value = expression.this
|
||||
type_sql = (
|
||||
self.TYPE_MAPPING.get(type_value, type_value.value)
|
||||
if isinstance(type_value, exp.DataType.Type)
|
||||
else type_value
|
||||
)
|
||||
|
||||
interior = self.expressions(expression, flat=True)
|
||||
nested = f"({interior})" if interior else ""
|
||||
|
||||
if expression.this in self.UNSIGNED_TYPE_MAPPING:
|
||||
return f"{type_sql} UNSIGNED{nested}"
|
||||
|
||||
return f"{type_sql}{nested}"
|
||||
|
||||
@@ -346,3 +346,78 @@ SELECT DISTINCT
|
||||
FROM "products"
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
||||
def test_cast_to_string() -> None:
|
||||
"""
|
||||
Test that CAST to STRING is preserved (not converted to CHAR).
|
||||
"""
|
||||
sql = "SELECT CAST(cohort_size AS STRING) FROM table"
|
||||
ast = sqlglot.parse_one(sql, Pinot)
|
||||
generated = Pinot().generate(expression=ast)
|
||||
|
||||
assert "STRING" in generated
|
||||
assert "CHAR" not in generated
|
||||
|
||||
|
||||
def test_concat_with_cast_string() -> None:
|
||||
"""
|
||||
Test CONCAT with CAST to STRING - verifies the original issue is fixed.
|
||||
"""
|
||||
sql = """
|
||||
SELECT concat(a, cast(b AS string), ' - ')
|
||||
FROM "default".c"""
|
||||
ast = sqlglot.parse_one(sql, Pinot)
|
||||
generated = Pinot().generate(expression=ast)
|
||||
|
||||
# Verify STRING type is preserved (not converted to CHAR)
|
||||
assert "STRING" in generated or "string" in generated.lower()
|
||||
assert "CHAR" not in generated
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cast_type, expected_type",
|
||||
[
|
||||
("INT", "INT"),
|
||||
("TINYINT", "INT"),
|
||||
("SMALLINT", "INT"),
|
||||
("BIGINT", "LONG"),
|
||||
("LONG", "LONG"),
|
||||
("FLOAT", "FLOAT"),
|
||||
("DOUBLE", "DOUBLE"),
|
||||
("BOOLEAN", "BOOLEAN"),
|
||||
("TIMESTAMP", "TIMESTAMP"),
|
||||
("STRING", "STRING"),
|
||||
("VARCHAR", "STRING"),
|
||||
("CHAR", "STRING"),
|
||||
("TEXT", "STRING"),
|
||||
("BYTES", "BYTES"),
|
||||
("BINARY", "BYTES"),
|
||||
("VARBINARY", "BYTES"),
|
||||
("JSON", "JSON"),
|
||||
],
|
||||
)
|
||||
def test_type_mappings(cast_type: str, expected_type: str) -> None:
|
||||
"""
|
||||
Test that Pinot type mappings work correctly for all basic types.
|
||||
"""
|
||||
sql = f"SELECT CAST(col AS {cast_type}) FROM table" # noqa: S608
|
||||
ast = sqlglot.parse_one(sql, Pinot)
|
||||
generated = Pinot().generate(expression=ast)
|
||||
|
||||
assert expected_type in generated
|
||||
|
||||
|
||||
def test_unsigned_type() -> None:
|
||||
"""
|
||||
Test that unsigned integer types are handled correctly.
|
||||
Tests the UNSIGNED_TYPE_MAPPING path in datatype_sql method.
|
||||
"""
|
||||
from sqlglot import exp
|
||||
|
||||
# Create a UBIGINT DataType which is in UNSIGNED_TYPE_MAPPING
|
||||
dt = exp.DataType(this=exp.DataType.Type.UBIGINT)
|
||||
result = Pinot.Generator().datatype_sql(dt)
|
||||
|
||||
assert "UNSIGNED" in result
|
||||
assert "BIGINT" in result
|
||||
|
||||
Reference in New Issue
Block a user