Compare commits

...

7 Commits

Author SHA1 Message Date
Beto Dealmeida
1dfe73d19c Fix tests 2025-08-26 18:10:46 -04:00
Beto Dealmeida
bbda5e2008 Fix tests 2025-08-26 16:22:44 -04:00
Beto Dealmeida
53999c12dd Use Result instead 2025-08-26 12:49:27 -04:00
Beto Dealmeida
f554036d29 Fix tests 2025-08-26 11:06:54 -04:00
Beto Dealmeida
33e7932491 More methods 2025-08-25 18:15:43 -04:00
Beto Dealmeida
92b02d993b More methods 2025-08-25 17:40:31 -04:00
Beto Dealmeida
72ba972e42 chore: standardize DB engine spec query execution 2025-08-25 17:31:15 -04:00
17 changed files with 398 additions and 140 deletions

View File

@@ -49,6 +49,7 @@ from flask_babel import gettext as __, lazy_gettext as _
from marshmallow import fields, Schema from marshmallow import fields, Schema
from marshmallow.validate import Range from marshmallow.validate import Range
from sqlalchemy import column, select, types from sqlalchemy import column, select, types
from sqlalchemy.engine import Result
from sqlalchemy.engine.base import Engine from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.interfaces import Compiled, Dialect from sqlalchemy.engine.interfaces import Compiled, Dialect
from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.reflection import Inspector
@@ -1670,14 +1671,13 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod @classmethod
def estimate_statement_cost( def estimate_statement_cost(
cls, database: Database, statement: str, cursor: Any cls, database: Database, statement: str
) -> dict[str, Any]: ) -> dict[str, Any]:
""" """
Generate a SQL query that estimates the cost of a given statement. Generate a SQL query that estimates the cost of a given statement.
:param database: A Database object :param database: A Database object
:param statement: A single SQL statement :param statement: A single SQL statement
:param cursor: Cursor instance
:return: Dictionary with different costs :return: Dictionary with different costs
""" """
raise Exception( # pylint: disable=broad-exception-raised raise Exception( # pylint: disable=broad-exception-raised
@@ -1738,20 +1738,13 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
parsed_script = SQLScript(sql, engine=cls.engine) parsed_script = SQLScript(sql, engine=cls.engine)
with database.get_raw_connection( return [
catalog=catalog, cls.estimate_statement_cost(
schema=schema, database,
source=source, cls.process_statement(statement, database),
) as conn: )
cursor = conn.cursor() for statement in parsed_script.statements
return [ ]
cls.estimate_statement_cost(
database,
cls.process_statement(statement, database),
cursor,
)
for statement in parsed_script.statements
]
@classmethod @classmethod
def impersonate_user( def impersonate_user(
@@ -1854,6 +1847,42 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
cls.start_oauth2_dance(database) cls.start_oauth2_dance(database)
raise cls.get_dbapi_mapped_exception(ex) from ex raise cls.get_dbapi_mapped_exception(ex) from ex
@classmethod
def execute_metadata_query(
    cls,
    database: Database,
    query: str,
    catalog: str | None = None,
    schema: str | None = None,
) -> Result:
    """
    Execute a metadata query and return its rows as a SQLAlchemy ``Result``.

    This is the single entry point engine specs should use for metadata
    lookups (catalog names, partition info, cost estimates, ...). The query
    runs on a SQLAlchemy connection obtained through ``cls.get_engine`` with
    ``QuerySource.METADATA`` as the source.

    All rows are read into memory before the connection is released, so the
    returned object stays usable after this method returns. Only use it for
    queries whose result set is small.

    To restrict a ``SELECT`` to one row, add ``LIMIT 1`` to the query itself.
    Do not append ``LIMIT`` to statements that have no such clause (for
    example ``EXPLAIN ...`` or ``SHOW ...``); call ``fetchone()`` or
    ``scalar()`` on the returned result instead.

    :param database: Database instance
    :param query: SQL query to execute for metadata
    :param catalog: Optional catalog/database name
    :param schema: Optional schema name
    :return: SQLAlchemy Result object with methods like:
        - result.fetchall() -> list[Row]: Get all rows
        - result.fetchone() -> Row | None: Get single row
        - result.scalar() -> Any: Get single value
        - result.mappings() -> mappings for dict-like access
    :raises Exception: whatever the driver/SQLAlchemy raises while executing
        the query is propagated unchanged
    """
    with cls.get_engine(
        database,
        catalog=catalog,
        schema=schema,
        source=utils.QuerySource.METADATA,
    ) as engine:
        with engine.connect() as conn:
            # Returning the live cursor result from inside these ``with``
            # blocks would hand callers a result whose connection is already
            # closed. ``freeze()`` consumes every row now, while the
            # connection is still open.
            frozen = conn.execute(text(query)).freeze()

    # Calling the frozen result yields a fresh, in-memory ``Result`` over the
    # buffered rows; it no longer depends on the connection.
    return frozen()
@classmethod @classmethod
def needs_oauth2(cls, ex: Exception) -> bool: def needs_oauth2(cls, ex: Exception) -> bool:
""" """

View File

@@ -331,9 +331,11 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
) )
# Build the query # Build the query
query = select( query = (
func.max(partitions_table.c.partition_id).label("max_partition_id") select(func.max(partitions_table.c.partition_id).label("max_partition_id"))
).where(partitions_table.c.table_name == table.table) .where(partitions_table.c.table_name == table.table)
.limit(1)
)
# Compile to BigQuery SQL # Compile to BigQuery SQL
compiled_query = query.compile( compiled_query = query.compile(
@@ -342,15 +344,13 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
) )
# Run the query and handle result # Run the query and handle result
with database.get_raw_connection( result = cls.execute_metadata_query(
database,
str(compiled_query),
catalog=table.catalog, catalog=table.catalog,
schema=table.schema, schema=table.schema,
) as conn: )
cursor = conn.cursor() return result.scalar()
cursor.execute(str(compiled_query))
if row := cursor.fetchone():
return row[0]
return None
@classmethod @classmethod
def get_time_partition_column( def get_time_partition_column(

View File

@@ -503,12 +503,13 @@ class DatabricksNativeEngineSpec(DatabricksDynamicBaseEngineSpec):
if default_catalog := connect_args.get("catalog"): if default_catalog := connect_args.get("catalog"):
return default_catalog return default_catalog
with database.get_sqla_engine() as engine: result = cls.execute_metadata_query(database, "SHOW CATALOGS")
catalogs = {catalog for (catalog,) in engine.execute("SHOW CATALOGS")} catalogs = {catalog for (catalog,) in result}
if len(catalogs) == 1: if len(catalogs) == 1:
return catalogs.pop() return catalogs.pop()
return engine.execute("SELECT current_catalog()").scalar() result = cls.execute_metadata_query(database, "SELECT current_catalog()")
return result.scalar()
@classmethod @classmethod
def get_prequeries( def get_prequeries(
@@ -532,7 +533,8 @@ class DatabricksNativeEngineSpec(DatabricksDynamicBaseEngineSpec):
database: Database, database: Database,
inspector: Inspector, inspector: Inspector,
) -> set[str]: ) -> set[str]:
return {catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")} results = cls.execute_metadata_query(database, "SHOW CATALOGS")
return {catalog for (catalog,) in results}
class DatabricksPythonConnectorEngineSpec(DatabricksDynamicBaseEngineSpec): class DatabricksPythonConnectorEngineSpec(DatabricksDynamicBaseEngineSpec):
@@ -624,7 +626,8 @@ class DatabricksPythonConnectorEngineSpec(DatabricksDynamicBaseEngineSpec):
database: Database, database: Database,
inspector: Inspector, inspector: Inspector,
) -> set[str]: ) -> set[str]:
return {catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")} results = cls.execute_metadata_query(database, "SHOW CATALOGS")
return {catalog for (catalog,) in results}
@classmethod @classmethod
def adjust_engine_params( def adjust_engine_params(

View File

@@ -297,8 +297,8 @@ class DorisEngineSpec(MySQLEngineSpec):
CatalogId, CatalogName, Type, IsCurrent, CreateTime, LastUpdateTime, Comment CatalogId, CatalogName, Type, IsCurrent, CreateTime, LastUpdateTime, Comment
We need to extract just the CatalogName column. We need to extract just the CatalogName column.
""" """
result = inspector.bind.execute("SHOW CATALOGS") results = cls.execute_metadata_query(database, "SHOW CATALOGS")
return {row.CatalogName for row in result} return {row.CatalogName for row in results}
@classmethod @classmethod
def get_schema_from_engine_params( def get_schema_from_engine_params(

View File

@@ -388,9 +388,8 @@ class MotherDuckEngineSpec(DuckDBEngineSpec):
database: Database, database: Database,
inspector: Inspector, inspector: Inspector,
) -> set[str]: ) -> set[str]:
return { results = cls.execute_metadata_query(
catalog database,
for (catalog,) in inspector.bind.execute( "SELECT alias FROM MD_ALL_DATABASES() WHERE is_attached;",
"SELECT alias FROM MD_ALL_DATABASES() WHERE is_attached;" )
) return {catalog for (catalog,) in results}
}

View File

@@ -154,15 +154,16 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
database: Database, database: Database,
table: Table, table: Table,
) -> dict[str, Any]: ) -> dict[str, Any]:
with database.get_raw_connection( result = cls.execute_metadata_query(
database,
f'SELECT GET_METADATA("{table.table}") LIMIT 1',
catalog=table.catalog, catalog=table.catalog,
schema=table.schema, schema=table.schema,
) as conn: )
cursor = conn.cursor() results_list = result.fetchall()
cursor.execute(f'SELECT GET_METADATA("{table.table}")') results = results_list[0][0] if results_list else None
results = cursor.fetchone()[0]
try: try:
metadata = json.loads(results) metadata = json.loads(results) if results else {}
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
metadata = {} metadata = {}

View File

@@ -614,8 +614,9 @@ class HiveEngineSpec(PrestoEngineSpec):
if schema: if schema:
sql += f" IN `{schema}`" sql += f" IN `{schema}`"
with database.get_raw_connection(schema=schema) as conn: result = cls.execute_metadata_query(
cursor = conn.cursor() database,
cursor.execute(sql) sql,
results = cursor.fetchall() schema=schema,
return {row[0] for row in results} )
return {row[0] for row in result}

View File

@@ -354,19 +354,18 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec):
@classmethod @classmethod
def estimate_statement_cost( def estimate_statement_cost(
cls, database: Database, statement: str, cursor: Any cls, database: Database, statement: str
) -> dict[str, Any]: ) -> dict[str, Any]:
""" """
Run a SQL query that estimates the cost of a given statement. Run a SQL query that estimates the cost of a given statement.
:param database: A Database object :param database: A Database object
:param statement: A single SQL statement :param statement: A single SQL statement
:param cursor: Cursor instance :return: Cost estimate dictionary
:return: JSON response from Trino
""" """
sql = f"EXPLAIN {statement}" sql = f"EXPLAIN {statement} LIMIT 1"
cursor.execute(sql) results = cls.execute_metadata_query(database, sql)
rows = results.fetchall()
result = cursor.fetchone()[0] result = rows[0][0] if rows else ""
match = re.search(r"cost=([\d\.]+)\.\.([\d\.]+)", result) match = re.search(r"cost=([\d\.]+)\.\.([\d\.]+)", result)
if match: if match:
return { return {
@@ -393,15 +392,14 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec):
In Postgres, a catalog is called a "database". In Postgres, a catalog is called a "database".
""" """
return { results = cls.execute_metadata_query(
catalog database,
for (catalog,) in inspector.bind.execute( """
"""
SELECT datname FROM pg_database SELECT datname FROM pg_database
WHERE datistemplate = false; WHERE datistemplate = false;
""" """,
) )
} return {catalog for (catalog,) in results}
@classmethod @classmethod
def get_table_names( def get_table_names(

View File

@@ -321,7 +321,8 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
""" """
Get all catalogs. Get all catalogs.
""" """
return {catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")} result = cls.execute_metadata_query(database, "SHOW CATALOGS")
return {catalog for (catalog,) in result}
@classmethod @classmethod
def adjust_engine_params( def adjust_engine_params(
@@ -373,17 +374,16 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
@classmethod @classmethod
def estimate_statement_cost( def estimate_statement_cost(
cls, database: Database, statement: str, cursor: Any cls, database: Database, statement: str
) -> dict[str, Any]: ) -> dict[str, Any]:
""" """
Run a SQL query that estimates the cost of a given statement. Run a SQL query that estimates the cost of a given statement.
:param database: A Database object :param database: A Database object
:param statement: A single SQL statement :param statement: A single SQL statement
:param cursor: Cursor instance
:return: JSON response from Trino :return: JSON response from Trino
""" """
sql = f"EXPLAIN (TYPE IO, FORMAT JSON) {statement}" sql = f"EXPLAIN (TYPE IO, FORMAT JSON) {statement} LIMIT 1"
cursor.execute(sql) results = cls.execute_metadata_query(database, sql)
# the output from Trino is a single column and a single row containing # the output from Trino is a single column and a single row containing
# JSON: # JSON:
@@ -398,7 +398,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
# "networkCost" : 3.41425774958E11 # "networkCost" : 3.41425774958E11
# } # }
# } # }
result = json.loads(cursor.fetchone()[0]) result = json.loads(results[0][0]) if results else {}
return result return result
@classmethod @classmethod
@@ -1037,7 +1037,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
AND table_type = 'VIEW' AND table_type = 'VIEW'
""" """
).strip() ).strip()
params = {"schema": schema} results = inspector.bind.execute(sql, {"schema": schema}).fetchall()
else: else:
sql = dedent( sql = dedent(
""" """
@@ -1045,13 +1045,9 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
WHERE table_type = 'VIEW' WHERE table_type = 'VIEW'
""" """
).strip() ).strip()
params = {} results = inspector.bind.execute(sql).fetchall()
with database.get_raw_connection(schema=schema) as conn: return {row[0] for row in results}
cursor = conn.cursor()
cursor.execute(sql, params)
results = cursor.fetchall()
return {row[0] for row in results}
@classmethod @classmethod
def _is_column_name_quoted(cls, column_name: str) -> bool: def _is_column_name_quoted(cls, column_name: str) -> bool:
@@ -1299,16 +1295,12 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
# pylint: disable=import-outside-toplevel # pylint: disable=import-outside-toplevel
from pyhive.exc import DatabaseError from pyhive.exc import DatabaseError
with database.get_raw_connection(schema=schema) as conn: sql = f"SHOW CREATE VIEW {schema}.{table} LIMIT 1"
cursor = conn.cursor() try:
sql = f"SHOW CREATE VIEW {schema}.{table}" results = cls.execute_metadata_query(database, sql, schema=schema)
try: return results[0][0] if results else None
cls.execute(cursor, sql, database) except DatabaseError: # not a VIEW
rows = cls.fetch_data(cursor, 1) return None
return rows[0][0]
except DatabaseError: # not a VIEW
return None
@classmethod @classmethod
def get_tracking_url(cls, cursor: Cursor) -> str | None: def get_tracking_url(cls, cursor: Cursor) -> str | None:

View File

@@ -209,12 +209,11 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
In Snowflake, a catalog is called a "database". In Snowflake, a catalog is called a "database".
""" """
return { results = cls.execute_metadata_query(
catalog database,
for (catalog,) in inspector.bind.execute( "SELECT DATABASE_NAME from information_schema.databases",
"SELECT DATABASE_NAME from information_schema.databases" )
) return {catalog for (catalog,) in results}
}
@classmethod @classmethod
def epoch_to_dttm(cls) -> str: def epoch_to_dttm(cls) -> str:

View File

@@ -320,6 +320,7 @@ class QuerySource(Enum):
CHART = 0 CHART = 0
DASHBOARD = 1 DASHBOARD = 1
SQL_LAB = 2 SQL_LAB = 2
METADATA = 3
class QueryStatus(StrEnum): class QueryStatus(StrEnum):

View File

@@ -152,13 +152,22 @@ class TestPostgresDbEngineSpec(SupersetTestCase):
""" """
database = mock.Mock() database = mock.Mock()
cursor = mock.Mock()
cursor.fetchone.return_value = (
"Seq Scan on birth_names (cost=0.00..1537.91 rows=75691 width=46)",
)
sql = "SELECT * FROM birth_names" sql = "SELECT * FROM birth_names"
results = PostgresEngineSpec.estimate_statement_cost(database, sql, cursor)
# Mock the execute_metadata_query method to return expected results
with mock.patch.object(
PostgresEngineSpec, "execute_metadata_query"
) as mock_execute:
mock_result = mock.Mock()
expected_results = [
("Seq Scan on birth_names (cost=0.00..1537.91 rows=75691 width=46)",)
]
mock_result.fetchall.return_value = expected_results
mock_execute.return_value = mock_result
results = PostgresEngineSpec.estimate_statement_cost(database, sql)
assert results == {"Start-up cost": 0.0, "Total cost": 1537.91} assert results == {"Start-up cost": 0.0, "Total cost": 1537.91}
mock_execute.assert_called_once_with(database, f"EXPLAIN {sql} LIMIT 1")
def test_estimate_statement_invalid_syntax(self): def test_estimate_statement_invalid_syntax(self):
""" """
@@ -167,17 +176,21 @@ class TestPostgresDbEngineSpec(SupersetTestCase):
from psycopg2 import errors from psycopg2 import errors
database = mock.Mock() database = mock.Mock()
cursor = mock.Mock()
cursor.execute.side_effect = errors.SyntaxError(
"""
syntax error at or near "EXPLAIN"
LINE 1: EXPLAIN DROP TABLE birth_names
^
"""
)
sql = "DROP TABLE birth_names" sql = "DROP TABLE birth_names"
with self.assertRaises(errors.SyntaxError): # noqa: PT027
PostgresEngineSpec.estimate_statement_cost(database, sql, cursor) # Mock the execute_metadata_query method to raise the expected exception
with mock.patch.object(
PostgresEngineSpec, "execute_metadata_query"
) as mock_execute:
mock_execute.side_effect = errors.SyntaxError(
"""
syntax error at or near "EXPLAIN"
LINE 1: EXPLAIN DROP TABLE birth_names
^
"""
)
with self.assertRaises(errors.SyntaxError): # noqa: PT027
PostgresEngineSpec.estimate_statement_cost(database, sql)
def test_query_cost_formatter_example_costs(self): def test_query_cost_formatter_example_costs(self):
""" """

View File

@@ -931,26 +931,34 @@ class TestPrestoDbEngineSpec(SupersetTestCase):
def test_estimate_statement_cost(self): def test_estimate_statement_cost(self):
mock_database = mock.MagicMock() mock_database = mock.MagicMock()
mock_cursor = mock.MagicMock()
estimate_json = {"a": "b"} estimate_json = {"a": "b"}
mock_cursor.fetchone.return_value = [ sql = "SELECT * FROM brth_names"
'{"a": "b"}',
] # Mock the execute_metadata_query method to return expected JSON results
result = PrestoEngineSpec.estimate_statement_cost( with mock.patch.object(
mock_database, PrestoEngineSpec, "execute_metadata_query"
"SELECT * FROM brth_names", ) as mock_execute:
mock_cursor, mock_result = mock.Mock()
) mock_result.scalar.return_value = '{"a": "b"}'
mock_execute.return_value = mock_result
result = PrestoEngineSpec.estimate_statement_cost(mock_database, sql)
assert result == estimate_json assert result == estimate_json
mock_execute.assert_called_once_with(
mock_database, f"EXPLAIN (TYPE IO, FORMAT JSON) {sql} LIMIT 1"
)
def test_estimate_statement_cost_invalid_syntax(self): def test_estimate_statement_cost_invalid_syntax(self):
mock_database = mock.MagicMock() mock_database = mock.MagicMock()
mock_cursor = mock.MagicMock() sql = "DROP TABLE brth_names"
mock_cursor.execute.side_effect = Exception()
with self.assertRaises(Exception): # noqa: B017, PT027 # Mock the execute_metadata_query method to raise an exception
PrestoEngineSpec.estimate_statement_cost( with mock.patch.object(
mock_database, "DROP TABLE brth_names", mock_cursor PrestoEngineSpec, "execute_metadata_query"
) ) as mock_execute:
mock_execute.side_effect = Exception("Invalid syntax")
with self.assertRaises(Exception): # noqa: B017, PT027
PrestoEngineSpec.estimate_statement_cost(mock_database, sql)
def test_get_create_view(self): def test_get_create_view(self):
mock_execute = mock.MagicMock() mock_execute = mock.MagicMock()
@@ -1207,10 +1215,10 @@ def test_get_catalog_names(app_context: AppContext) -> None:
return return
with database.get_inspector() as inspector: with database.get_inspector() as inspector:
assert PrestoEngineSpec.get_catalog_names(database, inspector) == [ assert PrestoEngineSpec.get_catalog_names(database, inspector) == {
"jmx", "jmx",
"memory", "memory",
"system", "system",
"tpcds", "tpcds",
"tpch", "tpch",
] }

View File

@@ -44,10 +44,10 @@ def database_with_catalog(mocker: MockerFixture) -> MagicMock:
database.database_name = "my_db" database.database_name = "my_db"
database.db_engine_spec.__name__ = "test_engine" database.db_engine_spec.__name__ = "test_engine"
database.db_engine_spec.supports_catalog = True database.db_engine_spec.supports_catalog = True
database.get_all_catalog_names.return_value = ["catalog1", "catalog2"] database.get_all_catalog_names.return_value = {"catalog1", "catalog2"}
database.get_all_schema_names.side_effect = [ database.get_all_schema_names.side_effect = [
["schema1", "schema2"], {"schema1", "schema2"},
["schema3", "schema4"], {"schema3", "schema4"},
] ]
database.get_default_catalog.return_value = "catalog2" database.get_default_catalog.return_value = "catalog2"
@@ -63,7 +63,7 @@ def database_without_catalog(mocker: MockerFixture) -> MagicMock:
database.database_name = "my_db" database.database_name = "my_db"
database.db_engine_spec.__name__ = "test_engine" database.db_engine_spec.__name__ = "test_engine"
database.db_engine_spec.supports_catalog = False database.db_engine_spec.supports_catalog = False
database.get_all_schema_names.return_value = ["schema1", "schema2"] database.get_all_schema_names.return_value = {"schema1", "schema2"}
database.is_oauth2_enabled.return_value = False database.is_oauth2_enabled.return_value = False
database.db_engine_spec.needs_oauth2.return_value = False database.db_engine_spec.needs_oauth2.return_value = False

View File

@@ -69,23 +69,24 @@ def test_sync_permissions_command_sync_mode(
add_pvm_mock.assert_has_calls( add_pvm_mock.assert_has_calls(
[ [
mocker.call( mocker.call(
db.session, security_manager, "catalog_access", "[my_db].[catalog2]" db.session, security_manager, "catalog_access", "[my_db].[catalog1]"
), ),
mocker.call( mocker.call(
db.session, db.session,
security_manager, security_manager,
"schema_access", "schema_access",
"[my_db].[catalog2].[schema3]", "[my_db].[catalog1].[schema3]",
), ),
mocker.call( mocker.call(
db.session, db.session,
security_manager, security_manager,
"schema_access", "schema_access",
"[my_db].[catalog2].[schema4]", "[my_db].[catalog1].[schema4]",
), ),
] ],
any_order=True,
) )
mock_refresh_schemas.assert_called_once_with("catalog1", ["schema1", "schema2"]) mock_refresh_schemas.assert_called_once_with("catalog2", {"schema1", "schema2"})
mock_rename_db_perm.assert_not_called() mock_rename_db_perm.assert_not_called()
@@ -246,7 +247,7 @@ def test_sync_permissions_command_get_catalogs(database_with_catalog: MagicMock)
Test the ``_get_catalog_names`` method. Test the ``_get_catalog_names`` method.
""" """
cmmd = SyncPermissionsCommand(1, None, db_connection=database_with_catalog) cmmd = SyncPermissionsCommand(1, None, db_connection=database_with_catalog)
assert cmmd._get_catalog_names() == ["catalog1", "catalog2"] assert cmmd._get_catalog_names() == {"catalog1", "catalog2"}
def test_sync_permissions_command_get_default_catalog(database_with_catalog: MagicMock): def test_sync_permissions_command_get_default_catalog(database_with_catalog: MagicMock):
@@ -263,7 +264,7 @@ def test_sync_permissions_command_get_default_catalog(database_with_catalog: Mag
database_with_catalog.allow_multi_catalog = True database_with_catalog.allow_multi_catalog = True
cmmd = SyncPermissionsCommand(1, None, db_connection=database_with_catalog) cmmd = SyncPermissionsCommand(1, None, db_connection=database_with_catalog)
assert cmmd._get_catalog_names() == ["catalog1", "catalog2"] assert cmmd._get_catalog_names() == {"catalog1", "catalog2"}
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -295,8 +296,8 @@ def test_sync_permissions_command_get_schemas(database_with_catalog: MagicMock):
Test the ``_get_schema_names`` method. Test the ``_get_schema_names`` method.
""" """
cmmd = SyncPermissionsCommand(1, None, db_connection=database_with_catalog) cmmd = SyncPermissionsCommand(1, None, db_connection=database_with_catalog)
assert cmmd._get_schema_names("catalog1") == ["schema1", "schema2"] assert cmmd._get_schema_names("catalog1") == {"schema1", "schema2"}
assert cmmd._get_schema_names("catalog2") == ["schema3", "schema4"] assert cmmd._get_schema_names("catalog2") == {"schema3", "schema4"}
@pytest.mark.parametrize( @pytest.mark.parametrize(

View File

@@ -0,0 +1,195 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest import mock
import pytest
from sqlalchemy.exc import SQLAlchemyError
from superset.db_engine_specs.base import BaseEngineSpec
from superset.utils.core import QuerySource
from tests.integration_tests.base_tests import SupersetTestCase
class TestBaseEngineSpecMetadata(SupersetTestCase):
    """Tests for ``BaseEngineSpec.execute_metadata_query``."""

    @staticmethod
    def _wire_connection(mock_get_engine, rows=None, error=None):
        """
        Wire ``get_engine`` -> engine -> connection -> result mocks.

        :param mock_get_engine: the patched ``BaseEngineSpec.get_engine``
        :param rows: rows the mocked result yields (``fetchall`` and iteration)
        :param error: if given, ``connection.execute`` raises it instead
        :return: ``(mock_connection, mock_result)``
        """
        mock_engine = mock.Mock()
        mock_connection = mock.Mock()
        mock_result = mock.Mock()
        rows = [] if rows is None else rows

        # ``get_engine(...)`` and ``engine.connect()`` are both used as
        # context managers, so each needs ``__enter__``/``__exit__``.
        engine_ctx = mock_get_engine.return_value
        engine_ctx.__enter__ = mock.Mock(return_value=mock_engine)
        engine_ctx.__exit__ = mock.Mock(return_value=None)
        connection_ctx = mock_engine.connect.return_value
        connection_ctx.__enter__ = mock.Mock(return_value=mock_connection)
        connection_ctx.__exit__ = mock.Mock(return_value=None)

        mock_result.fetchall.return_value = rows
        mock_result.__iter__ = lambda self: iter(rows)
        # An implementation may return the executed result directly or a
        # buffered copy made with ``Result.freeze()``. Route the copy to the
        # same mock so these tests check what callers receive, not how the
        # rows were buffered.
        mock_result.freeze.return_value.return_value = mock_result

        if error is not None:
            mock_connection.execute.side_effect = error
        else:
            mock_connection.execute.return_value = mock_result

        return mock_connection, mock_result

    @mock.patch.object(BaseEngineSpec, "get_engine")
    def test_execute_metadata_query_basic(self, mock_get_engine):
        """Test basic metadata query execution"""
        database = mock.Mock()
        expected_results = [("catalog1",), ("catalog2",)]
        mock_connection, mock_result = self._wire_connection(
            mock_get_engine, rows=expected_results
        )

        # Execute the method
        query = "SHOW CATALOGS"
        results = BaseEngineSpec.execute_metadata_query(database, query)

        # We get the Result object itself, and it iterates like real usage
        assert results == mock_result
        assert list(results) == expected_results

        # Verify call chain
        mock_get_engine.assert_called_once_with(
            database, catalog=None, schema=None, source=QuerySource.METADATA
        )
        mock_connection.execute.assert_called_once()
        executed_query = mock_connection.execute.call_args[0][0]
        assert str(executed_query) == query

    @mock.patch.object(BaseEngineSpec, "get_engine")
    def test_execute_metadata_query_with_catalog_schema(self, mock_get_engine):
        """Test metadata query with catalog and schema parameters"""
        database = mock.Mock()
        expected_results = [("table1",), ("table2",)]
        _, mock_result = self._wire_connection(
            mock_get_engine, rows=expected_results
        )

        # Execute with catalog and schema
        query = "SHOW TABLES"
        catalog = "my_catalog"
        schema = "my_schema"
        results = BaseEngineSpec.execute_metadata_query(
            database, query, catalog=catalog, schema=schema
        )

        assert results == mock_result
        assert list(results) == expected_results

        # Verify catalog and schema were passed to get_engine
        mock_get_engine.assert_called_once_with(
            database, catalog=catalog, schema=schema, source=QuerySource.METADATA
        )

    @mock.patch.object(BaseEngineSpec, "get_engine")
    def test_execute_metadata_query_uses_correct_query_source(self, mock_get_engine):
        """Test that QuerySource.METADATA is used correctly"""
        database = mock.Mock()
        self._wire_connection(mock_get_engine, rows=[])

        # Execute the method
        BaseEngineSpec.execute_metadata_query(database, "SELECT 1")

        # Verify QuerySource.METADATA was used
        mock_get_engine.assert_called_once_with(
            database, catalog=None, schema=None, source=QuerySource.METADATA
        )

    @mock.patch.object(BaseEngineSpec, "get_engine")
    def test_execute_metadata_query_handles_sql_error(self, mock_get_engine):
        """Test that SQL errors raised by the connection are propagated"""
        database = mock.Mock()
        self._wire_connection(
            mock_get_engine, error=SQLAlchemyError("Database error")
        )

        # Execute and verify exception is propagated
        with pytest.raises(SQLAlchemyError):
            BaseEngineSpec.execute_metadata_query(database, "INVALID QUERY")

    @mock.patch.object(BaseEngineSpec, "get_engine")
    def test_execute_metadata_query_empty_results(self, mock_get_engine):
        """Test handling of queries that return no results"""
        database = mock.Mock()
        _, mock_result = self._wire_connection(mock_get_engine, rows=[])

        # Execute the method
        results = BaseEngineSpec.execute_metadata_query(database, "SELECT 1 WHERE 1=0")

        # An empty result is still the Result object, and iterates to nothing
        assert results == mock_result
        assert list(results) == []

    def test_execute_metadata_query_query_source_enum_value(self):
        """Test QuerySource.METADATA enum has correct value"""
        # This is a simple enum test that doesn't require complex mocking
        assert QuerySource.METADATA.value == 3
        assert QuerySource.METADATA.name == "METADATA"

View File

@@ -1122,3 +1122,21 @@ def test_get_stacktrace():
except Exception: except Exception:
stacktrace = get_stacktrace() stacktrace = get_stacktrace()
assert stacktrace is None assert stacktrace is None
def test_query_source_metadata_enum():
    """Check that ``QuerySource.METADATA`` exists and every member keeps its value."""
    assert hasattr(QuerySource, "METADATA")
    metadata_source = QuerySource.METADATA
    assert metadata_source.value == 3
    assert metadata_source.name == "METADATA"

    # Pin the numeric value of each member, in declaration order.
    expected_values = [
        (QuerySource.CHART, 0),
        (QuerySource.DASHBOARD, 1),
        (QuerySource.SQL_LAB, 2),
        (QuerySource.METADATA, 3),
    ]
    for member, number in expected_values:
        assert member.value == number