feat: use sqlglot to set limit (#33473)

This commit is contained in:
Beto Dealmeida
2025-05-27 15:20:02 -04:00
committed by GitHub
parent cc8ab2c556
commit 8de58b9848
34 changed files with 573 additions and 557 deletions

View File

@@ -15,10 +15,10 @@
# specific language governing permissions and limitations
# under the License.
from superset.db_engine_specs.ascend import AscendEngineSpec
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
from tests.integration_tests.base_tests import SupersetTestCase
class TestAscendDbEngineSpec(TestDbEngineSpec):
class TestAscendDbEngineSpec(SupersetTestCase):
def test_convert_dttm(self):
dttm = self.get_dttm()

View File

@@ -25,14 +25,13 @@ from superset.db_engine_specs.base import (
BaseEngineSpec,
BasicParametersMixin,
builtin_time_grains,
LimitMethod,
)
from superset.db_engine_specs.mysql import MySQLEngineSpec
from superset.db_engine_specs.sqlite import SqliteEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.sql_parse import Table
from superset.utils.database import get_example_database
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.test_app import app
from ..fixtures.birth_names_dashboard import (
@@ -46,7 +45,7 @@ from ..fixtures.energy_dashboard import (
from ..fixtures.pyodbcRow import Row
class TestDbEngineSpecs(TestDbEngineSpec):
class SupersetTestCases(SupersetTestCase):
def test_extract_limit_from_query(self, engine_spec_class=BaseEngineSpec):
q0 = "select * from table"
q1 = "select * from mytable limit 10"
@@ -74,124 +73,9 @@ class TestDbEngineSpecs(TestDbEngineSpec):
assert engine_spec_class.get_limit_from_sql(q10) is None
assert engine_spec_class.get_limit_from_sql(q11) is None
def test_wrapped_semi_tabs(self):
    """
    A trailing semicolon wrapped in tabs and newlines should be dropped.

    The expected SQL is the bare query followed by ``LIMIT 1000``.
    """
    raw_sql = "SELECT * FROM a \t \n ; \t \n "
    limited_sql = "SELECT * FROM a\nLIMIT 1000"
    self.sql_limit_regex(raw_sql, limited_sql)
def test_simple_limit_query(self):
    """A query with no LIMIT clause is expected to gain ``LIMIT 1000``."""
    unlimited_sql = "SELECT * FROM a"
    limited_sql = "SELECT * FROM a\nLIMIT 1000"
    self.sql_limit_regex(unlimited_sql, limited_sql)
def test_modify_limit_query(self):
    """An existing ``LIMIT 9999`` is expected to be rewritten to ``LIMIT 1000``."""
    oversized_sql = "SELECT * FROM a LIMIT 9999"
    capped_sql = "SELECT * FROM a LIMIT 1000"
    self.sql_limit_regex(oversized_sql, capped_sql)
def test_limit_query_with_limit_subquery(self):  # pylint: disable=invalid-name
    """
    Only the outer LIMIT is expected to change.

    The subquery keeps its ``LIMIT 10`` while the outer ``LIMIT 9999``
    becomes ``LIMIT 1000``.
    """
    nested_sql = "SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 9999"
    capped_sql = "SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 1000"
    self.sql_limit_regex(nested_sql, capped_sql)
def test_limit_query_without_force(self):
    """
    Without ``force``, a query already limited to 10 is expected to stay as is.

    The requested limit (11) is larger than the one in the query.
    """
    already_limited = "SELECT * FROM a LIMIT 10"
    self.sql_limit_regex(already_limited, already_limited, limit=11)
def test_limit_query_with_force(self):
    """
    With ``force=True``, ``LIMIT 10`` is expected to become ``LIMIT 11``.

    The requested limit replaces the smaller one already in the query.
    """
    original_sql = "SELECT * FROM a LIMIT 10"
    forced_sql = "SELECT * FROM a LIMIT 11"
    self.sql_limit_regex(original_sql, forced_sql, limit=11, force=True)
# Checks a query whose select list contains the string literal 'LIMIT 777'
# and whose real clause is LIMIT 99990. The expected SQL keeps the literal
# untouched and ends in LIMIT 1000 (no explicit limit is passed to the helper).
def test_limit_with_expr(self):
self.sql_limit_regex(
"""
SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 99990""",
"""SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 1000""",
)
# Same scenario as a 'LIMIT 777' string literal plus a real LIMIT 99990 clause,
# but the input ends with " ;". The expected SQL has no trailing semicolon and
# ends in LIMIT 1000, with the literal left untouched.
def test_limit_expr_and_semicolon(self):
self.sql_limit_regex(
"""
SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 99990 ;""",
"""SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 1000""",
)
def test_get_datatype(self):
    """``BaseEngineSpec.get_datatype("VARCHAR")`` is expected to equal "VARCHAR"."""
    resolved = BaseEngineSpec.get_datatype("VARCHAR")
    assert "VARCHAR" == resolved
# Checks the comma form "LIMIT 99990, 999999". The expected SQL keeps the first
# number (which the test name treats as the offset) and replaces the second
# with 1000; the 'LIMIT 777' string literal is left untouched.
def test_limit_with_implicit_offset(self):
self.sql_limit_regex(
"""
SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 99990, 999999""",
"""SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 99990, 1000""",
)
# Checks a query with separate LIMIT and OFFSET clauses. The expected SQL
# changes LIMIT 99990 to LIMIT 1000 and keeps OFFSET 999999 as written; the
# 'LIMIT 777' string literal is left untouched.
def test_limit_with_explicit_offset(self):
self.sql_limit_regex(
"""
SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 99990
OFFSET 999999""",
"""SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 1000
OFFSET 999999""",
)
def test_limit_with_non_token_limit(self):
    """
    The word LIMIT inside a string literal should not count as a clause.

    The expected SQL is the original statement with ``LIMIT 1000`` appended.
    """
    literal_only_sql = """SELECT 'LIMIT 777'"""
    limited_sql = """SELECT 'LIMIT 777'\nLIMIT 1000"""
    self.sql_limit_regex(literal_only_sql, limited_sql)
def test_limit_with_fetch_many(self):
    """
    With ``LimitMethod.FETCH_MANY`` the SQL is expected to come back unchanged.
    """

    class DummyEngineSpec(BaseEngineSpec):
        # Engine spec whose only override is the limit method under test.
        limit_method = LimitMethod.FETCH_MANY

    untouched_sql = "SELECT * FROM table"
    self.sql_limit_regex(untouched_sql, untouched_sql, DummyEngineSpec)
def test_engine_time_grain_validity(self):
time_grains = set(builtin_time_grains.keys())
# loop over all subclasses of BaseEngineSpec

View File

@@ -1,36 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
from tests.integration_tests.test_app import app # noqa: F401
from tests.integration_tests.base_tests import SupersetTestCase
from superset.db_engine_specs.base import BaseEngineSpec
from superset.models.core import Database
class TestDbEngineSpec(SupersetTestCase):
    """Base test case with a helper for checking ``apply_limit_to_sql`` output."""

    def sql_limit_regex(
        self,
        sql,
        expected_sql,
        engine_spec_class=BaseEngineSpec,
        limit=1000,
        force=False,
    ):
        """
        Assert that applying a limit to ``sql`` yields exactly ``expected_sql``.

        :param sql: statement handed to ``apply_limit_to_sql``
        :param expected_sql: exact string the limited statement must equal
        :param engine_spec_class: engine spec whose ``apply_limit_to_sql`` is called
        :param limit: row limit passed through to the engine spec
        :param force: forwarded unchanged as the ``force`` argument
        :raises AssertionError: if the limited SQL differs from ``expected_sql``
        """
        # A throwaway in-memory SQLite database object is passed as the
        # database argument.
        database = Database(
            database_name="test_database",
            sqlalchemy_uri="sqlite://",
        )
        actual_sql = engine_spec_class.apply_limit_to_sql(
            sql,
            limit,
            database,
            force,
        )
        assert expected_sql == actual_sql

View File

@@ -26,7 +26,7 @@ from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.bigquery import BigQueryEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.sql_parse import Table
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices, # noqa: F401
load_birth_names_data, # noqa: F401
@@ -42,7 +42,7 @@ def mock_engine_with_credentials(*args, **kwargs):
yield engine_mock
class TestBigQueryDbEngineSpec(TestDbEngineSpec):
class TestBigQueryDbEngineSpec(SupersetTestCase):
def test_bigquery_sqla_column_label(self):
"""
DB Eng Specs (bigquery): Test column label

View File

@@ -18,12 +18,12 @@ from unittest import mock
from superset.db_engine_specs import get_engine_spec
from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.fixtures.certificates import ssl_certificate
from tests.integration_tests.fixtures.database import default_db_extra
class TestDatabricksDbEngineSpec(TestDbEngineSpec):
class TestDatabricksDbEngineSpec(SupersetTestCase):
def test_get_engine_spec(self):
"""
DB Eng Specs (databricks): Test "databricks" in engine spec

View File

@@ -19,10 +19,10 @@ from sqlalchemy import column
from superset.constants import TimeGrain
from superset.db_engine_specs.elasticsearch import ElasticSearchEngineSpec
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
from tests.integration_tests.base_tests import SupersetTestCase
class TestElasticsearchDbEngineSpec(TestDbEngineSpec):
class TestElasticsearchDbEngineSpec(SupersetTestCase):
@parameterized.expand(
[
[TimeGrain.SECOND, "DATE_TRUNC('second', ts)"],

View File

@@ -16,10 +16,10 @@
# under the License.
from superset.db_engine_specs.gsheets import GSheetsEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
from tests.integration_tests.base_tests import SupersetTestCase
class TestGsheetsDbEngineSpec(TestDbEngineSpec):
class TestGsheetsDbEngineSpec(SupersetTestCase):
def test_extract_errors(self):
"""
Test that custom error messages are extracted correctly.

View File

@@ -17,7 +17,6 @@
# isort:skip_file
from unittest import mock
import unittest
from .base_tests import SupersetTestCase
import pytest
import pandas as pd
@@ -26,6 +25,7 @@ from sqlalchemy.sql import select
from superset.db_engine_specs.hive import HiveEngineSpec, upload_to_s3
from superset.exceptions import SupersetException
from superset.sql_parse import Table
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.test_app import app

View File

@@ -21,12 +21,12 @@ from sqlalchemy.dialects.mysql import DATE, NVARCHAR, TEXT, VARCHAR
from superset.db_engine_specs.mysql import MySQLEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
from tests.integration_tests.base_tests import SupersetTestCase
class TestMySQLEngineSpecsDbEngineSpec(TestDbEngineSpec):
class TestMySQLEngineSpecsDbEngineSpec(SupersetTestCase):
@unittest.skipUnless(
TestDbEngineSpec.is_module_installed("MySQLdb"), "mysqlclient not installed"
SupersetTestCase.is_module_installed("MySQLdb"), "mysqlclient not installed"
)
def test_get_datatype_mysql(self):
"""Tests related to datatype mapping for MySQL"""

View File

@@ -17,10 +17,10 @@
from sqlalchemy import column
from superset.db_engine_specs.pinot import PinotEngineSpec
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
from tests.integration_tests.base_tests import SupersetTestCase
class TestPinotDbEngineSpec(TestDbEngineSpec):
class TestPinotDbEngineSpec(SupersetTestCase):
"""Tests pertaining to our Pinot database support"""
def test_pinot_time_expression_sec_one_1d_grain(self):

View File

@@ -27,12 +27,12 @@ from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.models.sql_lab import Query
from superset.utils.core import backend
from superset.utils.database import get_example_database
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.fixtures.certificates import ssl_certificate
from tests.integration_tests.fixtures.database import default_db_extra
class TestPostgresDbEngineSpec(TestDbEngineSpec):
class TestPostgresDbEngineSpec(SupersetTestCase):
def test_get_table_names(self):
"""
DB Eng Specs (postgres): Test get table names

View File

@@ -27,11 +27,11 @@ from superset.db_engine_specs.presto import PrestoEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.sql_parse import Table
from superset.utils.database import get_example_database
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
from tests.integration_tests.base_tests import SupersetTestCase
class TestPrestoDbEngineSpec(TestDbEngineSpec):
@skipUnless(TestDbEngineSpec.is_module_installed("pyhive"), "pyhive not installed")
class TestPrestoDbEngineSpec(SupersetTestCase):
@skipUnless(SupersetTestCase.is_module_installed("pyhive"), "pyhive not installed")
def test_get_datatype_presto(self):
assert "STRING" == PrestoEngineSpec.get_datatype("string")

View File

@@ -24,11 +24,11 @@ from sqlalchemy.types import NVARCHAR
from superset.db_engine_specs.redshift import RedshiftEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.sql_parse import Table
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.test_app import app
class TestRedshiftDbEngineSpec(TestDbEngineSpec):
class TestRedshiftDbEngineSpec(SupersetTestCase):
def test_extract_errors(self):
"""
Test that custom error messages are extracted correctly.