diff --git a/superset/sql/dialects/pinot.py b/superset/sql/dialects/pinot.py index f667b9bdabe..2e7cbe9ed5c 100644 --- a/superset/sql/dialects/pinot.py +++ b/superset/sql/dialects/pinot.py @@ -26,6 +26,7 @@ from __future__ import annotations from sqlglot import exp from sqlglot.dialects.mysql import MySQL +from sqlglot.helper import seq_get from sqlglot.tokens import TokenType @@ -50,6 +51,16 @@ class Pinot(MySQL): "BYTES": TokenType.VARBINARY, } + class Parser(MySQL.Parser): + FUNCTIONS = { + **MySQL.Parser.FUNCTIONS, + "DATE_ADD": lambda args: exp.DateAdd( + this=seq_get(args, 2), + expression=seq_get(args, 1), + unit=seq_get(args, 0), + ), + } + class Generator(MySQL.Generator): TYPE_MAPPING = { **MySQL.Generator.TYPE_MAPPING, @@ -80,6 +91,12 @@ class Pinot(MySQL): TRANSFORMS = { **MySQL.Generator.TRANSFORMS, + exp.DateAdd: lambda self, e: self.func( + "DATE_ADD", + exp.Literal.string(str(e.args.get("unit").name)), + e.args.get("expression"), + e.this, + ), } # Remove DATE_TRUNC transformation - Pinot supports standard SQL DATE_TRUNC TRANSFORMS.pop(exp.DateTrunc, None) diff --git a/superset/sql/parse.py b/superset/sql/parse.py index 822b8ec79be..d5392115f0e 100644 --- a/superset/sql/parse.py +++ b/superset/sql/parse.py @@ -552,14 +552,16 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): try: statements = sqlglot.parse(script, dialect=dialect) except sqlglot.errors.ParseError as ex: - error = ex.errors[0] - raise SupersetParseError( - script, - engine, - highlight=error["highlight"], - line=error["line"], - column=error["col"], - ) from ex + kwargs = ( + { + "highlight": ex.errors[0]["highlight"], + "line": ex.errors[0]["line"], + "column": ex.errors[0]["col"], + } + if ex.errors + else {} + ) + raise SupersetParseError(script, engine, **kwargs) from ex except sqlglot.errors.SqlglotError as ex: raise SupersetParseError( script, diff --git a/tests/unit_tests/sql/dialects/pinot_tests.py b/tests/unit_tests/sql/dialects/pinot_tests.py index 226a04e0f1d..bd2c0003325 100644 --- a/tests/unit_tests/sql/dialects/pinot_tests.py +++ b/tests/unit_tests/sql/dialects/pinot_tests.py @@ -22,7 +22,9 @@ from superset.sql.dialects.pinot import Pinot def test_pinot_dialect_registered() -> None: - """Test that Pinot dialect is properly registered.""" + """ + Test that Pinot dialect is properly registered. + """ from superset.sql.parse import SQLGLOT_DIALECTS assert "pinot" in SQLGLOT_DIALECTS @@ -498,3 +500,49 @@ LIMIT # Verify these are NOT converted to MySQL functions assert "TIMESTAMP(DATETIMECONVERT" not in result assert result.count("DATE_TRUNC") == 2 # Should appear twice (SELECT and GROUP BY) + + +def test_pinot_date_add_parsing() -> None: + """ + Test that Pinot's DATE_ADD function with Presto-like syntax can be parsed. + """ + from superset.sql.parse import SQLScript + + sql = """ +SELECT dt_epoch_ms FROM my_table WHERE dt_epoch_ms >= date_add('day', -180, now()) + """ + script = SQLScript(sql, "pinot") + assert len(script.statements) == 1 + assert not script.has_mutation() + + +def test_pinot_date_add_simple() -> None: + """ + Test parsing of simple DATE_ADD expressions. + """ + test_cases = [ + "date_add('day', -180, now())", + "DATE_ADD('month', 5, current_timestamp())", + "date_add('year', 1, my_date_column)", + ] + + for sql in test_cases: + parsed = sqlglot.parse_one(sql, Pinot) + assert parsed is not None + # Verify that it generates valid SQL + generated = parsed.sql(dialect=Pinot) + assert "DATE_ADD" in generated.upper() + + +def test_pinot_date_add_unit_quoted() -> None: + """ + Test that DATE_ADD preserves quotes around the unit argument. + + Pinot requires the unit to be a quoted string, not an identifier. + """ + sql = "dt_epoch_ms >= date_add('day', -180, now())" + result = sqlglot.parse_one(sql, Pinot).sql(Pinot) + + # The unit should be quoted: 'DAY' not DAY + assert "DATE_ADD('DAY', -180, NOW())" in result + assert "DATE_ADD(DAY," not in result