Compare commits

...

7 Commits

Author SHA1 Message Date
Beto Dealmeida
1dfe73d19c Fix tests 2025-08-26 18:10:46 -04:00
Beto Dealmeida
bbda5e2008 Fix tests 2025-08-26 16:22:44 -04:00
Beto Dealmeida
53999c12dd Use Result instead 2025-08-26 12:49:27 -04:00
Beto Dealmeida
f554036d29 Fix tests 2025-08-26 11:06:54 -04:00
Beto Dealmeida
33e7932491 More methods 2025-08-25 18:15:43 -04:00
Beto Dealmeida
92b02d993b More methods 2025-08-25 17:40:31 -04:00
Beto Dealmeida
72ba972e42 chore: standardize DB engine spec query execution 2025-08-25 17:31:15 -04:00
17 changed files with 398 additions and 140 deletions

View File

@@ -49,6 +49,7 @@ from flask_babel import gettext as __, lazy_gettext as _
from marshmallow import fields, Schema
from marshmallow.validate import Range
from sqlalchemy import column, select, types
from sqlalchemy.engine import Result
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.interfaces import Compiled, Dialect
from sqlalchemy.engine.reflection import Inspector
@@ -1670,14 +1671,13 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def estimate_statement_cost(
cls, database: Database, statement: str, cursor: Any
cls, database: Database, statement: str
) -> dict[str, Any]:
"""
Generate a SQL query that estimates the cost of a given statement.
:param database: A Database object
:param statement: A single SQL statement
:param cursor: Cursor instance
:return: Dictionary with different costs
"""
raise Exception( # pylint: disable=broad-exception-raised
@@ -1738,20 +1738,13 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
parsed_script = SQLScript(sql, engine=cls.engine)
with database.get_raw_connection(
catalog=catalog,
schema=schema,
source=source,
) as conn:
cursor = conn.cursor()
return [
cls.estimate_statement_cost(
database,
cls.process_statement(statement, database),
cursor,
)
for statement in parsed_script.statements
]
return [
cls.estimate_statement_cost(
database,
cls.process_statement(statement, database),
)
for statement in parsed_script.statements
]
@classmethod
def impersonate_user(
@@ -1854,6 +1847,42 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
cls.start_oauth2_dance(database)
raise cls.get_dbapi_mapped_exception(ex) from ex
@classmethod
def execute_metadata_query(
    cls,
    database: Database,
    query: str,
    catalog: str | None = None,
    schema: str | None = None,
) -> Result:
    """
    Standardized method for executing metadata queries.

    This method provides a unified interface for all metadata query operations
    across different database engines using SQLAlchemy connections. The
    connection is opened via ``cls.get_engine`` tagged with
    ``QuerySource.METADATA`` so downstream hooks can distinguish metadata
    traffic from chart/dashboard/SQL Lab queries.

    For single-row results, add "LIMIT 1" to your query rather than using
    separate fetch parameters.

    :param database: Database instance
    :param query: SQL query to execute for metadata
    :param catalog: Optional catalog/database name
    :param schema: Optional schema name
    :return: SQLAlchemy Result object with methods like:
             - result.fetchall() -> list[Row]: Get all rows
             - result.fetchone() -> Row | None: Get single row
             - result.scalar() -> Any: Get single value
             - result.mappings() -> mappings for dict-like access
    """
    # NOTE(review): `text` is assumed to be imported at module level
    # (sqlalchemy.text); it is not in the visible import hunk — confirm.
    with cls.get_engine(
        database,
        catalog=catalog,
        schema=schema,
        source=utils.QuerySource.METADATA,
    ) as engine:
        with engine.connect() as conn:
            # NOTE(review): the Result is returned AFTER both context
            # managers exit. Unless the dialect pre-buffers rows, fetching
            # from a Result whose connection has been released may fail or
            # return nothing — confirm callers consume the Result eagerly
            # (or consider buffering, e.g. Result.freeze(), here).
            return conn.execute(text(query))
@classmethod
def needs_oauth2(cls, ex: Exception) -> bool:
"""

View File

@@ -331,9 +331,11 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
)
# Build the query
query = select(
func.max(partitions_table.c.partition_id).label("max_partition_id")
).where(partitions_table.c.table_name == table.table)
query = (
select(func.max(partitions_table.c.partition_id).label("max_partition_id"))
.where(partitions_table.c.table_name == table.table)
.limit(1)
)
# Compile to BigQuery SQL
compiled_query = query.compile(
@@ -342,15 +344,13 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
)
# Run the query and handle result
with database.get_raw_connection(
result = cls.execute_metadata_query(
database,
str(compiled_query),
catalog=table.catalog,
schema=table.schema,
) as conn:
cursor = conn.cursor()
cursor.execute(str(compiled_query))
if row := cursor.fetchone():
return row[0]
return None
)
return result.scalar()
@classmethod
def get_time_partition_column(

View File

@@ -503,12 +503,13 @@ class DatabricksNativeEngineSpec(DatabricksDynamicBaseEngineSpec):
if default_catalog := connect_args.get("catalog"):
return default_catalog
with database.get_sqla_engine() as engine:
catalogs = {catalog for (catalog,) in engine.execute("SHOW CATALOGS")}
if len(catalogs) == 1:
return catalogs.pop()
result = cls.execute_metadata_query(database, "SHOW CATALOGS")
catalogs = {catalog for (catalog,) in result}
if len(catalogs) == 1:
return catalogs.pop()
return engine.execute("SELECT current_catalog()").scalar()
result = cls.execute_metadata_query(database, "SELECT current_catalog()")
return result.scalar()
@classmethod
def get_prequeries(
@@ -532,7 +533,8 @@ class DatabricksNativeEngineSpec(DatabricksDynamicBaseEngineSpec):
database: Database,
inspector: Inspector,
) -> set[str]:
return {catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")}
results = cls.execute_metadata_query(database, "SHOW CATALOGS")
return {catalog for (catalog,) in results}
class DatabricksPythonConnectorEngineSpec(DatabricksDynamicBaseEngineSpec):
@@ -624,7 +626,8 @@ class DatabricksPythonConnectorEngineSpec(DatabricksDynamicBaseEngineSpec):
database: Database,
inspector: Inspector,
) -> set[str]:
return {catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")}
results = cls.execute_metadata_query(database, "SHOW CATALOGS")
return {catalog for (catalog,) in results}
@classmethod
def adjust_engine_params(

View File

@@ -297,8 +297,8 @@ class DorisEngineSpec(MySQLEngineSpec):
CatalogId, CatalogName, Type, IsCurrent, CreateTime, LastUpdateTime, Comment
We need to extract just the CatalogName column.
"""
result = inspector.bind.execute("SHOW CATALOGS")
return {row.CatalogName for row in result}
results = cls.execute_metadata_query(database, "SHOW CATALOGS")
return {row.CatalogName for row in results}
@classmethod
def get_schema_from_engine_params(

View File

@@ -388,9 +388,8 @@ class MotherDuckEngineSpec(DuckDBEngineSpec):
database: Database,
inspector: Inspector,
) -> set[str]:
return {
catalog
for (catalog,) in inspector.bind.execute(
"SELECT alias FROM MD_ALL_DATABASES() WHERE is_attached;"
)
}
results = cls.execute_metadata_query(
database,
"SELECT alias FROM MD_ALL_DATABASES() WHERE is_attached;",
)
return {catalog for (catalog,) in results}

View File

@@ -154,15 +154,16 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
database: Database,
table: Table,
) -> dict[str, Any]:
with database.get_raw_connection(
result = cls.execute_metadata_query(
database,
f'SELECT GET_METADATA("{table.table}") LIMIT 1',
catalog=table.catalog,
schema=table.schema,
) as conn:
cursor = conn.cursor()
cursor.execute(f'SELECT GET_METADATA("{table.table}")')
results = cursor.fetchone()[0]
)
results_list = result.fetchall()
results = results_list[0][0] if results_list else None
try:
metadata = json.loads(results)
metadata = json.loads(results) if results else {}
except Exception: # pylint: disable=broad-except
metadata = {}

View File

@@ -614,8 +614,9 @@ class HiveEngineSpec(PrestoEngineSpec):
if schema:
sql += f" IN `{schema}`"
with database.get_raw_connection(schema=schema) as conn:
cursor = conn.cursor()
cursor.execute(sql)
results = cursor.fetchall()
return {row[0] for row in results}
result = cls.execute_metadata_query(
database,
sql,
schema=schema,
)
return {row[0] for row in result}

View File

@@ -354,19 +354,18 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec):
@classmethod
def estimate_statement_cost(
cls, database: Database, statement: str, cursor: Any
cls, database: Database, statement: str
) -> dict[str, Any]:
"""
Run a SQL query that estimates the cost of a given statement.
:param database: A Database object
:param statement: A single SQL statement
:param cursor: Cursor instance
:return: JSON response from Trino
:return: Cost estimate dictionary
"""
sql = f"EXPLAIN {statement}"
cursor.execute(sql)
result = cursor.fetchone()[0]
sql = f"EXPLAIN {statement} LIMIT 1"
results = cls.execute_metadata_query(database, sql)
rows = results.fetchall()
result = rows[0][0] if rows else ""
match = re.search(r"cost=([\d\.]+)\.\.([\d\.]+)", result)
if match:
return {
@@ -393,15 +392,14 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec):
In Postgres, a catalog is called a "database".
"""
return {
catalog
for (catalog,) in inspector.bind.execute(
"""
results = cls.execute_metadata_query(
database,
"""
SELECT datname FROM pg_database
WHERE datistemplate = false;
"""
)
}
""",
)
return {catalog for (catalog,) in results}
@classmethod
def get_table_names(

View File

@@ -321,7 +321,8 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
"""
Get all catalogs.
"""
return {catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")}
result = cls.execute_metadata_query(database, "SHOW CATALOGS")
return {catalog for (catalog,) in result}
@classmethod
def adjust_engine_params(
@@ -373,17 +374,16 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
@classmethod
def estimate_statement_cost(
cls, database: Database, statement: str, cursor: Any
cls, database: Database, statement: str
) -> dict[str, Any]:
"""
Run a SQL query that estimates the cost of a given statement.
:param database: A Database object
:param statement: A single SQL statement
:param cursor: Cursor instance
:return: JSON response from Trino
"""
sql = f"EXPLAIN (TYPE IO, FORMAT JSON) {statement}"
cursor.execute(sql)
sql = f"EXPLAIN (TYPE IO, FORMAT JSON) {statement} LIMIT 1"
results = cls.execute_metadata_query(database, sql)
# the output from Trino is a single column and a single row containing
# JSON:
@@ -398,7 +398,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
# "networkCost" : 3.41425774958E11
# }
# }
result = json.loads(cursor.fetchone()[0])
result = json.loads(results[0][0]) if results else {}
return result
@classmethod
@@ -1037,7 +1037,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
AND table_type = 'VIEW'
"""
).strip()
params = {"schema": schema}
results = inspector.bind.execute(sql, {"schema": schema}).fetchall()
else:
sql = dedent(
"""
@@ -1045,13 +1045,9 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
WHERE table_type = 'VIEW'
"""
).strip()
params = {}
results = inspector.bind.execute(sql).fetchall()
with database.get_raw_connection(schema=schema) as conn:
cursor = conn.cursor()
cursor.execute(sql, params)
results = cursor.fetchall()
return {row[0] for row in results}
return {row[0] for row in results}
@classmethod
def _is_column_name_quoted(cls, column_name: str) -> bool:
@@ -1299,16 +1295,12 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
# pylint: disable=import-outside-toplevel
from pyhive.exc import DatabaseError
with database.get_raw_connection(schema=schema) as conn:
cursor = conn.cursor()
sql = f"SHOW CREATE VIEW {schema}.{table}"
try:
cls.execute(cursor, sql, database)
rows = cls.fetch_data(cursor, 1)
return rows[0][0]
except DatabaseError: # not a VIEW
return None
sql = f"SHOW CREATE VIEW {schema}.{table} LIMIT 1"
try:
results = cls.execute_metadata_query(database, sql, schema=schema)
return results[0][0] if results else None
except DatabaseError: # not a VIEW
return None
@classmethod
def get_tracking_url(cls, cursor: Cursor) -> str | None:

View File

@@ -209,12 +209,11 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
In Snowflake, a catalog is called a "database".
"""
return {
catalog
for (catalog,) in inspector.bind.execute(
"SELECT DATABASE_NAME from information_schema.databases"
)
}
results = cls.execute_metadata_query(
database,
"SELECT DATABASE_NAME from information_schema.databases",
)
return {catalog for (catalog,) in results}
@classmethod
def epoch_to_dttm(cls) -> str:

View File

@@ -320,6 +320,7 @@ class QuerySource(Enum):
CHART = 0
DASHBOARD = 1
SQL_LAB = 2
METADATA = 3
class QueryStatus(StrEnum):

View File

@@ -152,13 +152,22 @@ class TestPostgresDbEngineSpec(SupersetTestCase):
"""
database = mock.Mock()
cursor = mock.Mock()
cursor.fetchone.return_value = (
"Seq Scan on birth_names (cost=0.00..1537.91 rows=75691 width=46)",
)
sql = "SELECT * FROM birth_names"
results = PostgresEngineSpec.estimate_statement_cost(database, sql, cursor)
# Mock the execute_metadata_query method to return expected results
with mock.patch.object(
PostgresEngineSpec, "execute_metadata_query"
) as mock_execute:
mock_result = mock.Mock()
expected_results = [
("Seq Scan on birth_names (cost=0.00..1537.91 rows=75691 width=46)",)
]
mock_result.fetchall.return_value = expected_results
mock_execute.return_value = mock_result
results = PostgresEngineSpec.estimate_statement_cost(database, sql)
assert results == {"Start-up cost": 0.0, "Total cost": 1537.91}
mock_execute.assert_called_once_with(database, f"EXPLAIN {sql} LIMIT 1")
def test_estimate_statement_invalid_syntax(self):
"""
@@ -167,17 +176,21 @@ class TestPostgresDbEngineSpec(SupersetTestCase):
from psycopg2 import errors
database = mock.Mock()
cursor = mock.Mock()
cursor.execute.side_effect = errors.SyntaxError(
"""
syntax error at or near "EXPLAIN"
LINE 1: EXPLAIN DROP TABLE birth_names
^
"""
)
sql = "DROP TABLE birth_names"
with self.assertRaises(errors.SyntaxError): # noqa: PT027
PostgresEngineSpec.estimate_statement_cost(database, sql, cursor)
# Mock the execute_metadata_query method to raise the expected exception
with mock.patch.object(
PostgresEngineSpec, "execute_metadata_query"
) as mock_execute:
mock_execute.side_effect = errors.SyntaxError(
"""
syntax error at or near "EXPLAIN"
LINE 1: EXPLAIN DROP TABLE birth_names
^
"""
)
with self.assertRaises(errors.SyntaxError): # noqa: PT027
PostgresEngineSpec.estimate_statement_cost(database, sql)
def test_query_cost_formatter_example_costs(self):
"""

View File

@@ -931,26 +931,34 @@ class TestPrestoDbEngineSpec(SupersetTestCase):
def test_estimate_statement_cost(self):
mock_database = mock.MagicMock()
mock_cursor = mock.MagicMock()
estimate_json = {"a": "b"}
mock_cursor.fetchone.return_value = [
'{"a": "b"}',
]
result = PrestoEngineSpec.estimate_statement_cost(
mock_database,
"SELECT * FROM brth_names",
mock_cursor,
)
sql = "SELECT * FROM brth_names"
# Mock the execute_metadata_query method to return expected JSON results
with mock.patch.object(
PrestoEngineSpec, "execute_metadata_query"
) as mock_execute:
mock_result = mock.Mock()
mock_result.scalar.return_value = '{"a": "b"}'
mock_execute.return_value = mock_result
result = PrestoEngineSpec.estimate_statement_cost(mock_database, sql)
assert result == estimate_json
mock_execute.assert_called_once_with(
mock_database, f"EXPLAIN (TYPE IO, FORMAT JSON) {sql} LIMIT 1"
)
def test_estimate_statement_cost_invalid_syntax(self):
mock_database = mock.MagicMock()
mock_cursor = mock.MagicMock()
mock_cursor.execute.side_effect = Exception()
with self.assertRaises(Exception): # noqa: B017, PT027
PrestoEngineSpec.estimate_statement_cost(
mock_database, "DROP TABLE brth_names", mock_cursor
)
sql = "DROP TABLE brth_names"
# Mock the execute_metadata_query method to raise an exception
with mock.patch.object(
PrestoEngineSpec, "execute_metadata_query"
) as mock_execute:
mock_execute.side_effect = Exception("Invalid syntax")
with self.assertRaises(Exception): # noqa: B017, PT027
PrestoEngineSpec.estimate_statement_cost(mock_database, sql)
def test_get_create_view(self):
mock_execute = mock.MagicMock()
@@ -1207,10 +1215,10 @@ def test_get_catalog_names(app_context: AppContext) -> None:
return
with database.get_inspector() as inspector:
assert PrestoEngineSpec.get_catalog_names(database, inspector) == [
assert PrestoEngineSpec.get_catalog_names(database, inspector) == {
"jmx",
"memory",
"system",
"tpcds",
"tpch",
]
}

View File

@@ -44,10 +44,10 @@ def database_with_catalog(mocker: MockerFixture) -> MagicMock:
database.database_name = "my_db"
database.db_engine_spec.__name__ = "test_engine"
database.db_engine_spec.supports_catalog = True
database.get_all_catalog_names.return_value = ["catalog1", "catalog2"]
database.get_all_catalog_names.return_value = {"catalog1", "catalog2"}
database.get_all_schema_names.side_effect = [
["schema1", "schema2"],
["schema3", "schema4"],
{"schema1", "schema2"},
{"schema3", "schema4"},
]
database.get_default_catalog.return_value = "catalog2"
@@ -63,7 +63,7 @@ def database_without_catalog(mocker: MockerFixture) -> MagicMock:
database.database_name = "my_db"
database.db_engine_spec.__name__ = "test_engine"
database.db_engine_spec.supports_catalog = False
database.get_all_schema_names.return_value = ["schema1", "schema2"]
database.get_all_schema_names.return_value = {"schema1", "schema2"}
database.is_oauth2_enabled.return_value = False
database.db_engine_spec.needs_oauth2.return_value = False

View File

@@ -69,23 +69,24 @@ def test_sync_permissions_command_sync_mode(
add_pvm_mock.assert_has_calls(
[
mocker.call(
db.session, security_manager, "catalog_access", "[my_db].[catalog2]"
db.session, security_manager, "catalog_access", "[my_db].[catalog1]"
),
mocker.call(
db.session,
security_manager,
"schema_access",
"[my_db].[catalog2].[schema3]",
"[my_db].[catalog1].[schema3]",
),
mocker.call(
db.session,
security_manager,
"schema_access",
"[my_db].[catalog2].[schema4]",
"[my_db].[catalog1].[schema4]",
),
]
],
any_order=True,
)
mock_refresh_schemas.assert_called_once_with("catalog1", ["schema1", "schema2"])
mock_refresh_schemas.assert_called_once_with("catalog2", {"schema1", "schema2"})
mock_rename_db_perm.assert_not_called()
@@ -246,7 +247,7 @@ def test_sync_permissions_command_get_catalogs(database_with_catalog: MagicMock)
Test the ``_get_catalog_names`` method.
"""
cmmd = SyncPermissionsCommand(1, None, db_connection=database_with_catalog)
assert cmmd._get_catalog_names() == ["catalog1", "catalog2"]
assert cmmd._get_catalog_names() == {"catalog1", "catalog2"}
def test_sync_permissions_command_get_default_catalog(database_with_catalog: MagicMock):
@@ -263,7 +264,7 @@ def test_sync_permissions_command_get_default_catalog(database_with_catalog: Mag
database_with_catalog.allow_multi_catalog = True
cmmd = SyncPermissionsCommand(1, None, db_connection=database_with_catalog)
assert cmmd._get_catalog_names() == ["catalog1", "catalog2"]
assert cmmd._get_catalog_names() == {"catalog1", "catalog2"}
@pytest.mark.parametrize(
@@ -295,8 +296,8 @@ def test_sync_permissions_command_get_schemas(database_with_catalog: MagicMock):
Test the ``_get_schema_names`` method.
"""
cmmd = SyncPermissionsCommand(1, None, db_connection=database_with_catalog)
assert cmmd._get_schema_names("catalog1") == ["schema1", "schema2"]
assert cmmd._get_schema_names("catalog2") == ["schema3", "schema4"]
assert cmmd._get_schema_names("catalog1") == {"schema1", "schema2"}
assert cmmd._get_schema_names("catalog2") == {"schema3", "schema4"}
@pytest.mark.parametrize(

View File

@@ -0,0 +1,195 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest import mock
import pytest
from sqlalchemy.exc import SQLAlchemyError
from superset.db_engine_specs.base import BaseEngineSpec
from superset.utils.core import QuerySource
from tests.integration_tests.base_tests import SupersetTestCase
class TestBaseEngineSpecMetadata(SupersetTestCase):
    """
    Tests for ``BaseEngineSpec.execute_metadata_query``.

    ``get_engine`` is patched so no real database is touched. The engine and
    connection context managers are wired up by hand because a plain
    ``mock.Mock`` does not support the ``with`` protocol on its own:
    ``__enter__``/``__exit__`` must be assigned explicitly.
    """

    @mock.patch.object(BaseEngineSpec, "get_engine")
    def test_execute_metadata_query_basic(self, mock_get_engine):
        """Test basic metadata query execution"""
        database = mock.Mock()
        mock_engine = mock.Mock()
        mock_connection = mock.Mock()
        mock_result = mock.Mock()
        # Setup mock chain
        mock_get_engine.return_value.__enter__ = mock.Mock(return_value=mock_engine)
        mock_get_engine.return_value.__exit__ = mock.Mock(return_value=None)
        mock_engine.connect.return_value.__enter__ = mock.Mock(
            return_value=mock_connection
        )
        mock_engine.connect.return_value.__exit__ = mock.Mock(return_value=None)
        expected_results = [("catalog1",), ("catalog2",)]
        mock_result.fetchall.return_value = expected_results
        # __iter__ takes ``self`` because it is looked up on the type,
        # not the instance, when iterating.
        mock_result.__iter__ = lambda self: iter(expected_results)
        mock_connection.execute.return_value = mock_result
        # Execute the method
        query = "SHOW CATALOGS"
        results = BaseEngineSpec.execute_metadata_query(database, query)
        # Verify results - now we get the Result object itself
        assert results == mock_result
        # Test that we can iterate over it like the real usage
        assert list(results) == expected_results
        # Verify call chain
        mock_get_engine.assert_called_once_with(
            database, catalog=None, schema=None, source=QuerySource.METADATA
        )
        mock_connection.execute.assert_called_once()
        # execute() receives a TextClause, so compare its string form.
        executed_query = mock_connection.execute.call_args[0][0]
        assert str(executed_query) == query

    @mock.patch.object(BaseEngineSpec, "get_engine")
    def test_execute_metadata_query_with_catalog_schema(self, mock_get_engine):
        """Test metadata query with catalog and schema parameters"""
        database = mock.Mock()
        mock_engine = mock.Mock()
        mock_connection = mock.Mock()
        mock_result = mock.Mock()
        # Setup mock chain
        mock_get_engine.return_value.__enter__ = mock.Mock(return_value=mock_engine)
        mock_get_engine.return_value.__exit__ = mock.Mock(return_value=None)
        mock_engine.connect.return_value.__enter__ = mock.Mock(
            return_value=mock_connection
        )
        mock_engine.connect.return_value.__exit__ = mock.Mock(return_value=None)
        expected_results = [("table1",), ("table2",)]
        mock_result.fetchall.return_value = expected_results
        mock_result.__iter__ = lambda self: iter(expected_results)
        mock_connection.execute.return_value = mock_result
        # Execute with catalog and schema
        query = "SHOW TABLES"
        catalog = "my_catalog"
        schema = "my_schema"
        results = BaseEngineSpec.execute_metadata_query(
            database, query, catalog=catalog, schema=schema
        )
        # Verify results - now we get the Result object itself
        assert results == mock_result
        # Test that we can iterate over it like the real usage
        assert list(results) == expected_results
        # Verify catalog and schema were passed to get_engine
        mock_get_engine.assert_called_once_with(
            database, catalog=catalog, schema=schema, source=QuerySource.METADATA
        )

    @mock.patch.object(BaseEngineSpec, "get_engine")
    def test_execute_metadata_query_uses_correct_query_source(self, mock_get_engine):
        """Test that QuerySource.METADATA is used correctly"""
        database = mock.Mock()
        mock_engine = mock.Mock()
        mock_connection = mock.Mock()
        mock_result = mock.Mock()
        # Setup mock chain
        mock_get_engine.return_value.__enter__ = mock.Mock(return_value=mock_engine)
        mock_get_engine.return_value.__exit__ = mock.Mock(return_value=None)
        mock_engine.connect.return_value.__enter__ = mock.Mock(
            return_value=mock_connection
        )
        mock_engine.connect.return_value.__exit__ = mock.Mock(return_value=None)
        empty_results = []
        mock_result.fetchall.return_value = empty_results
        mock_result.__iter__ = lambda self: iter(empty_results)
        mock_connection.execute.return_value = mock_result
        # Execute the method
        BaseEngineSpec.execute_metadata_query(database, "SELECT 1")
        # Verify QuerySource.METADATA was used
        mock_get_engine.assert_called_once_with(
            database, catalog=None, schema=None, source=QuerySource.METADATA
        )

    @mock.patch.object(BaseEngineSpec, "get_engine")
    def test_execute_metadata_query_handles_sql_error(self, mock_get_engine):
        """Test proper exception handling for SQL errors"""
        database = mock.Mock()
        mock_engine = mock.Mock()
        mock_connection = mock.Mock()
        # Setup mock chain
        mock_get_engine.return_value.__enter__ = mock.Mock(return_value=mock_engine)
        mock_get_engine.return_value.__exit__ = mock.Mock(return_value=None)
        mock_engine.connect.return_value.__enter__ = mock.Mock(
            return_value=mock_connection
        )
        mock_engine.connect.return_value.__exit__ = mock.Mock(return_value=None)
        # Mock connection to raise SQLAlchemyError
        mock_connection.execute.side_effect = SQLAlchemyError("Database error")
        # Execute and verify exception is propagated
        with pytest.raises(SQLAlchemyError):
            BaseEngineSpec.execute_metadata_query(database, "INVALID QUERY")

    @mock.patch.object(BaseEngineSpec, "get_engine")
    def test_execute_metadata_query_empty_results(self, mock_get_engine):
        """Test handling of queries that return no results"""
        database = mock.Mock()
        mock_engine = mock.Mock()
        mock_connection = mock.Mock()
        mock_result = mock.Mock()
        # Setup mock chain
        mock_get_engine.return_value.__enter__ = mock.Mock(return_value=mock_engine)
        mock_get_engine.return_value.__exit__ = mock.Mock(return_value=None)
        mock_engine.connect.return_value.__enter__ = mock.Mock(
            return_value=mock_connection
        )
        mock_engine.connect.return_value.__exit__ = mock.Mock(return_value=None)
        # Empty results
        empty_results = []
        mock_result.fetchall.return_value = empty_results
        mock_result.__iter__ = lambda self: iter(empty_results)
        mock_connection.execute.return_value = mock_result
        # Execute the method
        results = BaseEngineSpec.execute_metadata_query(database, "SELECT 1 WHERE 1=0")
        # Verify empty results are handled correctly - now we get the Result object
        assert results == mock_result
        # Test that iterating over it gives empty results like the real usage
        assert list(results) == []

    def test_execute_metadata_query_query_source_enum_value(self):
        """Test QuerySource.METADATA enum has correct value"""
        # This is a simple enum test that doesn't require complex mocking
        assert QuerySource.METADATA.value == 3
        assert QuerySource.METADATA.name == "METADATA"

View File

@@ -1122,3 +1122,21 @@ def test_get_stacktrace():
except Exception:
stacktrace = get_stacktrace()
assert stacktrace is None
def test_query_source_metadata_enum():
    """QuerySource.METADATA exists with the expected value and name."""
    assert hasattr(QuerySource, "METADATA")
    assert QuerySource.METADATA.value == 3
    assert QuerySource.METADATA.name == "METADATA"

    # Every QuerySource member keeps its documented integer value.
    expected_values = [
        (QuerySource.CHART, 0),
        (QuerySource.DASHBOARD, 1),
        (QuerySource.SQL_LAB, 2),
        (QuerySource.METADATA, 3),
    ]
    for member, value in expected_values:
        assert member.value == value