mirror of
https://github.com/apache/superset.git
synced 2026-04-28 12:34:23 +00:00
Compare commits
7 Commits
amin/execu
...
standardiz
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1dfe73d19c | ||
|
|
bbda5e2008 | ||
|
|
53999c12dd | ||
|
|
f554036d29 | ||
|
|
33e7932491 | ||
|
|
92b02d993b | ||
|
|
72ba972e42 |
@@ -49,6 +49,7 @@ from flask_babel import gettext as __, lazy_gettext as _
|
||||
from marshmallow import fields, Schema
|
||||
from marshmallow.validate import Range
|
||||
from sqlalchemy import column, select, types
|
||||
from sqlalchemy.engine import Result
|
||||
from sqlalchemy.engine.base import Engine
|
||||
from sqlalchemy.engine.interfaces import Compiled, Dialect
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
@@ -1670,14 +1671,13 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
|
||||
@classmethod
|
||||
def estimate_statement_cost(
|
||||
cls, database: Database, statement: str, cursor: Any
|
||||
cls, database: Database, statement: str
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate a SQL query that estimates the cost of a given statement.
|
||||
|
||||
:param database: A Database object
|
||||
:param statement: A single SQL statement
|
||||
:param cursor: Cursor instance
|
||||
:return: Dictionary with different costs
|
||||
"""
|
||||
raise Exception( # pylint: disable=broad-exception-raised
|
||||
@@ -1738,20 +1738,13 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
|
||||
parsed_script = SQLScript(sql, engine=cls.engine)
|
||||
|
||||
with database.get_raw_connection(
|
||||
catalog=catalog,
|
||||
schema=schema,
|
||||
source=source,
|
||||
) as conn:
|
||||
cursor = conn.cursor()
|
||||
return [
|
||||
cls.estimate_statement_cost(
|
||||
database,
|
||||
cls.process_statement(statement, database),
|
||||
cursor,
|
||||
)
|
||||
for statement in parsed_script.statements
|
||||
]
|
||||
return [
|
||||
cls.estimate_statement_cost(
|
||||
database,
|
||||
cls.process_statement(statement, database),
|
||||
)
|
||||
for statement in parsed_script.statements
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def impersonate_user(
|
||||
@@ -1854,6 +1847,42 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
cls.start_oauth2_dance(database)
|
||||
raise cls.get_dbapi_mapped_exception(ex) from ex
|
||||
|
||||
@classmethod
|
||||
def execute_metadata_query(
|
||||
cls,
|
||||
database: Database,
|
||||
query: str,
|
||||
catalog: str | None = None,
|
||||
schema: str | None = None,
|
||||
) -> Result:
|
||||
"""
|
||||
Standardized method for executing metadata queries.
|
||||
|
||||
This method provides a unified interface for all metadata query operations
|
||||
across different database engines using SQLAlchemy connections.
|
||||
|
||||
For single-row results, add "LIMIT 1" to your query rather than using
|
||||
separate fetch parameters.
|
||||
|
||||
:param database: Database instance
|
||||
:param query: SQL query to execute for metadata
|
||||
:param catalog: Optional catalog/database name
|
||||
:param schema: Optional schema name
|
||||
:return: SQLAlchemy Result object with methods like:
|
||||
- result.fetchall() -> list[Row]: Get all rows
|
||||
- result.fetchone() -> Row | None: Get single row
|
||||
- result.scalar() -> Any: Get single value
|
||||
- result.mappings() -> mappings for dict-like access
|
||||
"""
|
||||
with cls.get_engine(
|
||||
database,
|
||||
catalog=catalog,
|
||||
schema=schema,
|
||||
source=utils.QuerySource.METADATA,
|
||||
) as engine:
|
||||
with engine.connect() as conn:
|
||||
return conn.execute(text(query))
|
||||
|
||||
@classmethod
|
||||
def needs_oauth2(cls, ex: Exception) -> bool:
|
||||
"""
|
||||
|
||||
@@ -331,9 +331,11 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
|
||||
)
|
||||
|
||||
# Build the query
|
||||
query = select(
|
||||
func.max(partitions_table.c.partition_id).label("max_partition_id")
|
||||
).where(partitions_table.c.table_name == table.table)
|
||||
query = (
|
||||
select(func.max(partitions_table.c.partition_id).label("max_partition_id"))
|
||||
.where(partitions_table.c.table_name == table.table)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
# Compile to BigQuery SQL
|
||||
compiled_query = query.compile(
|
||||
@@ -342,15 +344,13 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
|
||||
)
|
||||
|
||||
# Run the query and handle result
|
||||
with database.get_raw_connection(
|
||||
result = cls.execute_metadata_query(
|
||||
database,
|
||||
str(compiled_query),
|
||||
catalog=table.catalog,
|
||||
schema=table.schema,
|
||||
) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(str(compiled_query))
|
||||
if row := cursor.fetchone():
|
||||
return row[0]
|
||||
return None
|
||||
)
|
||||
return result.scalar()
|
||||
|
||||
@classmethod
|
||||
def get_time_partition_column(
|
||||
|
||||
@@ -503,12 +503,13 @@ class DatabricksNativeEngineSpec(DatabricksDynamicBaseEngineSpec):
|
||||
if default_catalog := connect_args.get("catalog"):
|
||||
return default_catalog
|
||||
|
||||
with database.get_sqla_engine() as engine:
|
||||
catalogs = {catalog for (catalog,) in engine.execute("SHOW CATALOGS")}
|
||||
if len(catalogs) == 1:
|
||||
return catalogs.pop()
|
||||
result = cls.execute_metadata_query(database, "SHOW CATALOGS")
|
||||
catalogs = {catalog for (catalog,) in result}
|
||||
if len(catalogs) == 1:
|
||||
return catalogs.pop()
|
||||
|
||||
return engine.execute("SELECT current_catalog()").scalar()
|
||||
result = cls.execute_metadata_query(database, "SELECT current_catalog()")
|
||||
return result.scalar()
|
||||
|
||||
@classmethod
|
||||
def get_prequeries(
|
||||
@@ -532,7 +533,8 @@ class DatabricksNativeEngineSpec(DatabricksDynamicBaseEngineSpec):
|
||||
database: Database,
|
||||
inspector: Inspector,
|
||||
) -> set[str]:
|
||||
return {catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")}
|
||||
results = cls.execute_metadata_query(database, "SHOW CATALOGS")
|
||||
return {catalog for (catalog,) in results}
|
||||
|
||||
|
||||
class DatabricksPythonConnectorEngineSpec(DatabricksDynamicBaseEngineSpec):
|
||||
@@ -624,7 +626,8 @@ class DatabricksPythonConnectorEngineSpec(DatabricksDynamicBaseEngineSpec):
|
||||
database: Database,
|
||||
inspector: Inspector,
|
||||
) -> set[str]:
|
||||
return {catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")}
|
||||
results = cls.execute_metadata_query(database, "SHOW CATALOGS")
|
||||
return {catalog for (catalog,) in results}
|
||||
|
||||
@classmethod
|
||||
def adjust_engine_params(
|
||||
|
||||
@@ -297,8 +297,8 @@ class DorisEngineSpec(MySQLEngineSpec):
|
||||
CatalogId, CatalogName, Type, IsCurrent, CreateTime, LastUpdateTime, Comment
|
||||
We need to extract just the CatalogName column.
|
||||
"""
|
||||
result = inspector.bind.execute("SHOW CATALOGS")
|
||||
return {row.CatalogName for row in result}
|
||||
results = cls.execute_metadata_query(database, "SHOW CATALOGS")
|
||||
return {row.CatalogName for row in results}
|
||||
|
||||
@classmethod
|
||||
def get_schema_from_engine_params(
|
||||
|
||||
@@ -388,9 +388,8 @@ class MotherDuckEngineSpec(DuckDBEngineSpec):
|
||||
database: Database,
|
||||
inspector: Inspector,
|
||||
) -> set[str]:
|
||||
return {
|
||||
catalog
|
||||
for (catalog,) in inspector.bind.execute(
|
||||
"SELECT alias FROM MD_ALL_DATABASES() WHERE is_attached;"
|
||||
)
|
||||
}
|
||||
results = cls.execute_metadata_query(
|
||||
database,
|
||||
"SELECT alias FROM MD_ALL_DATABASES() WHERE is_attached;",
|
||||
)
|
||||
return {catalog for (catalog,) in results}
|
||||
|
||||
@@ -154,15 +154,16 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
|
||||
database: Database,
|
||||
table: Table,
|
||||
) -> dict[str, Any]:
|
||||
with database.get_raw_connection(
|
||||
result = cls.execute_metadata_query(
|
||||
database,
|
||||
f'SELECT GET_METADATA("{table.table}") LIMIT 1',
|
||||
catalog=table.catalog,
|
||||
schema=table.schema,
|
||||
) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(f'SELECT GET_METADATA("{table.table}")')
|
||||
results = cursor.fetchone()[0]
|
||||
)
|
||||
results_list = result.fetchall()
|
||||
results = results_list[0][0] if results_list else None
|
||||
try:
|
||||
metadata = json.loads(results)
|
||||
metadata = json.loads(results) if results else {}
|
||||
except Exception: # pylint: disable=broad-except
|
||||
metadata = {}
|
||||
|
||||
|
||||
@@ -614,8 +614,9 @@ class HiveEngineSpec(PrestoEngineSpec):
|
||||
if schema:
|
||||
sql += f" IN `{schema}`"
|
||||
|
||||
with database.get_raw_connection(schema=schema) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(sql)
|
||||
results = cursor.fetchall()
|
||||
return {row[0] for row in results}
|
||||
result = cls.execute_metadata_query(
|
||||
database,
|
||||
sql,
|
||||
schema=schema,
|
||||
)
|
||||
return {row[0] for row in result}
|
||||
|
||||
@@ -354,19 +354,18 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec):
|
||||
|
||||
@classmethod
|
||||
def estimate_statement_cost(
|
||||
cls, database: Database, statement: str, cursor: Any
|
||||
cls, database: Database, statement: str
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Run a SQL query that estimates the cost of a given statement.
|
||||
:param database: A Database object
|
||||
:param statement: A single SQL statement
|
||||
:param cursor: Cursor instance
|
||||
:return: JSON response from Trino
|
||||
:return: Cost estimate dictionary
|
||||
"""
|
||||
sql = f"EXPLAIN {statement}"
|
||||
cursor.execute(sql)
|
||||
|
||||
result = cursor.fetchone()[0]
|
||||
sql = f"EXPLAIN {statement} LIMIT 1"
|
||||
results = cls.execute_metadata_query(database, sql)
|
||||
rows = results.fetchall()
|
||||
result = rows[0][0] if rows else ""
|
||||
match = re.search(r"cost=([\d\.]+)\.\.([\d\.]+)", result)
|
||||
if match:
|
||||
return {
|
||||
@@ -393,15 +392,14 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec):
|
||||
|
||||
In Postgres, a catalog is called a "database".
|
||||
"""
|
||||
return {
|
||||
catalog
|
||||
for (catalog,) in inspector.bind.execute(
|
||||
"""
|
||||
results = cls.execute_metadata_query(
|
||||
database,
|
||||
"""
|
||||
SELECT datname FROM pg_database
|
||||
WHERE datistemplate = false;
|
||||
"""
|
||||
)
|
||||
}
|
||||
""",
|
||||
)
|
||||
return {catalog for (catalog,) in results}
|
||||
|
||||
@classmethod
|
||||
def get_table_names(
|
||||
|
||||
@@ -321,7 +321,8 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
|
||||
"""
|
||||
Get all catalogs.
|
||||
"""
|
||||
return {catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")}
|
||||
result = cls.execute_metadata_query(database, "SHOW CATALOGS")
|
||||
return {catalog for (catalog,) in result}
|
||||
|
||||
@classmethod
|
||||
def adjust_engine_params(
|
||||
@@ -373,17 +374,16 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
|
||||
|
||||
@classmethod
|
||||
def estimate_statement_cost(
|
||||
cls, database: Database, statement: str, cursor: Any
|
||||
cls, database: Database, statement: str
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Run a SQL query that estimates the cost of a given statement.
|
||||
:param database: A Database object
|
||||
:param statement: A single SQL statement
|
||||
:param cursor: Cursor instance
|
||||
:return: JSON response from Trino
|
||||
"""
|
||||
sql = f"EXPLAIN (TYPE IO, FORMAT JSON) {statement}"
|
||||
cursor.execute(sql)
|
||||
sql = f"EXPLAIN (TYPE IO, FORMAT JSON) {statement} LIMIT 1"
|
||||
results = cls.execute_metadata_query(database, sql)
|
||||
|
||||
# the output from Trino is a single column and a single row containing
|
||||
# JSON:
|
||||
@@ -398,7 +398,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
|
||||
# "networkCost" : 3.41425774958E11
|
||||
# }
|
||||
# }
|
||||
result = json.loads(cursor.fetchone()[0])
|
||||
result = json.loads(results[0][0]) if results else {}
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
@@ -1037,7 +1037,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
||||
AND table_type = 'VIEW'
|
||||
"""
|
||||
).strip()
|
||||
params = {"schema": schema}
|
||||
results = inspector.bind.execute(sql, {"schema": schema}).fetchall()
|
||||
else:
|
||||
sql = dedent(
|
||||
"""
|
||||
@@ -1045,13 +1045,9 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
||||
WHERE table_type = 'VIEW'
|
||||
"""
|
||||
).strip()
|
||||
params = {}
|
||||
results = inspector.bind.execute(sql).fetchall()
|
||||
|
||||
with database.get_raw_connection(schema=schema) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(sql, params)
|
||||
results = cursor.fetchall()
|
||||
return {row[0] for row in results}
|
||||
return {row[0] for row in results}
|
||||
|
||||
@classmethod
|
||||
def _is_column_name_quoted(cls, column_name: str) -> bool:
|
||||
@@ -1299,16 +1295,12 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from pyhive.exc import DatabaseError
|
||||
|
||||
with database.get_raw_connection(schema=schema) as conn:
|
||||
cursor = conn.cursor()
|
||||
sql = f"SHOW CREATE VIEW {schema}.{table}"
|
||||
try:
|
||||
cls.execute(cursor, sql, database)
|
||||
rows = cls.fetch_data(cursor, 1)
|
||||
|
||||
return rows[0][0]
|
||||
except DatabaseError: # not a VIEW
|
||||
return None
|
||||
sql = f"SHOW CREATE VIEW {schema}.{table} LIMIT 1"
|
||||
try:
|
||||
results = cls.execute_metadata_query(database, sql, schema=schema)
|
||||
return results[0][0] if results else None
|
||||
except DatabaseError: # not a VIEW
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_tracking_url(cls, cursor: Cursor) -> str | None:
|
||||
|
||||
@@ -209,12 +209,11 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
|
||||
|
||||
In Snowflake, a catalog is called a "database".
|
||||
"""
|
||||
return {
|
||||
catalog
|
||||
for (catalog,) in inspector.bind.execute(
|
||||
"SELECT DATABASE_NAME from information_schema.databases"
|
||||
)
|
||||
}
|
||||
results = cls.execute_metadata_query(
|
||||
database,
|
||||
"SELECT DATABASE_NAME from information_schema.databases",
|
||||
)
|
||||
return {catalog for (catalog,) in results}
|
||||
|
||||
@classmethod
|
||||
def epoch_to_dttm(cls) -> str:
|
||||
|
||||
@@ -320,6 +320,7 @@ class QuerySource(Enum):
|
||||
CHART = 0
|
||||
DASHBOARD = 1
|
||||
SQL_LAB = 2
|
||||
METADATA = 3
|
||||
|
||||
|
||||
class QueryStatus(StrEnum):
|
||||
|
||||
@@ -152,13 +152,22 @@ class TestPostgresDbEngineSpec(SupersetTestCase):
|
||||
"""
|
||||
|
||||
database = mock.Mock()
|
||||
cursor = mock.Mock()
|
||||
cursor.fetchone.return_value = (
|
||||
"Seq Scan on birth_names (cost=0.00..1537.91 rows=75691 width=46)",
|
||||
)
|
||||
sql = "SELECT * FROM birth_names"
|
||||
results = PostgresEngineSpec.estimate_statement_cost(database, sql, cursor)
|
||||
|
||||
# Mock the execute_metadata_query method to return expected results
|
||||
with mock.patch.object(
|
||||
PostgresEngineSpec, "execute_metadata_query"
|
||||
) as mock_execute:
|
||||
mock_result = mock.Mock()
|
||||
expected_results = [
|
||||
("Seq Scan on birth_names (cost=0.00..1537.91 rows=75691 width=46)",)
|
||||
]
|
||||
mock_result.fetchall.return_value = expected_results
|
||||
mock_execute.return_value = mock_result
|
||||
results = PostgresEngineSpec.estimate_statement_cost(database, sql)
|
||||
|
||||
assert results == {"Start-up cost": 0.0, "Total cost": 1537.91}
|
||||
mock_execute.assert_called_once_with(database, f"EXPLAIN {sql} LIMIT 1")
|
||||
|
||||
def test_estimate_statement_invalid_syntax(self):
|
||||
"""
|
||||
@@ -167,17 +176,21 @@ class TestPostgresDbEngineSpec(SupersetTestCase):
|
||||
from psycopg2 import errors
|
||||
|
||||
database = mock.Mock()
|
||||
cursor = mock.Mock()
|
||||
cursor.execute.side_effect = errors.SyntaxError(
|
||||
"""
|
||||
syntax error at or near "EXPLAIN"
|
||||
LINE 1: EXPLAIN DROP TABLE birth_names
|
||||
^
|
||||
"""
|
||||
)
|
||||
sql = "DROP TABLE birth_names"
|
||||
with self.assertRaises(errors.SyntaxError): # noqa: PT027
|
||||
PostgresEngineSpec.estimate_statement_cost(database, sql, cursor)
|
||||
|
||||
# Mock the execute_metadata_query method to raise the expected exception
|
||||
with mock.patch.object(
|
||||
PostgresEngineSpec, "execute_metadata_query"
|
||||
) as mock_execute:
|
||||
mock_execute.side_effect = errors.SyntaxError(
|
||||
"""
|
||||
syntax error at or near "EXPLAIN"
|
||||
LINE 1: EXPLAIN DROP TABLE birth_names
|
||||
^
|
||||
"""
|
||||
)
|
||||
with self.assertRaises(errors.SyntaxError): # noqa: PT027
|
||||
PostgresEngineSpec.estimate_statement_cost(database, sql)
|
||||
|
||||
def test_query_cost_formatter_example_costs(self):
|
||||
"""
|
||||
|
||||
@@ -931,26 +931,34 @@ class TestPrestoDbEngineSpec(SupersetTestCase):
|
||||
|
||||
def test_estimate_statement_cost(self):
|
||||
mock_database = mock.MagicMock()
|
||||
mock_cursor = mock.MagicMock()
|
||||
estimate_json = {"a": "b"}
|
||||
mock_cursor.fetchone.return_value = [
|
||||
'{"a": "b"}',
|
||||
]
|
||||
result = PrestoEngineSpec.estimate_statement_cost(
|
||||
mock_database,
|
||||
"SELECT * FROM brth_names",
|
||||
mock_cursor,
|
||||
)
|
||||
sql = "SELECT * FROM brth_names"
|
||||
|
||||
# Mock the execute_metadata_query method to return expected JSON results
|
||||
with mock.patch.object(
|
||||
PrestoEngineSpec, "execute_metadata_query"
|
||||
) as mock_execute:
|
||||
mock_result = mock.Mock()
|
||||
mock_result.scalar.return_value = '{"a": "b"}'
|
||||
mock_execute.return_value = mock_result
|
||||
result = PrestoEngineSpec.estimate_statement_cost(mock_database, sql)
|
||||
|
||||
assert result == estimate_json
|
||||
mock_execute.assert_called_once_with(
|
||||
mock_database, f"EXPLAIN (TYPE IO, FORMAT JSON) {sql} LIMIT 1"
|
||||
)
|
||||
|
||||
def test_estimate_statement_cost_invalid_syntax(self):
|
||||
mock_database = mock.MagicMock()
|
||||
mock_cursor = mock.MagicMock()
|
||||
mock_cursor.execute.side_effect = Exception()
|
||||
with self.assertRaises(Exception): # noqa: B017, PT027
|
||||
PrestoEngineSpec.estimate_statement_cost(
|
||||
mock_database, "DROP TABLE brth_names", mock_cursor
|
||||
)
|
||||
sql = "DROP TABLE brth_names"
|
||||
|
||||
# Mock the execute_metadata_query method to raise an exception
|
||||
with mock.patch.object(
|
||||
PrestoEngineSpec, "execute_metadata_query"
|
||||
) as mock_execute:
|
||||
mock_execute.side_effect = Exception("Invalid syntax")
|
||||
with self.assertRaises(Exception): # noqa: B017, PT027
|
||||
PrestoEngineSpec.estimate_statement_cost(mock_database, sql)
|
||||
|
||||
def test_get_create_view(self):
|
||||
mock_execute = mock.MagicMock()
|
||||
@@ -1207,10 +1215,10 @@ def test_get_catalog_names(app_context: AppContext) -> None:
|
||||
return
|
||||
|
||||
with database.get_inspector() as inspector:
|
||||
assert PrestoEngineSpec.get_catalog_names(database, inspector) == [
|
||||
assert PrestoEngineSpec.get_catalog_names(database, inspector) == {
|
||||
"jmx",
|
||||
"memory",
|
||||
"system",
|
||||
"tpcds",
|
||||
"tpch",
|
||||
]
|
||||
}
|
||||
|
||||
@@ -44,10 +44,10 @@ def database_with_catalog(mocker: MockerFixture) -> MagicMock:
|
||||
database.database_name = "my_db"
|
||||
database.db_engine_spec.__name__ = "test_engine"
|
||||
database.db_engine_spec.supports_catalog = True
|
||||
database.get_all_catalog_names.return_value = ["catalog1", "catalog2"]
|
||||
database.get_all_catalog_names.return_value = {"catalog1", "catalog2"}
|
||||
database.get_all_schema_names.side_effect = [
|
||||
["schema1", "schema2"],
|
||||
["schema3", "schema4"],
|
||||
{"schema1", "schema2"},
|
||||
{"schema3", "schema4"},
|
||||
]
|
||||
database.get_default_catalog.return_value = "catalog2"
|
||||
|
||||
@@ -63,7 +63,7 @@ def database_without_catalog(mocker: MockerFixture) -> MagicMock:
|
||||
database.database_name = "my_db"
|
||||
database.db_engine_spec.__name__ = "test_engine"
|
||||
database.db_engine_spec.supports_catalog = False
|
||||
database.get_all_schema_names.return_value = ["schema1", "schema2"]
|
||||
database.get_all_schema_names.return_value = {"schema1", "schema2"}
|
||||
database.is_oauth2_enabled.return_value = False
|
||||
database.db_engine_spec.needs_oauth2.return_value = False
|
||||
|
||||
|
||||
@@ -69,23 +69,24 @@ def test_sync_permissions_command_sync_mode(
|
||||
add_pvm_mock.assert_has_calls(
|
||||
[
|
||||
mocker.call(
|
||||
db.session, security_manager, "catalog_access", "[my_db].[catalog2]"
|
||||
db.session, security_manager, "catalog_access", "[my_db].[catalog1]"
|
||||
),
|
||||
mocker.call(
|
||||
db.session,
|
||||
security_manager,
|
||||
"schema_access",
|
||||
"[my_db].[catalog2].[schema3]",
|
||||
"[my_db].[catalog1].[schema3]",
|
||||
),
|
||||
mocker.call(
|
||||
db.session,
|
||||
security_manager,
|
||||
"schema_access",
|
||||
"[my_db].[catalog2].[schema4]",
|
||||
"[my_db].[catalog1].[schema4]",
|
||||
),
|
||||
]
|
||||
],
|
||||
any_order=True,
|
||||
)
|
||||
mock_refresh_schemas.assert_called_once_with("catalog1", ["schema1", "schema2"])
|
||||
mock_refresh_schemas.assert_called_once_with("catalog2", {"schema1", "schema2"})
|
||||
mock_rename_db_perm.assert_not_called()
|
||||
|
||||
|
||||
@@ -246,7 +247,7 @@ def test_sync_permissions_command_get_catalogs(database_with_catalog: MagicMock)
|
||||
Test the ``_get_catalog_names`` method.
|
||||
"""
|
||||
cmmd = SyncPermissionsCommand(1, None, db_connection=database_with_catalog)
|
||||
assert cmmd._get_catalog_names() == ["catalog1", "catalog2"]
|
||||
assert cmmd._get_catalog_names() == {"catalog1", "catalog2"}
|
||||
|
||||
|
||||
def test_sync_permissions_command_get_default_catalog(database_with_catalog: MagicMock):
|
||||
@@ -263,7 +264,7 @@ def test_sync_permissions_command_get_default_catalog(database_with_catalog: Mag
|
||||
|
||||
database_with_catalog.allow_multi_catalog = True
|
||||
cmmd = SyncPermissionsCommand(1, None, db_connection=database_with_catalog)
|
||||
assert cmmd._get_catalog_names() == ["catalog1", "catalog2"]
|
||||
assert cmmd._get_catalog_names() == {"catalog1", "catalog2"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -295,8 +296,8 @@ def test_sync_permissions_command_get_schemas(database_with_catalog: MagicMock):
|
||||
Test the ``_get_schema_names`` method.
|
||||
"""
|
||||
cmmd = SyncPermissionsCommand(1, None, db_connection=database_with_catalog)
|
||||
assert cmmd._get_schema_names("catalog1") == ["schema1", "schema2"]
|
||||
assert cmmd._get_schema_names("catalog2") == ["schema3", "schema4"]
|
||||
assert cmmd._get_schema_names("catalog1") == {"schema1", "schema2"}
|
||||
assert cmmd._get_schema_names("catalog2") == {"schema3", "schema4"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
195
tests/unit_tests/db_engine_specs/test_base_metadata.py
Normal file
195
tests/unit_tests/db_engine_specs/test_base_metadata.py
Normal file
@@ -0,0 +1,195 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
from superset.utils.core import QuerySource
|
||||
from tests.integration_tests.base_tests import SupersetTestCase
|
||||
|
||||
|
||||
class TestBaseEngineSpecMetadata(SupersetTestCase):
|
||||
@mock.patch.object(BaseEngineSpec, "get_engine")
|
||||
def test_execute_metadata_query_basic(self, mock_get_engine):
|
||||
"""Test basic metadata query execution"""
|
||||
database = mock.Mock()
|
||||
mock_engine = mock.Mock()
|
||||
mock_connection = mock.Mock()
|
||||
mock_result = mock.Mock()
|
||||
|
||||
# Setup mock chain
|
||||
mock_get_engine.return_value.__enter__ = mock.Mock(return_value=mock_engine)
|
||||
mock_get_engine.return_value.__exit__ = mock.Mock(return_value=None)
|
||||
|
||||
mock_engine.connect.return_value.__enter__ = mock.Mock(
|
||||
return_value=mock_connection
|
||||
)
|
||||
mock_engine.connect.return_value.__exit__ = mock.Mock(return_value=None)
|
||||
|
||||
expected_results = [("catalog1",), ("catalog2",)]
|
||||
mock_result.fetchall.return_value = expected_results
|
||||
mock_result.__iter__ = lambda self: iter(expected_results)
|
||||
mock_connection.execute.return_value = mock_result
|
||||
|
||||
# Execute the method
|
||||
query = "SHOW CATALOGS"
|
||||
results = BaseEngineSpec.execute_metadata_query(database, query)
|
||||
|
||||
# Verify results - now we get the Result object itself
|
||||
assert results == mock_result
|
||||
# Test that we can iterate over it like the real usage
|
||||
assert list(results) == expected_results
|
||||
|
||||
# Verify call chain
|
||||
mock_get_engine.assert_called_once_with(
|
||||
database, catalog=None, schema=None, source=QuerySource.METADATA
|
||||
)
|
||||
mock_connection.execute.assert_called_once()
|
||||
executed_query = mock_connection.execute.call_args[0][0]
|
||||
assert str(executed_query) == query
|
||||
|
||||
@mock.patch.object(BaseEngineSpec, "get_engine")
|
||||
def test_execute_metadata_query_with_catalog_schema(self, mock_get_engine):
|
||||
"""Test metadata query with catalog and schema parameters"""
|
||||
database = mock.Mock()
|
||||
mock_engine = mock.Mock()
|
||||
mock_connection = mock.Mock()
|
||||
mock_result = mock.Mock()
|
||||
|
||||
# Setup mock chain
|
||||
mock_get_engine.return_value.__enter__ = mock.Mock(return_value=mock_engine)
|
||||
mock_get_engine.return_value.__exit__ = mock.Mock(return_value=None)
|
||||
|
||||
mock_engine.connect.return_value.__enter__ = mock.Mock(
|
||||
return_value=mock_connection
|
||||
)
|
||||
mock_engine.connect.return_value.__exit__ = mock.Mock(return_value=None)
|
||||
|
||||
expected_results = [("table1",), ("table2",)]
|
||||
mock_result.fetchall.return_value = expected_results
|
||||
mock_result.__iter__ = lambda self: iter(expected_results)
|
||||
mock_connection.execute.return_value = mock_result
|
||||
|
||||
# Execute with catalog and schema
|
||||
query = "SHOW TABLES"
|
||||
catalog = "my_catalog"
|
||||
schema = "my_schema"
|
||||
results = BaseEngineSpec.execute_metadata_query(
|
||||
database, query, catalog=catalog, schema=schema
|
||||
)
|
||||
|
||||
# Verify results - now we get the Result object itself
|
||||
assert results == mock_result
|
||||
# Test that we can iterate over it like the real usage
|
||||
assert list(results) == expected_results
|
||||
|
||||
# Verify catalog and schema were passed to get_engine
|
||||
mock_get_engine.assert_called_once_with(
|
||||
database, catalog=catalog, schema=schema, source=QuerySource.METADATA
|
||||
)
|
||||
|
||||
@mock.patch.object(BaseEngineSpec, "get_engine")
|
||||
def test_execute_metadata_query_uses_correct_query_source(self, mock_get_engine):
|
||||
"""Test that QuerySource.METADATA is used correctly"""
|
||||
database = mock.Mock()
|
||||
mock_engine = mock.Mock()
|
||||
mock_connection = mock.Mock()
|
||||
mock_result = mock.Mock()
|
||||
|
||||
# Setup mock chain
|
||||
mock_get_engine.return_value.__enter__ = mock.Mock(return_value=mock_engine)
|
||||
mock_get_engine.return_value.__exit__ = mock.Mock(return_value=None)
|
||||
|
||||
mock_engine.connect.return_value.__enter__ = mock.Mock(
|
||||
return_value=mock_connection
|
||||
)
|
||||
mock_engine.connect.return_value.__exit__ = mock.Mock(return_value=None)
|
||||
|
||||
empty_results = []
|
||||
mock_result.fetchall.return_value = empty_results
|
||||
mock_result.__iter__ = lambda self: iter(empty_results)
|
||||
mock_connection.execute.return_value = mock_result
|
||||
|
||||
# Execute the method
|
||||
BaseEngineSpec.execute_metadata_query(database, "SELECT 1")
|
||||
|
||||
# Verify QuerySource.METADATA was used
|
||||
mock_get_engine.assert_called_once_with(
|
||||
database, catalog=None, schema=None, source=QuerySource.METADATA
|
||||
)
|
||||
|
||||
@mock.patch.object(BaseEngineSpec, "get_engine")
|
||||
def test_execute_metadata_query_handles_sql_error(self, mock_get_engine):
|
||||
"""Test proper exception handling for SQL errors"""
|
||||
database = mock.Mock()
|
||||
mock_engine = mock.Mock()
|
||||
mock_connection = mock.Mock()
|
||||
|
||||
# Setup mock chain
|
||||
mock_get_engine.return_value.__enter__ = mock.Mock(return_value=mock_engine)
|
||||
mock_get_engine.return_value.__exit__ = mock.Mock(return_value=None)
|
||||
|
||||
mock_engine.connect.return_value.__enter__ = mock.Mock(
|
||||
return_value=mock_connection
|
||||
)
|
||||
mock_engine.connect.return_value.__exit__ = mock.Mock(return_value=None)
|
||||
|
||||
# Mock connection to raise SQLAlchemyError
|
||||
mock_connection.execute.side_effect = SQLAlchemyError("Database error")
|
||||
|
||||
# Execute and verify exception is propagated
|
||||
with pytest.raises(SQLAlchemyError):
|
||||
BaseEngineSpec.execute_metadata_query(database, "INVALID QUERY")
|
||||
|
||||
@mock.patch.object(BaseEngineSpec, "get_engine")
|
||||
def test_execute_metadata_query_empty_results(self, mock_get_engine):
|
||||
"""Test handling of queries that return no results"""
|
||||
database = mock.Mock()
|
||||
mock_engine = mock.Mock()
|
||||
mock_connection = mock.Mock()
|
||||
mock_result = mock.Mock()
|
||||
|
||||
# Setup mock chain
|
||||
mock_get_engine.return_value.__enter__ = mock.Mock(return_value=mock_engine)
|
||||
mock_get_engine.return_value.__exit__ = mock.Mock(return_value=None)
|
||||
|
||||
mock_engine.connect.return_value.__enter__ = mock.Mock(
|
||||
return_value=mock_connection
|
||||
)
|
||||
mock_engine.connect.return_value.__exit__ = mock.Mock(return_value=None)
|
||||
|
||||
# Empty results
|
||||
empty_results = []
|
||||
mock_result.fetchall.return_value = empty_results
|
||||
mock_result.__iter__ = lambda self: iter(empty_results)
|
||||
mock_connection.execute.return_value = mock_result
|
||||
|
||||
# Execute the method
|
||||
results = BaseEngineSpec.execute_metadata_query(database, "SELECT 1 WHERE 1=0")
|
||||
|
||||
# Verify empty results are handled correctly - now we get the Result object
|
||||
assert results == mock_result
|
||||
# Test that iterating over it gives empty results like the real usage
|
||||
assert list(results) == []
|
||||
|
||||
def test_execute_metadata_query_query_source_enum_value(self):
|
||||
"""Test QuerySource.METADATA enum has correct value"""
|
||||
# This is a simple enum test that doesn't require complex mocking
|
||||
assert QuerySource.METADATA.value == 3
|
||||
assert QuerySource.METADATA.name == "METADATA"
|
||||
@@ -1122,3 +1122,21 @@ def test_get_stacktrace():
|
||||
except Exception:
|
||||
stacktrace = get_stacktrace()
|
||||
assert stacktrace is None
|
||||
|
||||
|
||||
def test_query_source_metadata_enum():
|
||||
"""Test that QuerySource.METADATA enum value exists and has correct value"""
|
||||
assert hasattr(QuerySource, "METADATA")
|
||||
assert QuerySource.METADATA.value == 3
|
||||
assert QuerySource.METADATA.name == "METADATA"
|
||||
|
||||
# Verify all expected QuerySource values
|
||||
expected_sources = {
|
||||
QuerySource.CHART: 0,
|
||||
QuerySource.DASHBOARD: 1,
|
||||
QuerySource.SQL_LAB: 2,
|
||||
QuerySource.METADATA: 3,
|
||||
}
|
||||
|
||||
for source, expected_value in expected_sources.items():
|
||||
assert source.value == expected_value
|
||||
|
||||
Reference in New Issue
Block a user