mirror of
https://github.com/apache/superset.git
synced 2026-06-09 17:49:26 +00:00
feat: DB2 dialect for sqlglot (#36365)
This commit is contained in:
@@ -15,8 +15,9 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from .db2 import DB2
|
||||
from .dremio import Dremio
|
||||
from .firebolt import Firebolt, FireboltOld
|
||||
from .pinot import Pinot
|
||||
|
||||
__all__ = ["Dremio", "Firebolt", "FireboltOld", "Pinot"]
|
||||
__all__ = ["DB2", "Dremio", "Firebolt", "FireboltOld", "Pinot"]
|
||||
|
||||
148
superset/sql/dialects/db2.py
Normal file
148
superset/sql/dialects/db2.py
Normal file
@@ -0,0 +1,148 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
IBM DB2 dialect.
|
||||
|
||||
DB2 uses labeled durations for date arithmetic (e.g., expr + 1 DAYS).
|
||||
This syntax is non-standard and requires custom parser support.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlglot import exp, tokens
|
||||
from sqlglot.dialects.postgres import Postgres
|
||||
|
||||
|
||||
class DB2Interval(exp.Expression):
|
||||
"""DB2 labeled duration expression (e.g., '1 DAYS', '2 MONTHS')."""
|
||||
|
||||
arg_types = {"this": True, "unit": True}
|
||||
|
||||
|
||||
class DB2(Postgres):
|
||||
"""
|
||||
IBM DB2 dialect.
|
||||
|
||||
Extends PostgreSQL with support for labeled durations in date arithmetic.
|
||||
"""
|
||||
|
||||
class Tokenizer(Postgres.Tokenizer):
|
||||
"""DB2 SQL tokenizer with support for DB2-specific keywords."""
|
||||
|
||||
KEYWORDS = {
|
||||
**Postgres.Tokenizer.KEYWORDS,
|
||||
# Time units; can follow numbers in date arithmetic
|
||||
"MICROSECOND": tokens.TokenType.VAR,
|
||||
"MICROSECONDS": tokens.TokenType.VAR,
|
||||
"SECOND": tokens.TokenType.VAR,
|
||||
"SECONDS": tokens.TokenType.VAR,
|
||||
"MINUTE": tokens.TokenType.VAR,
|
||||
"MINUTES": tokens.TokenType.VAR,
|
||||
"HOUR": tokens.TokenType.VAR,
|
||||
"HOURS": tokens.TokenType.VAR,
|
||||
"DAY": tokens.TokenType.VAR,
|
||||
"DAYS": tokens.TokenType.VAR,
|
||||
"MONTH": tokens.TokenType.VAR,
|
||||
"MONTHS": tokens.TokenType.VAR,
|
||||
"YEAR": tokens.TokenType.VAR,
|
||||
"YEARS": tokens.TokenType.VAR,
|
||||
}
|
||||
|
||||
class Parser(Postgres.Parser):
|
||||
"""DB2 SQL parser with support for labeled durations."""
|
||||
|
||||
def _parse_term(self) -> exp.Expression | None:
|
||||
"""
|
||||
Override term parsing to support DB2 labeled durations.
|
||||
|
||||
This is called during expression parsing for addition/subtraction
|
||||
operations. We intercept patterns like `expr + 1 DAYS` and parse them
|
||||
specially.
|
||||
"""
|
||||
this = self._parse_factor()
|
||||
if not this:
|
||||
return None
|
||||
|
||||
while self._match_set((tokens.TokenType.PLUS, tokens.TokenType.DASH)):
|
||||
op = self._prev.token_type
|
||||
|
||||
# Parse the right side of the + or -
|
||||
rhs = self._parse_factor()
|
||||
if not rhs: # pragma: no cover
|
||||
break
|
||||
|
||||
# Check if there's a time unit after the right side
|
||||
# This handles patterns like: expr + 1 DAYS, expr + (func()) DAYS
|
||||
if (
|
||||
self._curr
|
||||
and self._curr.token_type == tokens.TokenType.VAR
|
||||
and self._curr.text.upper()
|
||||
in {
|
||||
"MICROSECOND",
|
||||
"MICROSECONDS",
|
||||
"SECOND",
|
||||
"SECONDS",
|
||||
"MINUTE",
|
||||
"MINUTES",
|
||||
"HOUR",
|
||||
"HOURS",
|
||||
"DAY",
|
||||
"DAYS",
|
||||
"MONTH",
|
||||
"MONTHS",
|
||||
"YEAR",
|
||||
"YEARS",
|
||||
}
|
||||
):
|
||||
# Found a DB2 labeled duration
|
||||
unit_token = self._curr
|
||||
self._advance()
|
||||
|
||||
duration = DB2Interval(
|
||||
this=rhs,
|
||||
unit=exp.Literal.string(unit_token.text.upper()),
|
||||
)
|
||||
|
||||
if op == tokens.TokenType.PLUS:
|
||||
this = exp.Add(this=this, expression=duration)
|
||||
else:
|
||||
this = exp.Sub(this=this, expression=duration)
|
||||
else:
|
||||
# Not a labeled duration - use normal Add/Sub
|
||||
if op == tokens.TokenType.PLUS:
|
||||
this = exp.Add(this=this, expression=rhs)
|
||||
else:
|
||||
this = exp.Sub(this=this, expression=rhs)
|
||||
|
||||
return this
|
||||
|
||||
class Generator(Postgres.Generator):
|
||||
"""DB2 SQL generator."""
|
||||
|
||||
TRANSFORMS = {
|
||||
**Postgres.Generator.TRANSFORMS,
|
||||
}
|
||||
|
||||
def db2interval_sql(self, expression: DB2Interval) -> str:
|
||||
"""Generate SQL for DB2Interval expressions."""
|
||||
# Don't quote the unit (DAYS, MONTHS, etc.) - it's a keyword, not a string
|
||||
unit = expression.args["unit"]
|
||||
unit_text = (
|
||||
unit.this if isinstance(unit, exp.Literal) else str(unit).upper()
|
||||
)
|
||||
return f"{self.sql(expression, 'this')} {unit_text}"
|
||||
@@ -45,7 +45,7 @@ from sqlglot.optimizer.scope import (
|
||||
)
|
||||
|
||||
from superset.exceptions import QueryClauseValidationException, SupersetParseError
|
||||
from superset.sql.dialects import Dremio, Firebolt, Pinot
|
||||
from superset.sql.dialects import DB2, Dremio, Firebolt, Pinot
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.models.core import Database
|
||||
@@ -67,7 +67,7 @@ SQLGLOT_DIALECTS = {
|
||||
# "crate": ???
|
||||
# "databend": ???
|
||||
"databricks": Dialects.DATABRICKS,
|
||||
# "db2": ???
|
||||
"db2": DB2,
|
||||
# "denodo": ???
|
||||
"dremio": Dremio,
|
||||
"drill": Dialects.DRILL,
|
||||
|
||||
246
tests/unit_tests/sql/dialects/db2_tests.py
Normal file
246
tests/unit_tests/sql/dialects/db2_tests.py
Normal file
@@ -0,0 +1,246 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import pytest
|
||||
from sqlglot import errors, parse_one
|
||||
|
||||
from superset.sql.dialects.db2 import DB2
|
||||
|
||||
|
||||
def test_month_truncation() -> None:
|
||||
"""
|
||||
Test the month truncation pattern from Db2EngineSpec time grains.
|
||||
"""
|
||||
sql = """
|
||||
SELECT "DATE" - (DAY("DATE")-1) DAYS AS "DATE", sum("TOTAL_FEE") AS "SUM(TOTAL_FEE)"
|
||||
"""
|
||||
|
||||
# test with the generic dialect -- raises exception
|
||||
with pytest.raises(errors.ParseError):
|
||||
parse_one(sql)
|
||||
|
||||
ast = parse_one(sql, dialect=DB2)
|
||||
regenerated = ast.sql(dialect=DB2)
|
||||
|
||||
assert regenerated == (
|
||||
'SELECT "DATE" - (DAY("DATE") - 1) DAYS AS "DATE", '
|
||||
'SUM("TOTAL_FEE") AS "SUM(TOTAL_FEE)"'
|
||||
)
|
||||
|
||||
|
||||
def test_labeled_duration_with_day_function() -> None:
|
||||
"""
|
||||
Test labeled duration with DAY function.
|
||||
"""
|
||||
sql = "SELECT CURRENT_DATE - DAY(CURRENT_DATE) DAYS"
|
||||
|
||||
ast = parse_one(sql, dialect=DB2)
|
||||
regenerated = ast.sql(dialect=DB2)
|
||||
|
||||
assert regenerated == "SELECT CURRENT_DATE - DAY(CURRENT_DATE) DAYS"
|
||||
|
||||
|
||||
def test_labeled_duration_with_expression() -> None:
|
||||
"""
|
||||
Test labeled duration with complex expressions (from real DB2 queries).
|
||||
"""
|
||||
sql = 'SELECT "DATE" - (DAY("DATE") - 1) DAYS'
|
||||
|
||||
ast = parse_one(sql, dialect=DB2)
|
||||
regenerated = ast.sql(dialect=DB2)
|
||||
|
||||
assert regenerated == 'SELECT "DATE" - (DAY("DATE") - 1) DAYS'
|
||||
|
||||
|
||||
def test_labeled_duration_with_month_function() -> None:
|
||||
"""
|
||||
Test labeled duration with MONTH function.
|
||||
"""
|
||||
sql = 'SELECT "DATE" - (MONTH("DATE") - 1) MONTHS'
|
||||
|
||||
ast = parse_one(sql, dialect=DB2)
|
||||
regenerated = ast.sql(dialect=DB2)
|
||||
|
||||
assert regenerated == 'SELECT "DATE" - (MONTH("DATE") - 1) MONTHS'
|
||||
|
||||
|
||||
def test_year_truncation() -> None:
|
||||
"""
|
||||
Test the year truncation pattern from Db2EngineSpec time grains.
|
||||
"""
|
||||
sql = 'SELECT "DATE" - (DAY("DATE")-1) DAYS - (MONTH("DATE")-1) MONTHS'
|
||||
|
||||
ast = parse_one(sql, dialect=DB2)
|
||||
regenerated = ast.sql(dialect=DB2)
|
||||
|
||||
assert regenerated == (
|
||||
'SELECT "DATE" - (DAY("DATE") - 1) DAYS - (MONTH("DATE") - 1) MONTHS'
|
||||
)
|
||||
|
||||
|
||||
def test_quarter_truncation() -> None:
|
||||
"""
|
||||
Test the quarter truncation pattern from Db2EngineSpec time grains.
|
||||
"""
|
||||
sql = (
|
||||
'SELECT "DATE" - (DAY("DATE")-1) DAYS - (MONTH("DATE")-1) MONTHS'
|
||||
' + ((QUARTER("DATE")-1) * 3) MONTHS'
|
||||
)
|
||||
|
||||
ast = parse_one(sql, dialect=DB2)
|
||||
regenerated = ast.sql(dialect=DB2)
|
||||
|
||||
assert regenerated == (
|
||||
'SELECT "DATE" - (DAY("DATE") - 1) DAYS - (MONTH("DATE") - 1) MONTHS'
|
||||
' + ((QUARTER("DATE") - 1) * 3) MONTHS'
|
||||
)
|
||||
|
||||
|
||||
def test_regular_column_aliasing_still_works() -> None:
|
||||
"""
|
||||
Test that regular column aliasing still works (regression test).
|
||||
"""
|
||||
sql = "SELECT col1 AS days FROM table"
|
||||
|
||||
ast = parse_one(sql, dialect=DB2)
|
||||
regenerated = ast.sql(dialect=DB2)
|
||||
|
||||
assert regenerated == "SELECT col1 AS days FROM table"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sql, expected",
|
||||
[
|
||||
("SELECT col1 AS day", "SELECT col1 AS day"),
|
||||
("SELECT col1 AS month", "SELECT col1 AS month"),
|
||||
("SELECT col1 AS year", "SELECT col1 AS year"),
|
||||
("SELECT col1 AS days", "SELECT col1 AS days"),
|
||||
("SELECT col1 AS months", "SELECT col1 AS months"),
|
||||
("SELECT col1 AS years", "SELECT col1 AS years"),
|
||||
],
|
||||
)
|
||||
def test_column_aliasing_with_reserved_words(sql: str, expected: str) -> None:
|
||||
"""
|
||||
Test column aliasing with DB2 time unit words.
|
||||
"""
|
||||
ast = parse_one(sql, dialect=DB2)
|
||||
regenerated = ast.sql(dialect=DB2)
|
||||
assert regenerated == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sql, expected",
|
||||
[
|
||||
# Function-based patterns
|
||||
('SELECT "DATE" - DAY("DATE") DAYS', 'SELECT "DATE" - DAY("DATE") DAYS'),
|
||||
(
|
||||
'SELECT "DATE" + MONTH("DATE") MONTHS',
|
||||
'SELECT "DATE" + MONTH("DATE") MONTHS',
|
||||
),
|
||||
('SELECT "DATE" - YEAR("DATE") YEARS', 'SELECT "DATE" - YEAR("DATE") YEARS'),
|
||||
# Complex expression patterns
|
||||
(
|
||||
'SELECT "DATE" - (DAY("DATE") - 1) DAYS',
|
||||
'SELECT "DATE" - (DAY("DATE") - 1) DAYS',
|
||||
),
|
||||
(
|
||||
'SELECT "DATE" + (MONTH("DATE") + 2) MONTHS',
|
||||
'SELECT "DATE" + (MONTH("DATE") + 2) MONTHS',
|
||||
),
|
||||
# Nested expressions
|
||||
(
|
||||
'SELECT "DATE" - ((DAY("DATE") - 1) + 1) DAYS - (MONTH("DATE") - 1) MONTHS',
|
||||
'SELECT "DATE" - ((DAY("DATE") - 1) + 1) DAYS - (MONTH("DATE") - 1) MONTHS',
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_labeled_duration_variations(sql: str, expected: str) -> None:
|
||||
"""
|
||||
Test various labeled duration patterns that should work.
|
||||
"""
|
||||
ast = parse_one(sql, dialect=DB2)
|
||||
regenerated = ast.sql(dialect=DB2)
|
||||
assert regenerated == expected
|
||||
|
||||
|
||||
def test_addition_with_labeled_duration() -> None:
|
||||
"""
|
||||
Test addition operations with labeled durations.
|
||||
"""
|
||||
sql = 'SELECT "DATE" + (DAY("DATE") + 5) DAYS'
|
||||
|
||||
ast = parse_one(sql, dialect=DB2)
|
||||
regenerated = ast.sql(dialect=DB2)
|
||||
|
||||
assert regenerated == 'SELECT "DATE" + (DAY("DATE") + 5) DAYS'
|
||||
|
||||
|
||||
def test_arithmetic_with_different_units() -> None:
|
||||
"""
|
||||
Test arithmetic operations mixing different time units.
|
||||
"""
|
||||
sql = (
|
||||
'SELECT "DATE" - (DAY("DATE")-1) DAYS '
|
||||
'- (MONTH("DATE")-1) MONTHS + '
|
||||
'(YEAR("DATE")-1) YEARS'
|
||||
)
|
||||
|
||||
ast = parse_one(sql, dialect=DB2)
|
||||
regenerated = ast.sql(dialect=DB2)
|
||||
|
||||
assert regenerated == (
|
||||
'SELECT "DATE" - (DAY("DATE") - 1) DAYS - '
|
||||
'(MONTH("DATE") - 1) MONTHS + '
|
||||
'(YEAR("DATE") - 1) YEARS'
|
||||
)
|
||||
|
||||
|
||||
def test_multiple_function_calls_in_duration() -> None:
|
||||
"""
|
||||
Test labeled duration with multiple function calls.
|
||||
"""
|
||||
sql = 'SELECT "DATE" - (DAY("DATE") + MONTH("DATE")) DAYS'
|
||||
|
||||
ast = parse_one(sql, dialect=DB2)
|
||||
regenerated = ast.sql(dialect=DB2)
|
||||
|
||||
assert regenerated == 'SELECT "DATE" - (DAY("DATE") + MONTH("DATE")) DAYS'
|
||||
|
||||
|
||||
def test_labeled_duration_with_multiplication() -> None:
|
||||
"""
|
||||
Test labeled duration with multiplication in the expression.
|
||||
"""
|
||||
sql = 'SELECT "DATE" + ((QUARTER("DATE") - 1) * 3) MONTHS'
|
||||
|
||||
ast = parse_one(sql, dialect=DB2)
|
||||
regenerated = ast.sql(dialect=DB2)
|
||||
|
||||
assert regenerated == 'SELECT "DATE" + ((QUARTER("DATE") - 1) * 3) MONTHS'
|
||||
|
||||
|
||||
def test_column_plus_literal_duration() -> None:
|
||||
"""
|
||||
Test column + literal number with time unit.
|
||||
"""
|
||||
sql = "SELECT col + 1 DAYS FROM t"
|
||||
|
||||
ast = parse_one(sql, dialect=DB2)
|
||||
regenerated = ast.sql(dialect=DB2)
|
||||
|
||||
# Should parse as (col + 1 DAYS), not (col + 1) AS DAYS
|
||||
assert regenerated == "SELECT col + 1 DAYS FROM t"
|
||||
Reference in New Issue
Block a user