fix: add mutator to get_columns_description (#29885)

This commit is contained in:
Elizabeth Thompson
2024-08-08 17:25:33 -07:00
committed by GitHub
parent fb6efb9e9a
commit 38d64e8dd2
3 changed files with 162 additions and 61 deletions

View File

@@ -68,6 +68,24 @@ def create_test_table_context(database: Database):
engine.execute(f"DROP TABLE {full_table_name}")
@contextmanager
def create_and_cleanup_table(table=None):
if table is None:
table = SqlaTable(
table_name="dummy_sql_table",
database=get_example_database(),
schema=get_example_default_schema(),
sql="select 123 as intcol, 'abc' as strcol",
)
db.session.add(table)
db.session.commit()
try:
yield table
finally:
db.session.delete(table)
db.session.commit()
class TestDatasource(SupersetTestCase):
def setUp(self):
db.session.begin(subtransactions=True)
@@ -123,37 +141,22 @@ class TestDatasource(SupersetTestCase):
sql=sql,
)
db.session.add(table)
db.session.commit()
with create_and_cleanup_table(table):
table.always_filter_main_dttm = False
result = str(table.get_sqla_query(**query_obj).sqla_query.whereclause)
assert "default_dttm" not in result and "additional_dttm" in result
table.always_filter_main_dttm = False
result = str(table.get_sqla_query(**query_obj).sqla_query.whereclause)
assert "default_dttm" not in result and "additional_dttm" in result
table.always_filter_main_dttm = True
result = str(table.get_sqla_query(**query_obj).sqla_query.whereclause)
assert "default_dttm" in result and "additional_dttm" in result
db.session.delete(table)
db.session.commit()
table.always_filter_main_dttm = True
result = str(table.get_sqla_query(**query_obj).sqla_query.whereclause)
assert "default_dttm" in result and "additional_dttm" in result
def test_external_metadata_for_virtual_table(self):
self.login(ADMIN_USERNAME)
table = SqlaTable(
table_name="dummy_sql_table",
database=get_example_database(),
schema=get_example_default_schema(),
sql="select 123 as intcol, 'abc' as strcol",
)
db.session.add(table)
db.session.commit()
table = self.get_table(name="dummy_sql_table")
url = f"/datasource/external_metadata/table/{table.id}/"
resp = self.get_json_resp(url)
assert {o.get("column_name") for o in resp} == {"intcol", "strcol"}
db.session.delete(table)
db.session.commit()
with create_and_cleanup_table() as table:
url = f"/datasource/external_metadata/table/{table.id}/"
resp = self.get_json_resp(url)
assert {o.get("column_name") for o in resp} == {"intcol", "strcol"}
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_external_metadata_by_name_for_physical_table(self):
@@ -178,31 +181,42 @@ class TestDatasource(SupersetTestCase):
def test_external_metadata_by_name_for_virtual_table(self):
self.login(ADMIN_USERNAME)
table = SqlaTable(
table_name="dummy_sql_table",
database=get_example_database(),
schema=get_example_default_schema(),
sql="select 123 as intcol, 'abc' as strcol",
)
db.session.add(table)
db.session.commit()
with create_and_cleanup_table() as tbl:
params = prison.dumps(
{
"datasource_type": "table",
"database_name": tbl.database.database_name,
"schema_name": tbl.schema,
"table_name": tbl.table_name,
"normalize_columns": tbl.normalize_columns,
"always_filter_main_dttm": tbl.always_filter_main_dttm,
}
)
url = f"/datasource/external_metadata_by_name/?q={params}"
resp = self.get_json_resp(url)
assert {o.get("column_name") for o in resp} == {"intcol", "strcol"}
tbl = self.get_table(name="dummy_sql_table")
params = prison.dumps(
{
"datasource_type": "table",
"database_name": tbl.database.database_name,
"schema_name": tbl.schema,
"table_name": tbl.table_name,
"normalize_columns": tbl.normalize_columns,
"always_filter_main_dttm": tbl.always_filter_main_dttm,
}
)
url = f"/datasource/external_metadata_by_name/?q={params}"
resp = self.get_json_resp(url)
assert {o.get("column_name") for o in resp} == {"intcol", "strcol"}
db.session.delete(tbl)
db.session.commit()
def test_external_metadata_by_name_for_virtual_table_uses_mutator(self):
self.login(ADMIN_USERNAME)
with create_and_cleanup_table() as tbl:
app.config["SQL_QUERY_MUTATOR"] = (
lambda sql, **kwargs: "SELECT 456 as intcol, 'def' as mutated_strcol"
)
params = prison.dumps(
{
"datasource_type": "table",
"database_name": tbl.database.database_name,
"schema_name": tbl.schema,
"table_name": tbl.table_name,
"normalize_columns": tbl.normalize_columns,
"always_filter_main_dttm": tbl.always_filter_main_dttm,
}
)
url = f"/datasource/external_metadata_by_name/?q={params}"
resp = self.get_json_resp(url)
assert {o.get("column_name") for o in resp} == {"intcol", "mutated_strcol"}
app.config["SQL_QUERY_MUTATOR"] = None
def test_external_metadata_by_name_from_sqla_inspector(self):
self.login(ADMIN_USERNAME)
@@ -278,15 +292,10 @@ class TestDatasource(SupersetTestCase):
sql="select {{ foo }} as intcol",
template_params=json.dumps({"foo": "123"}),
)
db.session.add(table)
db.session.commit()
table = self.get_table(name="dummy_sql_table_with_template_params")
url = f"/datasource/external_metadata/table/{table.id}/"
resp = self.get_json_resp(url)
assert {o.get("column_name") for o in resp} == {"intcol"}
db.session.delete(table)
db.session.commit()
with create_and_cleanup_table(table) as tbl:
url = f"/datasource/external_metadata/table/{tbl.id}/"
resp = self.get_json_resp(url)
assert {o.get("column_name") for o in resp} == {"intcol"}
def test_external_metadata_for_malicious_virtual_table(self):
self.login(ADMIN_USERNAME)