fix: Ensure table uniqueness on update (#15909)

* fix: Ensure table uniqueness on update

* Update models.py

* Update slice.py

* Update datasource_tests.py

Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
John Bodley
2021-08-02 12:45:55 -07:00
committed by GitHub
parent 76a13dfc9a
commit c0615c55df
20 changed files with 344 additions and 274 deletions

View File

@@ -161,7 +161,7 @@ class TestRequestAccess(SupersetTestCase):
updated_override_me = security_manager.find_role("override_me")
self.assertEqual(1, len(updated_override_me.permissions))
birth_names = self.get_table_by_name("birth_names")
birth_names = self.get_table(name="birth_names")
self.assertEqual(
birth_names.perm, updated_override_me.permissions[0].view_menu.name
)
@@ -190,7 +190,7 @@ class TestRequestAccess(SupersetTestCase):
"datasource_access", updated_role.permissions[1].permission.name
)
birth_names = self.get_table_by_name("birth_names")
birth_names = self.get_table(name="birth_names")
self.assertEqual(birth_names.perm, perms[2].view_menu.name)
self.assertEqual(
"datasource_access", updated_role.permissions[2].permission.name
@@ -204,7 +204,7 @@ class TestRequestAccess(SupersetTestCase):
override_me = security_manager.find_role("override_me")
override_me.permissions.append(
security_manager.find_permission_view_menu(
view_menu_name=self.get_table_by_name("energy_usage").perm,
view_menu_name=self.get_table(name="energy_usage").perm,
permission_name="datasource_access",
)
)
@@ -218,7 +218,7 @@ class TestRequestAccess(SupersetTestCase):
self.assertEqual(201, response.status_code)
updated_override_me = security_manager.find_role("override_me")
self.assertEqual(1, len(updated_override_me.permissions))
birth_names = self.get_table_by_name("birth_names")
birth_names = self.get_table(name="birth_names")
self.assertEqual(
birth_names.perm, updated_override_me.permissions[0].view_menu.name
)

View File

@@ -99,10 +99,6 @@ def post_assert_metric(
return rv
def get_table_by_name(name: str) -> SqlaTable:
return db.session.query(SqlaTable).filter_by(table_name=name).one()
@pytest.fixture
def logged_in_admin():
"""Fixture with app context and logged in admin user."""
@@ -132,12 +128,7 @@ class SupersetTestCase(TestCase):
@staticmethod
def get_birth_names_dataset() -> SqlaTable:
example_db = get_example_database()
return (
db.session.query(SqlaTable)
.filter_by(database=example_db, table_name="birth_names")
.one()
)
return SupersetTestCase.get_table(name="birth_names")
@staticmethod
def create_user_with_roles(
@@ -254,13 +245,31 @@ class SupersetTestCase(TestCase):
return slc
@staticmethod
def get_table_by_name(name: str) -> SqlaTable:
return get_table_by_name(name)
def get_table(
name: str, database_id: Optional[int] = None, schema: Optional[str] = None
) -> SqlaTable:
return (
db.session.query(SqlaTable)
.filter_by(
database_id=database_id
or SupersetTestCase.get_database_by_name("examples").id,
schema=schema,
table_name=name,
)
.one()
)
@staticmethod
def get_database_by_id(db_id: int) -> Database:
return db.session.query(Database).filter_by(id=db_id).one()
@staticmethod
def get_database_by_name(database_name: str = "main") -> Database:
if database_name == "examples":
return get_example_database()
else:
raise ValueError("Database doesn't exist")
@staticmethod
def get_druid_ds_by_name(name: str) -> DruidDatasource:
return db.session.query(DruidDatasource).filter_by(datasource_name=name).first()
@@ -340,12 +349,6 @@ class SupersetTestCase(TestCase):
):
security_manager.del_permission_role(public_role, perm)
def _get_database_by_name(self, database_name="main"):
if database_name == "examples":
return get_example_database()
else:
raise ValueError("Database doesn't exist")
def run_sql(
self,
sql,
@@ -364,7 +367,7 @@ class SupersetTestCase(TestCase):
if user_name:
self.logout()
self.login(username=(user_name or "admin"))
dbid = self._get_database_by_name(database_name).id
dbid = SupersetTestCase.get_database_by_name(database_name).id
json_payload = {
"database_id": dbid,
"sql": sql,
@@ -448,7 +451,7 @@ class SupersetTestCase(TestCase):
if user_name:
self.logout()
self.login(username=(user_name if user_name else "admin"))
dbid = self._get_database_by_name(database_name).id
dbid = SupersetTestCase.get_database_by_name(database_name).id
resp = self.get_json_resp(
"/superset/validate_sql_json/",
raise_on_error=False,

View File

@@ -545,7 +545,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
"""
admin = self.get_user("admin")
gamma = self.get_user("gamma")
birth_names_table_id = SupersetTestCase.get_table_by_name("birth_names").id
birth_names_table_id = SupersetTestCase.get_table(name="birth_names").id
chart_id = self.insert_chart(
"title", [admin.id], birth_names_table_id, admin
).id

View File

@@ -221,7 +221,7 @@ def test_import_csv_explore_database(setup_csv_upload, create_csv_files):
f'CSV file "{CSV_FILENAME1}" uploaded to table "{CSV_UPLOAD_TABLE_W_EXPLORE}"'
in resp
)
table = SupersetTestCase.get_table_by_name(CSV_UPLOAD_TABLE_W_EXPLORE)
table = SupersetTestCase.get_table(name=CSV_UPLOAD_TABLE_W_EXPLORE)
assert table.database_id == utils.get_example_database().id
@@ -267,7 +267,7 @@ def test_import_csv(setup_csv_upload, create_csv_files):
)
assert success_msg_f2 in resp
table = SupersetTestCase.get_table_by_name(CSV_UPLOAD_TABLE)
table = SupersetTestCase.get_table(name=CSV_UPLOAD_TABLE)
# make sure the new column name is reflected in the table metadata
assert "d" in table.column_names

View File

@@ -35,6 +35,7 @@ def create_table_for_dashboard(
dtype: Dict[str, Any],
table_description: str = "",
fetch_values_predicate: Optional[str] = None,
schema: Optional[str] = None,
) -> SqlaTable:
df.to_sql(
table_name,
@@ -44,14 +45,17 @@ def create_table_for_dashboard(
dtype=dtype,
index=False,
method="multi",
schema=schema,
)
table_source = ConnectorRegistry.sources["table"]
table = (
db.session.query(table_source).filter_by(table_name=table_name).one_or_none()
db.session.query(table_source)
.filter_by(database_id=database.id, schema=schema, table_name=table_name)
.one_or_none()
)
if not table:
table = table_source(table_name=table_name)
table = table_source(schema=schema, table_name=table_name)
if fetch_values_predicate:
table.fetch_values_predicate = fetch_values_predicate
table.database = database

View File

@@ -63,10 +63,10 @@ class TestDatasetApi(SupersetTestCase):
@staticmethod
def insert_dataset(
table_name: str,
schema: str,
owners: List[int],
database: Database,
sql: Optional[str] = None,
schema: Optional[str] = None,
) -> SqlaTable:
obj_owners = list()
for owner in owners:
@@ -86,7 +86,7 @@ class TestDatasetApi(SupersetTestCase):
def insert_default_dataset(self):
return self.insert_dataset(
"ab_permission", "", [self.get_user("admin").id], get_main_database()
"ab_permission", [self.get_user("admin").id], get_main_database()
)
def get_fixture_datasets(self) -> List[SqlaTable]:
@@ -105,11 +105,7 @@ class TestDatasetApi(SupersetTestCase):
for table_name in self.fixture_virtual_table_names:
datasets.append(
self.insert_dataset(
table_name,
"",
[admin.id],
main_db,
"SELECT * from ab_view_menu;",
table_name, [admin.id], main_db, "SELECT * from ab_view_menu;",
)
)
yield datasets
@@ -126,9 +122,7 @@ class TestDatasetApi(SupersetTestCase):
admin = self.get_user("admin")
main_db = get_main_database()
for tables_name in self.fixture_tables_names:
datasets.append(
self.insert_dataset(tables_name, "", [admin.id], main_db)
)
datasets.append(self.insert_dataset(tables_name, [admin.id], main_db))
yield datasets
# rollback changes
@@ -270,11 +264,13 @@ class TestDatasetApi(SupersetTestCase):
datasets = []
if example_db.backend == "postgresql":
datasets.append(
self.insert_dataset("ab_permission", "public", [], get_main_database())
self.insert_dataset(
"ab_permission", [], get_main_database(), schema="public"
)
)
datasets.append(
self.insert_dataset(
"columns", "information_schema", [], get_main_database()
"columns", [], get_main_database(), schema="information_schema",
)
)
schema_values = [
@@ -921,7 +917,7 @@ class TestDatasetApi(SupersetTestCase):
dataset = self.insert_default_dataset()
self.login(username="admin")
ab_user = self.insert_dataset(
"ab_user", "", [self.get_user("admin").id], get_main_database()
"ab_user", [self.get_user("admin").id], get_main_database()
)
table_data = {"table_name": "ab_user"}
uri = f"api/v1/dataset/{dataset.id}"

View File

@@ -17,7 +17,6 @@
"""Unit tests for Superset"""
import json
from contextlib import contextmanager
from copy import deepcopy
from unittest import mock
import pytest
@@ -28,12 +27,11 @@ from superset.datasets.commands.exceptions import DatasetNotFoundError
from superset.exceptions import SupersetGenericDBErrorException
from superset.models.core import Database
from superset.utils.core import get_example_database
from tests.integration_tests.base_tests import db_insert_temp_object, SupersetTestCase
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
)
from .base_tests import db_insert_temp_object, SupersetTestCase
from .fixtures.datasource import datasource_post
from tests.integration_tests.fixtures.datasource import get_datasource_post
@contextmanager
@@ -54,20 +52,15 @@ def create_test_table_context(database: Database):
class TestDatasource(SupersetTestCase):
def setUp(self):
self.original_attrs = {}
self.datasource = None
db.session.begin(subtransactions=True)
def tearDown(self):
if self.datasource:
for key, value in self.original_attrs.items():
setattr(self.datasource, key, value)
db.session.commit()
db.session.rollback()
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_external_metadata_for_physical_table(self):
self.login(username="admin")
tbl = self.get_table_by_name("birth_names")
tbl = self.get_table(name="birth_names")
url = f"/datasource/external_metadata/table/{tbl.id}/"
resp = self.get_json_resp(url)
col_names = {o.get("name") for o in resp}
@@ -86,7 +79,7 @@ class TestDatasource(SupersetTestCase):
session.add(table)
session.commit()
table = self.get_table_by_name("dummy_sql_table")
table = self.get_table(name="dummy_sql_table")
url = f"/datasource/external_metadata/table/{table.id}/"
resp = self.get_json_resp(url)
assert {o.get("name") for o in resp} == {"intcol", "strcol"}
@@ -96,7 +89,7 @@ class TestDatasource(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_external_metadata_by_name_for_physical_table(self):
self.login(username="admin")
tbl = self.get_table_by_name("birth_names")
tbl = self.get_table(name="birth_names")
# empty schema need to be represented by undefined
url = (
f"/datasource/external_metadata_by_name/table/"
@@ -119,7 +112,7 @@ class TestDatasource(SupersetTestCase):
session.add(table)
session.commit()
table = self.get_table_by_name("dummy_sql_table")
table = self.get_table(name="dummy_sql_table")
# empty schema need to be represented by undefined
url = (
f"/datasource/external_metadata_by_name/table/"
@@ -160,7 +153,7 @@ class TestDatasource(SupersetTestCase):
session.add(table)
session.commit()
table = self.get_table_by_name("dummy_sql_table_with_template_params")
table = self.get_table(name="dummy_sql_table_with_template_params")
url = f"/datasource/external_metadata/table/{table.id}/"
resp = self.get_json_resp(url)
assert {o.get("name") for o in resp} == {"intcol"}
@@ -196,7 +189,7 @@ class TestDatasource(SupersetTestCase):
@mock.patch("superset.connectors.sqla.models.SqlaTable.external_metadata")
def test_external_metadata_error_return_400(self, mock_get_datasource):
self.login(username="admin")
tbl = self.get_table_by_name("birth_names")
tbl = self.get_table(name="birth_names")
url = f"/datasource/external_metadata/table/{tbl.id}/"
mock_get_datasource.side_effect = SupersetGenericDBErrorException("oops")
@@ -221,13 +214,9 @@ class TestDatasource(SupersetTestCase):
def test_save(self):
self.login(username="admin")
tbl_id = self.get_table_by_name("birth_names").id
self.datasource = ConnectorRegistry.get_datasource("table", tbl_id, db.session)
for key in self.datasource.export_fields:
self.original_attrs[key] = getattr(self.datasource, key)
tbl_id = self.get_table(name="birth_names").id
datasource_post = get_datasource_post()
datasource_post["id"] = tbl_id
data = dict(data=json.dumps(datasource_post))
resp = self.get_json_resp("/datasource/save/", data)
@@ -241,25 +230,21 @@ class TestDatasource(SupersetTestCase):
else:
self.assertEqual(resp[k], datasource_post[k])
def save_datasource_from_dict(self, datasource_dict):
def save_datasource_from_dict(self, datasource_post):
data = dict(data=json.dumps(datasource_post))
resp = self.get_json_resp("/datasource/save/", data)
return resp
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_change_database(self):
self.login(username="admin")
tbl = self.get_table_by_name("birth_names")
tbl = self.get_table(name="birth_names")
tbl_id = tbl.id
db_id = tbl.database_id
datasource_post = get_datasource_post()
datasource_post["id"] = tbl_id
self.datasource = ConnectorRegistry.get_datasource("table", tbl_id, db.session)
for key in self.datasource.export_fields:
self.original_attrs[key] = getattr(self.datasource, key)
new_db = self.create_fake_db()
datasource_post["database"]["id"] = new_db.id
resp = self.save_datasource_from_dict(datasource_post)
self.assertEqual(resp["database"]["id"], new_db.id)
@@ -272,15 +257,11 @@ class TestDatasource(SupersetTestCase):
def test_save_duplicate_key(self):
self.login(username="admin")
tbl_id = self.get_table_by_name("birth_names").id
self.datasource = ConnectorRegistry.get_datasource("table", tbl_id, db.session)
tbl_id = self.get_table(name="birth_names").id
for key in self.datasource.export_fields:
self.original_attrs[key] = getattr(self.datasource, key)
datasource_post_copy = deepcopy(datasource_post)
datasource_post_copy["id"] = tbl_id
datasource_post_copy["columns"].extend(
datasource_post = get_datasource_post()
datasource_post["id"] = tbl_id
datasource_post["columns"].extend(
[
{
"column_name": "<new column>",
@@ -298,18 +279,15 @@ class TestDatasource(SupersetTestCase):
},
]
)
data = dict(data=json.dumps(datasource_post_copy))
data = dict(data=json.dumps(datasource_post))
resp = self.get_json_resp("/datasource/save/", data, raise_on_error=False)
self.assertIn("Duplicate column name(s): <new column>", resp["error"])
def test_get_datasource(self):
self.login(username="admin")
tbl = self.get_table_by_name("birth_names")
self.datasource = ConnectorRegistry.get_datasource("table", tbl.id, db.session)
for key in self.datasource.export_fields:
self.original_attrs[key] = getattr(self.datasource, key)
tbl = self.get_table(name="birth_names")
datasource_post = get_datasource_post()
datasource_post["id"] = tbl.id
data = dict(data=json.dumps(datasource_post))
self.get_json_resp("/datasource/save/", data)
@@ -337,7 +315,7 @@ class TestDatasource(SupersetTestCase):
app.config["DATASET_HEALTH_CHECK"] = my_check
self.login(username="admin")
tbl = self.get_table_by_name("birth_names")
tbl = self.get_table(name="birth_names")
datasource = ConnectorRegistry.get_datasource("table", tbl.id, db.session)
assert datasource.health_check_message == "Warning message!"
app.config["DATASET_HEALTH_CHECK"] = None

View File

@@ -66,7 +66,7 @@ class TestDictImportExport(SupersetTestCase):
cls.delete_imports()
def create_table(
self, name, schema="", id=0, cols_names=[], cols_uuids=None, metric_names=[]
self, name, schema=None, id=0, cols_names=[], cols_uuids=None, metric_names=[]
):
database_name = "main"
name = "{0}{1}".format(NAME_PREFIX, name)
@@ -128,9 +128,6 @@ class TestDictImportExport(SupersetTestCase):
def get_datasource(self, datasource_id):
return db.session.query(DruidDatasource).filter_by(id=datasource_id).first()
def get_table_by_name(self, name):
return db.session.query(SqlaTable).filter_by(table_name=name).first()
def yaml_compare(self, obj_1, obj_2):
obj_1_str = yaml.safe_dump(obj_1, default_flow_style=False)
obj_2_str = yaml.safe_dump(obj_2, default_flow_style=False)

View File

@@ -15,138 +15,142 @@
# specific language governing permissions and limitations
# under the License.
"""Fixtures for test_datasource.py"""
datasource_post = {
"id": None,
"column_formats": {"ratio": ".2%"},
"database": {"id": 1},
"description": "Adding a DESCRip",
"default_endpoint": "",
"filter_select_enabled": True,
"name": "birth_names",
"table_name": "birth_names",
"datasource_name": "birth_names",
"type": "table",
"schema": "",
"offset": 66,
"cache_timeout": 55,
"sql": "",
"columns": [
{
"id": 504,
"column_name": "ds",
"verbose_name": "",
"description": None,
"expression": "",
"filterable": True,
"groupby": True,
"is_dttm": True,
"type": "DATETIME",
},
{
"id": 505,
"column_name": "gender",
"verbose_name": None,
"description": None,
"expression": "",
"filterable": True,
"groupby": True,
"is_dttm": False,
"type": "VARCHAR(16)",
},
{
"id": 506,
"column_name": "name",
"verbose_name": None,
"description": None,
"expression": None,
"filterable": True,
"groupby": True,
"is_dttm": None,
"type": "VARCHAR(255)",
},
{
"id": 508,
"column_name": "state",
"verbose_name": None,
"description": None,
"expression": None,
"filterable": True,
"groupby": True,
"is_dttm": None,
"type": "VARCHAR(10)",
},
{
"id": 509,
"column_name": "num_boys",
"verbose_name": None,
"description": None,
"expression": None,
"filterable": True,
"groupby": True,
"is_dttm": None,
"type": "BIGINT(20)",
},
{
"id": 510,
"column_name": "num_girls",
"verbose_name": None,
"description": None,
"expression": "",
"filterable": False,
"groupby": False,
"is_dttm": False,
"type": "BIGINT(20)",
},
{
"id": 532,
"column_name": "num",
"verbose_name": None,
"description": None,
"expression": None,
"filterable": True,
"groupby": True,
"is_dttm": None,
"type": "BIGINT(20)",
},
{
"id": 522,
"column_name": "num_california",
"verbose_name": None,
"description": None,
"expression": "CASE WHEN state = 'CA' THEN num ELSE 0 END",
"filterable": False,
"groupby": False,
"is_dttm": False,
"type": "NUMBER",
},
],
"metrics": [
{
"id": 824,
"metric_name": "sum__num",
"verbose_name": "Babies",
"description": "",
"expression": "SUM(num)",
"warning_text": "",
"d3format": "",
},
{
"id": 836,
"metric_name": "count",
"verbose_name": "",
"description": None,
"expression": "count(1)",
"warning_text": None,
"d3format": None,
},
{
"id": 843,
"metric_name": "ratio",
"verbose_name": "Ratio Boys/Girls",
"description": "This represents the ratio of boys/girls",
"expression": "sum(num_boys) / sum(num_girls)",
"warning_text": "no warning",
"d3format": ".2%",
},
],
}
from typing import Any, Dict
def get_datasource_post() -> Dict[str, Any]:
return {
"id": None,
"column_formats": {"ratio": ".2%"},
"database": {"id": 1},
"description": "Adding a DESCRip",
"default_endpoint": "",
"filter_select_enabled": True,
"name": "birth_names",
"table_name": "birth_names",
"datasource_name": "birth_names",
"type": "table",
"schema": None,
"offset": 66,
"cache_timeout": 55,
"sql": "",
"columns": [
{
"id": 504,
"column_name": "ds",
"verbose_name": "",
"description": None,
"expression": "",
"filterable": True,
"groupby": True,
"is_dttm": True,
"type": "DATETIME",
},
{
"id": 505,
"column_name": "gender",
"verbose_name": None,
"description": None,
"expression": "",
"filterable": True,
"groupby": True,
"is_dttm": False,
"type": "VARCHAR(16)",
},
{
"id": 506,
"column_name": "name",
"verbose_name": None,
"description": None,
"expression": None,
"filterable": True,
"groupby": True,
"is_dttm": None,
"type": "VARCHAR(255)",
},
{
"id": 508,
"column_name": "state",
"verbose_name": None,
"description": None,
"expression": None,
"filterable": True,
"groupby": True,
"is_dttm": None,
"type": "VARCHAR(10)",
},
{
"id": 509,
"column_name": "num_boys",
"verbose_name": None,
"description": None,
"expression": None,
"filterable": True,
"groupby": True,
"is_dttm": None,
"type": "BIGINT(20)",
},
{
"id": 510,
"column_name": "num_girls",
"verbose_name": None,
"description": None,
"expression": "",
"filterable": False,
"groupby": False,
"is_dttm": False,
"type": "BIGINT(20)",
},
{
"id": 532,
"column_name": "num",
"verbose_name": None,
"description": None,
"expression": None,
"filterable": True,
"groupby": True,
"is_dttm": None,
"type": "BIGINT(20)",
},
{
"id": 522,
"column_name": "num_california",
"verbose_name": None,
"description": None,
"expression": "CASE WHEN state = 'CA' THEN num ELSE 0 END",
"filterable": False,
"groupby": False,
"is_dttm": False,
"type": "NUMBER",
},
],
"metrics": [
{
"id": 824,
"metric_name": "sum__num",
"verbose_name": "Babies",
"description": "",
"expression": "SUM(num)",
"warning_text": "",
"d3format": "",
},
{
"id": 836,
"metric_name": "count",
"verbose_name": "",
"description": None,
"expression": "count(1)",
"warning_text": None,
"d3format": None,
},
{
"id": 843,
"metric_name": "ratio",
"verbose_name": "Ratio Boys/Girls",
"description": "This represents the ratio of boys/girls",
"expression": "sum(num_boys) / sum(num_girls)",
"warning_text": "no warning",
"d3format": ".2%",
},
],
}

View File

@@ -18,7 +18,7 @@ import copy
from typing import Any, Dict, List
from superset.utils.core import AnnotationType, DTTM_ALIAS, TimeRangeEndpoint
from tests.integration_tests.base_tests import get_table_by_name
from tests.integration_tests.base_tests import SupersetTestCase
query_birth_names = {
"extras": {
@@ -245,7 +245,7 @@ def get_query_context(
:return: Request payload
"""
table_name = query_name.split(":")[0]
table = get_table_by_name(table_name)
table = SupersetTestCase.get_table(name=table_name)
return {
"datasource": {"id": table.id, "type": table.type},
"queries": [

View File

@@ -89,19 +89,20 @@ class TestImportExport(SupersetTestCase):
id=None,
db_name="examples",
table_name="wb_health_population",
schema=None,
):
params = {
"num_period_compare": "10",
"remote_id": id,
"datasource_name": table_name,
"database_name": db_name,
"schema": "",
"schema": schema,
# Test for trailing commas
"metrics": ["sum__signup_attempt_email", "sum__signup_attempt_facebook"],
}
if table_name and not ds_id:
table = self.get_table_by_name(table_name)
table = self.get_table(schema=schema, name=table_name)
if table:
ds_id = table.id
@@ -167,9 +168,6 @@ class TestImportExport(SupersetTestCase):
def get_datasource(self, datasource_id):
return db.session.query(DruidDatasource).filter_by(id=datasource_id).first()
def get_table_by_name(self, name):
return db.session.query(SqlaTable).filter_by(table_name=name).first()
def assert_dash_equals(
self, expected_dash, actual_dash, check_position=True, check_slugs=True
):
@@ -273,9 +271,7 @@ class TestImportExport(SupersetTestCase):
resp.data.decode("utf-8"), object_hook=decode_dashboards
)["datasources"]
self.assertEqual(1, len(exported_tables))
self.assert_table_equals(
self.get_table_by_name("birth_names"), exported_tables[0]
)
self.assert_table_equals(self.get_table(name="birth_names"), exported_tables[0])
@pytest.mark.usefixtures(
"load_world_bank_dashboard_with_slices",
@@ -314,11 +310,9 @@ class TestImportExport(SupersetTestCase):
resp_data.get("datasources"), key=lambda t: t.table_name
)
self.assertEqual(2, len(exported_tables))
self.assert_table_equals(self.get_table(name="birth_names"), exported_tables[0])
self.assert_table_equals(
self.get_table_by_name("birth_names"), exported_tables[0]
)
self.assert_table_equals(
self.get_table_by_name("wb_health_population"), exported_tables[1]
self.get_table(name="wb_health_population"), exported_tables[1]
)
@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
@@ -329,12 +323,12 @@ class TestImportExport(SupersetTestCase):
self.assertEqual(slc.datasource.perm, slc.perm)
self.assert_slice_equals(expected_slice, slc)
table_id = self.get_table_by_name("wb_health_population").id
table_id = self.get_table(name="wb_health_population").id
self.assertEqual(table_id, self.get_slice(slc_id).datasource_id)
@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
def test_import_2_slices_for_same_table(self):
table_id = self.get_table_by_name("wb_health_population").id
table_id = self.get_table(name="wb_health_population").id
# table_id != 666, import func will have to find the table
slc_1 = self.create_slice("Import Me 1", ds_id=666, id=10002)
slc_id_1 = import_chart(slc_1, None)
@@ -351,13 +345,6 @@ class TestImportExport(SupersetTestCase):
self.assert_slice_equals(slc_2, imported_slc_2)
self.assertEqual(imported_slc_2.datasource.perm, imported_slc_2.perm)
def test_import_slices_for_non_existent_table(self):
with self.assertRaises(AttributeError):
import_chart(
self.create_slice("Import Me 3", id=10004, table_name="non_existent"),
None,
)
def test_import_slices_override(self):
slc = self.create_slice("Import Me New", id=10005)
slc_1_id = import_chart(slc, None, import_time=1990)

View File

@@ -339,7 +339,7 @@ class TestDatabaseModel(SupersetTestCase):
class TestSqlaTableModel(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_get_timestamp_expression(self):
tbl = self.get_table_by_name("birth_names")
tbl = self.get_table(name="birth_names")
ds_col = tbl.get_column("ds")
sqla_literal = ds_col.get_timestamp_expression(None)
self.assertEqual(str(sqla_literal.compile()), "ds")
@@ -359,7 +359,7 @@ class TestSqlaTableModel(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_get_timestamp_expression_epoch(self):
tbl = self.get_table_by_name("birth_names")
tbl = self.get_table(name="birth_names")
ds_col = tbl.get_column("ds")
ds_col.expression = None
@@ -384,7 +384,7 @@ class TestSqlaTableModel(SupersetTestCase):
ds_col.expression = prev_ds_expr
def query_with_expr_helper(self, is_timeseries, inner_join=True):
tbl = self.get_table_by_name("birth_names")
tbl = self.get_table(name="birth_names")
ds_col = tbl.get_column("ds")
ds_col.expression = None
ds_col.python_date_format = None
@@ -447,7 +447,7 @@ class TestSqlaTableModel(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_sql_mutator(self):
tbl = self.get_table_by_name("birth_names")
tbl = self.get_table(name="birth_names")
query_obj = dict(
groupby=[],
metrics=None,
@@ -472,7 +472,7 @@ class TestSqlaTableModel(SupersetTestCase):
app.config["SQL_QUERY_MUTATOR"] = None
def test_query_with_non_existent_metrics(self):
tbl = self.get_table_by_name("birth_names")
tbl = self.get_table(name="birth_names")
query_obj = dict(
groupby=[],
@@ -493,7 +493,7 @@ class TestSqlaTableModel(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_data_for_slices(self):
tbl = self.get_table_by_name("birth_names")
tbl = self.get_table(name="birth_names")
slc = (
metadata_db.session.query(Slice)
.filter_by(

View File

@@ -92,7 +92,7 @@ class TestQueryContext(SupersetTestCase):
def test_cache(self):
table_name = "birth_names"
table = self.get_table_by_name(table_name)
table = self.get_table(name=table_name)
payload = get_query_context(table.name, table.id)
payload["force"] = True

View File

@@ -1151,7 +1151,7 @@ class TestRowLevelSecurity(SupersetTestCase):
@pytest.mark.usefixtures("load_energy_table_with_slice")
def test_rls_filter_alters_energy_query(self):
g.user = self.get_user(username="alpha")
tbl = self.get_table_by_name("energy_usage")
tbl = self.get_table(name="energy_usage")
sql = tbl.get_query_str(self.query_obj)
assert tbl.get_extra_cache_keys(self.query_obj) == [1]
assert "value > 1" in sql
@@ -1161,7 +1161,7 @@ class TestRowLevelSecurity(SupersetTestCase):
g.user = self.get_user(
username="admin"
) # self.login() doesn't actually set the user
tbl = self.get_table_by_name("energy_usage")
tbl = self.get_table(name="energy_usage")
sql = tbl.get_query_str(self.query_obj)
assert tbl.get_extra_cache_keys(self.query_obj) == []
assert "value > 1" not in sql
@@ -1171,7 +1171,7 @@ class TestRowLevelSecurity(SupersetTestCase):
g.user = self.get_user(
username="alpha"
) # self.login() doesn't actually set the user
tbl = self.get_table_by_name("unicode_test")
tbl = self.get_table(name="unicode_test")
sql = tbl.get_query_str(self.query_obj)
assert tbl.get_extra_cache_keys(self.query_obj) == [1]
assert "value > 1" in sql
@@ -1179,7 +1179,7 @@ class TestRowLevelSecurity(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_rls_filter_alters_gamma_birth_names_query(self):
g.user = self.get_user(username="gamma")
tbl = self.get_table_by_name("birth_names")
tbl = self.get_table(name="birth_names")
sql = tbl.get_query_str(self.query_obj)
# establish that the filters are grouped together correctly with
@@ -1192,7 +1192,7 @@ class TestRowLevelSecurity(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_rls_filter_alters_no_role_user_birth_names_query(self):
g.user = self.get_user(username="NoRlsRoleUser")
tbl = self.get_table_by_name("birth_names")
tbl = self.get_table(name="birth_names")
sql = tbl.get_query_str(self.query_obj)
# gamma's filters should not be present query
@@ -1205,7 +1205,7 @@ class TestRowLevelSecurity(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_rls_filter_doesnt_alter_admin_birth_names_query(self):
g.user = self.get_user(username="admin")
tbl = self.get_table_by_name("birth_names")
tbl = self.get_table(name="birth_names")
sql = tbl.get_query_str(self.query_obj)
# no filters are applied for admin user

View File

@@ -241,7 +241,7 @@ class TestDatabaseModel(SupersetTestCase):
FilterTestCase(FilterOperator.IN, ["1", "2"], "IN (1, 2)"),
FilterTestCase(FilterOperator.NOT_IN, ["1", "2"], "NOT IN (1, 2)"),
)
table = self.get_table_by_name("birth_names")
table = self.get_table(name="birth_names")
for filter_ in filters:
query_obj = {
"granularity": None,

View File

@@ -42,11 +42,6 @@ from tests.integration_tests.fixtures.query_context import get_query_context
from tests.integration_tests.test_app import app
def get_table_by_name(name: str) -> SqlaTable:
with app.app_context():
return db.session.query(SqlaTable).filter_by(table_name=name).one()
class TestAsyncQueries(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@mock.patch.object(async_query_manager, "update_job")
@@ -127,7 +122,7 @@ class TestAsyncQueries(SupersetTestCase):
@mock.patch.object(async_query_manager, "update_job")
def test_load_explore_json_into_cache(self, mock_update_job):
async_query_manager.init_app(app)
table = get_table_by_name("birth_names")
table = self.get_table(name="birth_names")
user = security_manager.find_user("gamma")
form_data = {
"datasource": f"{table.id}__table",