fix(pinot): restrict types in dialect (#35337)

This commit is contained in:
Beto Dealmeida
2025-09-30 16:34:53 -04:00
committed by GitHub
parent d51b35f61b
commit bf88d9bb1c
2 changed files with 129 additions and 0 deletions

View File

@@ -24,7 +24,9 @@ double quotes are used for identifiers instead of string literals.
from __future__ import annotations
from sqlglot import exp
from sqlglot.dialects.mysql import MySQL
from sqlglot.tokens import TokenType
class Pinot(MySQL):
@@ -41,3 +43,55 @@ class Pinot(MySQL):
# String-literal and identifier quoting rules.
QUOTES = ["'"] # Only single quotes for strings
IDENTIFIERS = ['"', "`"] # Backticks and double quotes for identifiers
STRING_ESCAPES = ["'", "\\"] # Remove double quote from string escapes
# Start from MySQL's keyword table and additionally tokenize the type names
# STRING, LONG and BYTES as the TEXT, BIGINT and VARBINARY token types.
KEYWORDS = {
**MySQL.Tokenizer.KEYWORDS,
"STRING": TokenType.TEXT,
"LONG": TokenType.BIGINT,
"BYTES": TokenType.VARBINARY,
}
class Generator(MySQL.Generator):
    """SQL generator that renders data types with a restricted set of names.

    Builds on ``MySQL.Generator`` and overrides the type-name table, the
    cast table and ``datatype_sql`` so that types are emitted as INT, LONG,
    FLOAT, DOUBLE, BOOLEAN, TIMESTAMP, STRING, BYTES or JSON.
    """

    # MySQL's entries are kept as the base; the entries below replace or
    # extend them. Keys are listed in a fixed order: integers, floating
    # point, boolean, timestamps, strings, binary, JSON.
    TYPE_MAPPING = {
        **MySQL.Generator.TYPE_MAPPING,
        # Every integer narrower than BIGINT is rendered as INT.
        **dict.fromkeys(
            (
                exp.DataType.Type.TINYINT,
                exp.DataType.Type.SMALLINT,
                exp.DataType.Type.INT,
            ),
            "INT",
        ),
        exp.DataType.Type.BIGINT: "LONG",
        exp.DataType.Type.FLOAT: "FLOAT",
        exp.DataType.Type.DOUBLE: "DOUBLE",
        exp.DataType.Type.BOOLEAN: "BOOLEAN",
        # Timestamps with and without time zone share one name.
        **dict.fromkeys(
            (
                exp.DataType.Type.TIMESTAMP,
                exp.DataType.Type.TIMESTAMPTZ,
            ),
            "TIMESTAMP",
        ),
        # All character types collapse to STRING.
        **dict.fromkeys(
            (
                exp.DataType.Type.VARCHAR,
                exp.DataType.Type.CHAR,
                exp.DataType.Type.TEXT,
            ),
            "STRING",
        ),
        # All binary types collapse to BYTES.
        **dict.fromkeys(
            (
                exp.DataType.Type.BINARY,
                exp.DataType.Type.VARBINARY,
            ),
            "BYTES",
        ),
        exp.DataType.Type.JSON: "JSON",
    }

    # Override MySQL's CAST_MAPPING - don't convert integer or string types.
    # Only the blob variants (rewritten to VARBINARY) and UBIGINT (rewritten
    # to the literal "UNSIGNED") are listed.
    CAST_MAPPING = {
        **dict.fromkeys(
            (
                exp.DataType.Type.LONGBLOB,
                exp.DataType.Type.MEDIUMBLOB,
                exp.DataType.Type.TINYBLOB,
            ),
            exp.DataType.Type.VARBINARY,
        ),
        exp.DataType.Type.UBIGINT: "UNSIGNED",
    }

    def datatype_sql(self, expression: exp.DataType) -> str:
        """Render a data type using only ``TYPE_MAPPING``.

        This deliberately bypasses the inherited MySQL rendering (the
        original note cites its VARCHAR size requirement logic).

        Args:
            expression: The data type node to render.

        Returns:
            The type name, followed by " UNSIGNED" when the type is a key
            of ``UNSIGNED_TYPE_MAPPING``, followed by the parenthesized
            type parameters when there are any.
        """
        kind = expression.this

        if isinstance(kind, exp.DataType.Type):
            # Enum members fall back to their own value when unmapped.
            name = self.TYPE_MAPPING.get(kind, kind.value)
        else:
            # Non-enum values (e.g. the "UNSIGNED" string placed by
            # CAST_MAPPING) are emitted as-is.
            name = kind

        params = self.expressions(expression, flat=True)
        suffix = f"({params})" if params else ""
        unsigned = " UNSIGNED" if kind in self.UNSIGNED_TYPE_MAPPING else ""

        return f"{name}{unsigned}{suffix}"