diff --git a/superset/sql/dialects/__init__.py b/superset/sql/dialects/__init__.py index f4d56e17e6d..71c8958a80f 100644 --- a/superset/sql/dialects/__init__.py +++ b/superset/sql/dialects/__init__.py @@ -15,8 +15,9 @@ # specific language governing permissions and limitations # under the License. +from .db2 import DB2 from .dremio import Dremio from .firebolt import Firebolt, FireboltOld from .pinot import Pinot -__all__ = ["Dremio", "Firebolt", "FireboltOld", "Pinot"] +__all__ = ["DB2", "Dremio", "Firebolt", "FireboltOld", "Pinot"] diff --git a/superset/sql/dialects/db2.py b/superset/sql/dialects/db2.py new file mode 100644 index 00000000000..4f70543be33 --- /dev/null +++ b/superset/sql/dialects/db2.py @@ -0,0 +1,148 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +IBM DB2 dialect. + +DB2 uses labeled durations for date arithmetic (e.g., expr + 1 DAYS). +This syntax is non-standard and requires custom parser support. +""" + +from __future__ import annotations + +from sqlglot import exp, tokens +from sqlglot.dialects.postgres import Postgres + + +class DB2Interval(exp.Expression): + """DB2 labeled duration expression (e.g., '1 DAYS', '2 MONTHS').""" + + arg_types = {"this": True, "unit": True} + + +class DB2(Postgres): + """ + IBM DB2 dialect. + + Extends PostgreSQL with support for labeled durations in date arithmetic. + """ + + class Tokenizer(Postgres.Tokenizer): + """DB2 SQL tokenizer with support for DB2-specific keywords.""" + + KEYWORDS = { + **Postgres.Tokenizer.KEYWORDS, + # Time units; can follow numbers in date arithmetic + "MICROSECOND": tokens.TokenType.VAR, + "MICROSECONDS": tokens.TokenType.VAR, + "SECOND": tokens.TokenType.VAR, + "SECONDS": tokens.TokenType.VAR, + "MINUTE": tokens.TokenType.VAR, + "MINUTES": tokens.TokenType.VAR, + "HOUR": tokens.TokenType.VAR, + "HOURS": tokens.TokenType.VAR, + "DAY": tokens.TokenType.VAR, + "DAYS": tokens.TokenType.VAR, + "MONTH": tokens.TokenType.VAR, + "MONTHS": tokens.TokenType.VAR, + "YEAR": tokens.TokenType.VAR, + "YEARS": tokens.TokenType.VAR, + } + + class Parser(Postgres.Parser): + """DB2 SQL parser with support for labeled durations.""" + + def _parse_term(self) -> exp.Expression | None: + """ + Override term parsing to support DB2 labeled durations. + + This is called during expression parsing for addition/subtraction + operations. We intercept patterns like `expr + 1 DAYS` and parse them + specially. + """ + this = self._parse_factor() + if not this: + return None + + while self._match_set((tokens.TokenType.PLUS, tokens.TokenType.DASH)): + op = self._prev.token_type + + # Parse the right side of the + or - + rhs = self._parse_factor() + if not rhs: # pragma: no cover + break + + # Check if there's a time unit after the right side + # This handles patterns like: expr + 1 DAYS, expr + (func()) DAYS + if ( + self._curr + and self._curr.token_type == tokens.TokenType.VAR + and self._curr.text.upper() + in { + "MICROSECOND", + "MICROSECONDS", + "SECOND", + "SECONDS", + "MINUTE", + "MINUTES", + "HOUR", + "HOURS", + "DAY", + "DAYS", + "MONTH", + "MONTHS", + "YEAR", + "YEARS", + } + ): + # Found a DB2 labeled duration + unit_token = self._curr + self._advance() + + duration = DB2Interval( + this=rhs, + unit=exp.Literal.string(unit_token.text.upper()), + ) + + if op == tokens.TokenType.PLUS: + this = exp.Add(this=this, expression=duration) + else: + this = exp.Sub(this=this, expression=duration) + else: + # Not a labeled duration - use normal Add/Sub + if op == tokens.TokenType.PLUS: + this = exp.Add(this=this, expression=rhs) + else: + this = exp.Sub(this=this, expression=rhs) + + return this + + class Generator(Postgres.Generator): + """DB2 SQL generator.""" + + TRANSFORMS = { + **Postgres.Generator.TRANSFORMS, + } + + def db2interval_sql(self, expression: DB2Interval) -> str: + """Generate SQL for DB2Interval expressions.""" + # Don't quote the unit (DAYS, MONTHS, etc.) - it's a keyword, not a string + unit = expression.args["unit"] + unit_text = ( + unit.this if isinstance(unit, exp.Literal) else str(unit).upper() + ) + return f"{self.sql(expression, 'this')} {unit_text}" diff --git a/superset/sql/parse.py b/superset/sql/parse.py index 4d4a72d293c..ba4d288a972 100644 --- a/superset/sql/parse.py +++ b/superset/sql/parse.py @@ -45,7 +45,7 @@ from sqlglot.optimizer.scope import ( ) from superset.exceptions import QueryClauseValidationException, SupersetParseError -from superset.sql.dialects import Dremio, Firebolt, Pinot +from superset.sql.dialects import DB2, Dremio, Firebolt, Pinot if TYPE_CHECKING: from superset.models.core import Database @@ -67,7 +67,7 @@ SQLGLOT_DIALECTS = { # "crate": ??? # "databend": ??? "databricks": Dialects.DATABRICKS, - # "db2": ??? + "db2": DB2, # "denodo": ??? "dremio": Dremio, "drill": Dialects.DRILL, diff --git a/tests/unit_tests/sql/dialects/db2_tests.py b/tests/unit_tests/sql/dialects/db2_tests.py new file mode 100644 index 00000000000..b8add1d65b1 --- /dev/null +++ b/tests/unit_tests/sql/dialects/db2_tests.py @@ -0,0 +1,246 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +from sqlglot import errors, parse_one + +from superset.sql.dialects.db2 import DB2 + + +def test_month_truncation() -> None: + """ + Test the month truncation pattern from Db2EngineSpec time grains. + """ + sql = """ +SELECT "DATE" - (DAY("DATE")-1) DAYS AS "DATE", sum("TOTAL_FEE") AS "SUM(TOTAL_FEE)" + """ + + # test with the generic dialect -- raises exception + with pytest.raises(errors.ParseError): + parse_one(sql) + + ast = parse_one(sql, dialect=DB2) + regenerated = ast.sql(dialect=DB2) + + assert regenerated == ( + 'SELECT "DATE" - (DAY("DATE") - 1) DAYS AS "DATE", ' + 'SUM("TOTAL_FEE") AS "SUM(TOTAL_FEE)"' + ) + + +def test_labeled_duration_with_day_function() -> None: + """ + Test labeled duration with DAY function. + """ + sql = "SELECT CURRENT_DATE - DAY(CURRENT_DATE) DAYS" + + ast = parse_one(sql, dialect=DB2) + regenerated = ast.sql(dialect=DB2) + + assert regenerated == "SELECT CURRENT_DATE - DAY(CURRENT_DATE) DAYS" + + +def test_labeled_duration_with_expression() -> None: + """ + Test labeled duration with complex expressions (from real DB2 queries). + """ + sql = 'SELECT "DATE" - (DAY("DATE") - 1) DAYS' + + ast = parse_one(sql, dialect=DB2) + regenerated = ast.sql(dialect=DB2) + + assert regenerated == 'SELECT "DATE" - (DAY("DATE") - 1) DAYS' + + +def test_labeled_duration_with_month_function() -> None: + """ + Test labeled duration with MONTH function. + """ + sql = 'SELECT "DATE" - (MONTH("DATE") - 1) MONTHS' + + ast = parse_one(sql, dialect=DB2) + regenerated = ast.sql(dialect=DB2) + + assert regenerated == 'SELECT "DATE" - (MONTH("DATE") - 1) MONTHS' + + +def test_year_truncation() -> None: + """ + Test the year truncation pattern from Db2EngineSpec time grains. + """ + sql = 'SELECT "DATE" - (DAY("DATE")-1) DAYS - (MONTH("DATE")-1) MONTHS' + + ast = parse_one(sql, dialect=DB2) + regenerated = ast.sql(dialect=DB2) + + assert regenerated == ( + 'SELECT "DATE" - (DAY("DATE") - 1) DAYS - (MONTH("DATE") - 1) MONTHS' + ) + + +def test_quarter_truncation() -> None: + """ + Test the quarter truncation pattern from Db2EngineSpec time grains. + """ + sql = ( + 'SELECT "DATE" - (DAY("DATE")-1) DAYS - (MONTH("DATE")-1) MONTHS' + ' + ((QUARTER("DATE")-1) * 3) MONTHS' + ) + + ast = parse_one(sql, dialect=DB2) + regenerated = ast.sql(dialect=DB2) + + assert regenerated == ( + 'SELECT "DATE" - (DAY("DATE") - 1) DAYS - (MONTH("DATE") - 1) MONTHS' + ' + ((QUARTER("DATE") - 1) * 3) MONTHS' + ) + + +def test_regular_column_aliasing_still_works() -> None: + """ + Test that regular column aliasing still works (regression test). + """ + sql = "SELECT col1 AS days FROM table" + + ast = parse_one(sql, dialect=DB2) + regenerated = ast.sql(dialect=DB2) + + assert regenerated == "SELECT col1 AS days FROM table" + + +@pytest.mark.parametrize( + "sql, expected", + [ + ("SELECT col1 AS day", "SELECT col1 AS day"), + ("SELECT col1 AS month", "SELECT col1 AS month"), + ("SELECT col1 AS year", "SELECT col1 AS year"), + ("SELECT col1 AS days", "SELECT col1 AS days"), + ("SELECT col1 AS months", "SELECT col1 AS months"), + ("SELECT col1 AS years", "SELECT col1 AS years"), + ], +) +def test_column_aliasing_with_reserved_words(sql: str, expected: str) -> None: + """ + Test column aliasing with DB2 time unit words. + """ + ast = parse_one(sql, dialect=DB2) + regenerated = ast.sql(dialect=DB2) + assert regenerated == expected + + +@pytest.mark.parametrize( + "sql, expected", + [ + # Function-based patterns + ('SELECT "DATE" - DAY("DATE") DAYS', 'SELECT "DATE" - DAY("DATE") DAYS'), + ( + 'SELECT "DATE" + MONTH("DATE") MONTHS', + 'SELECT "DATE" + MONTH("DATE") MONTHS', + ), + ('SELECT "DATE" - YEAR("DATE") YEARS', 'SELECT "DATE" - YEAR("DATE") YEARS'), + # Complex expression patterns + ( + 'SELECT "DATE" - (DAY("DATE") - 1) DAYS', + 'SELECT "DATE" - (DAY("DATE") - 1) DAYS', + ), + ( + 'SELECT "DATE" + (MONTH("DATE") + 2) MONTHS', + 'SELECT "DATE" + (MONTH("DATE") + 2) MONTHS', + ), + # Nested expressions + ( + 'SELECT "DATE" - ((DAY("DATE") - 1) + 1) DAYS - (MONTH("DATE") - 1) MONTHS', + 'SELECT "DATE" - ((DAY("DATE") - 1) + 1) DAYS - (MONTH("DATE") - 1) MONTHS', + ), + ], +) +def test_labeled_duration_variations(sql: str, expected: str) -> None: + """ + Test various labeled duration patterns that should work. + """ + ast = parse_one(sql, dialect=DB2) + regenerated = ast.sql(dialect=DB2) + assert regenerated == expected + + +def test_addition_with_labeled_duration() -> None: + """ + Test addition operations with labeled durations. + """ + sql = 'SELECT "DATE" + (DAY("DATE") + 5) DAYS' + + ast = parse_one(sql, dialect=DB2) + regenerated = ast.sql(dialect=DB2) + + assert regenerated == 'SELECT "DATE" + (DAY("DATE") + 5) DAYS' + + +def test_arithmetic_with_different_units() -> None: + """ + Test arithmetic operations mixing different time units. + """ + sql = ( + 'SELECT "DATE" - (DAY("DATE")-1) DAYS ' + '- (MONTH("DATE")-1) MONTHS + ' + '(YEAR("DATE")-1) YEARS' + ) + + ast = parse_one(sql, dialect=DB2) + regenerated = ast.sql(dialect=DB2) + + assert regenerated == ( + 'SELECT "DATE" - (DAY("DATE") - 1) DAYS - ' + '(MONTH("DATE") - 1) MONTHS + ' + '(YEAR("DATE") - 1) YEARS' + ) + + +def test_multiple_function_calls_in_duration() -> None: + """ + Test labeled duration with multiple function calls. + """ + sql = 'SELECT "DATE" - (DAY("DATE") + MONTH("DATE")) DAYS' + + ast = parse_one(sql, dialect=DB2) + regenerated = ast.sql(dialect=DB2) + + assert regenerated == 'SELECT "DATE" - (DAY("DATE") + MONTH("DATE")) DAYS' + + +def test_labeled_duration_with_multiplication() -> None: + """ + Test labeled duration with multiplication in the expression. + """ + sql = 'SELECT "DATE" + ((QUARTER("DATE") - 1) * 3) MONTHS' + + ast = parse_one(sql, dialect=DB2) + regenerated = ast.sql(dialect=DB2) + + assert regenerated == 'SELECT "DATE" + ((QUARTER("DATE") - 1) * 3) MONTHS' + + +def test_column_plus_literal_duration() -> None: + """ + Test column + literal number with time unit. + """ + sql = "SELECT col + 1 DAYS FROM t" + + ast = parse_one(sql, dialect=DB2) + regenerated = ast.sql(dialect=DB2) + + # Should parse as (col + 1 DAYS), not (col + 1) AS DAYS + assert regenerated == "SELECT col + 1 DAYS FROM t"