diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 46786fdf659..e3ffdd335d8 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -1546,7 +1546,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods cols: list[ResultSetColumnType] | None = None, ) -> str: """ - Generate a "SELECT * from [schema.]table_name" query with appropriate limit. + Generate a "SELECT * from [catalog.][schema.]table_name" query with limit. WARNING: expects only unquoted table and schema names. @@ -1560,6 +1560,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods :param cols: Columns to include in query :return: SQL query """ + if not cls.supports_cross_catalog_queries: + table = Table(table.table, table.schema, None) + # pylint: disable=redefined-outer-name fields: str | list[Any] = "*" cols = cols or [] diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index 2297e2572b3..34ac2aa0ac8 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -498,6 +498,9 @@ class HiveEngineSpec(PrestoEngineSpec): latest_partition: bool = True, cols: list[ResultSetColumnType] | None = None, ) -> str: + # remove catalog from table name if it exists + table = Table(table.table, table.schema, None) + return super(PrestoEngineSpec, cls).select_star( database, table, diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py index 5ccbc7d627c..95261ba0469 100644 --- a/tests/unit_tests/db_engine_specs/test_base.py +++ b/tests/unit_tests/db_engine_specs/test_base.py @@ -206,7 +206,7 @@ def test_select_star(mocker: MockerFixture) -> None: sql = BaseEngineSpec.select_star( database=database, - table=Table("my_table"), + table=Table("my_table", "my_schema", "my_catalog"), engine=engine, limit=100, show_cols=True, @@ -214,7 +214,7 @@ def test_select_star(mocker: MockerFixture) -> None: latest_partition=False, cols=cols, ) - assert sql == "SELECT\n a\nFROM my_table\nLIMIT ?\nOFFSET ?" + assert sql == "SELECT\n a\nFROM my_schema.my_table\nLIMIT ?\nOFFSET ?" def test_extra_table_metadata(mocker: MockerFixture) -> None: diff --git a/tests/unit_tests/db_engine_specs/test_hive.py b/tests/unit_tests/db_engine_specs/test_hive.py index e534259d2c1..2bbb979782f 100644 --- a/tests/unit_tests/db_engine_specs/test_hive.py +++ b/tests/unit_tests/db_engine_specs/test_hive.py @@ -20,8 +20,11 @@ from datetime import datetime from typing import Optional import pytest +from pytest_mock import MockerFixture +from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.engine.url import make_url +from superset.sql.parse import Table from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm from tests.unit_tests.fixtures.common import dttm # noqa: F401 @@ -59,3 +62,40 @@ def test_get_schema_from_engine_params() -> None: ) == "default" ) + + +def test_select_star(mocker: MockerFixture) -> None: + """ + Test the ``select_star`` method. + """ + from superset.db_engine_specs.hive import HiveEngineSpec + + database = mocker.MagicMock() + engine = mocker.MagicMock() + + def quote_table(table: Table, dialect: Dialect) -> str: + return ".".join( + part for part in (table.catalog, table.schema, table.table) if part + ) + + mocker.patch.object(HiveEngineSpec, "quote_table", quote_table) + + HiveEngineSpec.select_star( + database=database, + table=Table("my_table", "my_schema", "my_catalog"), + engine=engine, + limit=100, + show_cols=False, + indent=True, + latest_partition=False, + cols=None, + ) + + query = database.compile_sqla_query.mock_calls[0][1][0] + assert ( + str(query) + == """ +SELECT * \nFROM my_schema.my_table + LIMIT :param_1 + """.strip() + ) diff --git a/tests/unit_tests/db_engine_specs/test_postgres.py b/tests/unit_tests/db_engine_specs/test_postgres.py index b93d70f5f76..290aa3575c9 100644 --- a/tests/unit_tests/db_engine_specs/test_postgres.py +++ b/tests/unit_tests/db_engine_specs/test_postgres.py @@ -22,10 +22,12 @@ import pytest from pytest_mock import MockerFixture from sqlalchemy import column, types from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, ENUM, JSON +from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.engine.url import make_url from superset.db_engine_specs.postgres import PostgresEngineSpec as spec # noqa: N813 from superset.exceptions import SupersetSecurityException +from superset.sql.parse import Table from superset.utils.core import GenericDataType from tests.unit_tests.db_engine_specs.utils import ( assert_column_spec, @@ -243,3 +245,38 @@ def test_timegrain_expressions(time_grain: str, expected_result: str) -> None: spec.get_timestamp_expr(col=column("col"), pdf=None, time_grain=time_grain) ) assert actual == expected_result + + +def test_select_star(mocker: MockerFixture) -> None: + """ + Test the ``select_star`` method. + """ + database = mocker.MagicMock() + engine = mocker.MagicMock() + + def quote_table(table: Table, dialect: Dialect) -> str: + return ".".join( + part for part in (table.catalog, table.schema, table.table) if part + ) + + mocker.patch.object(spec, "quote_table", quote_table) + + spec.select_star( + database=database, + table=Table("my_table", "my_schema", "my_catalog"), + engine=engine, + limit=100, + show_cols=False, + indent=True, + latest_partition=False, + cols=None, + ) + + query = database.compile_sqla_query.mock_calls[0][1][0] + assert ( + str(query) + == """ +SELECT * \nFROM my_schema.my_table + LIMIT :param_1 + """.strip() + ) diff --git a/tests/unit_tests/db_engine_specs/test_presto.py b/tests/unit_tests/db_engine_specs/test_presto.py index 9cba2653a52..857bc19b9be 100644 --- a/tests/unit_tests/db_engine_specs/test_presto.py +++ b/tests/unit_tests/db_engine_specs/test_presto.py @@ -21,7 +21,9 @@ from unittest import mock import pytest import pytz from pyhive.sqlalchemy_presto import PrestoDialect +from pytest_mock import MockerFixture from sqlalchemy import column, sql, text, types +from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.engine.url import make_url from superset.sql.parse import Table @@ -303,3 +305,40 @@ def test_timegrain_expressions(time_grain: str, expected_result: str) -> None: spec.get_timestamp_expr(col=column("col"), pdf=None, time_grain=time_grain) ) assert actual == expected_result + + +def test_select_star(mocker: MockerFixture) -> None: + """ + Test the ``select_star`` method. + """ + from superset.db_engine_specs.presto import PrestoEngineSpec as spec # noqa: N813 + + database = mocker.MagicMock() + engine = mocker.MagicMock() + + def quote_table(table: Table, dialect: Dialect) -> str: + return ".".join( + part for part in (table.catalog, table.schema, table.table) if part + ) + + mocker.patch.object(spec, "quote_table", quote_table) + + spec.select_star( + database=database, + table=Table("my_table", "my_schema", "my_catalog"), + engine=engine, + limit=100, + show_cols=False, + indent=True, + latest_partition=False, + cols=None, + ) + + query = database.compile_sqla_query.mock_calls[0][1][0] + assert ( + str(query) + == """ +SELECT * \nFROM my_catalog.my_schema.my_table + LIMIT :param_1 + """.strip() + )