mirror of
https://github.com/apache/superset.git
synced 2026-04-30 13:34:20 +00:00
Compare commits
7 Commits
codex/fix-
...
standardiz
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1dfe73d19c | ||
|
|
bbda5e2008 | ||
|
|
53999c12dd | ||
|
|
f554036d29 | ||
|
|
33e7932491 | ||
|
|
92b02d993b | ||
|
|
72ba972e42 |
@@ -49,6 +49,7 @@ from flask_babel import gettext as __, lazy_gettext as _
|
|||||||
from marshmallow import fields, Schema
|
from marshmallow import fields, Schema
|
||||||
from marshmallow.validate import Range
|
from marshmallow.validate import Range
|
||||||
from sqlalchemy import column, select, types
|
from sqlalchemy import column, select, types
|
||||||
|
from sqlalchemy.engine import Result
|
||||||
from sqlalchemy.engine.base import Engine
|
from sqlalchemy.engine.base import Engine
|
||||||
from sqlalchemy.engine.interfaces import Compiled, Dialect
|
from sqlalchemy.engine.interfaces import Compiled, Dialect
|
||||||
from sqlalchemy.engine.reflection import Inspector
|
from sqlalchemy.engine.reflection import Inspector
|
||||||
@@ -1670,14 +1671,13 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def estimate_statement_cost(
|
def estimate_statement_cost(
|
||||||
cls, database: Database, statement: str, cursor: Any
|
cls, database: Database, statement: str
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Generate a SQL query that estimates the cost of a given statement.
|
Generate a SQL query that estimates the cost of a given statement.
|
||||||
|
|
||||||
:param database: A Database object
|
:param database: A Database object
|
||||||
:param statement: A single SQL statement
|
:param statement: A single SQL statement
|
||||||
:param cursor: Cursor instance
|
|
||||||
:return: Dictionary with different costs
|
:return: Dictionary with different costs
|
||||||
"""
|
"""
|
||||||
raise Exception( # pylint: disable=broad-exception-raised
|
raise Exception( # pylint: disable=broad-exception-raised
|
||||||
@@ -1738,20 +1738,13 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||||||
|
|
||||||
parsed_script = SQLScript(sql, engine=cls.engine)
|
parsed_script = SQLScript(sql, engine=cls.engine)
|
||||||
|
|
||||||
with database.get_raw_connection(
|
return [
|
||||||
catalog=catalog,
|
cls.estimate_statement_cost(
|
||||||
schema=schema,
|
database,
|
||||||
source=source,
|
cls.process_statement(statement, database),
|
||||||
) as conn:
|
)
|
||||||
cursor = conn.cursor()
|
for statement in parsed_script.statements
|
||||||
return [
|
]
|
||||||
cls.estimate_statement_cost(
|
|
||||||
database,
|
|
||||||
cls.process_statement(statement, database),
|
|
||||||
cursor,
|
|
||||||
)
|
|
||||||
for statement in parsed_script.statements
|
|
||||||
]
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def impersonate_user(
|
def impersonate_user(
|
||||||
@@ -1854,6 +1847,42 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||||||
cls.start_oauth2_dance(database)
|
cls.start_oauth2_dance(database)
|
||||||
raise cls.get_dbapi_mapped_exception(ex) from ex
|
raise cls.get_dbapi_mapped_exception(ex) from ex
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute_metadata_query(
|
||||||
|
cls,
|
||||||
|
database: Database,
|
||||||
|
query: str,
|
||||||
|
catalog: str | None = None,
|
||||||
|
schema: str | None = None,
|
||||||
|
) -> Result:
|
||||||
|
"""
|
||||||
|
Standardized method for executing metadata queries.
|
||||||
|
|
||||||
|
This method provides a unified interface for all metadata query operations
|
||||||
|
across different database engines using SQLAlchemy connections.
|
||||||
|
|
||||||
|
For single-row results, add "LIMIT 1" to your query rather than using
|
||||||
|
separate fetch parameters.
|
||||||
|
|
||||||
|
:param database: Database instance
|
||||||
|
:param query: SQL query to execute for metadata
|
||||||
|
:param catalog: Optional catalog/database name
|
||||||
|
:param schema: Optional schema name
|
||||||
|
:return: SQLAlchemy Result object with methods like:
|
||||||
|
- result.fetchall() -> list[Row]: Get all rows
|
||||||
|
- result.fetchone() -> Row | None: Get single row
|
||||||
|
- result.scalar() -> Any: Get single value
|
||||||
|
- result.mappings() -> mappings for dict-like access
|
||||||
|
"""
|
||||||
|
with cls.get_engine(
|
||||||
|
database,
|
||||||
|
catalog=catalog,
|
||||||
|
schema=schema,
|
||||||
|
source=utils.QuerySource.METADATA,
|
||||||
|
) as engine:
|
||||||
|
with engine.connect() as conn:
|
||||||
|
return conn.execute(text(query))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def needs_oauth2(cls, ex: Exception) -> bool:
|
def needs_oauth2(cls, ex: Exception) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -331,9 +331,11 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Build the query
|
# Build the query
|
||||||
query = select(
|
query = (
|
||||||
func.max(partitions_table.c.partition_id).label("max_partition_id")
|
select(func.max(partitions_table.c.partition_id).label("max_partition_id"))
|
||||||
).where(partitions_table.c.table_name == table.table)
|
.where(partitions_table.c.table_name == table.table)
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
|
||||||
# Compile to BigQuery SQL
|
# Compile to BigQuery SQL
|
||||||
compiled_query = query.compile(
|
compiled_query = query.compile(
|
||||||
@@ -342,15 +344,13 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Run the query and handle result
|
# Run the query and handle result
|
||||||
with database.get_raw_connection(
|
result = cls.execute_metadata_query(
|
||||||
|
database,
|
||||||
|
str(compiled_query),
|
||||||
catalog=table.catalog,
|
catalog=table.catalog,
|
||||||
schema=table.schema,
|
schema=table.schema,
|
||||||
) as conn:
|
)
|
||||||
cursor = conn.cursor()
|
return result.scalar()
|
||||||
cursor.execute(str(compiled_query))
|
|
||||||
if row := cursor.fetchone():
|
|
||||||
return row[0]
|
|
||||||
return None
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_time_partition_column(
|
def get_time_partition_column(
|
||||||
|
|||||||
@@ -503,12 +503,13 @@ class DatabricksNativeEngineSpec(DatabricksDynamicBaseEngineSpec):
|
|||||||
if default_catalog := connect_args.get("catalog"):
|
if default_catalog := connect_args.get("catalog"):
|
||||||
return default_catalog
|
return default_catalog
|
||||||
|
|
||||||
with database.get_sqla_engine() as engine:
|
result = cls.execute_metadata_query(database, "SHOW CATALOGS")
|
||||||
catalogs = {catalog for (catalog,) in engine.execute("SHOW CATALOGS")}
|
catalogs = {catalog for (catalog,) in result}
|
||||||
if len(catalogs) == 1:
|
if len(catalogs) == 1:
|
||||||
return catalogs.pop()
|
return catalogs.pop()
|
||||||
|
|
||||||
return engine.execute("SELECT current_catalog()").scalar()
|
result = cls.execute_metadata_query(database, "SELECT current_catalog()")
|
||||||
|
return result.scalar()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_prequeries(
|
def get_prequeries(
|
||||||
@@ -532,7 +533,8 @@ class DatabricksNativeEngineSpec(DatabricksDynamicBaseEngineSpec):
|
|||||||
database: Database,
|
database: Database,
|
||||||
inspector: Inspector,
|
inspector: Inspector,
|
||||||
) -> set[str]:
|
) -> set[str]:
|
||||||
return {catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")}
|
results = cls.execute_metadata_query(database, "SHOW CATALOGS")
|
||||||
|
return {catalog for (catalog,) in results}
|
||||||
|
|
||||||
|
|
||||||
class DatabricksPythonConnectorEngineSpec(DatabricksDynamicBaseEngineSpec):
|
class DatabricksPythonConnectorEngineSpec(DatabricksDynamicBaseEngineSpec):
|
||||||
@@ -624,7 +626,8 @@ class DatabricksPythonConnectorEngineSpec(DatabricksDynamicBaseEngineSpec):
|
|||||||
database: Database,
|
database: Database,
|
||||||
inspector: Inspector,
|
inspector: Inspector,
|
||||||
) -> set[str]:
|
) -> set[str]:
|
||||||
return {catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")}
|
results = cls.execute_metadata_query(database, "SHOW CATALOGS")
|
||||||
|
return {catalog for (catalog,) in results}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def adjust_engine_params(
|
def adjust_engine_params(
|
||||||
|
|||||||
@@ -297,8 +297,8 @@ class DorisEngineSpec(MySQLEngineSpec):
|
|||||||
CatalogId, CatalogName, Type, IsCurrent, CreateTime, LastUpdateTime, Comment
|
CatalogId, CatalogName, Type, IsCurrent, CreateTime, LastUpdateTime, Comment
|
||||||
We need to extract just the CatalogName column.
|
We need to extract just the CatalogName column.
|
||||||
"""
|
"""
|
||||||
result = inspector.bind.execute("SHOW CATALOGS")
|
results = cls.execute_metadata_query(database, "SHOW CATALOGS")
|
||||||
return {row.CatalogName for row in result}
|
return {row.CatalogName for row in results}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_schema_from_engine_params(
|
def get_schema_from_engine_params(
|
||||||
|
|||||||
@@ -388,9 +388,8 @@ class MotherDuckEngineSpec(DuckDBEngineSpec):
|
|||||||
database: Database,
|
database: Database,
|
||||||
inspector: Inspector,
|
inspector: Inspector,
|
||||||
) -> set[str]:
|
) -> set[str]:
|
||||||
return {
|
results = cls.execute_metadata_query(
|
||||||
catalog
|
database,
|
||||||
for (catalog,) in inspector.bind.execute(
|
"SELECT alias FROM MD_ALL_DATABASES() WHERE is_attached;",
|
||||||
"SELECT alias FROM MD_ALL_DATABASES() WHERE is_attached;"
|
)
|
||||||
)
|
return {catalog for (catalog,) in results}
|
||||||
}
|
|
||||||
|
|||||||
@@ -154,15 +154,16 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
|
|||||||
database: Database,
|
database: Database,
|
||||||
table: Table,
|
table: Table,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
with database.get_raw_connection(
|
result = cls.execute_metadata_query(
|
||||||
|
database,
|
||||||
|
f'SELECT GET_METADATA("{table.table}") LIMIT 1',
|
||||||
catalog=table.catalog,
|
catalog=table.catalog,
|
||||||
schema=table.schema,
|
schema=table.schema,
|
||||||
) as conn:
|
)
|
||||||
cursor = conn.cursor()
|
results_list = result.fetchall()
|
||||||
cursor.execute(f'SELECT GET_METADATA("{table.table}")')
|
results = results_list[0][0] if results_list else None
|
||||||
results = cursor.fetchone()[0]
|
|
||||||
try:
|
try:
|
||||||
metadata = json.loads(results)
|
metadata = json.loads(results) if results else {}
|
||||||
except Exception: # pylint: disable=broad-except
|
except Exception: # pylint: disable=broad-except
|
||||||
metadata = {}
|
metadata = {}
|
||||||
|
|
||||||
|
|||||||
@@ -614,8 +614,9 @@ class HiveEngineSpec(PrestoEngineSpec):
|
|||||||
if schema:
|
if schema:
|
||||||
sql += f" IN `{schema}`"
|
sql += f" IN `{schema}`"
|
||||||
|
|
||||||
with database.get_raw_connection(schema=schema) as conn:
|
result = cls.execute_metadata_query(
|
||||||
cursor = conn.cursor()
|
database,
|
||||||
cursor.execute(sql)
|
sql,
|
||||||
results = cursor.fetchall()
|
schema=schema,
|
||||||
return {row[0] for row in results}
|
)
|
||||||
|
return {row[0] for row in result}
|
||||||
|
|||||||
@@ -354,19 +354,18 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def estimate_statement_cost(
|
def estimate_statement_cost(
|
||||||
cls, database: Database, statement: str, cursor: Any
|
cls, database: Database, statement: str
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Run a SQL query that estimates the cost of a given statement.
|
Run a SQL query that estimates the cost of a given statement.
|
||||||
:param database: A Database object
|
:param database: A Database object
|
||||||
:param statement: A single SQL statement
|
:param statement: A single SQL statement
|
||||||
:param cursor: Cursor instance
|
:return: Cost estimate dictionary
|
||||||
:return: JSON response from Trino
|
|
||||||
"""
|
"""
|
||||||
sql = f"EXPLAIN {statement}"
|
sql = f"EXPLAIN {statement} LIMIT 1"
|
||||||
cursor.execute(sql)
|
results = cls.execute_metadata_query(database, sql)
|
||||||
|
rows = results.fetchall()
|
||||||
result = cursor.fetchone()[0]
|
result = rows[0][0] if rows else ""
|
||||||
match = re.search(r"cost=([\d\.]+)\.\.([\d\.]+)", result)
|
match = re.search(r"cost=([\d\.]+)\.\.([\d\.]+)", result)
|
||||||
if match:
|
if match:
|
||||||
return {
|
return {
|
||||||
@@ -393,15 +392,14 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec):
|
|||||||
|
|
||||||
In Postgres, a catalog is called a "database".
|
In Postgres, a catalog is called a "database".
|
||||||
"""
|
"""
|
||||||
return {
|
results = cls.execute_metadata_query(
|
||||||
catalog
|
database,
|
||||||
for (catalog,) in inspector.bind.execute(
|
"""
|
||||||
"""
|
|
||||||
SELECT datname FROM pg_database
|
SELECT datname FROM pg_database
|
||||||
WHERE datistemplate = false;
|
WHERE datistemplate = false;
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
}
|
return {catalog for (catalog,) in results}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_table_names(
|
def get_table_names(
|
||||||
|
|||||||
@@ -321,7 +321,8 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
|
|||||||
"""
|
"""
|
||||||
Get all catalogs.
|
Get all catalogs.
|
||||||
"""
|
"""
|
||||||
return {catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")}
|
result = cls.execute_metadata_query(database, "SHOW CATALOGS")
|
||||||
|
return {catalog for (catalog,) in result}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def adjust_engine_params(
|
def adjust_engine_params(
|
||||||
@@ -373,17 +374,16 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def estimate_statement_cost(
|
def estimate_statement_cost(
|
||||||
cls, database: Database, statement: str, cursor: Any
|
cls, database: Database, statement: str
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Run a SQL query that estimates the cost of a given statement.
|
Run a SQL query that estimates the cost of a given statement.
|
||||||
:param database: A Database object
|
:param database: A Database object
|
||||||
:param statement: A single SQL statement
|
:param statement: A single SQL statement
|
||||||
:param cursor: Cursor instance
|
|
||||||
:return: JSON response from Trino
|
:return: JSON response from Trino
|
||||||
"""
|
"""
|
||||||
sql = f"EXPLAIN (TYPE IO, FORMAT JSON) {statement}"
|
sql = f"EXPLAIN (TYPE IO, FORMAT JSON) {statement} LIMIT 1"
|
||||||
cursor.execute(sql)
|
results = cls.execute_metadata_query(database, sql)
|
||||||
|
|
||||||
# the output from Trino is a single column and a single row containing
|
# the output from Trino is a single column and a single row containing
|
||||||
# JSON:
|
# JSON:
|
||||||
@@ -398,7 +398,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
|
|||||||
# "networkCost" : 3.41425774958E11
|
# "networkCost" : 3.41425774958E11
|
||||||
# }
|
# }
|
||||||
# }
|
# }
|
||||||
result = json.loads(cursor.fetchone()[0])
|
result = json.loads(results[0][0]) if results else {}
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -1037,7 +1037,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
|||||||
AND table_type = 'VIEW'
|
AND table_type = 'VIEW'
|
||||||
"""
|
"""
|
||||||
).strip()
|
).strip()
|
||||||
params = {"schema": schema}
|
results = inspector.bind.execute(sql, {"schema": schema}).fetchall()
|
||||||
else:
|
else:
|
||||||
sql = dedent(
|
sql = dedent(
|
||||||
"""
|
"""
|
||||||
@@ -1045,13 +1045,9 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
|||||||
WHERE table_type = 'VIEW'
|
WHERE table_type = 'VIEW'
|
||||||
"""
|
"""
|
||||||
).strip()
|
).strip()
|
||||||
params = {}
|
results = inspector.bind.execute(sql).fetchall()
|
||||||
|
|
||||||
with database.get_raw_connection(schema=schema) as conn:
|
return {row[0] for row in results}
|
||||||
cursor = conn.cursor()
|
|
||||||
cursor.execute(sql, params)
|
|
||||||
results = cursor.fetchall()
|
|
||||||
return {row[0] for row in results}
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _is_column_name_quoted(cls, column_name: str) -> bool:
|
def _is_column_name_quoted(cls, column_name: str) -> bool:
|
||||||
@@ -1299,16 +1295,12 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
|||||||
# pylint: disable=import-outside-toplevel
|
# pylint: disable=import-outside-toplevel
|
||||||
from pyhive.exc import DatabaseError
|
from pyhive.exc import DatabaseError
|
||||||
|
|
||||||
with database.get_raw_connection(schema=schema) as conn:
|
sql = f"SHOW CREATE VIEW {schema}.{table} LIMIT 1"
|
||||||
cursor = conn.cursor()
|
try:
|
||||||
sql = f"SHOW CREATE VIEW {schema}.{table}"
|
results = cls.execute_metadata_query(database, sql, schema=schema)
|
||||||
try:
|
return results[0][0] if results else None
|
||||||
cls.execute(cursor, sql, database)
|
except DatabaseError: # not a VIEW
|
||||||
rows = cls.fetch_data(cursor, 1)
|
return None
|
||||||
|
|
||||||
return rows[0][0]
|
|
||||||
except DatabaseError: # not a VIEW
|
|
||||||
return None
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_tracking_url(cls, cursor: Cursor) -> str | None:
|
def get_tracking_url(cls, cursor: Cursor) -> str | None:
|
||||||
|
|||||||
@@ -209,12 +209,11 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
|
|||||||
|
|
||||||
In Snowflake, a catalog is called a "database".
|
In Snowflake, a catalog is called a "database".
|
||||||
"""
|
"""
|
||||||
return {
|
results = cls.execute_metadata_query(
|
||||||
catalog
|
database,
|
||||||
for (catalog,) in inspector.bind.execute(
|
"SELECT DATABASE_NAME from information_schema.databases",
|
||||||
"SELECT DATABASE_NAME from information_schema.databases"
|
)
|
||||||
)
|
return {catalog for (catalog,) in results}
|
||||||
}
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def epoch_to_dttm(cls) -> str:
|
def epoch_to_dttm(cls) -> str:
|
||||||
|
|||||||
@@ -320,6 +320,7 @@ class QuerySource(Enum):
|
|||||||
CHART = 0
|
CHART = 0
|
||||||
DASHBOARD = 1
|
DASHBOARD = 1
|
||||||
SQL_LAB = 2
|
SQL_LAB = 2
|
||||||
|
METADATA = 3
|
||||||
|
|
||||||
|
|
||||||
class QueryStatus(StrEnum):
|
class QueryStatus(StrEnum):
|
||||||
|
|||||||
@@ -152,13 +152,22 @@ class TestPostgresDbEngineSpec(SupersetTestCase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
database = mock.Mock()
|
database = mock.Mock()
|
||||||
cursor = mock.Mock()
|
|
||||||
cursor.fetchone.return_value = (
|
|
||||||
"Seq Scan on birth_names (cost=0.00..1537.91 rows=75691 width=46)",
|
|
||||||
)
|
|
||||||
sql = "SELECT * FROM birth_names"
|
sql = "SELECT * FROM birth_names"
|
||||||
results = PostgresEngineSpec.estimate_statement_cost(database, sql, cursor)
|
|
||||||
|
# Mock the execute_metadata_query method to return expected results
|
||||||
|
with mock.patch.object(
|
||||||
|
PostgresEngineSpec, "execute_metadata_query"
|
||||||
|
) as mock_execute:
|
||||||
|
mock_result = mock.Mock()
|
||||||
|
expected_results = [
|
||||||
|
("Seq Scan on birth_names (cost=0.00..1537.91 rows=75691 width=46)",)
|
||||||
|
]
|
||||||
|
mock_result.fetchall.return_value = expected_results
|
||||||
|
mock_execute.return_value = mock_result
|
||||||
|
results = PostgresEngineSpec.estimate_statement_cost(database, sql)
|
||||||
|
|
||||||
assert results == {"Start-up cost": 0.0, "Total cost": 1537.91}
|
assert results == {"Start-up cost": 0.0, "Total cost": 1537.91}
|
||||||
|
mock_execute.assert_called_once_with(database, f"EXPLAIN {sql} LIMIT 1")
|
||||||
|
|
||||||
def test_estimate_statement_invalid_syntax(self):
|
def test_estimate_statement_invalid_syntax(self):
|
||||||
"""
|
"""
|
||||||
@@ -167,17 +176,21 @@ class TestPostgresDbEngineSpec(SupersetTestCase):
|
|||||||
from psycopg2 import errors
|
from psycopg2 import errors
|
||||||
|
|
||||||
database = mock.Mock()
|
database = mock.Mock()
|
||||||
cursor = mock.Mock()
|
|
||||||
cursor.execute.side_effect = errors.SyntaxError(
|
|
||||||
"""
|
|
||||||
syntax error at or near "EXPLAIN"
|
|
||||||
LINE 1: EXPLAIN DROP TABLE birth_names
|
|
||||||
^
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
sql = "DROP TABLE birth_names"
|
sql = "DROP TABLE birth_names"
|
||||||
with self.assertRaises(errors.SyntaxError): # noqa: PT027
|
|
||||||
PostgresEngineSpec.estimate_statement_cost(database, sql, cursor)
|
# Mock the execute_metadata_query method to raise the expected exception
|
||||||
|
with mock.patch.object(
|
||||||
|
PostgresEngineSpec, "execute_metadata_query"
|
||||||
|
) as mock_execute:
|
||||||
|
mock_execute.side_effect = errors.SyntaxError(
|
||||||
|
"""
|
||||||
|
syntax error at or near "EXPLAIN"
|
||||||
|
LINE 1: EXPLAIN DROP TABLE birth_names
|
||||||
|
^
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
with self.assertRaises(errors.SyntaxError): # noqa: PT027
|
||||||
|
PostgresEngineSpec.estimate_statement_cost(database, sql)
|
||||||
|
|
||||||
def test_query_cost_formatter_example_costs(self):
|
def test_query_cost_formatter_example_costs(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -931,26 +931,34 @@ class TestPrestoDbEngineSpec(SupersetTestCase):
|
|||||||
|
|
||||||
def test_estimate_statement_cost(self):
|
def test_estimate_statement_cost(self):
|
||||||
mock_database = mock.MagicMock()
|
mock_database = mock.MagicMock()
|
||||||
mock_cursor = mock.MagicMock()
|
|
||||||
estimate_json = {"a": "b"}
|
estimate_json = {"a": "b"}
|
||||||
mock_cursor.fetchone.return_value = [
|
sql = "SELECT * FROM brth_names"
|
||||||
'{"a": "b"}',
|
|
||||||
]
|
# Mock the execute_metadata_query method to return expected JSON results
|
||||||
result = PrestoEngineSpec.estimate_statement_cost(
|
with mock.patch.object(
|
||||||
mock_database,
|
PrestoEngineSpec, "execute_metadata_query"
|
||||||
"SELECT * FROM brth_names",
|
) as mock_execute:
|
||||||
mock_cursor,
|
mock_result = mock.Mock()
|
||||||
)
|
mock_result.scalar.return_value = '{"a": "b"}'
|
||||||
|
mock_execute.return_value = mock_result
|
||||||
|
result = PrestoEngineSpec.estimate_statement_cost(mock_database, sql)
|
||||||
|
|
||||||
assert result == estimate_json
|
assert result == estimate_json
|
||||||
|
mock_execute.assert_called_once_with(
|
||||||
|
mock_database, f"EXPLAIN (TYPE IO, FORMAT JSON) {sql} LIMIT 1"
|
||||||
|
)
|
||||||
|
|
||||||
def test_estimate_statement_cost_invalid_syntax(self):
|
def test_estimate_statement_cost_invalid_syntax(self):
|
||||||
mock_database = mock.MagicMock()
|
mock_database = mock.MagicMock()
|
||||||
mock_cursor = mock.MagicMock()
|
sql = "DROP TABLE brth_names"
|
||||||
mock_cursor.execute.side_effect = Exception()
|
|
||||||
with self.assertRaises(Exception): # noqa: B017, PT027
|
# Mock the execute_metadata_query method to raise an exception
|
||||||
PrestoEngineSpec.estimate_statement_cost(
|
with mock.patch.object(
|
||||||
mock_database, "DROP TABLE brth_names", mock_cursor
|
PrestoEngineSpec, "execute_metadata_query"
|
||||||
)
|
) as mock_execute:
|
||||||
|
mock_execute.side_effect = Exception("Invalid syntax")
|
||||||
|
with self.assertRaises(Exception): # noqa: B017, PT027
|
||||||
|
PrestoEngineSpec.estimate_statement_cost(mock_database, sql)
|
||||||
|
|
||||||
def test_get_create_view(self):
|
def test_get_create_view(self):
|
||||||
mock_execute = mock.MagicMock()
|
mock_execute = mock.MagicMock()
|
||||||
@@ -1207,10 +1215,10 @@ def test_get_catalog_names(app_context: AppContext) -> None:
|
|||||||
return
|
return
|
||||||
|
|
||||||
with database.get_inspector() as inspector:
|
with database.get_inspector() as inspector:
|
||||||
assert PrestoEngineSpec.get_catalog_names(database, inspector) == [
|
assert PrestoEngineSpec.get_catalog_names(database, inspector) == {
|
||||||
"jmx",
|
"jmx",
|
||||||
"memory",
|
"memory",
|
||||||
"system",
|
"system",
|
||||||
"tpcds",
|
"tpcds",
|
||||||
"tpch",
|
"tpch",
|
||||||
]
|
}
|
||||||
|
|||||||
@@ -44,10 +44,10 @@ def database_with_catalog(mocker: MockerFixture) -> MagicMock:
|
|||||||
database.database_name = "my_db"
|
database.database_name = "my_db"
|
||||||
database.db_engine_spec.__name__ = "test_engine"
|
database.db_engine_spec.__name__ = "test_engine"
|
||||||
database.db_engine_spec.supports_catalog = True
|
database.db_engine_spec.supports_catalog = True
|
||||||
database.get_all_catalog_names.return_value = ["catalog1", "catalog2"]
|
database.get_all_catalog_names.return_value = {"catalog1", "catalog2"}
|
||||||
database.get_all_schema_names.side_effect = [
|
database.get_all_schema_names.side_effect = [
|
||||||
["schema1", "schema2"],
|
{"schema1", "schema2"},
|
||||||
["schema3", "schema4"],
|
{"schema3", "schema4"},
|
||||||
]
|
]
|
||||||
database.get_default_catalog.return_value = "catalog2"
|
database.get_default_catalog.return_value = "catalog2"
|
||||||
|
|
||||||
@@ -63,7 +63,7 @@ def database_without_catalog(mocker: MockerFixture) -> MagicMock:
|
|||||||
database.database_name = "my_db"
|
database.database_name = "my_db"
|
||||||
database.db_engine_spec.__name__ = "test_engine"
|
database.db_engine_spec.__name__ = "test_engine"
|
||||||
database.db_engine_spec.supports_catalog = False
|
database.db_engine_spec.supports_catalog = False
|
||||||
database.get_all_schema_names.return_value = ["schema1", "schema2"]
|
database.get_all_schema_names.return_value = {"schema1", "schema2"}
|
||||||
database.is_oauth2_enabled.return_value = False
|
database.is_oauth2_enabled.return_value = False
|
||||||
database.db_engine_spec.needs_oauth2.return_value = False
|
database.db_engine_spec.needs_oauth2.return_value = False
|
||||||
|
|
||||||
|
|||||||
@@ -69,23 +69,24 @@ def test_sync_permissions_command_sync_mode(
|
|||||||
add_pvm_mock.assert_has_calls(
|
add_pvm_mock.assert_has_calls(
|
||||||
[
|
[
|
||||||
mocker.call(
|
mocker.call(
|
||||||
db.session, security_manager, "catalog_access", "[my_db].[catalog2]"
|
db.session, security_manager, "catalog_access", "[my_db].[catalog1]"
|
||||||
),
|
),
|
||||||
mocker.call(
|
mocker.call(
|
||||||
db.session,
|
db.session,
|
||||||
security_manager,
|
security_manager,
|
||||||
"schema_access",
|
"schema_access",
|
||||||
"[my_db].[catalog2].[schema3]",
|
"[my_db].[catalog1].[schema3]",
|
||||||
),
|
),
|
||||||
mocker.call(
|
mocker.call(
|
||||||
db.session,
|
db.session,
|
||||||
security_manager,
|
security_manager,
|
||||||
"schema_access",
|
"schema_access",
|
||||||
"[my_db].[catalog2].[schema4]",
|
"[my_db].[catalog1].[schema4]",
|
||||||
),
|
),
|
||||||
]
|
],
|
||||||
|
any_order=True,
|
||||||
)
|
)
|
||||||
mock_refresh_schemas.assert_called_once_with("catalog1", ["schema1", "schema2"])
|
mock_refresh_schemas.assert_called_once_with("catalog2", {"schema1", "schema2"})
|
||||||
mock_rename_db_perm.assert_not_called()
|
mock_rename_db_perm.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
@@ -246,7 +247,7 @@ def test_sync_permissions_command_get_catalogs(database_with_catalog: MagicMock)
|
|||||||
Test the ``_get_catalog_names`` method.
|
Test the ``_get_catalog_names`` method.
|
||||||
"""
|
"""
|
||||||
cmmd = SyncPermissionsCommand(1, None, db_connection=database_with_catalog)
|
cmmd = SyncPermissionsCommand(1, None, db_connection=database_with_catalog)
|
||||||
assert cmmd._get_catalog_names() == ["catalog1", "catalog2"]
|
assert cmmd._get_catalog_names() == {"catalog1", "catalog2"}
|
||||||
|
|
||||||
|
|
||||||
def test_sync_permissions_command_get_default_catalog(database_with_catalog: MagicMock):
|
def test_sync_permissions_command_get_default_catalog(database_with_catalog: MagicMock):
|
||||||
@@ -263,7 +264,7 @@ def test_sync_permissions_command_get_default_catalog(database_with_catalog: Mag
|
|||||||
|
|
||||||
database_with_catalog.allow_multi_catalog = True
|
database_with_catalog.allow_multi_catalog = True
|
||||||
cmmd = SyncPermissionsCommand(1, None, db_connection=database_with_catalog)
|
cmmd = SyncPermissionsCommand(1, None, db_connection=database_with_catalog)
|
||||||
assert cmmd._get_catalog_names() == ["catalog1", "catalog2"]
|
assert cmmd._get_catalog_names() == {"catalog1", "catalog2"}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -295,8 +296,8 @@ def test_sync_permissions_command_get_schemas(database_with_catalog: MagicMock):
|
|||||||
Test the ``_get_schema_names`` method.
|
Test the ``_get_schema_names`` method.
|
||||||
"""
|
"""
|
||||||
cmmd = SyncPermissionsCommand(1, None, db_connection=database_with_catalog)
|
cmmd = SyncPermissionsCommand(1, None, db_connection=database_with_catalog)
|
||||||
assert cmmd._get_schema_names("catalog1") == ["schema1", "schema2"]
|
assert cmmd._get_schema_names("catalog1") == {"schema1", "schema2"}
|
||||||
assert cmmd._get_schema_names("catalog2") == ["schema3", "schema4"]
|
assert cmmd._get_schema_names("catalog2") == {"schema3", "schema4"}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|||||||
195
tests/unit_tests/db_engine_specs/test_base_metadata.py
Normal file
195
tests/unit_tests/db_engine_specs/test_base_metadata.py
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
|
||||||
|
from superset.db_engine_specs.base import BaseEngineSpec
|
||||||
|
from superset.utils.core import QuerySource
|
||||||
|
from tests.integration_tests.base_tests import SupersetTestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestBaseEngineSpecMetadata(SupersetTestCase):
    """Tests for ``BaseEngineSpec.execute_metadata_query``.

    Every mock-based test needs the same scaffolding: ``get_engine()`` used as
    a context manager yielding an engine, ``engine.connect()`` used as a
    context manager yielding a connection, and a SQLAlchemy ``Result``-like
    mock returned by ``connection.execute()``.  That shared wiring lives in
    the ``_build_mocks`` helper so each test only states what is specific to
    it (the query, the rows, and the assertions).
    """

    @staticmethod
    def _build_mocks(mock_get_engine, rows):
        """Wire the ``get_engine -> connect -> execute`` mock chain.

        :param mock_get_engine: the patched ``BaseEngineSpec.get_engine``
        :param rows: rows the mocked ``Result`` yields via ``fetchall``
            and iteration
        :return: ``(database, connection, result)`` mocks
        """
        database = mock.Mock()
        engine = mock.Mock()
        connection = mock.Mock()
        result = mock.Mock()

        # get_engine(...) is entered as a context manager yielding the engine
        mock_get_engine.return_value.__enter__ = mock.Mock(return_value=engine)
        mock_get_engine.return_value.__exit__ = mock.Mock(return_value=None)

        # engine.connect() is entered as a context manager yielding the
        # connection
        engine.connect.return_value.__enter__ = mock.Mock(
            return_value=connection
        )
        engine.connect.return_value.__exit__ = mock.Mock(return_value=None)

        # The Result must support both fetchall() and direct iteration.
        # NOTE: assigning a dunder on a Mock instance is specially supported
        # by unittest.mock (it is installed on the per-instance subclass).
        result.fetchall.return_value = rows
        result.__iter__ = lambda self: iter(rows)
        connection.execute.return_value = result

        return database, connection, result

    @mock.patch.object(BaseEngineSpec, "get_engine")
    def test_execute_metadata_query_basic(self, mock_get_engine):
        """Test basic metadata query execution"""
        expected_results = [("catalog1",), ("catalog2",)]
        database, mock_connection, mock_result = self._build_mocks(
            mock_get_engine, expected_results
        )

        # Execute the method
        query = "SHOW CATALOGS"
        results = BaseEngineSpec.execute_metadata_query(database, query)

        # Verify results - now we get the Result object itself
        assert results == mock_result
        # Test that we can iterate over it like the real usage
        assert list(results) == expected_results

        # Verify call chain
        mock_get_engine.assert_called_once_with(
            database, catalog=None, schema=None, source=QuerySource.METADATA
        )
        mock_connection.execute.assert_called_once()
        executed_query = mock_connection.execute.call_args[0][0]
        assert str(executed_query) == query

    @mock.patch.object(BaseEngineSpec, "get_engine")
    def test_execute_metadata_query_with_catalog_schema(self, mock_get_engine):
        """Test metadata query with catalog and schema parameters"""
        expected_results = [("table1",), ("table2",)]
        database, _, mock_result = self._build_mocks(
            mock_get_engine, expected_results
        )

        # Execute with catalog and schema
        query = "SHOW TABLES"
        catalog = "my_catalog"
        schema = "my_schema"
        results = BaseEngineSpec.execute_metadata_query(
            database, query, catalog=catalog, schema=schema
        )

        # Verify results - now we get the Result object itself
        assert results == mock_result
        # Test that we can iterate over it like the real usage
        assert list(results) == expected_results

        # Verify catalog and schema were passed to get_engine
        mock_get_engine.assert_called_once_with(
            database, catalog=catalog, schema=schema, source=QuerySource.METADATA
        )

    @mock.patch.object(BaseEngineSpec, "get_engine")
    def test_execute_metadata_query_uses_correct_query_source(self, mock_get_engine):
        """Test that QuerySource.METADATA is used correctly"""
        database, _, _ = self._build_mocks(mock_get_engine, [])

        # Execute the method
        BaseEngineSpec.execute_metadata_query(database, "SELECT 1")

        # Verify QuerySource.METADATA was used
        mock_get_engine.assert_called_once_with(
            database, catalog=None, schema=None, source=QuerySource.METADATA
        )

    @mock.patch.object(BaseEngineSpec, "get_engine")
    def test_execute_metadata_query_handles_sql_error(self, mock_get_engine):
        """Test proper exception handling for SQL errors"""
        database, mock_connection, _ = self._build_mocks(mock_get_engine, [])

        # Mock connection to raise SQLAlchemyError (side_effect takes
        # precedence over the return_value set by the helper)
        mock_connection.execute.side_effect = SQLAlchemyError("Database error")

        # Execute and verify exception is propagated
        with pytest.raises(SQLAlchemyError):
            BaseEngineSpec.execute_metadata_query(database, "INVALID QUERY")

    @mock.patch.object(BaseEngineSpec, "get_engine")
    def test_execute_metadata_query_empty_results(self, mock_get_engine):
        """Test handling of queries that return no results"""
        database, _, mock_result = self._build_mocks(mock_get_engine, [])

        # Execute the method
        results = BaseEngineSpec.execute_metadata_query(database, "SELECT 1 WHERE 1=0")

        # Verify empty results are handled correctly - now we get the Result object
        assert results == mock_result
        # Test that iterating over it gives empty results like the real usage
        assert list(results) == []

    def test_execute_metadata_query_query_source_enum_value(self):
        """Test QuerySource.METADATA enum has correct value"""
        # This is a simple enum test that doesn't require complex mocking
        assert QuerySource.METADATA.value == 3
        assert QuerySource.METADATA.name == "METADATA"
@@ -1122,3 +1122,21 @@ def test_get_stacktrace():
|
|||||||
except Exception:
|
except Exception:
|
||||||
stacktrace = get_stacktrace()
|
stacktrace = get_stacktrace()
|
||||||
assert stacktrace is None
|
assert stacktrace is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_source_metadata_enum():
    """Test that QuerySource.METADATA enum value exists and has correct value"""
    assert hasattr(QuerySource, "METADATA")
    metadata_member = QuerySource.METADATA
    assert metadata_member.value == 3
    assert metadata_member.name == "METADATA"

    # Verify all expected QuerySource values
    for member, value in (
        (QuerySource.CHART, 0),
        (QuerySource.DASHBOARD, 1),
        (QuerySource.SQL_LAB, 2),
        (QuerySource.METADATA, 3),
    ):
        assert member.value == value
|
||||||
|
|||||||
Reference in New Issue
Block a user