mirror of
https://github.com/apache/superset.git
synced 2026-04-28 12:34:23 +00:00
Compare commits
5 Commits
docs/testi
...
elizabeth/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
33ff67d74e | ||
|
|
0adf291729 | ||
|
|
4942d73a08 | ||
|
|
fec70eaf3e | ||
|
|
aec236994a |
@@ -27,7 +27,9 @@ from flask import current_app
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from superset import db, security_manager
|
||||
# Note: Import Database functionality without importing the actual model
|
||||
from superset import db, db_engine_specs, security_manager
|
||||
from superset.databases.utils import make_url_safe
|
||||
from superset.db_engine_specs.base import GenericDBException
|
||||
from superset.migrations.shared.security_converge import (
|
||||
add_pvms,
|
||||
@@ -35,13 +37,80 @@ from superset.migrations.shared.security_converge import (
|
||||
PermissionView,
|
||||
ViewMenu,
|
||||
)
|
||||
from superset.models.core import Database
|
||||
|
||||
logger = logging.getLogger("alembic.env")
|
||||
|
||||
Base: Type[Any] = declarative_base()
|
||||
|
||||
|
||||
class Database(Base):
|
||||
"""Local Database model for migration"""
|
||||
|
||||
__tablename__ = "dbs"
|
||||
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
sqlalchemy_uri = sa.Column(sa.String(1024))
|
||||
encrypted_extra = sa.Column(sa.Text)
|
||||
database_name = sa.Column(sa.String(250))
|
||||
|
||||
@property
|
||||
def db_engine_spec(self) -> Type[Any]:
|
||||
url = make_url_safe(self.sqlalchemy_uri)
|
||||
backend = url.get_backend_name()
|
||||
try:
|
||||
driver = url.get_driver_name()
|
||||
except Exception:
|
||||
driver = None
|
||||
return db_engine_specs.get_engine_spec(backend, driver)
|
||||
|
||||
def get_default_catalog(self) -> str | None:
|
||||
"""Get default catalog using the engine spec."""
|
||||
return self.db_engine_spec.get_default_catalog(self)
|
||||
|
||||
def is_oauth2_enabled(self) -> bool:
|
||||
"""Check if OAuth2 is enabled for this database."""
|
||||
from superset.utils import json
|
||||
|
||||
encrypted_extra = json.loads(self.encrypted_extra or "{}")
|
||||
return bool(encrypted_extra.get("oauth2_client_info"))
|
||||
|
||||
def get_inspector(self, catalog: str | None = None) -> Any:
|
||||
"""Get a database inspector for introspection."""
|
||||
from sqlalchemy import create_engine, inspect
|
||||
|
||||
# Create an engine from the URI
|
||||
engine = create_engine(self.sqlalchemy_uri)
|
||||
if catalog and hasattr(engine, "execution_options"):
|
||||
engine = engine.execution_options(catalog=catalog)
|
||||
return inspect(engine)
|
||||
|
||||
def get_all_schema_names(self, catalog: str | None = None) -> list[str]:
|
||||
"""
|
||||
Get all schema names for this database.
|
||||
|
||||
Uses SQLAlchemy inspector to get schema names directly.
|
||||
"""
|
||||
try:
|
||||
with self.get_inspector(catalog=catalog) as inspector:
|
||||
return self.db_engine_spec.get_schema_names(inspector)
|
||||
except Exception as ex:
|
||||
# Convert any exception to GenericDBException for consistent handling
|
||||
raise GenericDBException(str(ex)) from ex
|
||||
|
||||
def get_all_catalog_names(self) -> list[str]:
|
||||
"""
|
||||
Get all catalog names for this database.
|
||||
|
||||
Uses SQLAlchemy inspector to get catalog names directly.
|
||||
"""
|
||||
try:
|
||||
with self.get_inspector() as inspector:
|
||||
return self.db_engine_spec.get_catalog_names(self, inspector)
|
||||
except Exception as ex:
|
||||
# Convert any exception to GenericDBException for consistent handling
|
||||
raise GenericDBException(str(ex)) from ex
|
||||
|
||||
|
||||
class SqlaTable(Base):
|
||||
__tablename__ = "tables"
|
||||
|
||||
|
||||
@@ -62,36 +62,106 @@ def test_upgrade_catalog_perms(mocker: MockerFixture, session: Session) -> None:
|
||||
|
||||
The function is called when catalogs are introduced into a new DB engine spec.
|
||||
"""
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.models.core import Database
|
||||
from superset.models.slice import Slice
|
||||
from superset.models.sql_lab import Query, SavedQuery, TableSchema, TabState
|
||||
from superset.migrations.shared.catalogs import (
|
||||
Database,
|
||||
Query,
|
||||
SavedQuery,
|
||||
Slice,
|
||||
SqlaTable,
|
||||
TableSchema,
|
||||
TabState,
|
||||
)
|
||||
|
||||
engine = session.get_bind()
|
||||
Database.metadata.create_all(engine)
|
||||
Permission.metadata.create_all(engine)
|
||||
PermissionView.metadata.create_all(engine)
|
||||
ViewMenu.metadata.create_all(engine)
|
||||
|
||||
mocker.patch("superset.migrations.shared.catalogs.op")
|
||||
db = mocker.patch("superset.migrations.shared.catalogs.db")
|
||||
db.Session.return_value = session
|
||||
|
||||
mocker.patch.object(
|
||||
Database,
|
||||
"get_all_schema_names",
|
||||
return_value=["public", "information_schema"],
|
||||
)
|
||||
mocker.patch.object(
|
||||
Database,
|
||||
"get_all_catalog_names",
|
||||
return_value=["db", "other_catalog"],
|
||||
# Mock current_app.config to ensure we don't skip non-default catalogs
|
||||
mocker.patch.dict(
|
||||
"superset.migrations.shared.catalogs.current_app.config",
|
||||
{"CATALOGS_SIMPLIFIED_MIGRATION": False},
|
||||
)
|
||||
|
||||
# Mock the db_engine_spec methods instead of the Database model methods
|
||||
mock_db_engine_spec = mocker.MagicMock()
|
||||
mock_db_engine_spec.supports_catalog = True
|
||||
mock_db_engine_spec.get_default_catalog.return_value = "db"
|
||||
mock_db_engine_spec.get_all_schema_names.return_value = [
|
||||
"public",
|
||||
"information_schema",
|
||||
]
|
||||
mock_db_engine_spec.get_all_catalog_names.return_value = ["db", "other_catalog"]
|
||||
|
||||
mocker.patch.object(
|
||||
Database,
|
||||
"db_engine_spec",
|
||||
new_callable=mocker.PropertyMock,
|
||||
return_value=mock_db_engine_spec,
|
||||
)
|
||||
mocker.patch.object(
|
||||
Database,
|
||||
"get_default_catalog",
|
||||
return_value="db",
|
||||
)
|
||||
|
||||
# Create a mock database that can call the engine spec methods
|
||||
def get_all_schema_names_mock(catalog=None):
|
||||
if catalog == "other_catalog":
|
||||
return ["public", "information_schema"]
|
||||
return ["public", "information_schema"]
|
||||
|
||||
def get_all_catalog_names_mock():
|
||||
return ["db", "other_catalog"]
|
||||
|
||||
database = Database(
|
||||
database_name="my_db",
|
||||
id=1,
|
||||
sqlalchemy_uri="postgresql://localhost/db",
|
||||
)
|
||||
database.database_name = "my_db"
|
||||
# Mock the methods instead of assigning
|
||||
mocker.patch.object(database, "get_all_schema_names", get_all_schema_names_mock)
|
||||
mocker.patch.object(database, "get_all_catalog_names", get_all_catalog_names_mock)
|
||||
session.add(database)
|
||||
session.commit()
|
||||
|
||||
# Create initial permissions for testing
|
||||
db_perm = ViewMenu(name="[my_db].(id:1)")
|
||||
table_perm = ViewMenu(name="[my_db].[my_table](id:1)")
|
||||
schema_perm = ViewMenu(name="[my_db].[public]")
|
||||
|
||||
database_access = Permission(name="database_access")
|
||||
datasource_access = Permission(name="datasource_access")
|
||||
schema_access = Permission(name="schema_access")
|
||||
|
||||
session.add_all(
|
||||
[
|
||||
db_perm,
|
||||
table_perm,
|
||||
schema_perm,
|
||||
database_access,
|
||||
datasource_access,
|
||||
schema_access,
|
||||
]
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# Create permission view associations
|
||||
pv1 = PermissionView(permission_id=database_access.id, view_menu_id=db_perm.id)
|
||||
pv2 = PermissionView(permission_id=datasource_access.id, view_menu_id=table_perm.id)
|
||||
pv3 = PermissionView(permission_id=schema_access.id, view_menu_id=schema_perm.id)
|
||||
session.add_all([pv1, pv2, pv3])
|
||||
session.commit()
|
||||
|
||||
dataset = SqlaTable(
|
||||
table_name="my_table",
|
||||
database=database,
|
||||
id=1,
|
||||
database_id=database.id,
|
||||
perm="[my_db].[my_table](id:1)",
|
||||
catalog=None,
|
||||
schema="public",
|
||||
catalog_perm=None,
|
||||
@@ -101,33 +171,26 @@ def test_upgrade_catalog_perms(mocker: MockerFixture, session: Session) -> None:
|
||||
session.commit()
|
||||
|
||||
chart = Slice(
|
||||
slice_name="my_chart",
|
||||
datasource_type="table",
|
||||
datasource_id=dataset.id,
|
||||
catalog_perm=None,
|
||||
schema_perm="[my_db].[public]",
|
||||
)
|
||||
query = Query(
|
||||
client_id="foo",
|
||||
database=database,
|
||||
database_id=database.id,
|
||||
catalog=None,
|
||||
schema="public",
|
||||
)
|
||||
saved_query = SavedQuery(
|
||||
database=database,
|
||||
sql="SELECT * FROM public.t",
|
||||
db_id=database.id,
|
||||
catalog=None,
|
||||
schema="public",
|
||||
)
|
||||
tab_state = TabState(
|
||||
database=database,
|
||||
database_id=database.id,
|
||||
catalog=None,
|
||||
schema="public",
|
||||
)
|
||||
table_schema = TableSchema(
|
||||
database=database,
|
||||
database_id=database.id,
|
||||
catalog=None,
|
||||
schema="public",
|
||||
)
|
||||
session.add_all([chart, query, saved_query, tab_state, table_schema])
|
||||
session.commit()
|
||||
@@ -158,8 +221,9 @@ def test_upgrade_catalog_perms(mocker: MockerFixture, session: Session) -> None:
|
||||
|
||||
# add dataset/chart in new catalog
|
||||
new_dataset = SqlaTable(
|
||||
table_name="my_table",
|
||||
database=database,
|
||||
id=2,
|
||||
database_id=database.id,
|
||||
perm="[my_db].[my_table](id:2)",
|
||||
catalog="other_catalog",
|
||||
schema="public",
|
||||
schema_perm="[my_db].[other_catalog].[public]",
|
||||
@@ -168,8 +232,17 @@ def test_upgrade_catalog_perms(mocker: MockerFixture, session: Session) -> None:
|
||||
session.add(new_dataset)
|
||||
session.commit()
|
||||
|
||||
# Add permission for the new dataset
|
||||
new_table_perm = ViewMenu(name="[my_db].[my_table](id:2)")
|
||||
session.add(new_table_perm)
|
||||
session.commit()
|
||||
pv_new = PermissionView(
|
||||
permission_id=datasource_access.id, view_menu_id=new_table_perm.id
|
||||
)
|
||||
session.add(pv_new)
|
||||
session.commit()
|
||||
|
||||
new_chart = Slice(
|
||||
slice_name="my_chart",
|
||||
datasource_type="table",
|
||||
datasource_id=new_dataset.id,
|
||||
)
|
||||
@@ -186,21 +259,24 @@ def test_upgrade_catalog_perms(mocker: MockerFixture, session: Session) -> None:
|
||||
assert dataset.schema_perm == "[my_db].[db].[public]"
|
||||
assert chart.catalog_perm == "[my_db].[db]"
|
||||
assert chart.schema_perm == "[my_db].[db].[public]"
|
||||
assert (
|
||||
|
||||
assert sorted(
|
||||
session.query(ViewMenu.name, Permission.name)
|
||||
.join(PermissionView, ViewMenu.id == PermissionView.view_menu_id)
|
||||
.join(Permission, PermissionView.permission_id == Permission.id)
|
||||
.all()
|
||||
) == [
|
||||
("[my_db].(id:1)", "database_access"),
|
||||
("[my_db].[my_table](id:1)", "datasource_access"),
|
||||
("[my_db].[db].[public]", "schema_access"),
|
||||
("[my_db].[db]", "catalog_access"),
|
||||
("[my_db].[other_catalog]", "catalog_access"),
|
||||
("[my_db].[other_catalog].[public]", "schema_access"),
|
||||
("[my_db].[other_catalog].[information_schema]", "schema_access"),
|
||||
("[my_db].[my_table](id:2)", "datasource_access"),
|
||||
]
|
||||
) == sorted(
|
||||
[
|
||||
("[my_db].(id:1)", "database_access"),
|
||||
("[my_db].[my_table](id:1)", "datasource_access"),
|
||||
("[my_db].[db].[public]", "schema_access"),
|
||||
("[my_db].[db]", "catalog_access"),
|
||||
("[my_db].[other_catalog]", "catalog_access"),
|
||||
("[my_db].[other_catalog].[public]", "schema_access"),
|
||||
("[my_db].[other_catalog].[information_schema]", "schema_access"),
|
||||
("[my_db].[my_table](id:2)", "datasource_access"),
|
||||
]
|
||||
)
|
||||
|
||||
# do a downgrade
|
||||
downgrade_catalog_perms()
|
||||
@@ -245,32 +321,94 @@ def test_upgrade_catalog_perms_graceful(
|
||||
catalog browsing on the database (permissions are always synced on a DB update, see
|
||||
`UpdateDatabaseCommand`).
|
||||
"""
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.models.core import Database
|
||||
from superset.models.slice import Slice
|
||||
from superset.models.sql_lab import Query, SavedQuery, TableSchema, TabState
|
||||
from superset.migrations.shared.catalogs import (
|
||||
Database,
|
||||
Query,
|
||||
SavedQuery,
|
||||
Slice,
|
||||
SqlaTable,
|
||||
TableSchema,
|
||||
TabState,
|
||||
)
|
||||
|
||||
engine = session.get_bind()
|
||||
Database.metadata.create_all(engine)
|
||||
Permission.metadata.create_all(engine)
|
||||
PermissionView.metadata.create_all(engine)
|
||||
ViewMenu.metadata.create_all(engine)
|
||||
|
||||
mocker.patch("superset.migrations.shared.catalogs.op")
|
||||
db = mocker.patch("superset.migrations.shared.catalogs.db")
|
||||
db.Session.return_value = session
|
||||
|
||||
# Mock the db_engine_spec to support catalogs but fail on get_all_schema_names
|
||||
mock_db_engine_spec = mocker.MagicMock()
|
||||
mock_db_engine_spec.supports_catalog = True
|
||||
mock_db_engine_spec.get_default_catalog.return_value = "db"
|
||||
|
||||
mocker.patch.object(
|
||||
Database,
|
||||
"get_all_schema_names",
|
||||
side_effect=Exception("Failed to connect to the database"),
|
||||
"db_engine_spec",
|
||||
new_callable=mocker.PropertyMock,
|
||||
return_value=mock_db_engine_spec,
|
||||
)
|
||||
mocker.patch.object(
|
||||
Database,
|
||||
"get_default_catalog",
|
||||
return_value="db",
|
||||
)
|
||||
|
||||
def get_all_schema_names_mock(catalog=None):
|
||||
raise Exception("Failed to connect to the database")
|
||||
|
||||
mocker.patch("superset.migrations.shared.catalogs.op", session)
|
||||
|
||||
database = Database(
|
||||
database_name="my_db",
|
||||
id=1,
|
||||
sqlalchemy_uri="postgresql://localhost/db",
|
||||
)
|
||||
database.database_name = "my_db"
|
||||
session.add(database)
|
||||
session.commit()
|
||||
|
||||
mocker.patch.object(
|
||||
database,
|
||||
"get_all_schema_names",
|
||||
side_effect=get_all_schema_names_mock,
|
||||
)
|
||||
|
||||
# Create initial permissions for testing
|
||||
db_perm = ViewMenu(name="[my_db].(id:1)")
|
||||
table_perm = ViewMenu(name="[my_db].[my_table](id:1)")
|
||||
schema_perm = ViewMenu(name="[my_db].[public]")
|
||||
|
||||
database_access = Permission(name="database_access")
|
||||
datasource_access = Permission(name="datasource_access")
|
||||
schema_access = Permission(name="schema_access")
|
||||
|
||||
session.add_all(
|
||||
[
|
||||
db_perm,
|
||||
table_perm,
|
||||
schema_perm,
|
||||
database_access,
|
||||
datasource_access,
|
||||
schema_access,
|
||||
]
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# Create permission view associations
|
||||
pv1 = PermissionView(permission_id=database_access.id, view_menu_id=db_perm.id)
|
||||
pv2 = PermissionView(permission_id=datasource_access.id, view_menu_id=table_perm.id)
|
||||
pv3 = PermissionView(permission_id=schema_access.id, view_menu_id=schema_perm.id)
|
||||
session.add_all([pv1, pv2, pv3])
|
||||
session.commit()
|
||||
|
||||
dataset = SqlaTable(
|
||||
table_name="my_table",
|
||||
database=database,
|
||||
id=1,
|
||||
database_id=database.id,
|
||||
perm="[my_db].[my_table](id:1)",
|
||||
catalog=None,
|
||||
schema="public",
|
||||
schema_perm="[my_db].[public]",
|
||||
@@ -279,31 +417,26 @@ def test_upgrade_catalog_perms_graceful(
|
||||
session.commit()
|
||||
|
||||
chart = Slice(
|
||||
slice_name="my_chart",
|
||||
datasource_type="table",
|
||||
datasource_id=dataset.id,
|
||||
catalog_perm=None,
|
||||
schema_perm="[my_db].[public]",
|
||||
)
|
||||
query = Query(
|
||||
client_id="foo",
|
||||
database=database,
|
||||
database_id=database.id,
|
||||
catalog=None,
|
||||
schema="public",
|
||||
)
|
||||
saved_query = SavedQuery(
|
||||
database=database,
|
||||
sql="SELECT * FROM public.t",
|
||||
db_id=database.id,
|
||||
catalog=None,
|
||||
schema="public",
|
||||
)
|
||||
tab_state = TabState(
|
||||
database=database,
|
||||
database_id=database.id,
|
||||
catalog=None,
|
||||
schema="public",
|
||||
)
|
||||
table_schema = TableSchema(
|
||||
database=database,
|
||||
database_id=database.id,
|
||||
catalog=None,
|
||||
schema="public",
|
||||
)
|
||||
session.add_all([chart, query, saved_query, tab_state, table_schema])
|
||||
session.commit()
|
||||
@@ -370,13 +503,21 @@ def test_upgrade_catalog_perms_oauth_connection(
|
||||
schemas. This step should be skipped if the database is set up using OAuth and not
|
||||
raise an exception.
|
||||
"""
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.models.core import Database
|
||||
from superset.models.slice import Slice
|
||||
from superset.models.sql_lab import Query, SavedQuery, TableSchema, TabState
|
||||
from superset.migrations.shared.catalogs import (
|
||||
Database,
|
||||
Query,
|
||||
SavedQuery,
|
||||
Slice,
|
||||
SqlaTable,
|
||||
TableSchema,
|
||||
TabState,
|
||||
)
|
||||
|
||||
engine = session.get_bind()
|
||||
Database.metadata.create_all(engine)
|
||||
Permission.metadata.create_all(engine)
|
||||
PermissionView.metadata.create_all(engine)
|
||||
ViewMenu.metadata.create_all(engine)
|
||||
|
||||
mocker.patch("superset.migrations.shared.catalogs.op")
|
||||
db = mocker.patch("superset.migrations.shared.catalogs.db")
|
||||
@@ -386,14 +527,64 @@ def test_upgrade_catalog_perms_oauth_connection(
|
||||
)
|
||||
mocker.patch("superset.migrations.shared.catalogs.op", session)
|
||||
|
||||
# Mock the db_engine_spec for BigQuery with catalog support
|
||||
mock_db_engine_spec = mocker.MagicMock()
|
||||
mock_db_engine_spec.supports_catalog = True
|
||||
mock_db_engine_spec.get_default_catalog.return_value = "my-test-project"
|
||||
|
||||
mocker.patch.object(
|
||||
Database,
|
||||
"db_engine_spec",
|
||||
new_callable=mocker.PropertyMock,
|
||||
return_value=mock_db_engine_spec,
|
||||
)
|
||||
mocker.patch.object(
|
||||
Database,
|
||||
"get_default_catalog",
|
||||
return_value="my-test-project",
|
||||
)
|
||||
|
||||
database = Database(
|
||||
database_name="my_db",
|
||||
id=1,
|
||||
sqlalchemy_uri="bigquery://my-test-project",
|
||||
encrypted_extra=json.dumps({"oauth2_client_info": oauth2_config}),
|
||||
)
|
||||
database.database_name = "my_db"
|
||||
session.add(database)
|
||||
session.commit()
|
||||
|
||||
# Create initial permissions for testing
|
||||
db_perm = ViewMenu(name="[my_db].(id:1)")
|
||||
table_perm = ViewMenu(name="[my_db].[my_table](id:1)")
|
||||
schema_perm = ViewMenu(name="[my_db].[public]")
|
||||
|
||||
database_access = Permission(name="database_access")
|
||||
datasource_access = Permission(name="datasource_access")
|
||||
schema_access = Permission(name="schema_access")
|
||||
|
||||
session.add_all(
|
||||
[
|
||||
db_perm,
|
||||
table_perm,
|
||||
schema_perm,
|
||||
database_access,
|
||||
datasource_access,
|
||||
schema_access,
|
||||
]
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# Create permission view associations
|
||||
pv1 = PermissionView(permission_id=database_access.id, view_menu_id=db_perm.id)
|
||||
pv2 = PermissionView(permission_id=datasource_access.id, view_menu_id=table_perm.id)
|
||||
pv3 = PermissionView(permission_id=schema_access.id, view_menu_id=schema_perm.id)
|
||||
session.add_all([pv1, pv2, pv3])
|
||||
session.commit()
|
||||
|
||||
dataset = SqlaTable(
|
||||
table_name="my_table",
|
||||
database=database,
|
||||
id=1,
|
||||
database_id=database.id,
|
||||
perm="[my_db].[my_table](id:1)",
|
||||
catalog=None,
|
||||
schema="public",
|
||||
schema_perm="[my_db].[public]",
|
||||
@@ -402,31 +593,26 @@ def test_upgrade_catalog_perms_oauth_connection(
|
||||
session.commit()
|
||||
|
||||
chart = Slice(
|
||||
slice_name="my_chart",
|
||||
datasource_type="table",
|
||||
datasource_id=dataset.id,
|
||||
catalog_perm=None,
|
||||
schema_perm="[my_db].[public]",
|
||||
)
|
||||
query = Query(
|
||||
client_id="foo",
|
||||
database=database,
|
||||
database_id=database.id,
|
||||
catalog=None,
|
||||
schema="public",
|
||||
)
|
||||
saved_query = SavedQuery(
|
||||
database=database,
|
||||
sql="SELECT * FROM public.t",
|
||||
db_id=database.id,
|
||||
catalog=None,
|
||||
schema="public",
|
||||
)
|
||||
tab_state = TabState(
|
||||
database=database,
|
||||
database_id=database.id,
|
||||
catalog=None,
|
||||
schema="public",
|
||||
)
|
||||
table_schema = TableSchema(
|
||||
database=database,
|
||||
database_id=database.id,
|
||||
catalog=None,
|
||||
schema="public",
|
||||
)
|
||||
session.add_all([chart, query, saved_query, tab_state, table_schema])
|
||||
session.commit()
|
||||
@@ -494,14 +680,22 @@ def test_upgrade_catalog_perms_simplified_migration(
|
||||
This should only update existing permissions + create a new permission
|
||||
for the default catalog.
|
||||
"""
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.models.core import Database
|
||||
from superset.models.slice import Slice
|
||||
from superset.models.sql_lab import Query, SavedQuery, TableSchema, TabState
|
||||
from superset.migrations.shared.catalogs import (
|
||||
Database,
|
||||
Query,
|
||||
SavedQuery,
|
||||
Slice,
|
||||
SqlaTable,
|
||||
TableSchema,
|
||||
TabState,
|
||||
)
|
||||
from superset.migrations.shared.security_converge import Base as SecurityBase
|
||||
|
||||
engine = session.get_bind()
|
||||
Database.metadata.create_all(engine)
|
||||
|
||||
SecurityBase.metadata.create_all(engine)
|
||||
|
||||
mocker.patch("superset.migrations.shared.catalogs.op")
|
||||
db = mocker.patch("superset.migrations.shared.catalogs.db")
|
||||
db.Session.return_value = session
|
||||
@@ -510,13 +704,65 @@ def test_upgrade_catalog_perms_simplified_migration(
|
||||
)
|
||||
mocker.patch("superset.migrations.shared.catalogs.op", session)
|
||||
|
||||
# Mock the db_engine_spec for BigQuery with catalog support
|
||||
mock_db_engine_spec = mocker.MagicMock()
|
||||
mock_db_engine_spec.supports_catalog = True
|
||||
mock_db_engine_spec.get_default_catalog.return_value = "my-test-project"
|
||||
|
||||
mocker.patch.object(
|
||||
Database,
|
||||
"db_engine_spec",
|
||||
new_callable=mocker.PropertyMock,
|
||||
return_value=mock_db_engine_spec,
|
||||
)
|
||||
mocker.patch.object(
|
||||
Database,
|
||||
"get_default_catalog",
|
||||
return_value="my-test-project",
|
||||
)
|
||||
|
||||
database = Database(
|
||||
database_name="my_db",
|
||||
id=1,
|
||||
sqlalchemy_uri="bigquery://my-test-project",
|
||||
)
|
||||
database.database_name = "my_db"
|
||||
session.add(database)
|
||||
session.commit()
|
||||
|
||||
# Create initial permissions for testing
|
||||
db_perm = ViewMenu(name="[my_db].(id:1)")
|
||||
table_perm = ViewMenu(name="[my_db].[my_table](id:1)")
|
||||
schema_perm = ViewMenu(name="[my_db].[public]")
|
||||
|
||||
database_access = Permission(name="database_access")
|
||||
datasource_access = Permission(name="datasource_access")
|
||||
schema_access = Permission(name="schema_access")
|
||||
catalog_access = Permission(name="catalog_access")
|
||||
|
||||
session.add_all(
|
||||
[
|
||||
db_perm,
|
||||
table_perm,
|
||||
schema_perm,
|
||||
database_access,
|
||||
datasource_access,
|
||||
schema_access,
|
||||
catalog_access,
|
||||
]
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# Create permission view associations
|
||||
pv1 = PermissionView(permission_id=database_access.id, view_menu_id=db_perm.id)
|
||||
pv2 = PermissionView(permission_id=datasource_access.id, view_menu_id=table_perm.id)
|
||||
pv3 = PermissionView(permission_id=schema_access.id, view_menu_id=schema_perm.id)
|
||||
session.add_all([pv1, pv2, pv3])
|
||||
session.commit()
|
||||
|
||||
dataset = SqlaTable(
|
||||
table_name="my_table",
|
||||
database=database,
|
||||
id=1,
|
||||
database_id=database.id,
|
||||
perm="[my_db].[my_table](id:1)",
|
||||
catalog=None,
|
||||
schema="public",
|
||||
schema_perm="[my_db].[public]",
|
||||
@@ -525,31 +771,26 @@ def test_upgrade_catalog_perms_simplified_migration(
|
||||
session.commit()
|
||||
|
||||
chart = Slice(
|
||||
slice_name="my_chart",
|
||||
datasource_type="table",
|
||||
datasource_id=dataset.id,
|
||||
catalog_perm=None,
|
||||
schema_perm="[my_db].[public]",
|
||||
)
|
||||
query = Query(
|
||||
client_id="foo",
|
||||
database=database,
|
||||
database_id=database.id,
|
||||
catalog=None,
|
||||
schema="public",
|
||||
)
|
||||
saved_query = SavedQuery(
|
||||
database=database,
|
||||
sql="SELECT * FROM public.t",
|
||||
db_id=database.id,
|
||||
catalog=None,
|
||||
schema="public",
|
||||
)
|
||||
tab_state = TabState(
|
||||
database=database,
|
||||
database_id=database.id,
|
||||
catalog=None,
|
||||
schema="public",
|
||||
)
|
||||
table_schema = TableSchema(
|
||||
database=database,
|
||||
database_id=database.id,
|
||||
catalog=None,
|
||||
schema="public",
|
||||
)
|
||||
session.add_all([chart, query, saved_query, tab_state, table_schema])
|
||||
session.commit()
|
||||
|
||||
Reference in New Issue
Block a user