diff --git a/superset-frontend/src/SqlLab/actions/sqlLab.test.js b/superset-frontend/src/SqlLab/actions/sqlLab.test.js
index ecf2c4d7e29..871b3ff6f6b 100644
--- a/superset-frontend/src/SqlLab/actions/sqlLab.test.js
+++ b/superset-frontend/src/SqlLab/actions/sqlLab.test.js
@@ -508,10 +508,11 @@ describe('async actions', () => {
fetchMock.delete(updateTableSchemaEndpoint, {});
fetchMock.post(updateTableSchemaEndpoint, JSON.stringify({ id: 1 }));
- const getTableMetadataEndpoint = 'glob:**/api/v1/database/*/table/*/*/';
+ const getTableMetadataEndpoint =
+ 'glob:**/api/v1/database/*/table_metadata/*';
fetchMock.get(getTableMetadataEndpoint, {});
const getExtraTableMetadataEndpoint =
- 'glob:**/api/v1/database/*/table_metadata/extra/';
+ 'glob:**/api/v1/database/*/table_metadata/extra/*';
fetchMock.get(getExtraTableMetadataEndpoint, {});
let isFeatureEnabledMock;
diff --git a/superset-frontend/src/SqlLab/components/SqlEditorLeftBar/SqlEditorLeftBar.test.tsx b/superset-frontend/src/SqlLab/components/SqlEditorLeftBar/SqlEditorLeftBar.test.tsx
index f8c94468bf7..b5003b16f7b 100644
--- a/superset-frontend/src/SqlLab/components/SqlEditorLeftBar/SqlEditorLeftBar.test.tsx
+++ b/superset-frontend/src/SqlLab/components/SqlEditorLeftBar/SqlEditorLeftBar.test.tsx
@@ -61,13 +61,13 @@ beforeEach(() => {
},
],
});
- fetchMock.get('glob:*/api/v1/database/*/table/*/*', {
+ fetchMock.get('glob:*/api/v1/database/*/table_metadata/*', {
status: 200,
body: {
columns: table.columns,
},
});
- fetchMock.get('glob:*/api/v1/database/*/table_metadata/extra/', {
+ fetchMock.get('glob:*/api/v1/database/*/table_metadata/extra/*', {
status: 200,
body: {},
});
diff --git a/superset-frontend/src/SqlLab/components/TableElement/TableElement.test.tsx b/superset-frontend/src/SqlLab/components/TableElement/TableElement.test.tsx
index a2fe88020aa..1489f23a13a 100644
--- a/superset-frontend/src/SqlLab/components/TableElement/TableElement.test.tsx
+++ b/superset-frontend/src/SqlLab/components/TableElement/TableElement.test.tsx
@@ -47,9 +47,10 @@ jest.mock(
{column.name}
),
);
-const getTableMetadataEndpoint = 'glob:**/api/v1/database/*/table/*/*/';
+const getTableMetadataEndpoint =
+ /\/api\/v1\/database\/\d+\/table_metadata\/(?:\?.*)?$/;
const getExtraTableMetadataEndpoint =
- 'glob:**/api/v1/database/*/table_metadata/extra/*';
+ /\/api\/v1\/database\/\d+\/table_metadata\/extra\/(?:\?.*)?$/;
const updateTableSchemaEndpoint = 'glob:*/tableschemaview/*/expanded';
beforeEach(() => {
diff --git a/superset-frontend/src/features/datasets/AddDataset/DatasetPanel/index.tsx b/superset-frontend/src/features/datasets/AddDataset/DatasetPanel/index.tsx
index ef5797fb309..b3f8aec8f99 100644
--- a/superset-frontend/src/features/datasets/AddDataset/DatasetPanel/index.tsx
+++ b/superset-frontend/src/features/datasets/AddDataset/DatasetPanel/index.tsx
@@ -74,7 +74,9 @@ const DatasetPanelWrapper = ({
const { dbId, tableName, schema } = props;
setLoading(true);
setHasColumns?.(false);
- const path = `/api/v1/database/${dbId}/table/${tableName}/${schema}/`;
+ const path = schema
+ ? `/api/v1/database/${dbId}/table_metadata/?name=${tableName}&schema=${schema}`
+ : `/api/v1/database/${dbId}/table_metadata/?name=${tableName}`;
try {
const response = await SupersetClient.get({
endpoint: path,
diff --git a/superset-frontend/src/hooks/apiResources/tables.ts b/superset-frontend/src/hooks/apiResources/tables.ts
index 164fe0f0ab1..41be4c167c9 100644
--- a/superset-frontend/src/hooks/apiResources/tables.ts
+++ b/superset-frontend/src/hooks/apiResources/tables.ts
@@ -114,9 +114,9 @@ const tableApi = api.injectEndpoints({
}),
tableMetadata: builder.query({
query: ({ dbId, schema, table }) => ({
- endpoint: `/api/v1/database/${dbId}/table/${encodeURIComponent(
- table,
- )}/${encodeURIComponent(schema)}/`,
+ endpoint: schema
+ ? `/api/v1/database/${dbId}/table_metadata/?name=${table}&schema=${schema}`
+ : `/api/v1/database/${dbId}/table_metadata/?name=${table}`,
transformResponse: ({ json }: TableMetadataReponse) => json,
}),
}),
diff --git a/superset/commands/database/tables.py b/superset/commands/database/tables.py
index fa98bcbc7ec..055c0be9aea 100644
--- a/superset/commands/database/tables.py
+++ b/superset/commands/database/tables.py
@@ -51,6 +51,7 @@ class TablesDatabaseCommand(BaseCommand):
datasource_names=sorted(
DatasourceName(*datasource_name)
for datasource_name in self._model.get_all_table_names_in_schema(
+ catalog=None,
schema=self._schema_name,
force=self._force,
cache=self._model.table_cache_enabled,
@@ -65,6 +66,7 @@ class TablesDatabaseCommand(BaseCommand):
datasource_names=sorted(
DatasourceName(*datasource_name)
for datasource_name in self._model.get_all_view_names_in_schema(
+ catalog=None,
schema=self._schema_name,
force=self._force,
cache=self._model.table_cache_enabled,
diff --git a/superset/commands/database/validate_sql.py b/superset/commands/database/validate_sql.py
index 6ecc4f1626e..6a93a01473a 100644
--- a/superset/commands/database/validate_sql.py
+++ b/superset/commands/database/validate_sql.py
@@ -61,11 +61,12 @@ class ValidateSQLCommand(BaseCommand):
raise ValidatorSQLUnexpectedError()
sql = self._properties["sql"]
schema = self._properties.get("schema")
+ catalog = self._properties.get("catalog")
try:
timeout = current_app.config["SQLLAB_VALIDATION_TIMEOUT"]
timeout_msg = f"The query exceeded the {timeout} seconds timeout."
with utils.timeout(seconds=timeout, error_message=timeout_msg):
- errors = self._validator.validate(sql, schema, self._model)
+ errors = self._validator.validate(sql, catalog, schema, self._model)
return [err.to_dict() for err in errors]
except Exception as ex:
logger.exception(ex)
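
A minimal sketch of how a custom validator adapts to the new `validate` signature, which now receives the catalog ahead of the schema (`MyValidator` is a hypothetical subclass, shown only for illustration):

```python
# Sketch only: "MyValidator" is hypothetical; it illustrates the new
# (sql, catalog, schema, database) argument order used by ValidateSQLCommand.
from superset.sql_validators.base import BaseSQLValidator


class MyValidator(BaseSQLValidator):
    name = "MyValidator"

    @classmethod
    def validate(cls, sql, catalog, schema, database):
        # Run the statement against the selected catalog/schema and return a
        # list of annotations; an empty list means no errors were found.
        return []
```
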
diff --git a/superset/commands/dataset/create.py b/superset/commands/dataset/create.py
index 16b87a567a5..dace92f911b 100644
--- a/superset/commands/dataset/create.py
+++ b/superset/commands/dataset/create.py
@@ -34,6 +34,7 @@ from superset.daos.dataset import DatasetDAO
from superset.daos.exceptions import DAOCreateFailedError
from superset.exceptions import SupersetSecurityException
from superset.extensions import db, security_manager
+from superset.sql_parse import Table
logger = logging.getLogger(__name__)
@@ -61,12 +62,15 @@ class CreateDatasetCommand(CreateMixin, BaseCommand):
exceptions: list[ValidationError] = []
database_id = self._properties["database"]
table_name = self._properties["table_name"]
- schema = self._properties.get("schema", None)
- sql = self._properties.get("sql", None)
+ schema = self._properties.get("schema")
+ catalog = self._properties.get("catalog")
+ sql = self._properties.get("sql")
owner_ids: Optional[list[int]] = self._properties.get("owners")
+ table = Table(table_name, schema, catalog)
+
# Validate uniqueness
- if not DatasetDAO.validate_uniqueness(database_id, schema, table_name):
+ if not DatasetDAO.validate_uniqueness(database_id, table):
exceptions.append(DatasetExistsValidationError(table_name))
# Validate/Populate database
@@ -80,7 +84,7 @@ class CreateDatasetCommand(CreateMixin, BaseCommand):
if (
database
and not sql
- and not DatasetDAO.validate_table_exists(database, table_name, schema)
+ and not DatasetDAO.validate_table_exists(database, table)
):
exceptions.append(TableNotFoundValidationError(table_name))
@@ -89,6 +93,7 @@ class CreateDatasetCommand(CreateMixin, BaseCommand):
security_manager.raise_for_access(
database=database,
sql=sql,
+ catalog=catalog,
schema=schema,
)
except SupersetSecurityException as ex:
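
For illustration, the new `Table`-based DAO calls look like this (a sketch assuming an application context and an existing database with ID 1; all names are illustrative):

```python
from superset.daos.dataset import DatasetDAO
from superset.sql_parse import Table

# Positional order is (table_name, schema, catalog); schema and catalog
# default to None when not set.
table = Table("my_table", "public", "examples")

if not DatasetDAO.validate_uniqueness(1, table):
    print("a dataset already points at examples.public.my_table")
```
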
diff --git a/superset/commands/dataset/importers/v1/utils.py b/superset/commands/dataset/importers/v1/utils.py
index 50bb916b078..0d2226f7242 100644
--- a/superset/commands/dataset/importers/v1/utils.py
+++ b/superset/commands/dataset/importers/v1/utils.py
@@ -32,6 +32,7 @@ from superset.commands.dataset.exceptions import DatasetForbiddenDataURI
from superset.commands.exceptions import ImportFailedError
from superset.connectors.sqla.models import SqlaTable
from superset.models.core import Database
+from superset.sql_parse import Table
from superset.utils.core import get_user
logger = logging.getLogger(__name__)
@@ -164,7 +165,9 @@ def import_dataset(
db.session.flush()
try:
- table_exists = dataset.database.has_table_by_name(dataset.table_name)
+ table_exists = dataset.database.has_table(
+ Table(dataset.table_name, dataset.schema),
+ )
except Exception: # pylint: disable=broad-except
# MySQL doesn't play nice with GSheets table names
logger.warning(
@@ -217,7 +220,10 @@ def load_data(data_uri: str, dataset: SqlaTable, database: Database) -> None:
)
else:
logger.warning("Loading data outside the import transaction")
- with database.get_sqla_engine() as engine:
+ with database.get_sqla_engine(
+ catalog=dataset.catalog,
+ schema=dataset.schema,
+ ) as engine:
df.to_sql(
dataset.table_name,
con=engine,
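
The engine context manager now pins both catalog and schema, so `to_sql` writes into the dataset's own namespace. A sketch, assuming `database` is a `superset.models.core.Database` instance:

```python
import pandas as pd

df = pd.DataFrame({"id": [1, 2, 3]})

# Both keyword arguments are optional and default to None, in which case the
# connection falls back to the database defaults.
with database.get_sqla_engine(catalog="examples", schema="public") as engine:
    df.to_sql("my_table", con=engine, if_exists="replace", index=False)
```
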
diff --git a/superset/commands/dataset/update.py b/superset/commands/dataset/update.py
index 5c0d87b230e..282c778eb43 100644
--- a/superset/commands/dataset/update.py
+++ b/superset/commands/dataset/update.py
@@ -41,6 +41,7 @@ from superset.connectors.sqla.models import SqlaTable
from superset.daos.dataset import DatasetDAO
from superset.daos.exceptions import DAOUpdateFailedError
from superset.exceptions import SupersetSecurityException
+from superset.sql_parse import Table
logger = logging.getLogger(__name__)
@@ -90,9 +91,8 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
# Validate uniqueness
if not DatasetDAO.validate_update_uniqueness(
self._model.database_id,
- self._model.schema,
+ Table(table_name, self._model.schema, self._model.catalog),
self._model_id,
- table_name,
):
exceptions.append(DatasetExistsValidationError(table_name))
# Validate/Populate database not allowed to change
diff --git a/superset/commands/sql_lab/estimate.py b/superset/commands/sql_lab/estimate.py
index bf1d6c4fa57..d3198815662 100644
--- a/superset/commands/sql_lab/estimate.py
+++ b/superset/commands/sql_lab/estimate.py
@@ -17,7 +17,7 @@
from __future__ import annotations
import logging
-from typing import Any
+from typing import Any, TypedDict
from flask_babel import gettext as __
@@ -27,7 +27,6 @@ from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetErrorException, SupersetTimeoutException
from superset.jinja_context import get_template_processor
from superset.models.core import Database
-from superset.sqllab.schemas import EstimateQueryCostSchema
from superset.utils import core as utils
config = app.config
@@ -37,18 +36,28 @@ stats_logger = config["STATS_LOGGER"]
logger = logging.getLogger(__name__)
+class EstimateQueryCostType(TypedDict):
+ database_id: int
+ sql: str
+ template_params: dict[str, Any]
+ catalog: str | None
+ schema: str | None
+
+
class QueryEstimationCommand(BaseCommand):
_database_id: int
_sql: str
_template_params: dict[str, Any]
_schema: str
_database: Database
+ _catalog: str | None
- def __init__(self, params: EstimateQueryCostSchema) -> None:
- self._database_id = params.get("database_id")
+ def __init__(self, params: EstimateQueryCostType) -> None:
+ self._database_id = params["database_id"]
self._sql = params.get("sql", "")
self._template_params = params.get("template_params", {})
- self._schema = params.get("schema", "")
+ self._schema = params.get("schema") or ""
+ self._catalog = params.get("catalog")
def validate(self) -> None:
self._database = db.session.query(Database).get(self._database_id)
@@ -77,7 +86,11 @@ class QueryEstimationCommand(BaseCommand):
try:
with utils.timeout(seconds=timeout, error_message=timeout_msg):
cost = self._database.db_engine_spec.estimate_query_cost(
- self._database, self._schema, sql, utils.QuerySource.SQL_LAB
+ self._database,
+ self._catalog,
+ self._schema,
+ sql,
+ utils.QuerySource.SQL_LAB,
)
except SupersetTimeoutException as ex:
logger.exception(ex)
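
Putting the new TypedDict to use, a sketch of estimating a query's cost in a non-default catalog (database ID 1 is assumed to exist):

```python
from superset.commands.sql_lab.estimate import (
    EstimateQueryCostType,
    QueryEstimationCommand,
)

params: EstimateQueryCostType = {
    "database_id": 1,
    "sql": "SELECT * FROM my_table",
    "template_params": {},
    "catalog": "examples",  # optional; None falls back to the default catalog
    "schema": "public",
}
costs = QueryEstimationCommand(params).run()  # validates, then estimates
```
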
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 339be9d177e..719d5af5888 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -49,7 +49,7 @@ from sqlalchemy import (
Integer,
or_,
String,
- Table,
+ Table as DBTable,
Text,
update,
)
@@ -108,7 +108,7 @@ from superset.models.helpers import (
validate_adhoc_subquery,
)
from superset.models.slice import Slice
-from superset.sql_parse import ParsedQuery, sanitize_clause
+from superset.sql_parse import ParsedQuery, sanitize_clause, Table
from superset.superset_typing import (
AdhocColumn,
AdhocMetric,
@@ -329,7 +329,7 @@ class BaseDatasource(AuditMixinNullable, ImportExportMixin): # pylint: disable=
"edit_url": self.url,
"id": self.id,
"uid": self.uid,
- "schema": self.schema,
+ "schema": self.schema or None,
"name": self.name,
"type": self.type,
"connection": self.connection,
@@ -383,7 +383,7 @@ class BaseDatasource(AuditMixinNullable, ImportExportMixin): # pylint: disable=
"datasource_name": self.datasource_name,
"table_name": self.datasource_name,
"type": self.type,
- "schema": self.schema,
+ "schema": self.schema or None,
"offset": self.offset,
"cache_timeout": self.cache_timeout,
"params": self.params,
@@ -1065,7 +1065,7 @@ class SqlMetric(AuditMixinNullable, ImportExportMixin, CertificationMixin, Model
return {s: getattr(self, s) for s in attrs}
-sqlatable_user = Table(
+sqlatable_user = DBTable(
"sqlatable_user",
metadata,
Column("id", Integer, primary_key=True),
@@ -1143,6 +1143,7 @@ class SqlaTable(
foreign_keys=[database_id],
)
schema = Column(String(255))
+ catalog = Column(String(256), nullable=True, default=None)
sql = Column(MediumText())
is_sqllab_view = Column(Boolean, default=False)
template_params = Column(Text)
@@ -1262,7 +1263,7 @@ class SqlaTable(
def get_schema_perm(self) -> str | None:
"""Returns schema permission if present, database one otherwise."""
- return security_manager.get_schema_perm(self.database, self.schema)
+ return security_manager.get_schema_perm(self.database, self.schema or None)
def get_perm(self) -> str:
"""
@@ -1319,8 +1320,7 @@ class SqlaTable(
return get_virtual_table_metadata(dataset=self)
return get_physical_table_metadata(
database=self.database,
- table_name=self.table_name,
- schema_name=self.schema,
+ table=Table(self.table_name, self.schema or None, self.catalog),
normalize_columns=self.normalize_columns,
)
@@ -1336,7 +1336,9 @@ class SqlaTable(
# show_cols and latest_partition set to false to avoid
# the expensive cost of inspecting the DB
return self.database.select_star(
- self.table_name, schema=self.schema, show_cols=False, latest_partition=False
+ Table(self.table_name, self.schema or None, self.catalog),
+ show_cols=False,
+ latest_partition=False,
)
@property
@@ -1523,7 +1525,12 @@ class SqlaTable(
tbl, _ = self.get_from_clause(template_processor)
qry = sa.select([sqla_column]).limit(1).select_from(tbl)
sql = self.database.compile_sqla_query(qry)
- col_desc = get_columns_description(self.database, self.schema, sql)
+ col_desc = get_columns_description(
+ self.database,
+ self.catalog,
+ self.schema or None,
+ sql,
+ )
if not col_desc:
raise SupersetGenericDBErrorException("Column not found")
is_dttm = col_desc[0]["is_dttm"] # type: ignore
@@ -1728,7 +1735,9 @@ class SqlaTable(
return df
try:
- df = self.database.get_df(sql, self.schema, mutator=assign_column_label)
+ df = self.database.get_df(
+ sql, self.schema or None, mutator=assign_column_label
+ )
except (SupersetErrorException, SupersetErrorsException) as ex:
# SupersetError(s) exception should not be captured; instead, they should
# bubble up to the Flask error handler so they are returned as proper SIP-40
@@ -1762,7 +1771,13 @@ class SqlaTable(
)
def get_sqla_table_object(self) -> Table:
- return self.database.get_table(self.table_name, schema=self.schema)
+ return self.database.get_table(
+ Table(
+ self.table_name,
+ self.schema or None,
+ self.catalog,
+ )
+ )
def fetch_metadata(self, commit: bool = True) -> MetadataResult:
"""
@@ -1774,7 +1789,13 @@ class SqlaTable(
new_columns = self.external_metadata()
metrics = [
SqlMetric(**metric)
- for metric in self.database.get_metrics(self.table_name, self.schema)
+ for metric in self.database.get_metrics(
+ Table(
+ self.table_name,
+ self.schema or None,
+ self.catalog,
+ )
+ )
]
any_date_col = None
db_engine_spec = self.db_engine_spec
@@ -2021,7 +2042,7 @@ sa.event.listen(SqlaTable, "after_delete", SqlaTable.after_delete)
sa.event.listen(SqlMetric, "after_update", SqlaTable.update_column)
sa.event.listen(TableColumn, "after_update", SqlaTable.update_column)
-RLSFilterRoles = Table(
+RLSFilterRoles = DBTable(
"rls_filter_roles",
metadata,
Column("id", Integer, primary_key=True),
@@ -2029,7 +2050,7 @@ RLSFilterRoles = Table(
Column("rls_filter_id", Integer, ForeignKey("row_level_security_filters.id")),
)
-RLSFilterTables = Table(
+RLSFilterTables = DBTable(
"rls_filter_tables",
metadata,
Column("id", Integer, primary_key=True),
diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py
index 4bc11aee42d..87b3d5dd3a2 100644
--- a/superset/connectors/sqla/utils.py
+++ b/superset/connectors/sqla/utils.py
@@ -38,7 +38,7 @@ from superset.exceptions import (
)
from superset.models.core import Database
from superset.result_set import SupersetResultSet
-from superset.sql_parse import ParsedQuery
+from superset.sql_parse import ParsedQuery, Table
from superset.superset_typing import ResultSetColumnType
if TYPE_CHECKING:
@@ -47,24 +47,18 @@ if TYPE_CHECKING:
def get_physical_table_metadata(
database: Database,
- table_name: str,
+ table: Table,
normalize_columns: bool,
- schema_name: str | None = None,
) -> list[ResultSetColumnType]:
"""Use SQLAlchemy inspector to get table metadata"""
db_engine_spec = database.db_engine_spec
db_dialect = database.get_dialect()
- # ensure empty schema
- _schema_name = schema_name if schema_name else None
+
# Table does not exist or is not visible to a connection.
+ if not (database.has_table(table) or database.has_view(table)):
+ raise NoSuchTableError(table)
- if not (
- database.has_table_by_name(table_name=table_name, schema=_schema_name)
- or database.has_view_by_name(view_name=table_name, schema=_schema_name)
- ):
- raise NoSuchTableError
-
- cols = database.get_columns(table_name, schema=_schema_name)
+ cols = database.get_columns(table)
for col in cols:
try:
if isinstance(col["type"], TypeEngine):
@@ -129,11 +123,17 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> list[ResultSetColumnType]:
level=ErrorLevel.ERROR,
)
)
- return get_columns_description(dataset.database, dataset.schema, statements[0])
+ return get_columns_description(
+ dataset.database,
+ dataset.catalog,
+ dataset.schema,
+ statements[0],
+ )
def get_columns_description(
database: Database,
+ catalog: str | None,
schema: str | None,
query: str,
) -> list[ResultSetColumnType]:
@@ -141,7 +141,7 @@ def get_columns_description(
# sql_lab.py:execute_sql_statements
db_engine_spec = database.db_engine_spec
try:
- with database.get_raw_connection(schema=schema) as conn:
+ with database.get_raw_connection(catalog=catalog, schema=schema) as conn:
cursor = conn.cursor()
query = database.apply_limit_to_sql(query, limit=1)
cursor.execute(query)
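
A sketch of the updated helper, which now threads the catalog into the raw connection (assumes `database` is a `Database` instance; names are illustrative):

```python
from superset.connectors.sqla.utils import get_columns_description

cols = get_columns_description(
    database,
    "examples",       # catalog; pass None to use the default
    "public",         # schema
    "SELECT 1 AS x",  # the query is limited to a single row internally
)
```
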
diff --git a/superset/constants.py b/superset/constants.py
index e4d467bdd83..28902ded6cf 100644
--- a/superset/constants.py
+++ b/superset/constants.py
@@ -134,6 +134,7 @@ MODEL_API_RW_METHOD_PERMISSION_MAP = {
"schemas": "read",
"select_star": "read",
"table_metadata": "read",
+ "table_metadata_deprecated": "read",
"table_extra_metadata": "read",
"table_extra_metadata_deprecated": "read",
"test_connection": "write",
diff --git a/superset/daos/dataset.py b/superset/daos/dataset.py
index 23b46e3329d..21c5ae1d0fa 100644
--- a/superset/daos/dataset.py
+++ b/superset/daos/dataset.py
@@ -30,6 +30,7 @@ from superset.extensions import db
from superset.models.core import Database
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
+from superset.sql_parse import Table
from superset.utils.core import DatasourceType
from superset.views.base import DatasourceFilter
@@ -72,25 +73,26 @@ class DatasetDAO(BaseDAO[SqlaTable]):
@staticmethod
def validate_table_exists(
- database: Database, table_name: str, schema: str | None
+ database: Database,
+ table: Table,
) -> bool:
try:
- database.get_table(table_name, schema=schema)
+ database.get_table(table)
return True
except SQLAlchemyError as ex: # pragma: no cover
- logger.warning("Got an error %s validating table: %s", str(ex), table_name)
+ logger.warning("Got an error %s validating table: %s", str(ex), table)
return False
@staticmethod
def validate_uniqueness(
database_id: int,
- schema: str | None,
- name: str,
+ table: Table,
dataset_id: int | None = None,
) -> bool:
dataset_query = db.session.query(SqlaTable).filter(
- SqlaTable.table_name == name,
- SqlaTable.schema == schema,
+ SqlaTable.table_name == table.table,
+ SqlaTable.schema == table.schema,
+ SqlaTable.catalog == table.catalog,
SqlaTable.database_id == database_id,
)
@@ -103,14 +105,14 @@ class DatasetDAO(BaseDAO[SqlaTable]):
@staticmethod
def validate_update_uniqueness(
database_id: int,
- schema: str | None,
+ table: Table,
dataset_id: int,
- name: str,
) -> bool:
dataset_query = db.session.query(SqlaTable).filter(
- SqlaTable.table_name == name,
+ SqlaTable.table_name == table.table,
SqlaTable.database_id == database_id,
- SqlaTable.schema == schema,
+ SqlaTable.schema == table.schema,
+ SqlaTable.catalog == table.catalog,
SqlaTable.id != dataset_id,
)
return not db.session.query(dataset_query.exists()).scalar()
diff --git a/superset/databases/api.py b/superset/databases/api.py
index 635a2da790b..a77019123b9 100644
--- a/superset/databases/api.py
+++ b/superset/databases/api.py
@@ -136,6 +136,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
RouteMethod.RELATED,
"tables",
"table_metadata",
+ "table_metadata_deprecated",
"table_extra_metadata",
"table_extra_metadata_deprecated",
"select_star",
@@ -722,10 +723,10 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
@statsd_metrics
@event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}"
- f".table_metadata",
+ f".table_metadata_deprecated",
log_to_statsd=False,
)
- def table_metadata(
+ def table_metadata_deprecated(
self, database: Database, table_name: str, schema_name: str
) -> FlaskResponse:
"""Get database table metadata.
@@ -766,16 +767,16 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
500:
$ref: '#/components/responses/500'
"""
- self.incr_stats("init", self.table_metadata.__name__)
+ self.incr_stats("init", self.table_metadata_deprecated.__name__)
try:
- table_info = get_table_metadata(database, table_name, schema_name)
+ table_info = get_table_metadata(database, Table(table_name, schema_name))
except SQLAlchemyError as ex:
- self.incr_stats("error", self.table_metadata.__name__)
+ self.incr_stats("error", self.table_metadata_deprecated.__name__)
return self.response_422(error_msg_from_exception(ex))
except SupersetException as ex:
return self.response(ex.status, message=ex.message)
- self.incr_stats("success", self.table_metadata.__name__)
+ self.incr_stats("success", self.table_metadata_deprecated.__name__)
return self.response(200, **table_info)
@expose("//table_extra///", methods=("GET",))
@@ -844,7 +845,86 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
payload = database.db_engine_spec.get_extra_table_metadata(database, table)
return self.response(200, **payload)
- @expose("//table_metadata/extra/", methods=("GET",))
+ @expose("//table_metadata/", methods=["GET"])
+ @protect()
+ @statsd_metrics
+ @event_logger.log_this_with_context(
+ action=lambda self, *args, **kwargs: f"{self.__class__.__name__}"
+ f".table_metadata",
+ log_to_statsd=False,
+ )
+ def table_metadata(self, pk: int) -> FlaskResponse:
+ """
+ Get metadata for a given table.
+
+ Optionally, a schema and a catalog can be passed, if different from the default
+ ones.
+ ---
+ get:
+ summary: Get table metadata
+ description: >-
+ Metadata associated with the table (columns, indexes, etc.)
+ parameters:
+ - in: path
+ schema:
+ type: integer
+ name: pk
+ description: The database id
+ - in: query
+ schema:
+ type: string
+ name: name
+ required: true
+ description: Table name
+ - in: query
+ schema:
+ type: string
+ name: schema
+ description: >-
+ Optional table schema, if not passed default schema will be used
+ - in: query
+ schema:
+ type: string
+ name: catalog
+ description: >-
+ Optional table catalog, if not passed default catalog will be used
+ responses:
+ 200:
+ description: Table metadata information
+ content:
+ application/json:
+ schema:
+ $ref: "#/components/schemas/TableExtraMetadataResponseSchema"
+ 401:
+ $ref: '#/components/responses/401'
+ 404:
+ $ref: '#/components/responses/404'
+ 500:
+ $ref: '#/components/responses/500'
+ """
+ self.incr_stats("init", self.table_metadata.__name__)
+
+ database = DatabaseDAO.find_by_id(pk)
+ if database is None:
+ raise DatabaseNotFoundException("No such database")
+
+ try:
+ parameters = QualifiedTableSchema().load(request.args)
+ except ValidationError as ex:
+ raise InvalidPayloadSchemaError(ex) from ex
+
+ table = Table(parameters["name"], parameters["schema"], parameters["catalog"])
+ try:
+ security_manager.raise_for_access(database=database, table=table)
+ except SupersetSecurityException as ex:
+ # instead of raising 403, raise 404 to hide table existence
+ raise TableNotFoundException("No such table") from ex
+
+ payload = database.db_engine_spec.get_table_metadata(database, table)
+
+ return self.response(200, **payload)
+
+ @expose("//table_metadata/extra/", methods=["GET"])
@protect()
@statsd_metrics
@event_logger.log_this_with_context(
@@ -978,7 +1058,8 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
self.incr_stats("init", self.select_star.__name__)
try:
result = database.select_star(
- table_name, schema_name, latest_partition=True
+ Table(table_name, schema_name),
+ latest_partition=True,
)
except NoSuchTableError:
self.incr_stats("error", self.select_star.__name__)
diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py
index 9a1fc9d6c15..1bc0af7472c 100644
--- a/superset/databases/schemas.py
+++ b/superset/databases/schemas.py
@@ -17,11 +17,13 @@
# pylint: disable=unused-argument, too-many-lines
+from __future__ import annotations
+
import inspect
import json
import os
import re
-from typing import Any
+from typing import Any, TypedDict
from flask import current_app
from flask_babel import lazy_gettext as _
@@ -581,6 +583,49 @@ class DatabaseTestConnectionSchema(DatabaseParametersSchemaMixin, Schema):
ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True)
+class TableMetadataOptionsResponse(TypedDict):
+ deferrable: bool
+ initially: bool
+ match: bool
+ ondelete: bool
+ onupdate: bool
+
+
+class TableMetadataColumnsResponse(TypedDict, total=False):
+ keys: list[str]
+ longType: str
+ name: str
+ type: str
+ duplicates_constraint: str | None
+ comment: str | None
+
+
+class TableMetadataForeignKeysIndexesResponse(TypedDict):
+ column_names: list[str]
+ name: str
+ options: TableMetadataOptionsResponse
+ referred_columns: list[str]
+ referred_schema: str
+ referred_table: str
+ type: str
+
+
+class TableMetadataPrimaryKeyResponse(TypedDict):
+ column_names: list[str]
+ name: str
+ type: str
+
+
+class TableMetadataResponse(TypedDict):
+ name: str
+ columns: list[TableMetadataColumnsResponse]
+ foreignKeys: list[TableMetadataForeignKeysIndexesResponse]
+ indexes: list[TableMetadataForeignKeysIndexesResponse]
+ primaryKey: TableMetadataPrimaryKeyResponse
+ selectStar: str
+ comment: str | None
+
+
class TableMetadataOptionsResponseSchema(Schema):
deferrable = fields.Bool()
initially = fields.Bool()
diff --git a/superset/databases/utils.py b/superset/databases/utils.py
index 8de4bb6f235..dfd75eb2233 100644
--- a/superset/databases/utils.py
+++ b/superset/databases/utils.py
@@ -14,19 +14,29 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Optional, Union
+
+from __future__ import annotations
+
+from typing import Any, TYPE_CHECKING
from sqlalchemy.engine.url import make_url, URL
from superset.commands.database.exceptions import DatabaseInvalidError
+from superset.sql_parse import Table
+
+if TYPE_CHECKING:
+ from superset.databases.schemas import (
+ TableMetadataColumnsResponse,
+ TableMetadataForeignKeysIndexesResponse,
+ TableMetadataResponse,
+ )
def get_foreign_keys_metadata(
database: Any,
- table_name: str,
- schema_name: Optional[str],
-) -> list[dict[str, Any]]:
- foreign_keys = database.get_foreign_keys(table_name, schema_name)
+ table: Table,
+) -> list[TableMetadataForeignKeysIndexesResponse]:
+ foreign_keys = database.get_foreign_keys(table)
for fk in foreign_keys:
fk["column_names"] = fk.pop("constrained_columns")
fk["type"] = "fk"
@@ -34,9 +44,10 @@ def get_foreign_keys_metadata(
def get_indexes_metadata(
- database: Any, table_name: str, schema_name: Optional[str]
-) -> list[dict[str, Any]]:
- indexes = database.get_indexes(table_name, schema_name)
+ database: Any,
+ table: Table,
+) -> list[TableMetadataForeignKeysIndexesResponse]:
+ indexes = database.get_indexes(table)
for idx in indexes:
idx["type"] = "index"
return indexes
@@ -51,30 +62,27 @@ def get_col_type(col: dict[Any, Any]) -> str:
return dtype
-def get_table_metadata(
- database: Any, table_name: str, schema_name: Optional[str]
-) -> dict[str, Any]:
+def get_table_metadata(database: Any, table: Table) -> TableMetadataResponse:
"""
Get table metadata information, including type, pk, fks.
This function raises SQLAlchemyError when a schema is not found.
:param database: The database model
- :param table_name: Table name
- :param schema_name: schema name
+ :param table: Table instance
:return: Dict table metadata ready for API response
"""
keys = []
- columns = database.get_columns(table_name, schema_name)
- primary_key = database.get_pk_constraint(table_name, schema_name)
+ columns = database.get_columns(table)
+ primary_key = database.get_pk_constraint(table)
if primary_key and primary_key.get("constrained_columns"):
primary_key["column_names"] = primary_key.pop("constrained_columns")
primary_key["type"] = "pk"
keys += [primary_key]
- foreign_keys = get_foreign_keys_metadata(database, table_name, schema_name)
- indexes = get_indexes_metadata(database, table_name, schema_name)
+ foreign_keys = get_foreign_keys_metadata(database, table)
+ indexes = get_indexes_metadata(database, table)
keys += foreign_keys + indexes
- payload_columns: list[dict[str, Any]] = []
- table_comment = database.get_table_comment(table_name, schema_name)
+ payload_columns: list[TableMetadataColumnsResponse] = []
+ table_comment = database.get_table_comment(table)
for col in columns:
dtype = get_col_type(col)
payload_columns.append(
@@ -87,11 +95,10 @@ def get_table_metadata(
}
)
return {
- "name": table_name,
+ "name": table.table,
"columns": payload_columns,
"selectStar": database.select_star(
- table_name,
- schema=schema_name,
+ table,
indent=True,
cols=columns,
latest_partition=True,
@@ -103,7 +110,7 @@ def get_table_metadata(
}
-def make_url_safe(raw_url: Union[str, URL]) -> URL:
+def make_url_safe(raw_url: str | URL) -> URL:
"""
Wrapper for SQLAlchemy's make_url(), which tends to raise too detailed of
errors, which inevitably find their way into server logs. ArgumentErrors
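
For reference, a sketch of the payload shape `get_table_metadata` returns, matching the `TableMetadataResponse` TypedDict added in `superset/databases/schemas.py` (all values are illustrative):

```python
payload = {
    "name": "my_table",
    "columns": [
        {"name": "id", "type": "INTEGER", "longType": "INTEGER",
         "keys": [], "comment": None},
    ],
    "foreignKeys": [],
    "indexes": [],
    "primaryKey": {"column_names": ["id"], "name": "pk_my_table", "type": "pk"},
    "selectStar": "SELECT *\nFROM my_table\nLIMIT 100",
    "comment": None,
}
```
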
diff --git a/superset/db_engine_specs/README.md b/superset/db_engine_specs/README.md
index 11f0f90ab7d..4a108be6587 100644
--- a/superset/db_engine_specs/README.md
+++ b/superset/db_engine_specs/README.md
@@ -660,7 +660,7 @@ This way, when a user selects a column that doesn't exist Superset can return a
### Dynamic schema
-In SQL Lab it's possible to select a database, and then a schema in that database. Ideally, when running a query in SQL Lab, any unqualified table names (eg, `table`, instead of `schema.table`) should be in the selected schema. For example, if the user select `dev` as the schema and then runs the following query:
+In SQL Lab it's possible to select a database, and then a schema in that database. Ideally, when running a query in SQL Lab, any unqualified table names (eg, `table`, instead of `schema.table`) should be in the selected schema. For example, if the user selects `dev` as the schema and then runs the following query:
```sql
SELECT * FROM my_table
@@ -674,7 +674,7 @@ Implementing this method is also important for usability. When the method is not
### Catalog
-In general, databases support a hierarchy of concepts of one-to-many concepts:
+In general, databases support a hierarchy of one-to-many concepts:
1. Database
2. Catalog
@@ -692,7 +692,7 @@ These concepts have different names depending on the database. For example, Post
BigQuery, on the other hand:
-1. Bigquery (database)
+1. BigQuery (database)
2. Project (catalog)
3. Schema (namespace)
4. Table
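
The same four levels line up with a fully qualified name; a sketch in Postgres nomenclature (names are illustrative):

```python
# database -> catalog -> schema -> table, expressed as a qualified name
catalog, schema, table = "examples", "public", "my_table"
print(f"{catalog}.{schema}.{table}")  # examples.public.my_table
```
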
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 1fc6a40a3a2..3cc13151295 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -61,7 +61,7 @@ from sqlparse.tokens import CTE
from superset import sql_parse
from superset.constants import TimeGrain as TimeGrainConstants
-from superset.databases.utils import make_url_safe
+from superset.databases.utils import get_table_metadata, make_url_safe
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import OAuth2Error, OAuth2RedirectError
from superset.sql_parse import ParsedQuery, SQLScript, Table
@@ -80,6 +80,7 @@ from superset.utils.oauth2 import encode_oauth2_state
if TYPE_CHECKING:
from superset.connectors.sqla.models import TableColumn
+ from superset.databases.schemas import TableMetadataResponse
from superset.models.core import Database
from superset.models.sql_lab import Query
@@ -638,11 +639,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return driver in cls.drivers
@classmethod
- def get_default_schema(cls, database: Database) -> str | None:
+ def get_default_schema(cls, database: Database, catalog: str | None) -> str | None:
"""
Return the default schema in a given database.
"""
- with database.get_inspector_with_context() as inspector:
+ with database.get_inspector(catalog=catalog) as inspector:
return inspector.default_schema_name
@classmethod
@@ -697,7 +698,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return schema
# return the default schema of the database
- return cls.get_default_schema(database)
+ return cls.get_default_schema(database, query.catalog)
@classmethod
def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]:
@@ -760,18 +761,19 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
def get_engine(
cls,
database: Database,
+ catalog: str | None = None,
schema: str | None = None,
source: utils.QuerySource | None = None,
) -> ContextManager[Engine]:
"""
Return an engine context manager.
- >>> with DBEngineSpec.get_engine(database, schema, source) as engine:
+ >>> with DBEngineSpec.get_engine(database, catalog, schema, source) as engine:
... connection = engine.connect()
... connection.execute(sql)
"""
- return database.get_sqla_engine(schema=schema, source=source)
+ return database.get_sqla_engine(catalog=catalog, schema=schema, source=source)
@classmethod
def get_timestamp_expr(
@@ -1033,6 +1035,21 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"""
return indexes
+ @classmethod
+ def get_table_metadata(
+ cls,
+ database: Database,
+ table: Table,
+ ) -> TableMetadataResponse:
+ """
+ Returns basic table metadata
+
+ :param database: Database instance
+ :param table: A Table instance
+ :return: Basic table metadata
+ """
+ return get_table_metadata(database, table)
+
@classmethod
def get_extra_table_metadata(
cls,
@@ -1236,7 +1253,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
# Only add schema when it is preset and non-empty.
to_sql_kwargs["schema"] = table.schema
- with cls.get_engine(database) as engine:
+ with cls.get_engine(
+ database,
+ catalog=table.catalog,
+ schema=table.schema,
+ ) as engine:
if engine.dialect.supports_multivalues_insert:
to_sql_kwargs["method"] = "multi"
@@ -1471,36 +1492,34 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
cls,
database: Database, # pylint: disable=unused-argument
inspector: Inspector,
- table_name: str,
- schema: str | None,
+ table: Table,
) -> list[dict[str, Any]]:
"""
Get the indexes associated with the specified schema/table.
:param database: The database to inspect
:param inspector: The SQLAlchemy inspector
- :param table_name: The table to inspect
- :param schema: The schema to inspect
+ :param table: The table instance to inspect
:returns: The indexes
"""
- return inspector.get_indexes(table_name, schema)
+ return inspector.get_indexes(table.table, table.schema)
@classmethod
def get_table_comment(
- cls, inspector: Inspector, table_name: str, schema: str | None
+ cls,
+ inspector: Inspector,
+ table: Table,
) -> str | None:
"""
Get comment of table from a given schema and table
-
:param inspector: SqlAlchemy Inspector instance
- :param table_name: Table name
- :param schema: Schema name. If omitted, uses default schema for database
+ :param table: Table instance
:return: comment of table
"""
comment = None
try:
- comment = inspector.get_table_comment(table_name, schema)
+ comment = inspector.get_table_comment(table.table, table.schema)
comment = comment.get("text") if isinstance(comment, dict) else None
except NotImplementedError:
# It's expected that some dialects don't implement the comment method
@@ -1514,22 +1533,25 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
def get_columns( # pylint: disable=unused-argument
cls,
inspector: Inspector,
- table_name: str,
- schema: str | None,
+ table: Table,
options: dict[str, Any] | None = None,
) -> list[ResultSetColumnType]:
"""
- Get all columns from a given schema and table
+ Get all columns from a given schema and table.
+
+ The inspector will be bound to a catalog, if one was specified.
:param inspector: SqlAlchemy Inspector instance
- :param table_name: Table name
- :param schema: Schema name. If omitted, uses default schema for database
+ :param table: Table instance
:param options: Extra options to customise the display of columns in
some databases
:return: All columns in table
"""
return convert_inspector_columns(
- cast(list[SQLAColumnType], inspector.get_columns(table_name, schema))
+ cast(
+ list[SQLAColumnType],
+ inspector.get_columns(table.table, table.schema),
+ )
)
@classmethod
@@ -1537,8 +1559,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
cls,
database: Database,
inspector: Inspector,
- table_name: str,
- schema: str | None,
+ table: Table,
) -> list[MetricType]:
"""
Get all metrics from a given schema and table.
@@ -1553,19 +1574,17 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
]
@classmethod
- def where_latest_partition( # pylint: disable=too-many-arguments,unused-argument
+ def where_latest_partition( # pylint: disable=unused-argument
cls,
- table_name: str,
- schema: str | None,
database: Database,
+ table: Table,
query: Select,
columns: list[ResultSetColumnType] | None = None,
) -> Select | None:
"""
Add a where clause to a query to reference only the most recent partition
- :param table_name: Table name
- :param schema: Schema name
+ :param table: Table instance
:param database: Database instance
:param query: SqlAlchemy query
:param columns: List of TableColumns
@@ -1588,9 +1607,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
def select_star( # pylint: disable=too-many-arguments,too-many-locals
cls,
database: Database,
- table_name: str,
+ table: Table,
engine: Engine,
- schema: str | None = None,
limit: int = 100,
show_cols: bool = False,
indent: bool = True,
@@ -1603,9 +1621,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
WARNING: expects only unquoted table and schema names.
:param database: Database instance
- :param table_name: Table name, unquoted
+ :param table: Table instance
:param engine: SqlAlchemy Engine instance
- :param schema: Schema, unquoted
:param limit: limit to impose on query
:param show_cols: Show columns in query; otherwise use "*"
:param indent: Add indentation to query
@@ -1617,16 +1634,18 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
fields: str | list[Any] = "*"
cols = cols or []
if (show_cols or latest_partition) and not cols:
- cols = database.get_columns(table_name, schema)
+ cols = database.get_columns(table)
if show_cols:
fields = cls._get_fields(cols)
+
quote = engine.dialect.identifier_preparer.quote
quote_schema = engine.dialect.identifier_preparer.quote_schema
- if schema:
- full_table_name = quote_schema(schema) + "." + quote(table_name)
- else:
- full_table_name = quote(table_name)
+ full_table_name = (
+ quote_schema(table.schema) + "." + quote(table.table)
+ if table.schema
+ else quote(table.table)
+ )
qry = select(fields).select_from(text(full_table_name))
@@ -1634,7 +1653,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
qry = qry.limit(limit)
if latest_partition:
partition_query = cls.where_latest_partition(
- table_name, schema, database, qry, columns=cols
+ database,
+ table,
+ qry,
+ columns=cols,
)
if partition_query is not None:
qry = partition_query
@@ -1685,9 +1707,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return database.mutate_sql_based_on_config(sql, is_split=True)
@classmethod
- def estimate_query_cost(
+ def estimate_query_cost( # pylint: disable=too-many-arguments
cls,
database: Database,
+ catalog: str | None,
schema: str,
sql: str,
source: utils.QuerySource | None = None,
@@ -1709,14 +1732,19 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
statements = parsed_query.get_statements()
- costs = []
- with database.get_raw_connection(schema=schema, source=source) as conn:
+ with database.get_raw_connection(
+ catalog=catalog,
+ schema=schema,
+ source=source,
+ ) as conn:
cursor = conn.cursor()
- for statement in statements:
- processed_statement = cls.process_statement(statement, database)
- costs.append(cls.estimate_statement_cost(processed_statement, cursor))
-
- return costs
+ return [
+ cls.estimate_statement_cost(
+ cls.process_statement(statement, database),
+ cursor,
+ )
+ for statement in statements
+ ]
@classmethod
def get_url_for_impersonation(
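
With `select_star` and the partition helpers taking a `Table`, call sites shrink accordingly. A sketch, assuming `database` is a `Database` instance:

```python
from superset.sql_parse import Table

sql = database.select_star(
    Table("my_table", "public"),  # catalog defaults to None
    show_cols=True,
    latest_partition=False,
)
print(sql)
```
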
diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py
index 78d845450fc..8a2612f5b09 100644
--- a/superset/db_engine_specs/bigquery.py
+++ b/superset/db_engine_specs/bigquery.py
@@ -14,13 +14,16 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
+from __future__ import annotations
+
import contextlib
import json
import re
import urllib
from datetime import datetime
from re import Pattern
-from typing import Any, Optional, TYPE_CHECKING, TypedDict
+from typing import Any, TYPE_CHECKING, TypedDict
import pandas as pd
from apispec import APISpec
@@ -220,8 +223,8 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
- ) -> Optional[str]:
+ cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None
+ ) -> str | None:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.Date):
return f"CAST('{dttm.date().isoformat()}' AS DATE)"
@@ -234,9 +237,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
return None
@classmethod
- def fetch_data(
- cls, cursor: Any, limit: Optional[int] = None
- ) -> list[tuple[Any, ...]]:
+ def fetch_data(cls, cursor: Any, limit: int | None = None) -> list[tuple[Any, ...]]:
data = super().fetch_data(cursor, limit)
# Support type BigQuery Row, introduced here PR #4071
# google.cloud.bigquery.table.Row
@@ -302,30 +303,28 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
@classmethod
def get_indexes(
cls,
- database: "Database",
+ database: Database,
inspector: Inspector,
- table_name: str,
- schema: Optional[str],
+ table: Table,
) -> list[dict[str, Any]]:
"""
Get the indexes associated with the specified schema/table.
:param database: The database to inspect
:param inspector: The SQLAlchemy inspector
- :param table_name: The table to inspect
- :param schema: The schema to inspect
+ :param table: The table instance to inspect
:returns: The indexes
"""
- return cls.normalize_indexes(inspector.get_indexes(table_name, schema))
+ return cls.normalize_indexes(inspector.get_indexes(table.table, table.schema))
@classmethod
def get_extra_table_metadata(
cls,
- database: "Database",
+ database: Database,
table: Table,
) -> dict[str, Any]:
- indexes = database.get_indexes(table.table, table.schema)
+ indexes = database.get_indexes(table)
if not indexes:
return {}
partitions_columns = [
@@ -354,7 +353,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
@classmethod
def df_to_sql(
cls,
- database: "Database",
+ database: Database,
table: Table,
df: pd.DataFrame,
to_sql_kwargs: dict[str, Any],
@@ -380,7 +379,11 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
raise SupersetException("The table schema must be defined")
to_gbq_kwargs = {}
- with cls.get_engine(database) as engine:
+ with cls.get_engine(
+ database,
+ catalog=table.catalog,
+ schema=table.schema,
+ ) as engine:
to_gbq_kwargs = {
"destination_table": str(table),
"project_id": engine.url.host,
@@ -403,7 +406,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
pandas_gbq.to_gbq(df, **to_gbq_kwargs)
@classmethod
- def _get_client(cls, engine: Engine) -> Any:
+ def _get_client(cls, engine: Engine) -> bigquery.Client:
"""
Return the BigQuery client associated with an engine.
"""
@@ -418,17 +421,19 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
return bigquery.Client(credentials=credentials)
@classmethod
- def estimate_query_cost(
+ def estimate_query_cost( # pylint: disable=too-many-arguments
cls,
- database: "Database",
+ database: Database,
+ catalog: str | None,
schema: str,
sql: str,
- source: Optional[utils.QuerySource] = None,
+ source: utils.QuerySource | None = None,
) -> list[dict[str, Any]]:
"""
Estimate the cost of a multiple statement SQL query.
:param database: Database instance
+ :param catalog: Database project
:param schema: Database schema
:param sql: SQL query with possibly multiple statements
:param source: Source of the query (eg, "sql_lab")
@@ -439,17 +444,25 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
statements = parsed_query.get_statements()
- costs = []
- for statement in statements:
- processed_statement = cls.process_statement(statement, database)
- costs.append(cls.estimate_statement_cost(processed_statement, database))
- return costs
+ with cls.get_engine(
+ database,
+ catalog=catalog,
+ schema=schema,
+ ) as engine:
+ client = cls._get_client(engine)
+ return [
+ cls.custom_estimate_statement_cost(
+ cls.process_statement(statement, database),
+ client,
+ )
+ for statement in statements
+ ]
@classmethod
def get_catalog_names(
cls,
- database: "Database",
+ database: Database,
inspector: Inspector,
) -> list[str]:
"""
@@ -469,14 +482,16 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
return True
@classmethod
- def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]:
- with cls.get_engine(cursor) as engine:
- client = cls._get_client(engine)
- job_config = bigquery.QueryJobConfig(dry_run=True)
- query_job = client.query(
- statement,
- job_config=job_config,
- ) # Make an API request.
+ def custom_estimate_statement_cost(
+ cls,
+ statement: str,
+ client: bigquery.Client,
+ ) -> dict[str, Any]:
+ """
+ Custom version that receives a client instead of a cursor.
+ """
+ job_config = bigquery.QueryJobConfig(dry_run=True)
+ query_job = client.query(statement, job_config=job_config)
# Format Bytes.
# TODO: Humanize in case more db engine specs need to be added,
@@ -514,7 +529,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
def build_sqlalchemy_uri(
cls,
parameters: BigQueryParametersType,
- encrypted_extra: Optional[dict[str, Any]] = None,
+ encrypted_extra: dict[str, Any] | None = None,
) -> str:
query = parameters.get("query", {})
query_params = urllib.parse.urlencode(query)
@@ -536,7 +551,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
def get_parameters_from_uri(
cls,
uri: str,
- encrypted_extra: Optional[dict[str, Any]] = None,
+ encrypted_extra: dict[str, Any] | None = None,
) -> Any:
value = make_url_safe(uri)
@@ -549,7 +564,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
raise ValidationError("Invalid service credentials")
@classmethod
- def mask_encrypted_extra(cls, encrypted_extra: Optional[str]) -> Optional[str]:
+ def mask_encrypted_extra(cls, encrypted_extra: str | None) -> str | None:
if encrypted_extra is None:
return encrypted_extra
@@ -563,9 +578,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
return json.dumps(config)
@classmethod
- def unmask_encrypted_extra(
- cls, old: Optional[str], new: Optional[str]
- ) -> Optional[str]:
+ def unmask_encrypted_extra(cls, old: str | None, new: str | None) -> str | None:
"""
Reuse ``private_key`` if available and unchanged.
"""
@@ -628,15 +641,14 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
@classmethod
def select_star( # pylint: disable=too-many-arguments
cls,
- database: "Database",
- table_name: str,
+ database: Database,
+ table: Table,
engine: Engine,
- schema: Optional[str] = None,
limit: int = 100,
show_cols: bool = False,
indent: bool = True,
latest_partition: bool = True,
- cols: Optional[list[ResultSetColumnType]] = None,
+ cols: list[ResultSetColumnType] | None = None,
) -> str:
"""
Remove array structures from `SELECT *`.
@@ -690,9 +702,8 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
return super().select_star(
database,
- table_name,
+ table,
engine,
- schema,
limit,
show_cols,
indent,
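
A sketch of the new BigQuery dry-run helper, which takes a client rather than a cursor (assumes google-cloud-bigquery is installed and application-default credentials are available; the table name is illustrative):

```python
from google.cloud import bigquery

from superset.db_engine_specs.bigquery import BigQueryEngineSpec

client = bigquery.Client()
cost = BigQueryEngineSpec.custom_estimate_statement_cost(
    "SELECT * FROM `examples.public.my_table`",
    client,  # dry-run is issued directly through the client
)
print(cost)  # e.g. a humanized "Total bytes processed" figure
```
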
diff --git a/superset/db_engine_specs/db2.py b/superset/db_engine_specs/db2.py
index db2e500b53d..b2151767d2d 100644
--- a/superset/db_engine_specs/db2.py
+++ b/superset/db_engine_specs/db2.py
@@ -21,6 +21,7 @@ from sqlalchemy.engine.reflection import Inspector
from superset.constants import TimeGrain
from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
+from superset.sql_parse import Table
logger = logging.getLogger(__name__)
@@ -64,7 +65,9 @@ class Db2EngineSpec(BaseEngineSpec):
@classmethod
def get_table_comment(
- cls, inspector: Inspector, table_name: str, schema: Union[str, None]
+ cls,
+ inspector: Inspector,
+ table: Table,
) -> Optional[str]:
"""
Get comment of table from a given schema
@@ -72,13 +75,12 @@ class Db2EngineSpec(BaseEngineSpec):
Ibm Db2 return comments as tuples, so we need to get the first element
:param inspector: SqlAlchemy Inspector instance
- :param table_name: Table name
- :param schema: Schema name. If omitted, uses default schema for database
+ :param table: Table instance
:return: comment of table
"""
comment = None
try:
- table_comment = inspector.get_table_comment(table_name, schema)
+ table_comment = inspector.get_table_comment(table.table, table.schema)
comment = table_comment.get("text")
return comment[0]
except IndexError:
diff --git a/superset/db_engine_specs/gsheets.py b/superset/db_engine_specs/gsheets.py
index 4aed2693f46..7606e93b500 100644
--- a/superset/db_engine_specs/gsheets.py
+++ b/superset/db_engine_specs/gsheets.py
@@ -142,7 +142,10 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
database: Database,
table: Table,
) -> dict[str, Any]:
- with database.get_raw_connection(schema=table.schema) as conn:
+ with database.get_raw_connection(
+ catalog=table.catalog,
+ schema=table.schema,
+ ) as conn:
cursor = conn.cursor()
cursor.execute(f'SELECT GET_METADATA("{table.table}")')
results = cursor.fetchone()[0]
@@ -395,7 +398,11 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
pass
# get the Google session from the Shillelagh adapter
- with cls.get_engine(database) as engine:
+ with cls.get_engine(
+ database,
+ catalog=table.catalog,
+ schema=table.schema,
+ ) as engine:
with engine.connect() as conn:
# any GSheets URL will work to get a working session
adapter = get_adapter_for_table_name(
diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py
index a10f5f66bc9..80892b59877 100644
--- a/superset/db_engine_specs/hive.py
+++ b/superset/db_engine_specs/hive.py
@@ -205,7 +205,11 @@ class HiveEngineSpec(PrestoEngineSpec):
if table_exists:
raise SupersetException("Table already exists")
elif to_sql_kwargs["if_exists"] == "replace":
- with cls.get_engine(database) as engine:
+ with cls.get_engine(
+ database,
+ catalog=table.catalog,
+ schema=table.schema,
+ ) as engine:
engine.execute(f"DROP TABLE IF EXISTS {str(table)}")
def _get_hive_type(dtype: np.dtype[Any]) -> str:
@@ -227,7 +231,11 @@ class HiveEngineSpec(PrestoEngineSpec):
) as file:
pq.write_table(pa.Table.from_pandas(df), where=file.name)
- with cls.get_engine(database) as engine:
+ with cls.get_engine(
+ database,
+ catalog=table.catalog,
+ schema=table.schema,
+ ) as engine:
engine.execute(
text(
f"""
@@ -410,24 +418,24 @@ class HiveEngineSpec(PrestoEngineSpec):
def get_columns(
cls,
inspector: Inspector,
- table_name: str,
- schema: str | None,
+ table: Table,
options: dict[str, Any] | None = None,
) -> list[ResultSetColumnType]:
- return BaseEngineSpec.get_columns(inspector, table_name, schema, options)
+ return BaseEngineSpec.get_columns(inspector, table, options)
@classmethod
- def where_latest_partition( # pylint: disable=too-many-arguments
+ def where_latest_partition(
cls,
- table_name: str,
- schema: str | None,
database: Database,
+ table: Table,
query: Select,
columns: list[ResultSetColumnType] | None = None,
) -> Select | None:
try:
col_names, values = cls.latest_partition(
- table_name, schema, database, show_first=True
+ database,
+ table,
+ show_first=True,
)
except Exception: # pylint: disable=broad-except
# table is not partitioned
@@ -447,7 +455,10 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod
def latest_sub_partition( # type: ignore
- cls, table_name: str, schema: str | None, database: Database, **kwargs: Any
+ cls,
+ database: Database,
+ table: Table,
+ **kwargs: Any,
) -> str:
# TODO(bogdan): implement`
pass
@@ -465,24 +476,24 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod
def _partition_query( # pylint: disable=too-many-arguments
cls,
- table_name: str,
- schema: str | None,
+ table: Table,
indexes: list[dict[str, Any]],
database: Database,
limit: int = 0,
order_by: list[tuple[str, bool]] | None = None,
filters: dict[Any, Any] | None = None,
) -> str:
- full_table_name = f"{schema}.{table_name}" if schema else table_name
+ full_table_name = (
+ f"{table.schema}.{table.table}" if table.schema else table.table
+ )
return f"SHOW PARTITIONS {full_table_name}"
@classmethod
def select_star( # pylint: disable=too-many-arguments
cls,
database: Database,
- table_name: str,
+ table: Table,
engine: Engine,
- schema: str | None = None,
limit: int = 100,
show_cols: bool = False,
indent: bool = True,
@@ -491,9 +502,8 @@ class HiveEngineSpec(PrestoEngineSpec):
) -> str:
return super(PrestoEngineSpec, cls).select_star(
database,
- table_name,
+ table,
engine,
- schema,
limit,
show_cols,
indent,
diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index 8a803d3f140..34c47eb522c 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -420,8 +420,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
@classmethod
def _partition_query( # pylint: disable=too-many-arguments,too-many-locals,unused-argument
cls,
- table_name: str,
- schema: str | None,
+ table: Table,
indexes: list[dict[str, Any]],
database: Database,
limit: int = 0,
@@ -434,8 +433,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
Note the unused arguments are exposed for sub-classing purposes where custom
integrations may require the schema, indexes, etc. to build the partition query.
- :param table_name: the name of the table to get partitions from
- :param schema: the schema name
+ :param table: the table instance
:param indexes: the indexes associated with the table
:param database: the database the query will be run against
:param limit: the number of partitions to be returned
@@ -464,12 +462,16 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
presto_version = database.get_extra().get("version")
if presto_version and Version(presto_version) < Version("0.199"):
- full_table_name = f"{schema}.{table_name}" if schema else table_name
+ full_table_name = (
+ f"{table.schema}.{table.table}" if table.schema else table.table
+ )
partition_select_clause = f"SHOW PARTITIONS FROM {full_table_name}"
else:
- system_table_name = f'"{table_name}$partitions"'
+ system_table_name = f'"{table.table}$partitions"'
full_table_name = (
- f"{schema}.{system_table_name}" if schema else system_table_name
+ f"{table.schema}.{system_table_name}"
+ if table.schema
+ else system_table_name
)
partition_select_clause = f"SELECT * FROM {full_table_name}"
@@ -484,18 +486,15 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
return sql
@classmethod
- def where_latest_partition( # pylint: disable=too-many-arguments
+ def where_latest_partition(
cls,
- table_name: str,
- schema: str | None,
database: Database,
+ table: Table,
query: Select,
columns: list[ResultSetColumnType] | None = None,
) -> Select | None:
try:
- col_names, values = cls.latest_partition(
- table_name, schema, database, show_first=True
- )
+ col_names, values = cls.latest_partition(database, table, show_first=True)
except Exception: # pylint: disable=broad-except
# table is not partitioned
return None
@@ -527,18 +526,16 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
@classmethod
@cache_manager.data_cache.memoize(timeout=60)
- def latest_partition( # pylint: disable=too-many-arguments
+ def latest_partition(
cls,
- table_name: str,
- schema: str | None,
database: Database,
+ table: Table,
show_first: bool = False,
indexes: list[dict[str, Any]] | None = None,
) -> tuple[list[str], list[str] | None]:
"""Returns col name and the latest (max) partition value for a table
- :param table_name: the name of the table
- :param schema: schema / database / namespace
+ :param table: the table instance
:param database: database query will be run against
:type database: models.Database
:param show_first: displays the value for the first partitioning key
@@ -550,11 +547,11 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
(['ds'], ('2018-01-01',))
"""
if indexes is None:
- indexes = database.get_indexes(table_name, schema)
+ indexes = database.get_indexes(table)
if not indexes:
raise SupersetTemplateException(
- f"Error getting partition for {schema}.{table_name}. "
+ f"Error getting partition for {table}. "
"Verify that this table has a partition."
)
@@ -575,20 +572,23 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
return column_names, cls._latest_partition_from_df(
df=database.get_df(
sql=cls._partition_query(
- table_name,
- schema,
+ table,
indexes,
database,
limit=1,
order_by=[(column_name, True) for column_name in column_names],
),
- schema=schema,
+ catalog=table.catalog,
+ schema=table.schema,
)
)
@classmethod
def latest_sub_partition(
- cls, table_name: str, schema: str | None, database: Database, **kwargs: Any
+ cls,
+ database: Database,
+ table: Table,
+ **kwargs: Any,
) -> Any:
"""Returns the latest (max) partition value for a table
@@ -601,12 +601,9 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
``latest_sub_partition('my_table',
event_category='page', event_type='click')``
- :param table_name: the name of the table, can be just the table
- name or a fully qualified table name as ``schema_name.table_name``
- :type table_name: str
- :param schema: schema / database / namespace
- :type schema: str
:param database: database query will be run against
+ :param table: the table instance
+ :type table: Table
:type database: models.Database
:param kwargs: keyword arguments define the filtering criteria
@@ -615,7 +612,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
>>> latest_sub_partition('sub_partition_table', event_type='click')
'2018-01-01'
"""
- indexes = database.get_indexes(table_name, schema)
+ indexes = database.get_indexes(table)
part_fields = indexes[0]["column_names"]
for k in kwargs.keys(): # pylint: disable=consider-iterating-dictionary
            if k not in part_fields:
@@ -633,15 +630,14 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
field_to_return = field
sql = cls._partition_query(
- table_name,
- schema,
+ table,
indexes,
database,
limit=1,
order_by=[(field_to_return, True)],
filters=kwargs,
)
- df = database.get_df(sql, schema)
+ df = database.get_df(sql, table.catalog, table.schema)
if df.empty:
return ""
return df.to_dict()[field_to_return][0]
@@ -966,40 +962,39 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
@classmethod
def _show_columns(
- cls, inspector: Inspector, table_name: str, schema: str | None
+ cls,
+ inspector: Inspector,
+ table: Table,
) -> list[ResultRow]:
"""
Show presto column names
:param inspector: object that performs database schema inspection
- :param table_name: table name
- :param schema: schema name
+ :param table: table instance
:return: list of column objects
"""
quote = inspector.engine.dialect.identifier_preparer.quote_identifier
- full_table = quote(table_name)
- if schema:
- full_table = f"{quote(schema)}.{full_table}"
+ full_table = quote(table.table)
+ if table.schema:
+ full_table = f"{quote(table.schema)}.{full_table}"
return inspector.bind.execute(f"SHOW COLUMNS FROM {full_table}").fetchall()
@classmethod
def get_columns(
cls,
inspector: Inspector,
- table_name: str,
- schema: str | None,
+ table: Table,
options: dict[str, Any] | None = None,
) -> list[ResultSetColumnType]:
"""
Get columns from a Presto data source. This includes handling row and
array data types
:param inspector: object that performs database schema inspection
- :param table_name: table name
- :param schema: schema name
+ :param table: table instance
:param options: Extra configuration options, not used by this backend
:return: a list of results that contain column info
(i.e. column name and data type)
"""
- columns = cls._show_columns(inspector, table_name, schema)
+ columns = cls._show_columns(inspector, table)
result: list[ResultSetColumnType] = []
for column in columns:
# parse column if it is a row or array
@@ -1077,9 +1072,8 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
def select_star( # pylint: disable=too-many-arguments
cls,
database: Database,
- table_name: str,
+ table: Table,
engine: Engine,
- schema: str | None = None,
limit: int = 100,
show_cols: bool = False,
indent: bool = True,
@@ -1102,9 +1096,8 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
]
return super().select_star(
database,
- table_name,
+ table,
engine,
- schema,
limit,
show_cols,
indent,
@@ -1232,11 +1225,10 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
) -> dict[str, Any]:
metadata = {}
- if indexes := database.get_indexes(table.table, table.schema):
+ if indexes := database.get_indexes(table):
col_names, latest_parts = cls.latest_partition(
- table.table,
- table.schema,
database,
+ table,
show_first=True,
indexes=indexes,
)
@@ -1248,8 +1240,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
"cols": sorted(indexes[0].get("column_names", [])),
"latest": dict(zip(col_names, latest_parts)),
"partitionQuery": cls._partition_query(
- table_name=table.table,
- schema=table.schema,
+ table=table,
indexes=indexes,
database=database,
),
diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py
index 32c119ec1f6..08a38894e66 100644
--- a/superset/db_engine_specs/trino.py
+++ b/superset/db_engine_specs/trino.py
@@ -40,12 +40,12 @@ from superset.db_engine_specs.exceptions import (
)
from superset.db_engine_specs.presto import PrestoBaseEngineSpec
from superset.models.sql_lab import Query
+from superset.sql_parse import Table
from superset.superset_typing import ResultSetColumnType
from superset.utils import core as utils
if TYPE_CHECKING:
from superset.models.core import Database
- from superset.sql_parse import Table
with contextlib.suppress(ImportError): # trino may not be installed
from trino.dbapi import Cursor
@@ -66,11 +66,10 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
) -> dict[str, Any]:
metadata = {}
- if indexes := database.get_indexes(table.table, table.schema):
+ if indexes := database.get_indexes(table):
col_names, latest_parts = cls.latest_partition(
- table.table,
- table.schema,
database,
+ table,
show_first=True,
indexes=indexes,
)
@@ -91,15 +90,17 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
),
"latest": dict(zip(col_names, latest_parts)),
"partitionQuery": cls._partition_query(
- table_name=table.table,
- schema=table.schema,
+ table=table,
indexes=indexes,
database=database,
),
}
- if database.has_view_by_name(table.table, table.schema):
- with database.get_inspector_with_context() as inspector:
+ if database.has_view(Table(table.table, table.schema)):
+ with database.get_inspector(
+ catalog=table.catalog,
+ schema=table.schema,
+ ) as inspector:
metadata["view"] = inspector.get_view_definition(
table.table,
table.schema,
@@ -414,8 +415,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
def get_columns(
cls,
inspector: Inspector,
- table_name: str,
- schema: str | None,
+ table: Table,
options: dict[str, Any] | None = None,
) -> list[ResultSetColumnType]:
"""
@@ -423,7 +423,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
"schema_options", expand the schema definition out to show all
subfields of nested ROWs as their appropriate dotted paths.
"""
- base_cols = super().get_columns(inspector, table_name, schema, options)
+ base_cols = super().get_columns(inspector, table, options)
if not (options or {}).get("expand_rows"):
return base_cols
@@ -434,8 +434,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
cls,
database: Database,
inspector: Inspector,
- table_name: str,
- schema: str | None,
+ table: Table,
) -> list[dict[str, Any]]:
"""
Get the indexes associated with the specified schema/table.
@@ -444,11 +443,10 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
:param database: The database to inspect
:param inspector: The SQLAlchemy inspector
- :param table_name: The table to inspect
- :param schema: The schema to inspect
+ :param table: The table instance to inspect
:returns: The indexes
"""
try:
- return super().get_indexes(database, inspector, table_name, schema)
+ return super().get_indexes(database, inspector, table)
except NoSuchTableError:
return []
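
To make the renamed helpers concrete, a minimal sketch of the view lookup above; the table coordinates are illustrative assumptions:

    from superset.sql_parse import Table

    table = Table("orders_v", "sales", "warehouse")  # (table, schema, catalog)

    # has_view() replaces has_view_by_name() and takes a Table value.
    if database.has_view(Table(table.table, table.schema)):
        # get_inspector() replaces get_inspector_with_context() and pins the
        # catalog and schema on the underlying engine.
        with database.get_inspector(
            catalog=table.catalog,
            schema=table.schema,
        ) as inspector:
            view_ddl = inspector.get_view_definition(table.table, table.schema)
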
diff --git a/superset/examples/bart_lines.py b/superset/examples/bart_lines.py
index 9ce27d49523..efbb8302015 100644
--- a/superset/examples/bart_lines.py
+++ b/superset/examples/bart_lines.py
@@ -21,6 +21,7 @@ import polyline
from sqlalchemy import inspect, String, Text
from superset import db
+from superset.sql_parse import Table
from ..utils.database import get_example_database
from .helpers import get_example_url, get_table_connector_registry
@@ -31,7 +32,7 @@ def load_bart_lines(only_metadata: bool = False, force: bool = False) -> None:
database = get_example_database()
with database.get_sqla_engine() as engine:
schema = inspect(engine).default_schema_name
- table_exists = database.has_table_by_name(tbl_name)
+ table_exists = database.has_table(Table(tbl_name, schema))
if not only_metadata and (not table_exists or force):
url = get_example_url("bart-lines.json.gz")
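
The same existence check recurs in every example loader touched below; a sketch of the shared pattern using this repository's helpers:

    from sqlalchemy import inspect

    from superset.sql_parse import Table
    from superset.utils.database import get_example_database

    database = get_example_database()
    with database.get_sqla_engine() as engine:
        schema = inspect(engine).default_schema_name

    # has_table() replaces has_table_by_name() and takes a Table value.
    table_exists = database.has_table(Table("bart_lines", schema))
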
diff --git a/superset/examples/birth_names.py b/superset/examples/birth_names.py
index 2e711bef290..7b7928c5321 100644
--- a/superset/examples/birth_names.py
+++ b/superset/examples/birth_names.py
@@ -27,6 +27,7 @@ from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.models.core import Database
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
+from superset.sql_parse import Table
from superset.utils.core import DatasourceType
from ..utils.database import get_example_database
@@ -95,7 +96,7 @@ def load_birth_names(
schema = inspect(engine).default_schema_name
tbl_name = "birth_names"
- table_exists = database.has_table_by_name(tbl_name, schema=schema)
+ table_exists = database.has_table(Table(tbl_name, schema))
if not only_metadata and (not table_exists or force):
load_data(tbl_name, database, sample=sample)
diff --git a/superset/examples/country_map.py b/superset/examples/country_map.py
index 59c257bc80b..1741219470a 100644
--- a/superset/examples/country_map.py
+++ b/superset/examples/country_map.py
@@ -24,6 +24,7 @@ import superset.utils.database as database_utils
from superset import db
from superset.connectors.sqla.models import SqlMetric
from superset.models.slice import Slice
+from superset.sql_parse import Table
from superset.utils.core import DatasourceType
from .helpers import (
@@ -42,7 +43,7 @@ def load_country_map_data(only_metadata: bool = False, force: bool = False) -> N
with database.get_sqla_engine() as engine:
schema = inspect(engine).default_schema_name
- table_exists = database.has_table_by_name(tbl_name)
+ table_exists = database.has_table(Table(tbl_name, schema))
if not only_metadata and (not table_exists or force):
url = get_example_url("birth_france_data_for_country_map.csv")
diff --git a/superset/examples/energy.py b/superset/examples/energy.py
index 16d4eea3741..98b444f9db2 100644
--- a/superset/examples/energy.py
+++ b/superset/examples/energy.py
@@ -26,6 +26,7 @@ import superset.utils.database as database_utils
from superset import db
from superset.connectors.sqla.models import SqlMetric
from superset.models.slice import Slice
+from superset.sql_parse import Table
from superset.utils.core import DatasourceType
from .helpers import (
@@ -45,7 +46,7 @@ def load_energy(
with database.get_sqla_engine() as engine:
schema = inspect(engine).default_schema_name
- table_exists = database.has_table_by_name(tbl_name)
+ table_exists = database.has_table(Table(tbl_name, schema))
if not only_metadata and (not table_exists or force):
url = get_example_url("energy.json.gz")
diff --git a/superset/examples/flights.py b/superset/examples/flights.py
index 1e22fed4688..4db029519fd 100644
--- a/superset/examples/flights.py
+++ b/superset/examples/flights.py
@@ -19,6 +19,7 @@ from sqlalchemy import DateTime, inspect
import superset.utils.database as database_utils
from superset import db
+from superset.sql_parse import Table
from .helpers import get_example_url, get_table_connector_registry
@@ -29,7 +30,7 @@ def load_flights(only_metadata: bool = False, force: bool = False) -> None:
database = database_utils.get_example_database()
with database.get_sqla_engine() as engine:
schema = inspect(engine).default_schema_name
- table_exists = database.has_table_by_name(tbl_name)
+ table_exists = database.has_table(Table(tbl_name, schema))
if not only_metadata and (not table_exists or force):
flight_data_url = get_example_url("flight_data.csv.gz")
diff --git a/superset/examples/long_lat.py b/superset/examples/long_lat.py
index 95cccadc240..4f8de31453c 100644
--- a/superset/examples/long_lat.py
+++ b/superset/examples/long_lat.py
@@ -24,6 +24,7 @@ from sqlalchemy import DateTime, Float, inspect, String
import superset.utils.database as database_utils
from superset import db
from superset.models.slice import Slice
+from superset.sql_parse import Table
from superset.utils.core import DatasourceType
from .helpers import (
@@ -41,7 +42,7 @@ def load_long_lat_data(only_metadata: bool = False, force: bool = False) -> None
database = database_utils.get_example_database()
with database.get_sqla_engine() as engine:
schema = inspect(engine).default_schema_name
- table_exists = database.has_table_by_name(tbl_name)
+ table_exists = database.has_table(Table(tbl_name, schema))
if not only_metadata and (not table_exists or force):
url = get_example_url("san_francisco.csv.gz")
diff --git a/superset/examples/multiformat_time_series.py b/superset/examples/multiformat_time_series.py
index 91799b2c2cf..979be10686f 100644
--- a/superset/examples/multiformat_time_series.py
+++ b/superset/examples/multiformat_time_series.py
@@ -21,6 +21,7 @@ from sqlalchemy import BigInteger, Date, DateTime, inspect, String
from superset import app, db
from superset.models.slice import Slice
+from superset.sql_parse import Table
from superset.utils.core import DatasourceType
from ..utils.database import get_example_database
@@ -41,7 +42,7 @@ def load_multiformat_time_series( # pylint: disable=too-many-locals
database = get_example_database()
with database.get_sqla_engine() as engine:
schema = inspect(engine).default_schema_name
- table_exists = database.has_table_by_name(tbl_name)
+ table_exists = database.has_table(Table(tbl_name, schema))
if not only_metadata and (not table_exists or force):
url = get_example_url("multiformat_time_series.json.gz")
diff --git a/superset/examples/paris.py b/superset/examples/paris.py
index cea784be775..1cd6c84d92d 100644
--- a/superset/examples/paris.py
+++ b/superset/examples/paris.py
@@ -21,6 +21,7 @@ from sqlalchemy import inspect, String, Text
import superset.utils.database as database_utils
from superset import db
+from superset.sql_parse import Table
from .helpers import get_example_url, get_table_connector_registry
@@ -30,7 +31,7 @@ def load_paris_iris_geojson(only_metadata: bool = False, force: bool = False) ->
database = database_utils.get_example_database()
with database.get_sqla_engine() as engine:
schema = inspect(engine).default_schema_name
- table_exists = database.has_table_by_name(tbl_name)
+ table_exists = database.has_table(Table(tbl_name, schema))
if not only_metadata and (not table_exists or force):
url = get_example_url("paris_iris.json.gz")
diff --git a/superset/examples/random_time_series.py b/superset/examples/random_time_series.py
index 9b5306781df..ec232995fa2 100644
--- a/superset/examples/random_time_series.py
+++ b/superset/examples/random_time_series.py
@@ -21,6 +21,7 @@ from sqlalchemy import DateTime, inspect, String
import superset.utils.database as database_utils
from superset import app, db
from superset.models.slice import Slice
+from superset.sql_parse import Table
from superset.utils.core import DatasourceType
from .helpers import (
@@ -39,7 +40,7 @@ def load_random_time_series_data(
database = database_utils.get_example_database()
with database.get_sqla_engine() as engine:
schema = inspect(engine).default_schema_name
- table_exists = database.has_table_by_name(tbl_name)
+ table_exists = database.has_table(Table(tbl_name, schema))
if not only_metadata and (not table_exists or force):
url = get_example_url("random_time_series.json.gz")
diff --git a/superset/examples/sf_population_polygons.py b/superset/examples/sf_population_polygons.py
index d97ffd3ae51..d4754887c72 100644
--- a/superset/examples/sf_population_polygons.py
+++ b/superset/examples/sf_population_polygons.py
@@ -21,6 +21,7 @@ from sqlalchemy import BigInteger, Float, inspect, Text
import superset.utils.database as database_utils
from superset import db
+from superset.sql_parse import Table
from .helpers import get_example_url, get_table_connector_registry
@@ -32,7 +33,7 @@ def load_sf_population_polygons(
database = database_utils.get_example_database()
with database.get_sqla_engine() as engine:
schema = inspect(engine).default_schema_name
- table_exists = database.has_table_by_name(tbl_name)
+ table_exists = database.has_table(Table(tbl_name, schema))
if not only_metadata and (not table_exists or force):
url = get_example_url("sf_population.json.gz")
diff --git a/superset/examples/supported_charts_dashboard.py b/superset/examples/supported_charts_dashboard.py
index 371f03d18b1..ae0962fc172 100644
--- a/superset/examples/supported_charts_dashboard.py
+++ b/superset/examples/supported_charts_dashboard.py
@@ -26,6 +26,7 @@ from superset import db
from superset.connectors.sqla.models import SqlaTable
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
+from superset.sql_parse import Table
from superset.utils.core import DatasourceType
from ..utils.database import get_example_database
@@ -443,7 +444,7 @@ def load_supported_charts_dashboard() -> None:
schema = inspect(engine).default_schema_name
tbl_name = "birth_names"
- table_exists = database.has_table_by_name(tbl_name, schema=schema)
+ table_exists = database.has_table(Table(tbl_name, schema))
if table_exists:
table = get_table_connector_registry()
diff --git a/superset/examples/world_bank.py b/superset/examples/world_bank.py
index c98c1fc11ce..7b6b3749213 100644
--- a/superset/examples/world_bank.py
+++ b/superset/examples/world_bank.py
@@ -37,6 +37,7 @@ from superset.examples.helpers import (
)
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
+from superset.sql_parse import Table
from superset.utils import core as utils
from superset.utils.core import DatasourceType
@@ -51,7 +52,7 @@ def load_world_bank_health_n_pop( # pylint: disable=too-many-locals, too-many-s
database = superset.utils.database.get_example_database()
with database.get_sqla_engine() as engine:
schema = inspect(engine).default_schema_name
- table_exists = database.has_table_by_name(tbl_name)
+ table_exists = database.has_table(Table(tbl_name, schema))
if not only_metadata and (not table_exists or force):
url = get_example_url("countries.json.gz")
diff --git a/superset/extensions/metadb.py b/superset/extensions/metadb.py
index 0d33ac97e2b..2d8444cc993 100644
--- a/superset/extensions/metadb.py
+++ b/superset/extensions/metadb.py
@@ -271,6 +271,8 @@ class SupersetShillelaghAdapter(Adapter):
self.catalog = parts.pop(-1) if parts else None
if self.catalog:
+ # TODO (betodealmeida): when SIP-95 is implemented we should check to see if
+ # the database has multi-catalog enabled, and if so, give access.
raise NotImplementedError("Catalogs are not currently supported")
# If the table has a single integer primary key we use that as the row ID in order
@@ -314,7 +316,8 @@ class SupersetShillelaghAdapter(Adapter):
# store this callable for later whenever we need an engine
self.engine_context = partial(
database.get_sqla_engine,
- self.schema,
+ catalog=self.catalog,
+ schema=self.schema,
)
# fetch column names and types
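
The adapter stores a partial rather than a live engine so that every use re-enters the context manager (and with it any SSH tunnel setup and teardown); a hedged sketch of the idea with placeholder arguments:

    from functools import partial

    # `database` is an assumed Database instance; catalog/schema are placeholders.
    engine_context = partial(
        database.get_sqla_engine,
        catalog=None,
        schema="main",
    )

    # Later, whenever an engine is needed:
    with engine_context() as engine:
        engine.execute("SELECT 1")
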
diff --git a/superset/migrations/versions/2024-04-11_15-41_5f57af97bc3f_add_catalog_column.py b/superset/migrations/versions/2024-04-11_15-41_5f57af97bc3f_add_catalog_column.py
new file mode 100644
index 00000000000..ec5733e1510
--- /dev/null
+++ b/superset/migrations/versions/2024-04-11_15-41_5f57af97bc3f_add_catalog_column.py
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Add catalog column
+
+Revision ID: 5f57af97bc3f
+Revises: d60591c5515f
+Create Date: 2024-04-11 15:41:34.663989
+
+"""
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "5f57af97bc3f"
+down_revision = "d60591c5515f"
+
+
+def upgrade():
+ op.add_column("tables", sa.Column("catalog", sa.String(length=256), nullable=True))
+ op.add_column("query", sa.Column("catalog", sa.String(length=256), nullable=True))
+ op.add_column(
+ "saved_query",
+ sa.Column("catalog", sa.String(length=256), nullable=True),
+ )
+ op.add_column(
+ "tab_state",
+ sa.Column("catalog", sa.String(length=256), nullable=True),
+ )
+ op.add_column(
+ "table_schema",
+ sa.Column("catalog", sa.String(length=256), nullable=True),
+ )
+
+
+def downgrade():
+ op.drop_column("table_schema", "catalog")
+ op.drop_column("tab_state", "catalog")
+ op.drop_column("saved_query", "catalog")
+ op.drop_column("query", "catalog")
+ op.drop_column("tables", "catalog")
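
As a quick sanity check after running the migration, one could confirm that the nullable column landed on all five tables; a sketch, with the metadata database URL as an assumption:

    import sqlalchemy as sa

    engine = sa.create_engine("sqlite:///superset.db")  # assumed metadata DB
    inspector = sa.inspect(engine)

    for name in ("tables", "query", "saved_query", "tab_state", "table_schema"):
        columns = {c["name"]: c for c in inspector.get_columns(name)}
        assert "catalog" in columns
        assert columns["catalog"]["nullable"]
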
diff --git a/superset/models/core.py b/superset/models/core.py
index 42f6a78244f..9a4a1de4037 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -15,7 +15,8 @@
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=too-many-lines
+# pylint: disable=too-many-lines, too-many-arguments
+
"""A collection of ORM sqlalchemy models for Superset"""
from __future__ import annotations
@@ -46,7 +47,7 @@ from sqlalchemy import (
Integer,
MetaData,
String,
- Table,
+ Table as SqlaTable,
Text,
)
from sqlalchemy.engine import Connection, Dialect, Engine
@@ -73,6 +74,7 @@ from superset.extensions import (
)
from superset.models.helpers import AuditMixinNullable, ImportExportMixin
from superset.result_set import SupersetResultSet
+from superset.sql_parse import Table
from superset.superset_typing import OAuth2ClientConfig, ResultSetColumnType
from superset.utils import cache as cache_util, core as utils
from superset.utils.backports import StrEnum
@@ -382,13 +384,22 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
)
@contextmanager
- def get_sqla_engine(
+ def get_sqla_engine( # pylint: disable=too-many-arguments
self,
+ catalog: str | None = None,
schema: str | None = None,
nullpool: bool = True,
source: utils.QuerySource | None = None,
override_ssh_tunnel: SSHTunnel | None = None,
) -> Engine:
+ """
+ Context manager for a SQLAlchemy engine.
+
+ This method will return a context manager for a SQLAlchemy engine. Using the
+ context manager (as opposed to the engine directly) is important because we need
+ to potentially establish SSH tunnels before the connection is created, and clean
+ them up once the engine is no longer used.
+ """
from superset.daos.database import ( # pylint: disable=import-outside-toplevel
DatabaseDAO,
)
@@ -403,7 +414,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
# if ssh_tunnel is available build engine with information
engine_context = ssh_manager_factory.instance.create_tunnel(
ssh_tunnel=ssh_tunnel,
- sqlalchemy_database_uri=self.sqlalchemy_uri_decrypted,
+ sqlalchemy_database_uri=sqlalchemy_uri,
)
with engine_context as server_context:
@@ -415,22 +426,21 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
server_context.local_bind_address,
)
sqlalchemy_uri = ssh_manager_factory.instance.build_sqla_url(
- sqlalchemy_uri, server_context
+ sqlalchemy_uri,
+ server_context,
)
+
yield self._get_sqla_engine(
+ catalog=catalog,
schema=schema,
nullpool=nullpool,
source=source,
sqlalchemy_uri=sqlalchemy_uri,
)
- # The `get_sqla_engine_with_context` was renamed to `get_sqla_engine`, but we kept a
- # reference to the old method to prevent breaking third-party applications.
- # TODO (betodealmeida): Remove in 5.0
- get_sqla_engine_with_context = get_sqla_engine
-
def _get_sqla_engine(
self,
+ catalog: str | None = None,
schema: str | None = None,
nullpool: bool = True,
source: utils.QuerySource | None = None,
@@ -447,26 +457,10 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
params["poolclass"] = NullPool
connect_args = params.get("connect_args", {})
- # The ``adjust_database_uri`` method was renamed to ``adjust_engine_params`` and
- # had its signature changed in order to support more DB engine specs. Since DB
- # engine specs can be released as 3rd party modules we want to make sure the old
- # method is still supported so we don't introduce a breaking change.
- if hasattr(self.db_engine_spec, "adjust_database_uri"):
- sqlalchemy_url = self.db_engine_spec.adjust_database_uri(
- sqlalchemy_url,
- schema,
- )
- logger.warning(
- "DB engine spec %s implements the method `adjust_database_uri`, which is "
- "deprecated and will be removed in version 3.0. Please update it to "
- "implement `adjust_engine_params` instead.",
- self.db_engine_spec,
- )
-
sqlalchemy_url, connect_args = self.db_engine_spec.adjust_engine_params(
uri=sqlalchemy_url,
connect_args=connect_args,
- catalog=None,
+ catalog=catalog,
schema=schema,
)
@@ -532,17 +526,24 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
@contextmanager
def get_raw_connection(
self,
+ catalog: str | None = None,
schema: str | None = None,
nullpool: bool = True,
source: utils.QuerySource | None = None,
) -> Connection:
with self.get_sqla_engine(
- schema=schema, nullpool=nullpool, source=source
+ catalog=catalog,
+ schema=schema,
+ nullpool=nullpool,
+ source=source,
) as engine:
with closing(engine.raw_connection()) as conn:
# pre-session queries are used to set the selected schema and, in the
# future, the selected catalog
- for prequery in self.db_engine_spec.get_prequeries(schema=schema):
+ for prequery in self.db_engine_spec.get_prequeries(
+ catalog=catalog,
+ schema=schema,
+ ):
cursor = conn.cursor()
cursor.execute(prequery)
@@ -606,14 +607,15 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
)
return sql_
- def get_df(
+ def get_df( # pylint: disable=too-many-locals
self,
sql: str,
+ catalog: str | None = None,
schema: str | None = None,
mutator: Callable[[pd.DataFrame], None] | None = None,
) -> pd.DataFrame:
sqls = self.db_engine_spec.parse_sql(sql)
- with self.get_sqla_engine(schema) as engine:
+ with self.get_sqla_engine(catalog=catalog, schema=schema) as engine:
engine_url = engine.url
def _log_query(sql: str) -> None:
@@ -626,7 +628,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
security_manager,
)
- with self.get_raw_connection(schema=schema) as conn:
+ with self.get_raw_connection(catalog=catalog, schema=schema) as conn:
cursor = conn.cursor()
df = None
for i, sql_ in enumerate(sqls):
@@ -653,8 +655,13 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
return self.post_process_df(df)
- def compile_sqla_query(self, qry: Select, schema: str | None = None) -> str:
- with self.get_sqla_engine(schema) as engine:
+ def compile_sqla_query(
+ self,
+ qry: Select,
+ catalog: str | None = None,
+ schema: str | None = None,
+ ) -> str:
+ with self.get_sqla_engine(catalog=catalog, schema=schema) as engine:
sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True}))
# pylint: disable=protected-access
@@ -665,8 +672,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
def select_star( # pylint: disable=too-many-arguments
self,
- table_name: str,
- schema: str | None = None,
+ table: Table,
limit: int = 100,
show_cols: bool = False,
indent: bool = True,
@@ -674,11 +680,10 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
cols: list[ResultSetColumnType] | None = None,
) -> str:
"""Generates a ``select *`` statement in the proper dialect"""
- with self.get_sqla_engine(schema) as engine:
+ with self.get_sqla_engine(catalog=table.catalog, schema=table.schema) as engine:
return self.db_engine_spec.select_star(
self,
- table_name,
- schema=schema,
+ table,
engine=engine,
limit=limit,
show_cols=show_cols,
@@ -703,6 +708,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
)
def get_all_table_names_in_schema( # pylint: disable=unused-argument
self,
+ catalog: str | None,
schema: str,
cache: bool = False,
cache_timeout: int | None = None,
@@ -720,7 +726,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
:return: The table/schema pairs
"""
try:
- with self.get_inspector_with_context() as inspector:
+ with self.get_inspector(catalog=catalog, schema=schema) as inspector:
return {
(table, schema)
for table in self.db_engine_spec.get_table_names(
@@ -738,6 +744,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
)
def get_all_view_names_in_schema( # pylint: disable=unused-argument
self,
+ catalog: str | None,
schema: str,
cache: bool = False,
cache_timeout: int | None = None,
@@ -755,7 +762,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
:return: set of views
"""
try:
- with self.get_inspector_with_context() as inspector:
+ with self.get_inspector(catalog=catalog, schema=schema) as inspector:
return {
(view, schema)
for view in self.db_engine_spec.get_view_names(
@@ -768,10 +775,17 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
raise self.db_engine_spec.get_dbapi_mapped_exception(ex)
@contextmanager
- def get_inspector_with_context(
- self, ssh_tunnel: SSHTunnel | None = None
+ def get_inspector(
+ self,
+ catalog: str | None = None,
+ schema: str | None = None,
+ ssh_tunnel: SSHTunnel | None = None,
) -> Inspector:
- with self.get_sqla_engine(override_ssh_tunnel=ssh_tunnel) as engine:
+ with self.get_sqla_engine(
+ catalog=catalog,
+ schema=schema,
+ override_ssh_tunnel=ssh_tunnel,
+ ) as engine:
yield sqla.inspect(engine)
@cache_util.memoized_func(
@@ -780,6 +794,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
)
def get_all_schema_names( # pylint: disable=unused-argument
self,
+ catalog: str | None = None,
cache: bool = False,
cache_timeout: int | None = None,
force: bool = False,
@@ -796,7 +811,10 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
:return: schema list
"""
try:
- with self.get_inspector_with_context(ssh_tunnel=ssh_tunnel) as inspector:
+ with self.get_inspector(
+ catalog=catalog,
+ ssh_tunnel=ssh_tunnel,
+ ) as inspector:
return self.db_engine_spec.get_schema_names(inspector)
except Exception as ex:
raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
@@ -848,51 +866,57 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
def update_params_from_encrypted_extra(self, params: dict[str, Any]) -> None:
self.db_engine_spec.update_params_from_encrypted_extra(self, params)
- def get_table(self, table_name: str, schema: str | None = None) -> Table:
+ def get_table(self, table: Table) -> SqlaTable:
extra = self.get_extra()
meta = MetaData(**extra.get("metadata_params", {}))
- with self.get_sqla_engine() as engine:
- return Table(
- table_name,
+ with self.get_sqla_engine(catalog=table.catalog, schema=table.schema) as engine:
+ return SqlaTable(
+ table.table,
meta,
- schema=schema or None,
+ schema=table.schema or None,
autoload=True,
autoload_with=engine,
)
- def get_table_comment(
- self, table_name: str, schema: str | None = None
- ) -> str | None:
- with self.get_inspector_with_context() as inspector:
- return self.db_engine_spec.get_table_comment(inspector, table_name, schema)
+ def get_table_comment(self, table: Table) -> str | None:
+ with self.get_inspector(
+ catalog=table.catalog,
+ schema=table.schema,
+ ) as inspector:
+ return self.db_engine_spec.get_table_comment(inspector, table)
- def get_columns(
- self, table_name: str, schema: str | None = None
- ) -> list[ResultSetColumnType]:
- with self.get_inspector_with_context() as inspector:
+ def get_columns(self, table: Table) -> list[ResultSetColumnType]:
+ with self.get_inspector(
+ catalog=table.catalog,
+ schema=table.schema,
+ ) as inspector:
return self.db_engine_spec.get_columns(
- inspector, table_name, schema, self.schema_options
+ inspector, table, self.schema_options
)
def get_metrics(
self,
- table_name: str,
- schema: str | None = None,
+ table: Table,
) -> list[MetricType]:
- with self.get_inspector_with_context() as inspector:
- return self.db_engine_spec.get_metrics(self, inspector, table_name, schema)
+ with self.get_inspector(
+ catalog=table.catalog,
+ schema=table.schema,
+ ) as inspector:
+ return self.db_engine_spec.get_metrics(self, inspector, table)
- def get_indexes(
- self, table_name: str, schema: str | None = None
- ) -> list[dict[str, Any]]:
- with self.get_inspector_with_context() as inspector:
- return self.db_engine_spec.get_indexes(self, inspector, table_name, schema)
+ def get_indexes(self, table: Table) -> list[dict[str, Any]]:
+ with self.get_inspector(
+ catalog=table.catalog,
+ schema=table.schema,
+ ) as inspector:
+ return self.db_engine_spec.get_indexes(self, inspector, table)
- def get_pk_constraint(
- self, table_name: str, schema: str | None = None
- ) -> dict[str, Any]:
- with self.get_inspector_with_context() as inspector:
- pk_constraint = inspector.get_pk_constraint(table_name, schema) or {}
+ def get_pk_constraint(self, table: Table) -> dict[str, Any]:
+ with self.get_inspector(
+ catalog=table.catalog,
+ schema=table.schema,
+ ) as inspector:
+ pk_constraint = inspector.get_pk_constraint(table.table, table.schema) or {}
def _convert(value: Any) -> Any:
try:
@@ -902,11 +926,12 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
return {key: _convert(value) for key, value in pk_constraint.items()}
- def get_foreign_keys(
- self, table_name: str, schema: str | None = None
- ) -> list[dict[str, Any]]:
- with self.get_inspector_with_context() as inspector:
- return inspector.get_foreign_keys(table_name, schema)
+ def get_foreign_keys(self, table: Table) -> list[dict[str, Any]]:
+ with self.get_inspector(
+ catalog=table.catalog,
+ schema=table.schema,
+ ) as inspector:
+ return inspector.get_foreign_keys(table.table, table.schema)
def get_schema_access_for_file_upload( # pylint: disable=invalid-name
self,
@@ -955,36 +980,23 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
return self.perm # type: ignore
def has_table(self, table: Table) -> bool:
- with self.get_sqla_engine() as engine:
- return engine.has_table(table.table_name, table.schema or None)
+ with self.get_sqla_engine(catalog=table.catalog, schema=table.schema) as engine:
+ # do not pass "" as an empty schema; force null
+ return engine.has_table(table.table, table.schema or None)
- def has_table_by_name(self, table_name: str, schema: str | None = None) -> bool:
- with self.get_sqla_engine() as engine:
- return engine.has_table(table_name, schema)
+ def has_view(self, table: Table) -> bool:
+ with self.get_sqla_engine(catalog=table.catalog, schema=table.schema) as engine:
+ connection = engine.connect()
+ try:
+ views = engine.dialect.get_view_names(
+ connection=connection,
+ schema=table.schema,
+ )
+ except Exception: # pylint: disable=broad-except
+ logger.warning("Has view failed", exc_info=True)
+ views = []
- @classmethod
- def _has_view(
- cls,
- conn: Connection,
- dialect: Dialect,
- view_name: str,
- schema: str | None = None,
- ) -> bool:
- view_names: list[str] = []
- try:
- view_names = dialect.get_view_names(connection=conn, schema=schema)
- except Exception: # pylint: disable=broad-except
- logger.warning("Has view failed", exc_info=True)
- return view_name in view_names
-
- def has_view(self, view_name: str, schema: str | None = None) -> bool:
- with self.get_sqla_engine(schema) as engine:
- return engine.run_callable(
- self._has_view, engine.dialect, view_name, schema
- )
-
- def has_view_by_name(self, view_name: str, schema: str | None = None) -> bool:
- return self.has_view(view_name=view_name, schema=schema)
+ return table.table in views
def get_dialect(self) -> Dialect:
sqla_url = make_url_safe(self.sqlalchemy_uri_decrypted)
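
Putting the catalog-aware Database API above together, a hedged usage sketch; the catalog and schema names are assumptions:

    # Run a query pinned to a catalog/schema pair.
    df = database.get_df(
        "SELECT COUNT(*) AS n FROM some_table",
        catalog="analytics",
        schema="public",
    )

    # Or drive a raw DB-API cursor; prequeries set the catalog/schema first.
    with database.get_raw_connection(catalog="analytics", schema="public") as conn:
        cursor = conn.cursor()
        cursor.execute("SELECT 1")
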
diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py
index caf90067007..991d2d41a46 100644
--- a/superset/models/dashboard.py
+++ b/superset/models/dashboard.py
@@ -32,7 +32,6 @@ from sqlalchemy import (
Column,
ForeignKey,
Integer,
- MetaData,
String,
Table,
Text,
@@ -214,13 +213,6 @@ class Dashboard(AuditMixinNullable, ImportExportMixin, Model):
def charts(self) -> list[str]:
return [slc.chart for slc in self.slices]
- @property
- def sqla_metadata(self) -> None:
- # pylint: disable=no-member
- with self.get_sqla_engine() as engine:
- meta = MetaData(bind=engine)
- meta.reflect()
-
@property
def status(self) -> utils.DashboardStatus:
if self.published:
diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py
index 18d0f2e1d0f..40a5132c556 100644
--- a/superset/models/sql_lab.py
+++ b/superset/models/sql_lab.py
@@ -109,6 +109,7 @@ class Query(
tab_name = Column(String(256))
sql_editor_id = Column(String(256))
schema = Column(String(256))
+ catalog = Column(String(256), nullable=True, default=None)
sql = Column(MediumText())
# Query to retrieve the results,
# used only in case of select_as_cta_used is true.
@@ -386,6 +387,7 @@ class SavedQuery(
user_id = Column(Integer, ForeignKey("ab_user.id"), nullable=True)
db_id = Column(Integer, ForeignKey("dbs.id"), nullable=True)
schema = Column(String(128))
+ catalog = Column(String(256), nullable=True, default=None)
label = Column(String(256))
description = Column(Text)
sql = Column(MediumText())
@@ -474,6 +476,7 @@ class TabState(AuditMixinNullable, ExtraJSONMixin, Model):
database_id = Column(Integer, ForeignKey("dbs.id", ondelete="CASCADE"))
database = relationship("Database", foreign_keys=[database_id])
schema = Column(String(256))
+ catalog = Column(String(256), nullable=True, default=None)
# tables that are open in the schema browser and their data previews
table_schemas = relationship(
@@ -535,6 +538,7 @@ class TableSchema(AuditMixinNullable, ExtraJSONMixin, Model):
)
database = relationship("Database", foreign_keys=[database_id])
schema = Column(String(256))
+ catalog = Column(String(256), nullable=True, default=None)
table = Column(String(256))
# JSON describing the schema, partitions, latest partition, etc.
diff --git a/superset/security/manager.py b/superset/security/manager.py
index 4da85b7d1f7..a84c0cec0d2 100644
--- a/superset/security/manager.py
+++ b/superset/security/manager.py
@@ -1922,6 +1922,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
table: Optional["Table"] = None,
viz: Optional["BaseViz"] = None,
sql: Optional[str] = None,
+ catalog: Optional[str] = None, # pylint: disable=unused-argument
schema: Optional[str] = None,
) -> None:
"""
@@ -1934,6 +1935,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
:param table: The Superset table (requires database)
:param viz: The visualization
:param sql: The SQL string (requires database)
+ :param catalog: Optional catalog name
:param schema: Optional schema name
:raises SupersetSecurityException: If the user cannot access the resource
"""
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index 9076136c64f..3f8c1cc7370 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -55,6 +55,7 @@ from superset.sql_parse import (
insert_rls_as_subquery,
insert_rls_in_predicate,
ParsedQuery,
+ Table,
)
from superset.sqllab.limiting_factor import LimitingFactor
from superset.sqllab.utils import write_ipc_buffer
@@ -470,7 +471,11 @@ def execute_sql_statements(
)
)
- with database.get_raw_connection(query.schema, source=QuerySource.SQL_LAB) as conn:
+ with database.get_raw_connection(
+ catalog=query.catalog,
+ schema=query.schema,
+ source=QuerySource.SQL_LAB,
+ ) as conn:
# Sharing a single connection and cursor across the
# execution of all statements (if many)
cursor = conn.cursor()
@@ -539,8 +544,7 @@ def execute_sql_statements(
query.set_extra_json_key("columns", result_set.columns)
if query.select_as_cta:
query.select_sql = database.select_star(
- query.tmp_table_name,
- schema=query.tmp_schema_name,
+ Table(query.tmp_table_name, query.tmp_schema_name),
limit=query.limit,
show_cols=False,
latest_partition=False,
@@ -645,7 +649,9 @@ def cancel_query(query: Query) -> bool:
return False
with query.database.get_sqla_engine(
- query.schema, source=QuerySource.SQL_LAB
+ catalog=query.catalog,
+ schema=query.schema,
+ source=QuerySource.SQL_LAB,
) as engine:
with closing(engine.raw_connection()) as conn:
with closing(conn.cursor()) as cursor:
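
For the CTA branch above, the preview statement is now built from a Table value; a minimal sketch, assuming `query` is a Query row whose CTAS target exists:

    from superset.sql_parse import Table

    select_sql = database.select_star(
        Table(query.tmp_table_name, query.tmp_schema_name),
        limit=query.limit,
        show_cols=False,
        latest_partition=False,
    )
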
diff --git a/superset/sql_validators/base.py b/superset/sql_validators/base.py
index 8344fc9264d..25f73af0894 100644
--- a/superset/sql_validators/base.py
+++ b/superset/sql_validators/base.py
@@ -14,7 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Optional
+
+from __future__ import annotations
+
+from typing import Any
from superset.models.core import Database
@@ -25,9 +28,9 @@ class SQLValidationAnnotation: # pylint: disable=too-few-public-methods
def __init__(
self,
message: str,
- line_number: Optional[int],
- start_column: Optional[int],
- end_column: Optional[int],
+ line_number: int | None,
+ start_column: int | None,
+ end_column: int | None,
):
self.message = message
self.line_number = line_number
@@ -52,7 +55,11 @@ class BaseSQLValidator: # pylint: disable=too-few-public-methods
@classmethod
def validate(
- cls, sql: str, schema: Optional[str], database: Database
+ cls,
+ sql: str,
+ catalog: str | None,
+ schema: str | None,
+ database: Database,
) -> list[SQLValidationAnnotation]:
"""Check that the given SQL querystring is valid for the given engine"""
raise NotImplementedError
diff --git a/superset/sql_validators/postgres.py b/superset/sql_validators/postgres.py
index 60c15ca034c..279520292ea 100644
--- a/superset/sql_validators/postgres.py
+++ b/superset/sql_validators/postgres.py
@@ -15,8 +15,9 @@
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import re
-from typing import Optional
from pgsanity.pgsanity import check_string
@@ -31,7 +32,11 @@ class PostgreSQLValidator(BaseSQLValidator): # pylint: disable=too-few-public-m
@classmethod
def validate(
- cls, sql: str, schema: Optional[str], database: Database
+ cls,
+ sql: str,
+ catalog: str | None,
+ schema: str | None,
+ database: Database,
) -> list[SQLValidationAnnotation]:
annotations: list[SQLValidationAnnotation] = []
valid, error = check_string(sql, add_semicolon=True)
diff --git a/superset/sql_validators/presto_db.py b/superset/sql_validators/presto_db.py
index 4d4d898034c..06bee217cf2 100644
--- a/superset/sql_validators/presto_db.py
+++ b/superset/sql_validators/presto_db.py
@@ -15,10 +15,12 @@
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import logging
import time
from contextlib import closing
-from typing import Any, Optional
+from typing import Any
from superset import app
from superset.models.core import Database
@@ -47,7 +49,7 @@ class PrestoDBSQLValidator(BaseSQLValidator):
statement: str,
database: Database,
cursor: Any,
- ) -> Optional[SQLValidationAnnotation]:
+ ) -> SQLValidationAnnotation | None:
# pylint: disable=too-many-locals
db_engine_spec = database.db_engine_spec
parsed_query = ParsedQuery(statement, engine=db_engine_spec.engine)
@@ -140,7 +142,11 @@ class PrestoDBSQLValidator(BaseSQLValidator):
@classmethod
def validate(
- cls, sql: str, schema: Optional[str], database: Database
+ cls,
+ sql: str,
+ catalog: str | None,
+ schema: str | None,
+ database: Database,
) -> list[SQLValidationAnnotation]:
"""
Presto supports query-validation queries by running them with a
@@ -155,7 +161,11 @@ class PrestoDBSQLValidator(BaseSQLValidator):
logger.info("Validating %i statement(s)", len(statements))
# todo(hughhh): update this to use new database.get_raw_connection()
# this function keeps stalling CI
- with database.get_sqla_engine(schema, source=QuerySource.SQL_LAB) as engine:
+ with database.get_sqla_engine(
+ catalog=catalog,
+ schema=schema,
+ source=QuerySource.SQL_LAB,
+ ) as engine:
# Sharing a single connection and cursor across the
# execution of all statements (if many)
annotations: list[SQLValidationAnnotation] = []
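
All validators now share the four-argument signature; a sketch of a caller, with the database object assumed:

    from superset.sql_validators.presto_db import PrestoDBSQLValidator

    annotations = PrestoDBSQLValidator.validate(
        sql="SELECT 1",
        catalog=None,  # no catalog pinned yet
        schema="default",
        database=database,  # assumed Database instance
    )
    for annotation in annotations:
        print(annotation.line_number, annotation.message)
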
diff --git a/superset/utils/mock_data.py b/superset/utils/mock_data.py
index fc082ecb456..5013d0954cf 100644
--- a/superset/utils/mock_data.py
+++ b/superset/utils/mock_data.py
@@ -29,12 +29,13 @@ from uuid import uuid4
import sqlalchemy.sql.sqltypes
import sqlalchemy_utils
from flask_appbuilder import Model
-from sqlalchemy import Column, inspect, MetaData, Table
+from sqlalchemy import Column, inspect, MetaData, Table as DBTable
from sqlalchemy.dialects import postgresql
from sqlalchemy.sql import func
from sqlalchemy.sql.visitors import VisitableType
from superset import db
+from superset.sql_parse import Table
logger = logging.getLogger(__name__)
@@ -182,7 +183,7 @@ def add_data(
from superset.utils.database import get_example_database
database = get_example_database()
- table_exists = database.has_table_by_name(table_name)
+ table_exists = database.has_table(Table(table_name))
with database.get_sqla_engine() as engine:
if columns is None:
@@ -198,7 +199,7 @@ def add_data(
# create table if needed
column_objects = get_column_objects(columns)
metadata = MetaData()
- table = Table(table_name, metadata, *column_objects)
+ table = DBTable(table_name, metadata, *column_objects)
metadata.create_all(engine)
if not append:
diff --git a/superset/views/datasource/views.py b/superset/views/datasource/views.py
index eba3acf36ed..7f810817775 100644
--- a/superset/views/datasource/views.py
+++ b/superset/views/datasource/views.py
@@ -37,6 +37,7 @@ from superset.connectors.sqla.utils import get_physical_table_metadata
from superset.daos.datasource import DatasourceDAO
from superset.exceptions import SupersetException, SupersetSecurityException
from superset.models.core import Database
+from superset.sql_parse import Table
from superset.superset_typing import FlaskResponse
from superset.utils.core import DatasourceType
from superset.views.base import (
@@ -180,8 +181,7 @@ class Datasource(BaseSupersetView):
)
external_metadata = get_physical_table_metadata(
database=database,
- table_name=params["table_name"],
- schema_name=params["schema_name"],
+ table=Table(params["table_name"], params["schema_name"]),
normalize_columns=params.get("normalize_columns") or False,
)
except (NoResultFound, NoSuchTableError) as ex:
diff --git a/tests/integration_tests/celery_tests.py b/tests/integration_tests/celery_tests.py
index 48497b97794..3bd82211e5d 100644
--- a/tests/integration_tests/celery_tests.py
+++ b/tests/integration_tests/celery_tests.py
@@ -121,7 +121,7 @@ def drop_table_if_exists(table_name: str, table_type: CtasMethod) -> None:
def quote_f(value: Optional[str]):
if not value:
return value
- with get_example_database().get_inspector_with_context() as inspector:
+ with get_example_database().get_inspector() as inspector:
return inspector.engine.dialect.identifier_preparer.quote_identifier(value)
diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py
index 8122eac9d45..6d70a1cd75a 100644
--- a/tests/integration_tests/charts/data/api_tests.py
+++ b/tests/integration_tests/charts/data/api_tests.py
@@ -132,9 +132,7 @@ class BaseTestChartDataApi(SupersetTestCase):
def quote_name(self, name: str):
if get_main_database().backend in {"presto", "hive"}:
- with (
- get_example_database().get_inspector_with_context() as inspector
- ): # E: Ne
+ with get_example_database().get_inspector() as inspector: # E: Ne
return inspector.engine.dialect.identifier_preparer.quote_identifier(
name
)
diff --git a/tests/integration_tests/core_tests.py b/tests/integration_tests/core_tests.py
index bcb9aa32919..90873d49b9f 100644
--- a/tests/integration_tests/core_tests.py
+++ b/tests/integration_tests/core_tests.py
@@ -50,6 +50,7 @@ from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.models.sql_lab import Query
from superset.result_set import SupersetResultSet
+from superset.sql_parse import Table
from superset.utils import core as utils
from superset.utils.core import backend
from superset.utils.database import get_example_database
@@ -1197,14 +1198,11 @@ class TestCore(SupersetTestCase):
)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
- def test_has_table_by_name(self):
+ def test_has_table(self):
if backend() in ("sqlite", "mysql"):
return
example_db = superset.utils.database.get_example_database()
- assert (
- example_db.has_table_by_name(table_name="birth_names", schema="public")
- is True
- )
+ assert example_db.has_table(Table("birth_names", "public")) is True
@mock.patch("superset.views.core.request")
@mock.patch(
diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py
index 016c8979884..ad7d71c7685 100644
--- a/tests/integration_tests/databases/api_tests.py
+++ b/tests/integration_tests/databases/api_tests.py
@@ -2031,7 +2031,7 @@ class TestDatabaseApi(SupersetTestCase):
if database.backend == "postgresql":
response = json.loads(rv.data.decode("utf-8"))
schemas = [
- s[0] for s in database.get_all_table_names_in_schema(schema_name)
+ s[0] for s in database.get_all_table_names_in_schema(None, schema_name)
]
self.assertEqual(response["count"], len(schemas))
for option in response["result"]:
diff --git a/tests/integration_tests/databases/commands/upload_test.py b/tests/integration_tests/databases/commands/upload_test.py
index 26379aa9769..1af85c3ab1f 100644
--- a/tests/integration_tests/databases/commands/upload_test.py
+++ b/tests/integration_tests/databases/commands/upload_test.py
@@ -73,7 +73,7 @@ def _setup_csv_upload(allowed_schemas: list[str] | None = None):
yield
upload_db = get_upload_db()
- with upload_db.get_sqla_engine_with_context() as engine:
+ with upload_db.get_sqla_engine() as engine:
engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE}")
engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE_W_SCHEMA}")
db.session.delete(upload_db)
@@ -107,7 +107,7 @@ def test_csv_upload_with_nulls():
None,
CSVReader({"null_values": ["N/A", "None"]}),
).run()
- with upload_database.get_sqla_engine_with_context() as engine:
+ with upload_database.get_sqla_engine() as engine:
data = engine.execute(f"SELECT * from {CSV_UPLOAD_TABLE}").fetchall()
assert data == [
("name1", None, "city1", "1-1-1980"),
diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py
index 8d24c2993b7..c10d589d97f 100644
--- a/tests/integration_tests/datasets/api_tests.py
+++ b/tests/integration_tests/datasets/api_tests.py
@@ -773,14 +773,14 @@ class TestDatasetApi(SupersetTestCase):
assert rv.status_code == 422
@patch("superset.models.core.Database.get_columns")
- @patch("superset.models.core.Database.has_table_by_name")
- @patch("superset.models.core.Database.has_view_by_name")
+ @patch("superset.models.core.Database.has_table")
+ @patch("superset.models.core.Database.has_view")
@patch("superset.models.core.Database.get_table")
def test_create_dataset_validate_view_exists(
self,
mock_get_table,
- mock_has_table_by_name,
- mock_has_view_by_name,
+ mock_has_table,
+ mock_has_view,
mock_get_columns,
):
"""
@@ -796,13 +796,12 @@ class TestDatasetApi(SupersetTestCase):
}
]
- mock_has_table_by_name.return_value = False
- mock_has_view_by_name.return_value = True
+ mock_has_table.return_value = False
+ mock_has_view.return_value = True
mock_get_table.return_value = None
example_db = get_example_database()
with example_db.get_sqla_engine() as engine:
- engine = engine
dialect = engine.dialect
with patch.object(
diff --git a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py
index d7498dc4fee..c8db1f912ad 100644
--- a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py
+++ b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py
@@ -30,7 +30,7 @@ from superset.db_engine_specs.base import (
from superset.db_engine_specs.mysql import MySQLEngineSpec
from superset.db_engine_specs.sqlite import SqliteEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
-from superset.sql_parse import ParsedQuery
+from superset.sql_parse import ParsedQuery, Table
from superset.utils.database import get_example_database
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
from tests.integration_tests.test_app import app
@@ -238,7 +238,7 @@ class TestDbEngineSpecs(TestDbEngineSpec):
@pytest.mark.usefixtures("load_energy_table_with_slice")
def test_column_datatype_to_string(self):
example_db = get_example_database()
- sqla_table = example_db.get_table("energy_usage")
+ sqla_table = example_db.get_table(Table("energy_usage"))
dialect = example_db.get_dialect()
# TODO: fix column type conversion for presto.
@@ -540,8 +540,7 @@ def test_get_indexes():
BaseEngineSpec.get_indexes(
database=mock.Mock(),
inspector=inspector,
- table_name="bar",
- schema="foo",
+ table=Table("bar", "foo"),
)
== indexes
)
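
Most signature changes in this patch thread a single Table value through where table_name/schema (and sometimes catalog) strings used to travel separately. Table is a named tuple from superset.sql_parse; a sketch of its shape as used here:

    from superset.sql_parse import Table

    # schema and catalog default to None, so all of these are valid:
    t1 = Table("energy_usage")   # table only
    t2 = Table("bar", "foo")     # table "bar" in schema "foo"
    t3 = Table("t", "s", "c")    # table, schema, and catalog
    assert t2.table == "bar" and t2.schema == "foo" and t2.catalog is None
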
diff --git a/tests/integration_tests/db_engine_specs/bigquery_tests.py b/tests/integration_tests/db_engine_specs/bigquery_tests.py
index ce184685db5..53f9137076b 100644
--- a/tests/integration_tests/db_engine_specs/bigquery_tests.py
+++ b/tests/integration_tests/db_engine_specs/bigquery_tests.py
@@ -165,8 +165,7 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec):
BigQueryEngineSpec.get_indexes(
database,
inspector,
- table_name,
- schema,
+ Table(table_name, schema),
)
== []
)
@@ -184,8 +183,7 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec):
assert BigQueryEngineSpec.get_indexes(
database,
inspector,
- table_name,
- schema,
+ Table(table_name, schema),
) == [
{
"name": "partition",
@@ -207,8 +205,7 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec):
assert BigQueryEngineSpec.get_indexes(
database,
inspector,
- table_name,
- schema,
+ Table(table_name, schema),
) == [
{
"name": "partition",
diff --git a/tests/integration_tests/db_engine_specs/hive_tests.py b/tests/integration_tests/db_engine_specs/hive_tests.py
index 39d2c30fd11..4d1a8450816 100644
--- a/tests/integration_tests/db_engine_specs/hive_tests.py
+++ b/tests/integration_tests/db_engine_specs/hive_tests.py
@@ -23,7 +23,7 @@ from sqlalchemy.sql import select
from superset.db_engine_specs.hive import HiveEngineSpec, upload_to_s3
from superset.exceptions import SupersetException
-from superset.sql_parse import Table, ParsedQuery
+from superset.sql_parse import ParsedQuery, Table
from tests.integration_tests.test_app import app
@@ -328,7 +328,10 @@ def test_where_latest_partition(mock_method):
columns = [{"name": "ds"}, {"name": "hour"}]
with app.app_context():
result = HiveEngineSpec.where_latest_partition(
- "test_table", "test_schema", database, select(), columns
+ database,
+ Table("test_table", "test_schema"),
+ select(),
+ columns,
)
query_result = str(result.compile(compile_kwargs={"literal_binds": True}))
assert "SELECT \nWHERE ds = '01-01-19' AND hour = 1" == query_result
@@ -341,7 +344,10 @@ def test_where_latest_partition_super_method_exception(mock_method):
columns = [{"name": "ds"}, {"name": "hour"}]
with app.app_context():
result = HiveEngineSpec.where_latest_partition(
- "test_table", "test_schema", database, select(), columns
+ database,
+ Table("test_table", "test_schema"),
+ select(),
+ columns,
)
assert result is None
mock_method.assert_called()
@@ -353,7 +359,9 @@ def test_where_latest_partition_no_columns_no_values(mock_method):
db = mock.Mock()
with app.app_context():
result = HiveEngineSpec.where_latest_partition(
- "test_table", "test_schema", db, select()
+ db,
+ Table("test_table", "test_schema"),
+ select(),
)
assert result is None
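
where_latest_partition now puts the database first and bundles the table coordinates into a Table. A sketch of the new call shape, stubbing latest_partition as the tests above do (run inside an app context):

    from unittest import mock

    from sqlalchemy.sql import select

    from superset.db_engine_specs.hive import HiveEngineSpec
    from superset.sql_parse import Table

    # Argument order is now: database, table, query, columns.
    with mock.patch.object(
        HiveEngineSpec, "latest_partition", return_value=("01-01-19", 1)
    ):
        query = HiveEngineSpec.where_latest_partition(
            mock.MagicMock(),
            Table("test_table", "test_schema"),
            select(),
            [{"name": "ds"}, {"name": "hour"}],
        )
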
diff --git a/tests/integration_tests/db_engine_specs/postgres_tests.py b/tests/integration_tests/db_engine_specs/postgres_tests.py
index 0f4841fb356..708b9498767 100644
--- a/tests/integration_tests/db_engine_specs/postgres_tests.py
+++ b/tests/integration_tests/db_engine_specs/postgres_tests.py
@@ -530,7 +530,7 @@ def test_get_catalog_names(app_context: AppContext) -> None:
if database.backend != "postgresql":
return
- with database.get_inspector_with_context() as inspector:
+ with database.get_inspector() as inspector:
assert PostgresEngineSpec.get_catalog_names(database, inspector) == [
"postgres",
"superset",
diff --git a/tests/integration_tests/db_engine_specs/presto_tests.py b/tests/integration_tests/db_engine_specs/presto_tests.py
index 02669a162fd..607afa6953f 100644
--- a/tests/integration_tests/db_engine_specs/presto_tests.py
+++ b/tests/integration_tests/db_engine_specs/presto_tests.py
@@ -82,7 +82,7 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
row = mock.Mock()
row.Column, row.Type, row.Null = column
inspector.bind.execute.return_value.fetchall = mock.Mock(return_value=[row])
- results = PrestoEngineSpec.get_columns(inspector, "", "")
+ results = PrestoEngineSpec.get_columns(inspector, Table("", ""))
self.assertEqual(len(expected_results), len(results))
for expected_result, result in zip(expected_results, results):
self.assertEqual(expected_result[0], result["column_name"])
@@ -573,7 +573,10 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
db.get_df = mock.Mock(return_value=df)
columns = [{"name": "ds"}, {"name": "hour"}]
result = PrestoEngineSpec.where_latest_partition(
- "test_table", "test_schema", db, select(), columns
+ db,
+ Table("test_table", "test_schema"),
+ select(),
+ columns,
)
query_result = str(result.compile(compile_kwargs={"literal_binds": True}))
self.assertEqual("SELECT \nWHERE ds = '01-01-19' AND hour = 1", query_result)
@@ -802,7 +805,7 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
return_value=["a", "b"]
)
table_name = "table_name"
- result = PrestoEngineSpec._show_columns(inspector, table_name, None)
+ result = PrestoEngineSpec._show_columns(inspector, Table(table_name))
assert result == ["a", "b"]
inspector.bind.execute.assert_called_once_with(
f'SHOW COLUMNS FROM "{table_name}"'
@@ -818,7 +821,7 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
)
table_name = "table_name"
schema = "schema"
- result = PrestoEngineSpec._show_columns(inspector, table_name, schema)
+ result = PrestoEngineSpec._show_columns(inspector, Table(table_name, schema))
assert result == ["a", "b"]
inspector.bind.execute.assert_called_once_with(
f'SHOW COLUMNS FROM "{schema}"."{table_name}"'
@@ -846,9 +849,16 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
{"col1": "val1"},
{"col2": "val2"},
]
- PrestoEngineSpec.select_star(database, table_name, engine, cols=cols)
+ PrestoEngineSpec.select_star(database, Table(table_name), engine, cols=cols)
mock_select_star.assert_called_once_with(
- database, table_name, engine, None, 100, False, True, True, cols
+ database,
+ Table(table_name),
+ engine,
+ 100,
+ False,
+ True,
+ True,
+ cols,
)
@mock.patch("superset.db_engine_specs.presto.is_feature_enabled")
@@ -869,13 +879,16 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
{"column_name": ".val2."},
]
PrestoEngineSpec.select_star(
- database, table_name, engine, show_cols=True, cols=cols
+ database,
+ Table(table_name),
+ engine,
+ show_cols=True,
+ cols=cols,
)
mock_select_star.assert_called_once_with(
database,
- table_name,
+ Table(table_name),
engine,
- None,
100,
True,
True,
@@ -1172,7 +1185,7 @@ def test_get_catalog_names(app_context: AppContext) -> None:
if database.backend != "presto":
return
- with database.get_inspector_with_context() as inspector:
+ with database.get_inspector() as inspector:
assert PrestoEngineSpec.get_catalog_names(database, inspector) == [
"jmx",
"memory",
diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py
index 7c3bc15c39a..b68cb7c05f5 100644
--- a/tests/integration_tests/model_tests.py
+++ b/tests/integration_tests/model_tests.py
@@ -39,6 +39,7 @@ from superset.db_engine_specs.postgres import PostgresEngineSpec # noqa: F401
from superset.common.db_query_status import QueryStatus
from superset.models.core import Database
from superset.models.slice import Slice
+from superset.sql_parse import Table
from superset.utils.database import get_example_database
from .base_tests import SupersetTestCase
@@ -294,14 +295,14 @@ class TestDatabaseModel(SupersetTestCase):
def test_select_star(self):
db = get_example_database()
table_name = "energy_usage"
- sql = db.select_star(table_name, show_cols=False, latest_partition=False)
+ sql = db.select_star(Table(table_name), show_cols=False, latest_partition=False)
with db.get_sqla_engine() as engine:
quote = engine.dialect.identifier_preparer.quote_identifier
source = quote(table_name) if db.backend in {"presto", "hive"} else table_name
expected = f"SELECT\n *\nFROM {source}\nLIMIT 100"
assert expected in sql
- sql = db.select_star(table_name, show_cols=True, latest_partition=False)
+ sql = db.select_star(Table(table_name), show_cols=True, latest_partition=False)
# TODO(bkyryliuk): unify sql generation
if db.backend == "presto":
assert (
@@ -324,7 +325,9 @@ class TestDatabaseModel(SupersetTestCase):
schema = "schema.name"
table_name = "table/name"
sql = db.select_star(
- table_name, schema=schema, show_cols=False, latest_partition=False
+ Table(table_name, schema),
+ show_cols=False,
+ latest_partition=False,
)
fully_qualified_names = {
"sqlite": '"schema.name"."table/name"',
diff --git a/tests/integration_tests/sql_validator_tests.py b/tests/integration_tests/sql_validator_tests.py
index 12cb530582e..c286bf3a438 100644
--- a/tests/integration_tests/sql_validator_tests.py
+++ b/tests/integration_tests/sql_validator_tests.py
@@ -56,7 +56,7 @@ class TestPrestoValidator(SupersetTestCase):
sql = "SELECT 1 FROM default.notarealtable"
schema = "default"
- errors = self.validator.validate(sql, schema, self.database)
+ errors = self.validator.validate(sql, None, schema, self.database)
self.assertEqual([], errors)
@@ -70,7 +70,7 @@ class TestPrestoValidator(SupersetTestCase):
fetch_fn.side_effect = DatabaseError("dummy db error")
with self.assertRaises(PrestoSQLValidationError):
- self.validator.validate(sql, schema, self.database)
+ self.validator.validate(sql, None, schema, self.database)
@patch("superset.utils.core.g")
def test_validator_unexpected_error(self, flask_g):
@@ -82,7 +82,7 @@ class TestPrestoValidator(SupersetTestCase):
fetch_fn.side_effect = Exception("a mysterious failure")
with self.assertRaises(Exception):
- self.validator.validate(sql, schema, self.database)
+ self.validator.validate(sql, None, schema, self.database)
@patch("superset.utils.core.g")
def test_validator_query_error(self, flask_g):
@@ -93,7 +93,7 @@ class TestPrestoValidator(SupersetTestCase):
fetch_fn = self.database.db_engine_spec.fetch_data
fetch_fn.side_effect = DatabaseError(self.PRESTO_ERROR_TEMPLATE)
- errors = self.validator.validate(sql, schema, self.database)
+ errors = self.validator.validate(sql, None, schema, self.database)
self.assertEqual(1, len(errors))
@@ -105,7 +105,10 @@ class TestPostgreSQLValidator(SupersetTestCase):
mock_database = MagicMock()
annotations = PostgreSQLValidator.validate(
- sql='SELECT 1, "col" FROM "table"', schema="", database=mock_database
+ sql='SELECT 1, "col" FROM "table"',
+ catalog=None,
+ schema="",
+ database=mock_database,
)
assert annotations == []
@@ -115,7 +118,10 @@ class TestPostgreSQLValidator(SupersetTestCase):
mock_database = MagicMock()
annotations = PostgreSQLValidator.validate(
- sql='SELECT 1, "col"\nFROOM "table"', schema="", database=mock_database
+ sql='SELECT 1, "col"\nFROOM "table"',
+ catalog=None,
+ schema="",
+ database=mock_database,
)
assert len(annotations) == 1
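
The validators gain a catalog parameter between sql and schema; passing None preserves the old behaviour. Keyword form, as a sketch (validator import path as in Superset's sql_validators package):

    from unittest.mock import MagicMock

    from superset.sql_validators.postgres import PostgreSQLValidator

    annotations = PostgreSQLValidator.validate(
        sql='SELECT 1, "col" FROM "table"',
        catalog=None,  # new parameter; None keeps pre-catalog behaviour
        schema="",
        database=MagicMock(),
    )
    assert annotations == []
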
diff --git a/tests/unit_tests/dao/dataset_test.py b/tests/unit_tests/dao/dataset_test.py
index 1e3d1ec9750..a2e2b2b39fb 100644
--- a/tests/unit_tests/dao/dataset_test.py
+++ b/tests/unit_tests/dao/dataset_test.py
@@ -18,6 +18,7 @@
from sqlalchemy.orm.session import Session
from superset.daos.dataset import DatasetDAO
+from superset.sql_parse import Table
def test_validate_update_uniqueness(session: Session) -> None:
@@ -54,9 +55,8 @@ def test_validate_update_uniqueness(session: Session) -> None:
assert (
DatasetDAO.validate_update_uniqueness(
database_id=database.id,
- schema=dataset1.schema,
+ table=Table(dataset1.table_name, dataset1.schema),
dataset_id=dataset1.id,
- name=dataset1.table_name,
)
is True
)
@@ -65,9 +65,8 @@ def test_validate_update_uniqueness(session: Session) -> None:
assert (
DatasetDAO.validate_update_uniqueness(
database_id=database.id,
- schema=dataset2.schema,
+ table=Table(dataset1.table_name, dataset2.schema),
dataset_id=dataset1.id,
- name=dataset1.table_name,
)
is False
)
@@ -76,9 +75,8 @@ def test_validate_update_uniqueness(session: Session) -> None:
assert (
DatasetDAO.validate_update_uniqueness(
database_id=database.id,
- schema=None,
+ table=Table(dataset1.table_name),
dataset_id=dataset1.id,
- name=dataset1.table_name,
)
is True
)
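
validate_update_uniqueness likewise swaps its schema=/name= pair for one table= argument. Sketch of the new keyword call (ids are illustrative):

    from superset.daos.dataset import DatasetDAO
    from superset.sql_parse import Table

    # name= and schema= collapse into table=; dataset_id still identifies
    # the row being updated.
    is_unique = DatasetDAO.validate_update_uniqueness(
        database_id=1,
        table=Table("my_table", "my_schema"),
        dataset_id=42,
    )
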
diff --git a/tests/unit_tests/databases/api_test.py b/tests/unit_tests/databases/api_test.py
index 2f3c11f9a35..154fd6501c5 100644
--- a/tests/unit_tests/databases/api_test.py
+++ b/tests/unit_tests/databases/api_test.py
@@ -1415,6 +1415,170 @@ def test_excel_upload_file_extension_invalid(
assert response.json == {"message": {"file": ["File extension is not allowed."]}}
+def test_table_metadata_happy_path(
+ mocker: MockFixture,
+ client: Any,
+ full_api_access: None,
+) -> None:
+ """
+ Test the `table_metadata` endpoint.
+ """
+ database = mocker.MagicMock()
+ database.db_engine_spec.get_table_metadata.return_value = {"hello": "world"}
+ mocker.patch("superset.databases.api.DatabaseDAO.find_by_id", return_value=database)
+ mocker.patch("superset.databases.api.security_manager.raise_for_access")
+
+ response = client.get("/api/v1/database/1/table_metadata/?name=t")
+ assert response.json == {"hello": "world"}
+ database.db_engine_spec.get_table_metadata.assert_called_with(
+ database,
+ Table("t"),
+ )
+
+ client.get("/api/v1/database/1/table_metadata/?name=t&schema=s")
+ database.db_engine_spec.get_table_metadata.assert_called_with(
+ database,
+ Table("t", "s"),
+ )
+
+ client.get("/api/v1/database/1/table_metadata/?name=t&catalog=c")
+ database.db_engine_spec.get_table_metadata.assert_called_with(
+ database,
+ Table("t", None, "c"),
+ )
+
+ client.get("/api/v1/database/1/table_metadata/?name=t&schema=s&catalog=c")
+ database.db_engine_spec.get_table_metadata.assert_called_with(
+ database,
+ Table("t", "s", "c"),
+ )
+
+
+def test_table_metadata_no_table(
+ mocker: MockFixture,
+ client: Any,
+ full_api_access: None,
+) -> None:
+ """
+ Test the `table_metadata` endpoint when no table name is passed.
+ """
+ database = mocker.MagicMock()
+ mocker.patch("superset.databases.api.DatabaseDAO.find_by_id", return_value=database)
+
+ response = client.get("/api/v1/database/1/table_metadata/?schema=s&catalog=c")
+ assert response.status_code == 422
+ assert response.json == {
+ "errors": [
+ {
+ "message": "An error happened when validating the request",
+ "error_type": "INVALID_PAYLOAD_SCHEMA_ERROR",
+ "level": "error",
+ "extra": {
+ "messages": {"name": ["Missing data for required field."]},
+ "issue_codes": [
+ {
+ "code": 1020,
+ "message": "Issue 1020 - The submitted payload has the incorrect schema.",
+ }
+ ],
+ },
+ }
+ ]
+ }
+
+
+def test_table_metadata_slashes(
+ mocker: MockFixture,
+ client: Any,
+ full_api_access: None,
+) -> None:
+ """
+ Test the `table_metadata` endpoint with names that have slashes.
+ """
+ database = mocker.MagicMock()
+ database.db_engine_spec.get_table_metadata.return_value = {"hello": "world"}
+ mocker.patch("superset.databases.api.DatabaseDAO.find_by_id", return_value=database)
+ mocker.patch("superset.databases.api.security_manager.raise_for_access")
+
+ client.get("/api/v1/database/1/table_metadata/?name=foo/bar")
+ database.db_engine_spec.get_table_metadata.assert_called_with(
+ database,
+ Table("foo/bar"),
+ )
+
+
+def test_table_metadata_invalid_database(
+ mocker: MockFixture,
+ client: Any,
+ full_api_access: None,
+) -> None:
+ """
+ Test the `table_metadata` endpoint when the database is invalid.
+ """
+ mocker.patch("superset.databases.api.DatabaseDAO.find_by_id", return_value=None)
+
+ response = client.get("/api/v1/database/1/table_metadata/?name=t")
+ assert response.status_code == 404
+ assert response.json == {
+ "errors": [
+ {
+ "message": "No such database",
+ "error_type": "DATABASE_NOT_FOUND_ERROR",
+ "level": "error",
+ "extra": {
+ "issue_codes": [
+ {
+ "code": 1011,
+ "message": "Issue 1011 - Superset encountered an unexpected error.",
+ },
+ {
+ "code": 1036,
+ "message": "Issue 1036 - The database was deleted.",
+ },
+ ]
+ },
+ }
+ ]
+ }
+
+
+def test_table_metadata_unauthorized(
+ mocker: MockFixture,
+ client: Any,
+ full_api_access: None,
+) -> None:
+ """
+ Test the `table_metadata` endpoint when the user is unauthorized.
+ """
+ database = mocker.MagicMock()
+ mocker.patch("superset.databases.api.DatabaseDAO.find_by_id", return_value=database)
+ mocker.patch(
+ "superset.databases.api.security_manager.raise_for_access",
+ side_effect=SupersetSecurityException(
+ SupersetError(
+ error_type=SupersetErrorType.TABLE_SECURITY_ACCESS_ERROR,
+ message="You don't have access to the table",
+ level=ErrorLevel.ERROR,
+ )
+ ),
+ )
+
+ response = client.get("/api/v1/database/1/table_metadata/?name=t")
+ assert response.status_code == 404
+ assert response.json == {
+ "errors": [
+ {
+ "message": "No such table",
+ "error_type": "TABLE_NOT_FOUND_ERROR",
+ "level": "error",
+ "extra": None,
+ }
+ ]
+ }
+
+
def test_table_extra_metadata_happy_path(
mocker: MockFixture,
client: Any,
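
Taken together, these tests pin down how the query string maps onto a Table: name is required, while schema and catalog are optional and fill the second and third fields. A client-side sketch of the URL shape:

    from urllib.parse import urlencode

    # name is required; schema and catalog are optional and map onto
    # Table(name, schema, catalog) on the server side.
    params = {"name": "my_table", "schema": "my_schema", "catalog": "my_catalog"}
    url = f"/api/v1/database/1/table_metadata/?{urlencode(params)}"
    # GET returns the engine spec's table metadata as JSON; 404 for a missing
    # database or an inaccessible table, 422 when name is absent.
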
diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py
index e17e0d2833d..3bc05ee20ee 100644
--- a/tests/unit_tests/db_engine_specs/test_base.py
+++ b/tests/unit_tests/db_engine_specs/test_base.py
@@ -232,9 +232,8 @@ def test_select_star(mocker: MockFixture) -> None:
sql = BaseEngineSpec.select_star(
database=database,
- table_name="my_table",
+ table=Table("my_table"),
engine=engine,
- schema=None,
limit=100,
show_cols=True,
indent=True,
@@ -252,9 +251,8 @@ OFFSET ?"""
sql = NoLimitDBEngineSpec.select_star(
database=database,
- table_name="my_table",
+ table=Table("my_table"),
engine=engine,
- schema=None,
limit=100,
show_cols=True,
indent=True,
diff --git a/tests/unit_tests/db_engine_specs/test_bigquery.py b/tests/unit_tests/db_engine_specs/test_bigquery.py
index 663fd7cac84..616ae668418 100644
--- a/tests/unit_tests/db_engine_specs/test_bigquery.py
+++ b/tests/unit_tests/db_engine_specs/test_bigquery.py
@@ -27,6 +27,7 @@ from sqlalchemy import select
from sqlalchemy.sql import sqltypes
from sqlalchemy_bigquery import BigQueryDialect
+from superset.sql_parse import Table
from superset.superset_typing import ResultSetColumnType
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm # noqa: F401
@@ -156,9 +157,8 @@ def test_select_star(mocker: MockFixture) -> None:
sql = BigQueryEngineSpec.select_star(
database=database,
- table_name="my_table",
+ table=Table("my_table"),
engine=engine,
- schema=None,
limit=100,
show_cols=True,
indent=True,
diff --git a/tests/unit_tests/db_engine_specs/test_db2.py b/tests/unit_tests/db_engine_specs/test_db2.py
index 6d0d604a25f..017fcd7b80e 100644
--- a/tests/unit_tests/db_engine_specs/test_db2.py
+++ b/tests/unit_tests/db_engine_specs/test_db2.py
@@ -18,6 +18,8 @@
import pytest # noqa: F401
from pytest_mock import MockerFixture
+from superset.sql_parse import Table
+
def test_epoch_to_dttm() -> None:
"""
@@ -43,7 +45,7 @@ def test_get_table_comment(mocker: MockerFixture):
}
assert (
- Db2EngineSpec.get_table_comment(mock_inspector, "my_table", "my_schema")
+ Db2EngineSpec.get_table_comment(mock_inspector, Table("my_table", "my_schema"))
== "This is a table comment"
)
@@ -59,7 +61,8 @@ def test_get_table_comment_empty(mocker: MockerFixture):
mock_inspector.get_table_comment.return_value = {}
assert (
- Db2EngineSpec.get_table_comment(mock_inspector, "my_table", "my_schema") is None # noqa: E711
+ Db2EngineSpec.get_table_comment(mock_inspector, Table("my_table", "my_schema"))
+ is None
)
diff --git a/tests/unit_tests/db_engine_specs/test_presto.py b/tests/unit_tests/db_engine_specs/test_presto.py
index 8d57d4ed1a8..638b377c827 100644
--- a/tests/unit_tests/db_engine_specs/test_presto.py
+++ b/tests/unit_tests/db_engine_specs/test_presto.py
@@ -24,6 +24,7 @@ from pyhive.sqlalchemy_presto import PrestoDialect
from sqlalchemy import sql, text, types
from sqlalchemy.engine.url import make_url
+from superset.sql_parse import Table
from superset.superset_typing import ResultSetColumnType
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import (
@@ -143,7 +144,10 @@ def test_where_latest_partition(
expected = f"""SELECT * FROM table \nWHERE "partition_key" = {expected_value}"""
result = spec.where_latest_partition(
- "table", mock.MagicMock(), mock.MagicMock(), query, columns
+ mock.MagicMock(),
+ Table("table"),
+ query,
+ columns,
)
assert result is not None
actual = result.compile(
diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py
index 5353578850a..5bd83828ed2 100644
--- a/tests/unit_tests/db_engine_specs/test_trino.py
+++ b/tests/unit_tests/db_engine_specs/test_trino.py
@@ -311,15 +311,15 @@ def test_convert_dttm(
assert_convert_dttm(TrinoEngineSpec, target_type, expected_result, dttm)
-def test_get_extra_table_metadata() -> None:
+def test_get_extra_table_metadata(mocker: MockerFixture) -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
- db_mock = Mock()
+ db_mock = mocker.MagicMock()
db_mock.get_indexes = Mock(
return_value=[{"column_names": ["ds", "hour"], "name": "partition"}]
)
db_mock.get_extra = Mock(return_value={})
- db_mock.has_view_by_name = Mock(return_value=None)
+ db_mock.has_view = Mock(return_value=None)
db_mock.get_df = Mock(return_value=pd.DataFrame({"ds": ["01-01-19"], "hour": [1]}))
result = TrinoEngineSpec.get_extra_table_metadata(
db_mock,
@@ -442,7 +442,7 @@ def test_get_columns(mocker: MockerFixture):
mock_inspector = mocker.MagicMock()
mock_inspector.get_columns.return_value = sqla_columns
- actual = TrinoEngineSpec.get_columns(mock_inspector, "table", "schema")
+ actual = TrinoEngineSpec.get_columns(mock_inspector, Table("table", "schema"))
expected = [
ResultSetColumnType(
name="field1", column_name="field1", type=field1_type, is_dttm=False
@@ -475,7 +475,9 @@ def test_get_columns_expand_rows(mocker: MockerFixture):
mock_inspector.get_columns.return_value = sqla_columns
actual = TrinoEngineSpec.get_columns(
- mock_inspector, "table", "schema", {"expand_rows": True}
+ mock_inspector,
+ Table("table", "schema"),
+ {"expand_rows": True},
)
expected = [
ResultSetColumnType(
@@ -538,7 +540,9 @@ def test_get_indexes_no_table():
side_effect=NoSuchTableError("The specified table does not exist.")
)
result = TrinoEngineSpec.get_indexes(
- db_mock, inspector_mock, "test_table", "test_schema"
+ db_mock,
+ inspector_mock,
+ Table("test_table", "test_schema"),
)
assert result == []
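
On the Trino spec, get_columns now takes the Table plus an optional options mapping. Sketch of the expanded-rows call from the test above:

    from unittest import mock

    from superset.db_engine_specs.trino import TrinoEngineSpec
    from superset.sql_parse import Table

    inspector = mock.MagicMock()
    inspector.get_columns.return_value = []
    # The optional third argument carries engine-specific options; the tests
    # exercise {"expand_rows": True} to flatten ROW columns into children.
    columns = TrinoEngineSpec.get_columns(
        inspector,
        Table("table", "schema"),
        {"expand_rows": True},
    )
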
diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py
index beefd3ea3cc..ce3ad182227 100644
--- a/tests/unit_tests/models/core_test.py
+++ b/tests/unit_tests/models/core_test.py
@@ -18,7 +18,6 @@
# pylint: disable=import-outside-toplevel
import json
from datetime import datetime
-from typing import Optional
import pytest
from pytest_mock import MockFixture
@@ -26,6 +25,7 @@ from sqlalchemy.engine.reflection import Inspector
from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.models.core import Database
+from superset.sql_parse import Table
def test_get_metrics(mocker: MockFixture) -> None:
@@ -37,7 +37,7 @@ def test_get_metrics(mocker: MockFixture) -> None:
from superset.models.core import Database
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
- assert database.get_metrics("table") == [
+ assert database.get_metrics(Table("table")) == [
{
"expression": "COUNT(*)",
"metric_name": "count",
@@ -52,8 +52,7 @@ def test_get_metrics(mocker: MockFixture) -> None:
cls,
database: Database,
inspector: Inspector,
- table_name: str,
- schema: Optional[str],
+ table: Table,
) -> list[MetricType]:
return [
{
@@ -65,7 +64,7 @@ def test_get_metrics(mocker: MockFixture) -> None:
]
database.get_db_engine_spec = mocker.MagicMock(return_value=CustomSqliteEngineSpec)
- assert database.get_metrics("table") == [
+ assert database.get_metrics(Table("table")) == [
{
"expression": "COUNT(DISTINCT user_id)",
"metric_name": "count_distinct_user_id",