fix: Persist catalog change during dataset update + validation fixes (#33384)

Vitor Avila
2025-05-08 15:22:25 -03:00
committed by GitHub
parent 4ed05f4ff1
commit 72cd9dffa3
6 changed files with 542 additions and 245 deletions
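
For orientation before the diffs: the behavior being fixed is exercised through the dataset REST API. Below is a minimal sketch of the request this commit affects, assuming a local Superset instance and an auth token (both hypothetical here), with a made-up dataset ID:

    import requests

    BASE = "http://localhost:8088"  # hypothetical local instance
    HEADERS = {"Authorization": "Bearer <token>"}  # hypothetical token

    # PUT /api/v1/dataset/{id} with a catalog change. Before this commit the
    # new catalog value was not persisted on update; after it, the value is
    # persisted when the target connection allows multi-catalog, and rejected
    # otherwise.
    resp = requests.put(
        f"{BASE}/api/v1/dataset/42",  # 42 is a made-up dataset ID
        json={"catalog": "other_catalog"},
        headers=HEADERS,
    )
    # 200 when allow_multi_catalog is enabled on the connection;
    # 422 with {"message": {"catalog": ["Only the default catalog is
    # supported for this connection"]}} when it is disabled.
    print(resp.status_code, resp.json())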

View File

@@ -120,71 +120,83 @@ const DatasourceModal: FunctionComponent<DatasourceModalProps> = ({
   const [isEditing, setIsEditing] = useState<boolean>(false);
   const dialog = useRef<any>(null);
   const [modal, contextHolder] = Modal.useModal();
-  const buildPayload = (datasource: Record<string, any>) => ({
-    table_name: datasource.table_name,
-    database_id: datasource.database?.id,
-    sql: datasource.sql,
-    filter_select_enabled: datasource.filter_select_enabled,
-    fetch_values_predicate: datasource.fetch_values_predicate,
-    schema:
-      datasource.tableSelector?.schema ||
-      datasource.databaseSelector?.schema ||
-      datasource.schema,
-    description: datasource.description,
-    main_dttm_col: datasource.main_dttm_col,
-    normalize_columns: datasource.normalize_columns,
-    always_filter_main_dttm: datasource.always_filter_main_dttm,
-    offset: datasource.offset,
-    default_endpoint: datasource.default_endpoint,
-    cache_timeout:
-      datasource.cache_timeout === '' ? null : datasource.cache_timeout,
-    is_sqllab_view: datasource.is_sqllab_view,
-    template_params: datasource.template_params,
-    extra: datasource.extra,
-    is_managed_externally: datasource.is_managed_externally,
-    external_url: datasource.external_url,
-    metrics: datasource?.metrics?.map((metric: DatasetObject['metrics'][0]) => {
-      const metricBody: any = {
-        expression: metric.expression,
-        description: metric.description,
-        metric_name: metric.metric_name,
-        metric_type: metric.metric_type,
-        d3format: metric.d3format || null,
-        currency: !isDefined(metric.currency)
-          ? null
-          : JSON.stringify(metric.currency),
-        verbose_name: metric.verbose_name,
-        warning_text: metric.warning_text,
-        uuid: metric.uuid,
-        extra: buildExtraJsonObject(metric),
-      };
-      if (!Number.isNaN(Number(metric.id))) {
-        metricBody.id = metric.id;
-      }
-      return metricBody;
-    }),
-    columns: datasource?.columns?.map(
-      (column: DatasetObject['columns'][0]) => ({
-        id: typeof column.id === 'number' ? column.id : undefined,
-        column_name: column.column_name,
-        type: column.type,
-        advanced_data_type: column.advanced_data_type,
-        verbose_name: column.verbose_name,
-        description: column.description,
-        expression: column.expression,
-        filterable: column.filterable,
-        groupby: column.groupby,
-        is_active: column.is_active,
-        is_dttm: column.is_dttm,
-        python_date_format: column.python_date_format || null,
-        uuid: column.uuid,
-        extra: buildExtraJsonObject(column),
-      }),
-    ),
-    owners: datasource.owners.map(
-      (o: Record<string, number>) => o.value || o.id,
-    ),
-  });
+  const buildPayload = (datasource: Record<string, any>) => {
+    const payload: Record<string, any> = {
+      table_name: datasource.table_name,
+      database_id: datasource.database?.id,
+      sql: datasource.sql,
+      filter_select_enabled: datasource.filter_select_enabled,
+      fetch_values_predicate: datasource.fetch_values_predicate,
+      schema:
+        datasource.tableSelector?.schema ||
+        datasource.databaseSelector?.schema ||
+        datasource.schema,
+      description: datasource.description,
+      main_dttm_col: datasource.main_dttm_col,
+      normalize_columns: datasource.normalize_columns,
+      always_filter_main_dttm: datasource.always_filter_main_dttm,
+      offset: datasource.offset,
+      default_endpoint: datasource.default_endpoint,
+      cache_timeout:
+        datasource.cache_timeout === '' ? null : datasource.cache_timeout,
+      is_sqllab_view: datasource.is_sqllab_view,
+      template_params: datasource.template_params,
+      extra: datasource.extra,
+      is_managed_externally: datasource.is_managed_externally,
+      external_url: datasource.external_url,
+      metrics: datasource?.metrics?.map(
+        (metric: DatasetObject['metrics'][0]) => {
+          const metricBody: any = {
+            expression: metric.expression,
+            description: metric.description,
+            metric_name: metric.metric_name,
+            metric_type: metric.metric_type,
+            d3format: metric.d3format || null,
+            currency: !isDefined(metric.currency)
+              ? null
+              : JSON.stringify(metric.currency),
+            verbose_name: metric.verbose_name,
+            warning_text: metric.warning_text,
+            uuid: metric.uuid,
+            extra: buildExtraJsonObject(metric),
+          };
+          if (!Number.isNaN(Number(metric.id))) {
+            metricBody.id = metric.id;
+          }
+          return metricBody;
+        },
+      ),
+      columns: datasource?.columns?.map(
+        (column: DatasetObject['columns'][0]) => ({
+          id: typeof column.id === 'number' ? column.id : undefined,
+          column_name: column.column_name,
+          type: column.type,
+          advanced_data_type: column.advanced_data_type,
+          verbose_name: column.verbose_name,
+          description: column.description,
+          expression: column.expression,
+          filterable: column.filterable,
+          groupby: column.groupby,
+          is_active: column.is_active,
+          is_dttm: column.is_dttm,
+          python_date_format: column.python_date_format || null,
+          uuid: column.uuid,
+          extra: buildExtraJsonObject(column),
+        }),
+      ),
+      owners: datasource.owners.map(
+        (o: Record<string, number>) => o.value || o.id,
+      ),
+    };
+
+    // Handle catalog based on database's allow_multi_catalog setting:
+    // if multi-catalog is disabled, don't include catalog in the payload
+    // (the backend will use the default catalog); if multi-catalog is
+    // enabled, include the selected catalog.
+    if (datasource.database?.allow_multi_catalog) {
+      payload.catalog = datasource.catalog;
+    }
+    return payload;
+  };

   const onConfirmSave = async () => {
     // Pull out extra fields into the extra object
     setIsSaving(true);
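
To make the new frontend rule concrete, here are the two payload shapes `buildPayload` can now emit, sketched as Python dicts for illustration only (field list abbreviated; values invented):

    # Connection with allow_multi_catalog enabled: the selected catalog is sent.
    payload_with_catalog = {
        "table_name": "my_table",  # invented example values
        "database_id": 1,
        "schema": "main",
        "catalog": "analytics",
    }

    # Connection with allow_multi_catalog disabled: no "catalog" key at all,
    # so the backend resolves the connection's default catalog itself.
    payload_without_catalog = {
        "table_name": "my_table",
        "database_id": 1,
        "schema": "main",
    }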

View File

@@ -62,6 +62,7 @@ export type DatasetObject = {
   filter_select_enabled?: boolean;
   fetch_values_predicate?: string;
   schema?: string;
+  catalog?: string;
   description: string | null;
   main_dttm_col: string;
   offset?: number;

View File

@@ -33,6 +33,18 @@ def get_dataset_exist_error_msg(table: Table) -> str:
     return _("Dataset %(table)s already exists", table=table)


+class MultiCatalogDisabledValidationError(ValidationError):
+    """
+    Validation error for using a non-default catalog when multi-catalog is disabled
+    """
+
+    def __init__(self) -> None:
+        super().__init__(
+            [_("Only the default catalog is supported for this connection")],
+            field_name="catalog",
+        )
+
+
 class DatabaseNotFoundValidationError(ValidationError):
     """
     Marshmallow validation error for database does not exist
@@ -42,15 +54,6 @@ class DatabaseNotFoundValidationError(ValidationError):
         super().__init__([_("Database does not exist")], field_name="database")


-class DatabaseChangeValidationError(ValidationError):
-    """
-    Marshmallow validation error database changes are not allowed on update
-    """
-
-    def __init__(self) -> None:
-        super().__init__([_("Database not allowed to change")], field_name="database")
-
-
 class DatasetExistsValidationError(ValidationError):
     """
     Marshmallow validation error for dataset already exists
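
These exception classes are plain marshmallow `ValidationError` subclasses; the `field_name` and message list determine the shape of the eventual 422 body. A minimal, standalone sketch of that mapping, outside Superset's actual error-handling plumbing:

    from marshmallow import ValidationError

    err = ValidationError(
        ["Only the default catalog is supported for this connection"],
        field_name="catalog",
    )
    # normalized_messages() keys the messages by field name; this matches the
    # structure the API tests below expect under "message" in 422 responses:
    # {"catalog": ["Only the default catalog is supported for this connection"]}
    print(err.normalized_messages())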

View File

@@ -14,6 +14,8 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+from __future__ import annotations
+
 import logging
 from collections import Counter
 from functools import partial
@@ -26,7 +28,7 @@ from sqlalchemy.exc import SQLAlchemyError
 from superset import is_feature_enabled, security_manager
 from superset.commands.base import BaseCommand, UpdateMixin
 from superset.commands.dataset.exceptions import (
-    DatabaseChangeValidationError,
+    DatabaseNotFoundValidationError,
     DatasetColumnNotFoundValidationError,
     DatasetColumnsDuplicateValidationError,
     DatasetColumnsExistsValidationError,
@@ -38,11 +40,13 @@ from superset.commands.dataset.exceptions import (
     DatasetMetricsNotFoundValidationError,
     DatasetNotFoundError,
     DatasetUpdateFailedError,
+    MultiCatalogDisabledValidationError,
 )
 from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
 from superset.daos.dataset import DatasetDAO
 from superset.datasets.schemas import FolderSchema
 from superset.exceptions import SupersetSecurityException
+from superset.models.core import Database
 from superset.sql_parse import Table
 from superset.utils.decorators import on_error, transaction
@@ -86,38 +90,12 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
         if not self._model:
             raise DatasetNotFoundError()

-        # Check ownership
+        # Check permission to update the dataset
         try:
             security_manager.raise_for_ownership(self._model)
         except SupersetSecurityException as ex:
             raise DatasetForbiddenError() from ex

-        database_id = self._properties.get("database")
-        catalog = self._properties.get("catalog")
-        if not catalog:
-            catalog = self._properties["catalog"] = (
-                self._model.database.get_default_catalog()
-            )
-
-        table = Table(
-            self._properties.get("table_name"),  # type: ignore
-            self._properties.get("schema"),
-            catalog,
-        )
-
-        # Validate uniqueness
-        if not DatasetDAO.validate_update_uniqueness(
-            self._model.database,
-            table,
-            self._model_id,
-        ):
-            exceptions.append(DatasetExistsValidationError(table))
-
-        # Validate/Populate database not allowed to change
-        if database_id and database_id != self._model:
-            exceptions.append(DatabaseChangeValidationError())
-
         # Validate/Populate owner
         try:
             owners = self.compute_owners(
@@ -128,15 +106,68 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
         except ValidationError as ex:
             exceptions.append(ex)

+        self._validate_dataset_source(exceptions)
         self._validate_semantics(exceptions)

         if exceptions:
             raise DatasetInvalidError(exceptions=exceptions)

+    def _validate_dataset_source(self, exceptions: list[ValidationError]) -> None:
+        # we know we have a valid model
+        self._model = cast(SqlaTable, self._model)
+
+        database_id = self._properties.pop("database_id", None)
+        catalog = self._properties.get("catalog")
+        new_db_connection: Database | None = None
+
+        if database_id and database_id != self._model.database.id:
+            if new_db_connection := DatasetDAO.get_database_by_id(database_id):
+                self._properties["database"] = new_db_connection
+            else:
+                exceptions.append(DatabaseNotFoundValidationError())
+        db = new_db_connection or self._model.database
+        default_catalog = db.get_default_catalog()
+
+        # If multi-catalog is disabled and the catalog provided is not
+        # the default one, fail
+        if (
+            "catalog" in self._properties
+            and catalog != default_catalog
+            and not db.allow_multi_catalog
+        ):
+            exceptions.append(MultiCatalogDisabledValidationError())
+        # If the DB connection does not support multi-catalog,
+        # use the default catalog
+        elif not db.allow_multi_catalog:
+            catalog = self._properties["catalog"] = default_catalog
+        # Fall back to the previous value if no catalog was provided
+        elif "catalog" not in self._properties:
+            catalog = self._model.catalog
+
+        schema = (
+            self._properties["schema"]
+            if "schema" in self._properties
+            else self._model.schema
+        )
+        table = Table(
+            self._properties.get("table_name", self._model.table_name),
+            schema,
+            catalog,
+        )
+
+        # Validate uniqueness
+        if not DatasetDAO.validate_update_uniqueness(
+            db,
+            table,
+            self._model_id,
+        ):
+            exceptions.append(DatasetExistsValidationError(table))
+
     def _validate_semantics(self, exceptions: list[ValidationError]) -> None:
         # we know we have a valid model
         self._model = cast(SqlaTable, self._model)

         if columns := self._properties.get("columns"):
             self._validate_columns(columns, exceptions)
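
The catalog-resolution branching in `_validate_dataset_source` reduces to a small pure function. Here is a sketch for illustration only (the function name and signature are invented, not part of the codebase):

    def resolve_catalog(
        properties: dict,             # the update payload
        current_catalog: str | None,  # catalog currently set on the dataset
        default_catalog: str | None,  # target connection's default catalog
        allow_multi_catalog: bool,
    ) -> str | None:
        if "catalog" in properties:
            if properties["catalog"] != default_catalog and not allow_multi_catalog:
                # mirrors MultiCatalogDisabledValidationError above
                raise ValueError("Only the default catalog is supported")
            return properties["catalog"]
        if not allow_multi_catalog:
            # nothing provided and multi-catalog disabled: force the default
            return default_catalog
        # nothing provided and multi-catalog enabled: keep the existing value
        return current_catalog

The resolved catalog then feeds the uniqueness check together with the (possibly new) database connection, schema, and table name.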

View File

@@ -14,11 +14,10 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Unit tests for Superset"""
+from __future__ import annotations
+
 import unittest
 from io import BytesIO
-from typing import Optional
 from unittest.mock import ANY, patch
 from zipfile import is_zipfile, ZipFile
@@ -70,14 +69,26 @@ from tests.integration_tests.fixtures.importexport import (
 class TestDatasetApi(SupersetTestCase):
     fixture_tables_names = ("ab_permission", "ab_permission_view", "ab_view_menu")
     fixture_virtual_table_names = ("sql_virtual_dataset_1", "sql_virtual_dataset_2")
+    items_to_delete: list[SqlaTable | Database | TableColumn] = []
+
+    def setUp(self):
+        self.items_to_delete = []
+
+    def tearDown(self):
+        for item in self.items_to_delete:
+            db.session.delete(item)
+        db.session.commit()
+        super().tearDown()

     @staticmethod
     def insert_dataset(
         table_name: str,
         owners: list[int],
         database: Database,
-        sql: Optional[str] = None,
-        schema: Optional[str] = None,
+        sql: str | None = None,
+        schema: str | None = None,
+        catalog: str | None = None,
+        fetch_metadata: bool = True,
     ) -> SqlaTable:
         obj_owners = list()  # noqa: C408
         for owner in owners:
@@ -89,10 +100,12 @@ class TestDatasetApi(SupersetTestCase):
             owners=obj_owners,
             database=database,
             sql=sql,
+            catalog=catalog,
         )
         db.session.add(table)
         db.session.commit()
-        table.fetch_metadata()
+        if fetch_metadata:
+            table.fetch_metadata()
         return table

     def insert_default_dataset(self):
@@ -100,6 +113,16 @@ class TestDatasetApi(SupersetTestCase):
             "ab_permission", [self.get_user("admin").id], get_main_database()
         )

+    def insert_database(self, name: str, allow_multi_catalog: bool = False) -> Database:
+        db_connection = Database(
+            database_name=name,
+            sqlalchemy_uri=get_example_database().sqlalchemy_uri,
+            extra=('{"allow_multi_catalog": true}' if allow_multi_catalog else "{}"),
+        )
+        db.session.add(db_connection)
+        db.session.commit()
+        return db_connection
+
     def get_fixture_datasets(self) -> list[SqlaTable]:
         return (
             db.session.query(SqlaTable)
@@ -315,8 +338,7 @@
         # revert gamma permission
         gamma_role.permissions.remove(main_db_pvm)

-        db.session.delete(dataset)
-        db.session.commit()
+        self.items_to_delete = [dataset]

     def test_get_dataset_related_database_gamma(self):
         """
@@ -480,8 +502,7 @@
             ],
         }

-        db.session.delete(dataset)
-        db.session.commit()
+        self.items_to_delete = [dataset]

     def test_get_dataset_render_jinja_exceptions(self):
         """
@@ -547,8 +568,7 @@
             == "Unable to render expression from dataset calculated column."
         )

-        db.session.delete(dataset)
-        db.session.commit()
+        self.items_to_delete = [dataset]

     def test_get_dataset_distinct_schema(self):
         """
@@ -618,9 +638,7 @@
             },
         )

-        for dataset in datasets:
-            db.session.delete(dataset)
-        db.session.commit()
+        self.items_to_delete = datasets

     def test_get_dataset_distinct_not_allowed(self):
         """
@@ -647,8 +665,7 @@
         assert response["count"] == 0
         assert response["result"] == []

-        db.session.delete(dataset)
-        db.session.commit()
+        self.items_to_delete = [dataset]

     def test_get_dataset_info(self):
         """
@@ -722,8 +739,7 @@
         )
         assert columns[0].expression == "COUNT(*)"

-        db.session.delete(model)
-        db.session.commit()
+        self.items_to_delete = [model]

     def test_create_dataset_item_normalize(self):
         """
@@ -749,8 +765,7 @@
         assert model.database_id == table_data["database"]
         assert model.normalize_columns is True

-        db.session.delete(model)
-        db.session.commit()
+        self.items_to_delete = [model]

     def test_create_dataset_item_gamma(self):
         """
@@ -791,8 +806,7 @@
         model = db.session.query(SqlaTable).get(data.get("id"))
         assert admin in model.owners
         assert alpha in model.owners
-        db.session.delete(model)
-        db.session.commit()
+        self.items_to_delete = [model]

     def test_create_dataset_item_owners_invalid(self):
         """
@@ -839,8 +853,7 @@
         model = db.session.query(SqlaTable).get(data.get("id"))
         assert admin in model.owners
         assert alpha in model.owners
-        db.session.delete(model)
-        db.session.commit()
+        self.items_to_delete = [model]

     @unittest.skip("test is failing stochastically")
     def test_create_dataset_same_name_different_schema(self):
@@ -991,8 +1004,7 @@
         model = db.session.query(SqlaTable).get(dataset.id)
         assert model.owners == current_owners

-        db.session.delete(dataset)
-        db.session.commit()
+        self.items_to_delete = [dataset]

     def test_update_dataset_clear_owner_list(self):
         """
@@ -1008,8 +1020,7 @@
         model = db.session.query(SqlaTable).get(dataset.id)
         assert model.owners == []

-        db.session.delete(dataset)
-        db.session.commit()
+        self.items_to_delete = [dataset]

     def test_update_dataset_populate_owner(self):
         """
@@ -1026,8 +1037,7 @@
         model = db.session.query(SqlaTable).get(dataset.id)
         assert model.owners == [gamma]

-        db.session.delete(dataset)
-        db.session.commit()
+        self.items_to_delete = [dataset]

     def test_update_dataset_item(self):
         """
@@ -1045,8 +1055,7 @@
         assert model.description == dataset_data["description"]
         assert model.owners == current_owners

-        db.session.delete(dataset)
-        db.session.commit()
+        self.items_to_delete = [dataset]

     def test_update_dataset_item_w_override_columns(self):
         """
@@ -1082,8 +1091,7 @@
             col.advanced_data_type for col in columns
         ]

-        db.session.delete(dataset)
-        db.session.commit()
+        self.items_to_delete = [dataset]

     def test_update_dataset_item_w_override_columns_same_columns(self):
         """
@@ -1130,8 +1138,7 @@
         columns = db.session.query(TableColumn).filter_by(table_id=dataset.id).all()
         assert len(columns) != prev_col_len
         assert len(columns) == 3
-        db.session.delete(dataset)
-        db.session.commit()
+        self.items_to_delete = [dataset]

     def test_update_dataset_create_column_and_metric(self):
         """
@@ -1226,8 +1233,7 @@
         assert metrics[1].warning_text == new_metric_data["warning_text"]
         assert str(metrics[1].uuid) == new_metric_data["uuid"]

-        db.session.delete(dataset)
-        db.session.commit()
+        self.items_to_delete = [dataset]

     def test_update_dataset_delete_column(self):
         """
@@ -1276,8 +1282,7 @@
         assert columns[1].column_name == "name"
         assert len(columns) == 2

-        db.session.delete(dataset)
-        db.session.commit()
+        self.items_to_delete = [dataset]

     def test_update_dataset_update_column(self):
         """
@@ -1313,8 +1318,7 @@
         assert columns[0].groupby is False
         assert columns[0].filterable is False

-        db.session.delete(dataset)
-        db.session.commit()
+        self.items_to_delete = [dataset]

     def test_update_dataset_delete_metric(self):
         """
@@ -1357,8 +1361,7 @@
         metrics = metrics_query.all()
         assert len(metrics) == 1

-        db.session.delete(dataset)
-        db.session.commit()
+        self.items_to_delete = [dataset]

     def test_update_dataset_update_column_uniqueness(self):
         """
@@ -1378,8 +1381,7 @@
             "message": {"columns": ["One or more columns already exist"]}
         }
         assert data == expected_result
-        db.session.delete(dataset)
-        db.session.commit()
+        self.items_to_delete = [dataset]

     def test_update_dataset_update_metric_uniqueness(self):
         """
@@ -1399,8 +1401,7 @@
             "message": {"metrics": ["One or more metrics already exist"]}
         }
         assert data == expected_result
-        db.session.delete(dataset)
-        db.session.commit()
+        self.items_to_delete = [dataset]

     def test_update_dataset_update_column_duplicate(self):
         """
@@ -1425,8 +1426,7 @@
             "message": {"columns": ["One or more columns are duplicated"]}
         }
         assert data == expected_result
-        db.session.delete(dataset)
-        db.session.commit()
+        self.items_to_delete = [dataset]

     def test_update_dataset_update_metric_duplicate(self):
         """
@@ -1451,8 +1451,7 @@
             "message": {"metrics": ["One or more metrics are duplicated"]}
         }
         assert data == expected_result
-        db.session.delete(dataset)
-        db.session.commit()
+        self.items_to_delete = [dataset]

     def test_update_dataset_item_gamma(self):
         """
@@ -1465,8 +1464,7 @@
         uri = f"api/v1/dataset/{dataset.id}"
         rv = self.client.put(uri, json=table_data)
         assert rv.status_code == 403
-        db.session.delete(dataset)
-        db.session.commit()
+        self.items_to_delete = [dataset]

     def test_dataset_get_list_no_username(self):
         """
@@ -1491,8 +1489,7 @@
         assert current_dataset["description"] == "changed_description"
         assert "username" not in current_dataset["changed_by"].keys()

-        db.session.delete(dataset)
-        db.session.commit()
+        self.items_to_delete = [dataset]

     def test_dataset_get_no_username(self):
         """
@@ -1512,8 +1509,7 @@
         assert res["description"] == "changed_description"
         assert "username" not in res["changed_by"].keys()

-        db.session.delete(dataset)
-        db.session.commit()
+        self.items_to_delete = [dataset]

     def test_update_dataset_item_not_owned(self):
         """
@@ -1526,8 +1522,7 @@
         uri = f"api/v1/dataset/{dataset.id}"
         rv = self.put_assert_metric(uri, table_data, "put")
         assert rv.status_code == 403
-        db.session.delete(dataset)
-        db.session.commit()
+        self.items_to_delete = [dataset]

     def test_update_dataset_item_owners_invalid(self):
         """
@@ -1540,8 +1535,7 @@
         uri = f"api/v1/dataset/{dataset.id}"
         rv = self.put_assert_metric(uri, table_data, "put")
         assert rv.status_code == 422
-        db.session.delete(dataset)
-        db.session.commit()
+        self.items_to_delete = [dataset]

     @patch("superset.daos.dataset.DatasetDAO.update")
     def test_update_dataset_sqlalchemy_error(self, mock_dao_update):
@@ -1560,8 +1554,7 @@
         assert rv.status_code == 422
         assert data == {"message": "Dataset could not be updated."}

-        db.session.delete(dataset)
-        db.session.commit()
+        self.items_to_delete = [dataset]

     @with_feature_flags(DATASET_FOLDERS=True)
     def test_update_dataset_add_folders(self):
@@ -1607,7 +1600,6 @@
         uri = f"api/v1/dataset/{dataset.id}"
         rv = self.put_assert_metric(uri, dataset_data, "put")
-        print(rv.data.decode("utf-8"))
         assert rv.status_code == 200

         model = db.session.query(SqlaTable).get(dataset.id)
@@ -1643,8 +1635,229 @@
             },
         ]

-        db.session.delete(dataset)
-        db.session.commit()
+        self.items_to_delete = [dataset]
+
+    def test_update_dataset_change_db_connection_multi_catalog_disabled(self):
+        """
+        Dataset API: Test changing the DB connection powering the dataset
+        to a connection with multi-catalog disabled.
+        """
+        self.login(ADMIN_USERNAME)
+        db_connection = self.insert_database("db_connection")
+        new_db_connection = self.insert_database("new_db_connection")
+        dataset = self.insert_dataset(
+            table_name="test_dataset",
+            owners=[],
+            database=db_connection,
+            sql="select 1 as one",
+            schema="test_schema",
+            catalog="old_default_catalog",
+            fetch_metadata=False,
+        )
+
+        with patch.object(
+            new_db_connection, "get_default_catalog", return_value="new_default_catalog"
+        ):
+            payload = {"database_id": new_db_connection.id}
+            uri = f"api/v1/dataset/{dataset.id}"
+            rv = self.put_assert_metric(uri, payload, "put")
+            assert rv.status_code == 200
+
+        model = db.session.query(SqlaTable).get(dataset.id)
+        assert model.database == new_db_connection
+        # Catalog should have been updated to new connection's default catalog
+        assert model.catalog == "new_default_catalog"
+
+        self.items_to_delete = [dataset, db_connection, new_db_connection]
+
+    def test_update_dataset_change_db_connection_multi_catalog_enabled(self):
+        """
+        Dataset API: Test changing the DB connection powering the dataset
+        to a connection with multi-catalog enabled.
+        """
+        self.login(ADMIN_USERNAME)
+        db_connection = self.insert_database("db_connection")
+        new_db_connection = self.insert_database(
+            "new_db_connection", allow_multi_catalog=True
+        )
+        dataset = self.insert_dataset(
+            table_name="test_dataset",
+            owners=[],
+            database=db_connection,
+            sql="select 1 as one",
+            schema="test_schema",
+            catalog="old_default_catalog",
+            fetch_metadata=False,
+        )
+
+        with patch.object(
+            new_db_connection, "get_default_catalog", return_value="default"
+        ):
+            payload = {"database_id": new_db_connection.id}
+            uri = f"api/v1/dataset/{dataset.id}"
+            rv = self.put_assert_metric(uri, payload, "put")
+            assert rv.status_code == 200
+
+        model = db.session.query(SqlaTable).get(dataset.id)
+        assert model.database == new_db_connection
+        # Catalog was not changed as not provided and multi-catalog is enabled
+        assert model.catalog == "old_default_catalog"
+
+        self.items_to_delete = [dataset, db_connection, new_db_connection]
+
+    def test_update_dataset_change_db_connection_not_found(self):
+        """
+        Dataset API: Test changing the DB connection powering the dataset
+        to an invalid DB connection.
+        """
+        self.login(ADMIN_USERNAME)
+        dataset = self.insert_default_dataset()
+        payload = {"database_id": 1500}
+        uri = f"api/v1/dataset/{dataset.id}"
+        rv = self.put_assert_metric(uri, payload, "put")
+        response = json.loads(rv.data.decode("utf-8"))
+
+        assert rv.status_code == 422
+        assert response["message"] == {"database": ["Database does not exist"]}
+
+        self.items_to_delete = [dataset]
+
+    def test_update_dataset_change_catalog(self):
+        """
+        Dataset API: Test changing the catalog associated with the dataset.
+        """
+        self.login(ADMIN_USERNAME)
+        db_connection = self.insert_database("db_connection", allow_multi_catalog=True)
+        dataset = self.insert_dataset(
+            table_name="test_dataset",
+            owners=[],
+            database=db_connection,
+            sql="select 1 as one",
+            schema="test_schema",
+            catalog="test_catalog",
+            fetch_metadata=False,
+        )
+
+        with patch.object(db_connection, "get_default_catalog", return_value="default"):
+            payload = {"catalog": "other_catalog"}
+            uri = f"api/v1/dataset/{dataset.id}"
+            rv = self.put_assert_metric(uri, payload, "put")
+            assert rv.status_code == 200
+
+        model = db.session.query(SqlaTable).get(dataset.id)
+        assert model.catalog == "other_catalog"
+
+        self.items_to_delete = [dataset, db_connection]
+
+    def test_update_dataset_change_catalog_not_allowed(self):
+        """
+        Dataset API: Test changing the catalog associated with the dataset fails
+        when multi-catalog is disabled on the DB connection.
+        """
+        self.login(ADMIN_USERNAME)
+        db_connection = self.insert_database("db_connection")
+        dataset = self.insert_dataset(
+            table_name="test_dataset",
+            owners=[],
+            database=db_connection,
+            sql="select 1 as one",
+            schema="test_schema",
+            catalog="test_catalog",
+            fetch_metadata=False,
+        )
+
+        with patch.object(db_connection, "get_default_catalog", return_value="default"):
+            payload = {"catalog": "other_catalog"}
+            uri = f"api/v1/dataset/{dataset.id}"
+            rv = self.put_assert_metric(uri, payload, "put")
+            response = json.loads(rv.data.decode("utf-8"))
+
+            assert rv.status_code == 422
+            assert response["message"] == {
+                "catalog": ["Only the default catalog is supported for this connection"]
+            }
+
+        self.items_to_delete = [dataset, db_connection]
+
+    def test_update_dataset_validate_uniqueness(self):
+        """
+        Dataset API: Test the dataset uniqueness validation takes into
+        consideration the new database connection.
+        """
+        test_db = get_main_database()
+        if test_db.backend == "sqlite":
+            # Skip this test for SQLite as it doesn't support multiple
+            # schemas.
+            return
+
+        self.login(ADMIN_USERNAME)
+        db_connection = self.insert_database("db_connection")
+        new_db_connection = self.insert_database("new_db_connection")
+        first_schema_dataset = self.insert_dataset(
+            table_name="test_dataset",
+            owners=[],
+            database=db_connection,
+            sql="select 1 as one",
+            schema="first_schema",
+            fetch_metadata=False,
+        )
+        second_schema_dataset = self.insert_dataset(
+            table_name="test_dataset",
+            owners=[],
+            database=db_connection,
+            sql="select 1 as one",
+            schema="second_schema",
+            fetch_metadata=False,
+        )
+        new_db_conn_dataset = self.insert_dataset(
+            table_name="test_dataset",
+            owners=[],
+            database=new_db_connection,
+            sql="select 1 as one",
+            schema="first_schema",
+            fetch_metadata=False,
+        )
+
+        with patch.object(
+            db_connection,
+            "get_default_catalog",
+            return_value=None,
+        ):
+            payload = {"schema": "second_schema"}
+            uri = f"api/v1/dataset/{first_schema_dataset.id}"
+            rv = self.put_assert_metric(uri, payload, "put")
+            response = json.loads(rv.data.decode("utf-8"))
+
+            assert rv.status_code == 422
+            assert response["message"] == {
+                "table": ["Dataset second_schema.test_dataset already exists"]
+            }
+
+        with patch.object(
+            new_db_connection,
+            "get_default_catalog",
+            return_value=None,
+        ):
+            payload["database_id"] = new_db_connection.id
+            uri = f"api/v1/dataset/{first_schema_dataset.id}"
+            rv = self.put_assert_metric(uri, payload, "put")
+            assert rv.status_code == 200
+
+        model = db.session.query(SqlaTable).get(first_schema_dataset.id)
+        assert model.database == new_db_connection
+        assert model.schema == "second_schema"
+
+        self.items_to_delete = [
+            first_schema_dataset,
+            second_schema_dataset,
+            new_db_conn_dataset,
+            new_db_connection,
+            db_connection,
+        ]

     def test_delete_dataset_item(self):
         """
@@ -1674,8 +1887,7 @@
         uri = f"api/v1/dataset/{dataset.id}"
         rv = self.delete_assert_metric(uri, "delete")
         assert rv.status_code == 403
-        db.session.delete(dataset)
-        db.session.commit()
+        self.items_to_delete = [dataset]

     def test_delete_dataset_item_not_authorized(self):
         """
@@ -1687,8 +1899,7 @@
         uri = f"api/v1/dataset/{dataset.id}"
         rv = self.client.delete(uri)
         assert rv.status_code == 403
-        db.session.delete(dataset)
-        db.session.commit()
+        self.items_to_delete = [dataset]

     @patch("superset.daos.dataset.DatasetDAO.delete")
     def test_delete_dataset_sqlalchemy_error(self, mock_dao_delete):
@@ -1705,8 +1916,7 @@
         data = json.loads(rv.data.decode("utf-8"))
         assert rv.status_code == 422
         assert data == {"message": "Datasets could not be deleted."}
-        db.session.delete(dataset)
-        db.session.commit()
+        self.items_to_delete = [dataset]

     @pytest.mark.usefixtures("create_datasets")
     def test_delete_dataset_column(self):
@@ -1947,8 +2157,7 @@
             .filter_by(table_id=dataset.id, column_name="id")
             .one()
         )
-        db.session.delete(id_column)
-        db.session.commit()
+        self.items_to_delete = [id_column]

         self.login(ADMIN_USERNAME)
         uri = f"api/v1/dataset/{dataset.id}/refresh"
@@ -1961,8 +2170,7 @@
             .one()
         )
         assert id_column is not None
-        db.session.delete(dataset)
-        db.session.commit()
+        self.items_to_delete = [dataset]

     def test_dataset_item_refresh_not_found(self):
         """
@@ -1987,8 +2195,7 @@
         rv = self.put_assert_metric(uri, {}, "refresh")
         assert rv.status_code == 403
-        db.session.delete(dataset)
-        db.session.commit()
+        self.items_to_delete = [dataset]

     @unittest.skip("test is failing stochastically")
     def test_export_dataset(self):
@@ -2250,8 +2457,7 @@
         dataset = (
             db.session.query(SqlaTable).filter_by(table_name="birth_names_2").one()
         )
-        db.session.delete(dataset)
-        db.session.commit()
+        self.items_to_delete = [dataset]

     @patch("superset.commands.database.importers.v1.utils.add_permissions")
     def test_import_dataset_overwrite(self, mock_add_permissions):
@@ -2447,8 +2653,7 @@
         response = json.loads(rv.data.decode("utf-8"))
         assert response.get("count") == 1

-        db.session.delete(table_w_certification)
-        db.session.commit()
+        self.items_to_delete = [table_w_certification]

     @pytest.mark.usefixtures("create_virtual_datasets")
     def test_duplicate_virtual_dataset(self):
@@ -2473,8 +2678,7 @@
         assert len(new_dataset.columns) == 2
         assert new_dataset.columns[0].column_name == "id"
         assert new_dataset.columns[1].column_name == "name"
-        db.session.delete(new_dataset)
-        db.session.commit()
+        self.items_to_delete = [new_dataset]

     @pytest.mark.usefixtures("create_datasets")
     def test_duplicate_physical_dataset(self):
@@ -2604,8 +2808,7 @@
         assert table.template_params == '{"param": 1}'
         assert table.normalize_columns is False

-        db.session.delete(table)
-        db.session.commit()
+        self.items_to_delete = [table]

         with examples_db.get_sqla_engine() as engine:
             engine.execute("DROP TABLE test_create_sqla_table_api")
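
The recurring test change above swaps per-test `db.session.delete(...)` / `db.session.commit()` pairs for a `self.items_to_delete` list consumed in `tearDown`, so cleanup runs even when an assertion fails mid-test. A generic, self-contained sketch of the pattern (no Superset imports; the fake in-memory "database" stands in for the SQLAlchemy session):

    import unittest

    FAKE_DB: list[str] = []  # stand-in for rows that tests create


    class CleanupTestCase(unittest.TestCase):
        def setUp(self):
            self.items_to_delete: list[str] = []

        def tearDown(self):
            # runs whether or not the test body raised, so created rows
            # never leak into the next test
            for item in self.items_to_delete:
                FAKE_DB.remove(item)
            super().tearDown()


    class ExampleTest(CleanupTestCase):
        def test_creates_a_row(self):
            FAKE_DB.append("dataset_1")
            self.items_to_delete = ["dataset_1"]
            self.assertIn("dataset_1", FAKE_DB)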

View File

@@ -15,78 +15,125 @@
 # specific language governing permissions and limitations
 # under the License.
-from typing import cast
-from unittest.mock import MagicMock
+from typing import Any, cast

 import pytest
 from marshmallow import ValidationError
 from pytest_mock import MockerFixture

-from superset import db
-from superset.commands.dataset.exceptions import DatasetInvalidError
+from superset.commands.dataset.exceptions import (
+    DatabaseNotFoundValidationError,
+    DatasetExistsValidationError,
+    DatasetForbiddenError,
+    DatasetInvalidError,
+    DatasetNotFoundError,
+    MultiCatalogDisabledValidationError,
+)
 from superset.commands.dataset.update import UpdateDatasetCommand, validate_folders
-from superset.connectors.sqla.models import SqlaTable
+from superset.commands.exceptions import OwnersNotFoundValidationError
 from superset.datasets.schemas import FolderSchema
-from superset.models.core import Database
+from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
+from superset.exceptions import SupersetSecurityException
 from tests.unit_tests.conftest import with_feature_flags


-@pytest.mark.usefixture("session")
-def test_update_uniqueness_error(mocker: MockerFixture) -> None:
-    """
-    Test uniqueness validation in dataset update command.
-    """
-    SqlaTable.metadata.create_all(db.session.get_bind())
-
-    # First, make sure session is clean
-    db.session.rollback()
-
-    try:
-        # Set up test data
-        database = Database(database_name="my_db", sqlalchemy_uri="sqlite://")
-        bar = SqlaTable(table_name="bar", schema="foo", database=database)
-        baz = SqlaTable(table_name="baz", schema="qux", database=database)
-        db.session.add_all([database, bar, baz])
-        db.session.commit()
-
-        # Set up mocks
-        mock_g = mocker.patch("superset.security.manager.g")
-        mock_g.user = MagicMock()
-        mocker.patch(
-            "superset.views.base.security_manager.can_access_all_datasources",
-            return_value=True,
-        )
-        mocker.patch(
-            "superset.commands.dataset.update.security_manager.raise_for_ownership",
-            return_value=None,
-        )
-        mocker.patch.object(UpdateDatasetCommand, "compute_owners", return_value=[])
-
-        # Run the test that should fail
-        with pytest.raises(DatasetInvalidError):
-            UpdateDatasetCommand(
-                bar.id,
-                {
-                    "table_name": "baz",
-                    "schema": "qux",
-                },
-            ).run()
-    except Exception:
-        db.session.rollback()
-        raise
-    finally:
-        # Clean up - this will run even if the test fails
-        try:
-            db.session.query(SqlaTable).filter(
-                SqlaTable.table_name.in_(["bar", "baz"]),
-                SqlaTable.schema.in_(["foo", "qux"]),
-            ).delete(synchronize_session=False)
-            db.session.query(Database).filter(Database.database_name == "my_db").delete(
-                synchronize_session=False
-            )
-            db.session.commit()
-        except Exception:
-            db.session.rollback()
+def test_update_dataset_not_found(mocker: MockerFixture) -> None:
+    """
+    Test that updating a nonexistent ID raises a `DatasetNotFoundError`.
+    """
+    mock_dataset_dao = mocker.patch("superset.commands.dataset.update.DatasetDAO")
+    mock_dataset_dao.find_by_id.return_value = None
+
+    with pytest.raises(DatasetNotFoundError):
+        UpdateDatasetCommand(1, {"name": "test"}).run()
+
+
+def test_update_dataset_forbidden(mocker: MockerFixture) -> None:
+    """
+    Test that updating a dataset without permission raises a `DatasetForbiddenError`.
+    """
+    mock_dataset_dao = mocker.patch("superset.commands.dataset.update.DatasetDAO")
+    mock_dataset_dao.find_by_id.return_value = mocker.MagicMock()
+
+    mocker.patch(
+        "superset.commands.dataset.update.security_manager.raise_for_ownership",
+        side_effect=SupersetSecurityException(
+            SupersetError(
+                error_type=SupersetErrorType.MISSING_OWNERSHIP_ERROR,
+                message="Sample message",
+                level=ErrorLevel.ERROR,
+            )
+        ),
+    )
+
+    with pytest.raises(DatasetForbiddenError):
+        UpdateDatasetCommand(1, {"name": "test"}).run()
+
+
+@pytest.mark.parametrize(
+    ("payload, exception, error_msg"),
+    [
+        (
+            {"database_id": 2},
+            DatabaseNotFoundValidationError,
+            "Database does not exist",
+        ),
+        (
+            {"catalog": "test"},
+            MultiCatalogDisabledValidationError,
+            "Only the default catalog is supported for this connection",
+        ),
+        (
+            {"table_name": "table", "schema": "schema"},
+            DatasetExistsValidationError,
+            "Dataset catalog.schema.table already exists",
+        ),
+        (
+            {"owners": [1]},
+            OwnersNotFoundValidationError,
+            "Owners are invalid",
+        ),
+    ],
+)
+def test_update_validation_errors(
+    payload: dict[str, Any],
+    exception: Exception,
+    error_msg: str,
+    mocker: MockerFixture,
+) -> None:
+    """
+    Test validation errors for the `UpdateDatasetCommand`.
+    """
+    mock_dataset_dao = mocker.patch("superset.commands.dataset.update.DatasetDAO")
+    mocker.patch(
+        "superset.commands.dataset.update.security_manager.raise_for_ownership",
+    )
+    mocker.patch("superset.commands.utils.security_manager.is_admin", return_value=True)
+    mocker.patch(
+        "superset.commands.utils.security_manager.get_user_by_id", return_value=None
+    )
+
+    mock_database = mocker.MagicMock()
+    mock_database.id = 1
+    mock_database.get_default_catalog.return_value = "catalog"
+    mock_database.allow_multi_catalog = False
+
+    mock_dataset = mocker.MagicMock()
+    mock_dataset.database = mock_database
+    mock_dataset.catalog = "catalog"
+    mock_dataset_dao.find_by_id.return_value = mock_dataset
+
+    if exception == DatabaseNotFoundValidationError:
+        mock_dataset_dao.get_database_by_id.return_value = None
+    else:
+        mock_dataset_dao.get_database_by_id.return_value = mock_database
+    if exception == DatasetExistsValidationError:
+        mock_dataset_dao.validate_update_uniqueness.return_value = False
+    else:
+        mock_dataset_dao.validate_update_uniqueness.return_value = True
+
+    with pytest.raises(DatasetInvalidError) as excinfo:
+        UpdateDatasetCommand(1, payload).run()
+    assert any(error_msg in str(exc) for exc in excinfo.value._exceptions)


 @with_feature_flags(DATASET_FOLDERS=True)