feat: DB2 dialect for sqlglot (#36365)

This commit is contained in:
Beto Dealmeida
2025-12-02 12:19:52 -05:00
committed by GitHub
parent 005e4e3ea8
commit e4cb84bc02
4 changed files with 398 additions and 3 deletions

View File

@@ -15,8 +15,9 @@
# specific language governing permissions and limitations
# under the License.
from .db2 import DB2
from .dremio import Dremio
from .firebolt import Firebolt, FireboltOld
from .pinot import Pinot
__all__ = ["Dremio", "Firebolt", "FireboltOld", "Pinot"]
__all__ = ["DB2", "Dremio", "Firebolt", "FireboltOld", "Pinot"]

View File

@@ -0,0 +1,148 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
IBM DB2 dialect.
DB2 uses labeled durations for date arithmetic (e.g., expr + 1 DAYS).
This syntax is non-standard and requires custom parser support.
"""
from __future__ import annotations
from sqlglot import exp, tokens
from sqlglot.dialects.postgres import Postgres
class DB2Interval(exp.Expression):
    """
    AST node for a DB2 labeled duration such as ``1 DAYS`` or ``2 MONTHS``.

    ``this`` carries the quantity expression and ``unit`` the duration keyword.
    Both are flagged ``True`` in ``arg_types`` (presumably "required", per the
    sqlglot ``arg_types`` convention -- confirm against the sqlglot docs).
    """

    arg_types = {
        "this": True,
        "unit": True,
    }
# Labeled-duration unit keywords, singular and plural. Kept in a single place
# so the tokenizer and the parser always agree on the accepted spellings.
_LABELED_DURATION_UNITS: tuple[str, ...] = (
    "MICROSECOND",
    "MICROSECONDS",
    "SECOND",
    "SECONDS",
    "MINUTE",
    "MINUTES",
    "HOUR",
    "HOURS",
    "DAY",
    "DAYS",
    "MONTH",
    "MONTHS",
    "YEAR",
    "YEARS",
)


class DB2(Postgres):
    """
    IBM DB2 dialect.

    Extends PostgreSQL with support for labeled durations in date arithmetic
    (e.g. ``expr + 1 DAYS``).
    """

    class Tokenizer(Postgres.Tokenizer):
        """DB2 SQL tokenizer with support for DB2-specific keywords."""

        KEYWORDS = {
            **Postgres.Tokenizer.KEYWORDS,
            # Time units; can follow numbers in date arithmetic
            **{unit: tokens.TokenType.VAR for unit in _LABELED_DURATION_UNITS},
        }

    class Parser(Postgres.Parser):
        """DB2 SQL parser with support for labeled durations."""

        # Upper-cased unit keywords that turn the preceding operand into a
        # labeled duration.
        LABELED_DURATION_UNITS = frozenset(_LABELED_DURATION_UNITS)

        # Only addition and subtraction can take a labeled duration operand.
        LABELED_DURATION_OPERATORS = frozenset(
            {tokens.TokenType.PLUS, tokens.TokenType.DASH}
        )

        def _parse_term(self) -> exp.Expression | None:
            """
            Parse a chain of term-level operators, adding labeled durations.

            Every operator registered in the inherited ``TERM`` table is
            handled, so overriding this method does not drop operators the
            parent dialect understands. On top of that, when the right-hand
            operand of ``+`` or ``-`` is followed by a time unit keyword
            (``expr + 1 DAYS``, ``expr - (func()) MONTHS``), the operand is
            wrapped in a ``DB2Interval`` node.

            Returns:
                The parsed expression, or ``None`` when no leading factor
                could be parsed.
            """
            this = self._parse_factor()
            if not this:
                return None

            while self._match_set(self.TERM):
                op_type = self._prev.token_type

                # Parse the right side of the operator
                rhs = self._parse_factor()

                # Check if there's a time unit after the right side
                unit_token = self._curr
                if (
                    rhs is not None
                    and op_type in self.LABELED_DURATION_OPERATORS
                    and unit_token
                    and unit_token.token_type == tokens.TokenType.VAR
                    and unit_token.text.upper() in self.LABELED_DURATION_UNITS
                ):
                    # Found a DB2 labeled duration: consume the unit keyword
                    self._advance()
                    rhs = DB2Interval(
                        this=rhs,
                        unit=exp.Literal.string(unit_token.text.upper()),
                    )

                # Validate instead of silently dropping a dangling operator: a
                # missing right-hand side is reported through the parser's
                # regular error handling rather than being ignored.
                this = self.validate_expression(
                    self.TERM[op_type](this=this, expression=rhs)
                )

                # NOTE(review): this mirrors how upstream sqlglot is believed to
                # normalize single-part collation names in its own
                # ``_parse_term`` -- re-check when upgrading sqlglot.
                if isinstance(this, exp.Collate):
                    collation = this.expression
                    if (
                        isinstance(collation, exp.Column)
                        and len(collation.parts) == 1
                    ):
                        ident = collation.this
                        if isinstance(ident, exp.Identifier):
                            this.set(
                                "expression",
                                ident if ident.quoted else exp.var(ident.name),
                            )

            return this

    class Generator(Postgres.Generator):
        """DB2 SQL generator."""

        TRANSFORMS = {
            **Postgres.Generator.TRANSFORMS,
        }

        def db2interval_sql(self, expression: DB2Interval) -> str:
            """
            Generate SQL for DB2Interval expressions.

            Returns the rendered quantity followed by a space and the unit.
            """
            # Don't quote the unit (DAYS, MONTHS, etc.) - it's a keyword, not a string
            unit = expression.args["unit"]
            unit_text = (
                unit.this if isinstance(unit, exp.Literal) else str(unit).upper()
            )
            return f"{self.sql(expression, 'this')} {unit_text}"

View File

@@ -45,7 +45,7 @@ from sqlglot.optimizer.scope import (
)
from superset.exceptions import QueryClauseValidationException, SupersetParseError
from superset.sql.dialects import Dremio, Firebolt, Pinot
from superset.sql.dialects import DB2, Dremio, Firebolt, Pinot
if TYPE_CHECKING:
from superset.models.core import Database
@@ -67,7 +67,7 @@ SQLGLOT_DIALECTS = {
# "crate": ???
# "databend": ???
"databricks": Dialects.DATABRICKS,
# "db2": ???
"db2": DB2,
# "denodo": ???
"dremio": Dremio,
"drill": Dialects.DRILL,

View File

@@ -0,0 +1,246 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from sqlglot import errors, parse_one
from superset.sql.dialects.db2 import DB2
def test_month_truncation() -> None:
    """
    Round-trip the month truncation pattern from Db2EngineSpec time grains.

    The generic dialect is expected to reject the labeled duration, while the
    DB2 dialect must parse it and regenerate normalized SQL.
    """
    query = """
SELECT "DATE" - (DAY("DATE")-1) DAYS AS "DATE", sum("TOTAL_FEE") AS "SUM(TOTAL_FEE)"
"""
    expected = (
        'SELECT "DATE" - (DAY("DATE") - 1) DAYS AS "DATE", '
        'SUM("TOTAL_FEE") AS "SUM(TOTAL_FEE)"'
    )

    # test with the generic dialect -- raises exception
    with pytest.raises(errors.ParseError):
        parse_one(query)

    tree = parse_one(query, dialect=DB2)
    assert tree.sql(dialect=DB2) == expected
def test_labeled_duration_with_day_function() -> None:
    """
    A bare DAY(...) call followed by DAYS must round-trip unchanged.
    """
    query = "SELECT CURRENT_DATE - DAY(CURRENT_DATE) DAYS"
    tree = parse_one(query, dialect=DB2)
    output = tree.sql(dialect=DB2)
    assert output == "SELECT CURRENT_DATE - DAY(CURRENT_DATE) DAYS"
def test_labeled_duration_with_expression() -> None:
    """
    A parenthesized arithmetic expression followed by DAYS must round-trip
    unchanged (pattern taken from real DB2 queries).
    """
    query = 'SELECT "DATE" - (DAY("DATE") - 1) DAYS'
    tree = parse_one(query, dialect=DB2)
    output = tree.sql(dialect=DB2)
    assert output == 'SELECT "DATE" - (DAY("DATE") - 1) DAYS'
def test_labeled_duration_with_month_function() -> None:
    """
    A MONTH(...) based expression followed by MONTHS must round-trip unchanged.
    """
    query = 'SELECT "DATE" - (MONTH("DATE") - 1) MONTHS'
    tree = parse_one(query, dialect=DB2)
    output = tree.sql(dialect=DB2)
    assert output == 'SELECT "DATE" - (MONTH("DATE") - 1) MONTHS'
def test_year_truncation() -> None:
    """
    Round-trip the year truncation pattern from Db2EngineSpec time grains.

    Two labeled durations are chained; the generated SQL adds spaces around
    the inner minus signs.
    """
    query = 'SELECT "DATE" - (DAY("DATE")-1) DAYS - (MONTH("DATE")-1) MONTHS'
    expected = (
        'SELECT "DATE" - (DAY("DATE") - 1) DAYS - (MONTH("DATE") - 1) MONTHS'
    )
    tree = parse_one(query, dialect=DB2)
    assert tree.sql(dialect=DB2) == expected
def test_quarter_truncation() -> None:
    """
    Round-trip the quarter truncation pattern from Db2EngineSpec time grains.

    Combines two subtracted durations with an added, multiplied MONTHS duration.
    """
    query = (
        'SELECT "DATE" - (DAY("DATE")-1) DAYS - (MONTH("DATE")-1) MONTHS'
        ' + ((QUARTER("DATE")-1) * 3) MONTHS'
    )
    expected = (
        'SELECT "DATE" - (DAY("DATE") - 1) DAYS - (MONTH("DATE") - 1) MONTHS'
        ' + ((QUARTER("DATE") - 1) * 3) MONTHS'
    )
    tree = parse_one(query, dialect=DB2)
    assert tree.sql(dialect=DB2) == expected
def test_regular_column_aliasing_still_works() -> None:
    """
    Regression test: an explicit ``AS days`` alias is still treated as an
    alias and round-trips unchanged.
    """
    query = "SELECT col1 AS days FROM table"
    tree = parse_one(query, dialect=DB2)
    output = tree.sql(dialect=DB2)
    assert output == "SELECT col1 AS days FROM table"
@pytest.mark.parametrize(
    "sql, expected",
    [
        ("SELECT col1 AS day", "SELECT col1 AS day"),
        ("SELECT col1 AS month", "SELECT col1 AS month"),
        ("SELECT col1 AS year", "SELECT col1 AS year"),
        ("SELECT col1 AS days", "SELECT col1 AS days"),
        ("SELECT col1 AS months", "SELECT col1 AS months"),
        ("SELECT col1 AS years", "SELECT col1 AS years"),
    ],
)
def test_column_aliasing_with_reserved_words(sql: str, expected: str) -> None:
    """
    Time unit words used as explicit column aliases must survive a DB2
    parse/generate round trip.
    """
    tree = parse_one(sql, dialect=DB2)
    output = tree.sql(dialect=DB2)
    assert output == expected
@pytest.mark.parametrize(
    "sql, expected",
    [
        # Function-based patterns
        ('SELECT "DATE" - DAY("DATE") DAYS', 'SELECT "DATE" - DAY("DATE") DAYS'),
        (
            'SELECT "DATE" + MONTH("DATE") MONTHS',
            'SELECT "DATE" + MONTH("DATE") MONTHS',
        ),
        ('SELECT "DATE" - YEAR("DATE") YEARS', 'SELECT "DATE" - YEAR("DATE") YEARS'),
        # Complex expression patterns
        (
            'SELECT "DATE" - (DAY("DATE") - 1) DAYS',
            'SELECT "DATE" - (DAY("DATE") - 1) DAYS',
        ),
        (
            'SELECT "DATE" + (MONTH("DATE") + 2) MONTHS',
            'SELECT "DATE" + (MONTH("DATE") + 2) MONTHS',
        ),
        # Nested expressions
        (
            'SELECT "DATE" - ((DAY("DATE") - 1) + 1) DAYS - (MONTH("DATE") - 1) MONTHS',
            'SELECT "DATE" - ((DAY("DATE") - 1) + 1) DAYS - (MONTH("DATE") - 1) MONTHS',
        ),
    ],
)
def test_labeled_duration_variations(sql: str, expected: str) -> None:
    """
    Each supported labeled duration shape (function call, parenthesized
    expression, nested expression) must round-trip to the expected SQL.
    """
    tree = parse_one(sql, dialect=DB2)
    output = tree.sql(dialect=DB2)
    assert output == expected
def test_addition_with_labeled_duration() -> None:
    """
    A labeled duration on the right-hand side of ``+`` must round-trip
    unchanged.
    """
    query = 'SELECT "DATE" + (DAY("DATE") + 5) DAYS'
    tree = parse_one(query, dialect=DB2)
    output = tree.sql(dialect=DB2)
    assert output == 'SELECT "DATE" + (DAY("DATE") + 5) DAYS'
def test_arithmetic_with_different_units() -> None:
    """
    A single expression chaining DAYS, MONTHS and YEARS durations must
    round-trip, with the generator spacing out the inner minus signs.
    """
    query = (
        'SELECT "DATE" - (DAY("DATE")-1) DAYS '
        '- (MONTH("DATE")-1) MONTHS + '
        '(YEAR("DATE")-1) YEARS'
    )
    expected = (
        'SELECT "DATE" - (DAY("DATE") - 1) DAYS - '
        '(MONTH("DATE") - 1) MONTHS + '
        '(YEAR("DATE") - 1) YEARS'
    )
    tree = parse_one(query, dialect=DB2)
    assert tree.sql(dialect=DB2) == expected
def test_multiple_function_calls_in_duration() -> None:
    """
    A duration whose quantity sums two function calls must round-trip
    unchanged.
    """
    query = 'SELECT "DATE" - (DAY("DATE") + MONTH("DATE")) DAYS'
    tree = parse_one(query, dialect=DB2)
    output = tree.sql(dialect=DB2)
    assert output == 'SELECT "DATE" - (DAY("DATE") + MONTH("DATE")) DAYS'
def test_labeled_duration_with_multiplication() -> None:
    """
    A duration whose quantity contains a multiplication must round-trip
    unchanged.
    """
    query = 'SELECT "DATE" + ((QUARTER("DATE") - 1) * 3) MONTHS'
    tree = parse_one(query, dialect=DB2)
    output = tree.sql(dialect=DB2)
    assert output == 'SELECT "DATE" + ((QUARTER("DATE") - 1) * 3) MONTHS'
def test_column_plus_literal_duration() -> None:
    """
    A column plus a numeric literal followed by a time unit must round-trip
    unchanged.
    """
    query = "SELECT col + 1 DAYS FROM t"
    tree = parse_one(query, dialect=DB2)
    output = tree.sql(dialect=DB2)
    # Should parse as (col + 1 DAYS), not (col + 1) AS DAYS
    assert output == "SELECT col + 1 DAYS FROM t"