Compare commits

...

5 Commits

Author SHA1 Message Date
Elizabeth Thompson
33ff67d74e Apply suggestion from @bito-code-review[bot]
Co-authored-by: bito-code-review[bot] <188872107+bito-code-review[bot]@users.noreply.github.com>
2025-09-05 14:42:38 -07:00
Elizabeth Thompson
0adf291729 Apply suggestion from @bito-code-review[bot]
Co-authored-by: bito-code-review[bot] <188872107+bito-code-review[bot]@users.noreply.github.com>
2025-09-05 14:35:00 -07:00
Elizabeth Thompson
4942d73a08 Update tests/unit_tests/migrations/shared/catalogs_test.py
Co-authored-by: bito-code-review[bot] <188872107+bito-code-review[bot]@users.noreply.github.com>
2025-09-05 14:34:11 -07:00
Elizabeth Thompson
fec70eaf3e fix tests 2025-09-05 13:05:10 -07:00
Elizabeth Thompson
aec236994a use local db model in migration 2025-08-27 17:22:39 -07:00
2 changed files with 409 additions and 99 deletions

View File

@@ -27,7 +27,9 @@ from flask import current_app
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session
from superset import db, security_manager
# Note: Import Database functionality without importing the actual model
from superset import db, db_engine_specs, security_manager
from superset.databases.utils import make_url_safe
from superset.db_engine_specs.base import GenericDBException
from superset.migrations.shared.security_converge import (
add_pvms,
@@ -35,13 +37,80 @@ from superset.migrations.shared.security_converge import (
PermissionView,
ViewMenu,
)
from superset.models.core import Database
logger = logging.getLogger("alembic.env")
Base: Type[Any] = declarative_base()
class Database(Base):
"""Local Database model for migration"""
__tablename__ = "dbs"
id = sa.Column(sa.Integer, primary_key=True)
sqlalchemy_uri = sa.Column(sa.String(1024))
encrypted_extra = sa.Column(sa.Text)
database_name = sa.Column(sa.String(250))
@property
def db_engine_spec(self) -> Type[Any]:
url = make_url_safe(self.sqlalchemy_uri)
backend = url.get_backend_name()
try:
driver = url.get_driver_name()
except Exception:
driver = None
return db_engine_specs.get_engine_spec(backend, driver)
def get_default_catalog(self) -> str | None:
"""Get default catalog using the engine spec."""
return self.db_engine_spec.get_default_catalog(self)
def is_oauth2_enabled(self) -> bool:
"""Check if OAuth2 is enabled for this database."""
from superset.utils import json
encrypted_extra = json.loads(self.encrypted_extra or "{}")
return bool(encrypted_extra.get("oauth2_client_info"))
def get_inspector(self, catalog: str | None = None) -> Any:
"""Get a database inspector for introspection."""
from sqlalchemy import create_engine, inspect
# Create an engine from the URI
engine = create_engine(self.sqlalchemy_uri)
if catalog and hasattr(engine, "execution_options"):
engine = engine.execution_options(catalog=catalog)
return inspect(engine)
def get_all_schema_names(self, catalog: str | None = None) -> list[str]:
"""
Get all schema names for this database.
Uses SQLAlchemy inspector to get schema names directly.
"""
try:
with self.get_inspector(catalog=catalog) as inspector:
return self.db_engine_spec.get_schema_names(inspector)
except Exception as ex:
# Convert any exception to GenericDBException for consistent handling
raise GenericDBException(str(ex)) from ex
def get_all_catalog_names(self) -> list[str]:
"""
Get all catalog names for this database.
Uses SQLAlchemy inspector to get catalog names directly.
"""
try:
with self.get_inspector() as inspector:
return self.db_engine_spec.get_catalog_names(self, inspector)
except Exception as ex:
# Convert any exception to GenericDBException for consistent handling
raise GenericDBException(str(ex)) from ex
class SqlaTable(Base):
__tablename__ = "tables"

View File

@@ -62,36 +62,106 @@ def test_upgrade_catalog_perms(mocker: MockerFixture, session: Session) -> None:
The function is called when catalogs are introduced into a new DB engine spec.
"""
from superset.connectors.sqla.models import SqlaTable
from superset.models.core import Database
from superset.models.slice import Slice
from superset.models.sql_lab import Query, SavedQuery, TableSchema, TabState
from superset.migrations.shared.catalogs import (
Database,
Query,
SavedQuery,
Slice,
SqlaTable,
TableSchema,
TabState,
)
engine = session.get_bind()
Database.metadata.create_all(engine)
Permission.metadata.create_all(engine)
PermissionView.metadata.create_all(engine)
ViewMenu.metadata.create_all(engine)
mocker.patch("superset.migrations.shared.catalogs.op")
db = mocker.patch("superset.migrations.shared.catalogs.db")
db.Session.return_value = session
mocker.patch.object(
Database,
"get_all_schema_names",
return_value=["public", "information_schema"],
)
mocker.patch.object(
Database,
"get_all_catalog_names",
return_value=["db", "other_catalog"],
# Mock current_app.config to ensure we don't skip non-default catalogs
mocker.patch.dict(
"superset.migrations.shared.catalogs.current_app.config",
{"CATALOGS_SIMPLIFIED_MIGRATION": False},
)
# Mock the db_engine_spec methods instead of the Database model methods
mock_db_engine_spec = mocker.MagicMock()
mock_db_engine_spec.supports_catalog = True
mock_db_engine_spec.get_default_catalog.return_value = "db"
mock_db_engine_spec.get_all_schema_names.return_value = [
"public",
"information_schema",
]
mock_db_engine_spec.get_all_catalog_names.return_value = ["db", "other_catalog"]
mocker.patch.object(
Database,
"db_engine_spec",
new_callable=mocker.PropertyMock,
return_value=mock_db_engine_spec,
)
mocker.patch.object(
Database,
"get_default_catalog",
return_value="db",
)
# Create a mock database that can call the engine spec methods
def get_all_schema_names_mock(catalog=None):
if catalog == "other_catalog":
return ["public", "information_schema"]
return ["public", "information_schema"]
def get_all_catalog_names_mock():
return ["db", "other_catalog"]
database = Database(
database_name="my_db",
id=1,
sqlalchemy_uri="postgresql://localhost/db",
)
database.database_name = "my_db"
# Mock the methods instead of assigning
mocker.patch.object(database, "get_all_schema_names", get_all_schema_names_mock)
mocker.patch.object(database, "get_all_catalog_names", get_all_catalog_names_mock)
session.add(database)
session.commit()
# Create initial permissions for testing
db_perm = ViewMenu(name="[my_db].(id:1)")
table_perm = ViewMenu(name="[my_db].[my_table](id:1)")
schema_perm = ViewMenu(name="[my_db].[public]")
database_access = Permission(name="database_access")
datasource_access = Permission(name="datasource_access")
schema_access = Permission(name="schema_access")
session.add_all(
[
db_perm,
table_perm,
schema_perm,
database_access,
datasource_access,
schema_access,
]
)
session.commit()
# Create permission view associations
pv1 = PermissionView(permission_id=database_access.id, view_menu_id=db_perm.id)
pv2 = PermissionView(permission_id=datasource_access.id, view_menu_id=table_perm.id)
pv3 = PermissionView(permission_id=schema_access.id, view_menu_id=schema_perm.id)
session.add_all([pv1, pv2, pv3])
session.commit()
dataset = SqlaTable(
table_name="my_table",
database=database,
id=1,
database_id=database.id,
perm="[my_db].[my_table](id:1)",
catalog=None,
schema="public",
catalog_perm=None,
@@ -101,33 +171,26 @@ def test_upgrade_catalog_perms(mocker: MockerFixture, session: Session) -> None:
session.commit()
chart = Slice(
slice_name="my_chart",
datasource_type="table",
datasource_id=dataset.id,
catalog_perm=None,
schema_perm="[my_db].[public]",
)
query = Query(
client_id="foo",
database=database,
database_id=database.id,
catalog=None,
schema="public",
)
saved_query = SavedQuery(
database=database,
sql="SELECT * FROM public.t",
db_id=database.id,
catalog=None,
schema="public",
)
tab_state = TabState(
database=database,
database_id=database.id,
catalog=None,
schema="public",
)
table_schema = TableSchema(
database=database,
database_id=database.id,
catalog=None,
schema="public",
)
session.add_all([chart, query, saved_query, tab_state, table_schema])
session.commit()
@@ -158,8 +221,9 @@ def test_upgrade_catalog_perms(mocker: MockerFixture, session: Session) -> None:
# add dataset/chart in new catalog
new_dataset = SqlaTable(
table_name="my_table",
database=database,
id=2,
database_id=database.id,
perm="[my_db].[my_table](id:2)",
catalog="other_catalog",
schema="public",
schema_perm="[my_db].[other_catalog].[public]",
@@ -168,8 +232,17 @@ def test_upgrade_catalog_perms(mocker: MockerFixture, session: Session) -> None:
session.add(new_dataset)
session.commit()
# Add permission for the new dataset
new_table_perm = ViewMenu(name="[my_db].[my_table](id:2)")
session.add(new_table_perm)
session.commit()
pv_new = PermissionView(
permission_id=datasource_access.id, view_menu_id=new_table_perm.id
)
session.add(pv_new)
session.commit()
new_chart = Slice(
slice_name="my_chart",
datasource_type="table",
datasource_id=new_dataset.id,
)
@@ -186,21 +259,24 @@ def test_upgrade_catalog_perms(mocker: MockerFixture, session: Session) -> None:
assert dataset.schema_perm == "[my_db].[db].[public]"
assert chart.catalog_perm == "[my_db].[db]"
assert chart.schema_perm == "[my_db].[db].[public]"
assert (
assert sorted(
session.query(ViewMenu.name, Permission.name)
.join(PermissionView, ViewMenu.id == PermissionView.view_menu_id)
.join(Permission, PermissionView.permission_id == Permission.id)
.all()
) == [
("[my_db].(id:1)", "database_access"),
("[my_db].[my_table](id:1)", "datasource_access"),
("[my_db].[db].[public]", "schema_access"),
("[my_db].[db]", "catalog_access"),
("[my_db].[other_catalog]", "catalog_access"),
("[my_db].[other_catalog].[public]", "schema_access"),
("[my_db].[other_catalog].[information_schema]", "schema_access"),
("[my_db].[my_table](id:2)", "datasource_access"),
]
) == sorted(
[
("[my_db].(id:1)", "database_access"),
("[my_db].[my_table](id:1)", "datasource_access"),
("[my_db].[db].[public]", "schema_access"),
("[my_db].[db]", "catalog_access"),
("[my_db].[other_catalog]", "catalog_access"),
("[my_db].[other_catalog].[public]", "schema_access"),
("[my_db].[other_catalog].[information_schema]", "schema_access"),
("[my_db].[my_table](id:2)", "datasource_access"),
]
)
# do a downgrade
downgrade_catalog_perms()
@@ -245,32 +321,94 @@ def test_upgrade_catalog_perms_graceful(
catalog browsing on the database (permissions are always synced on a DB update, see
`UpdateDatabaseCommand`).
"""
from superset.connectors.sqla.models import SqlaTable
from superset.models.core import Database
from superset.models.slice import Slice
from superset.models.sql_lab import Query, SavedQuery, TableSchema, TabState
from superset.migrations.shared.catalogs import (
Database,
Query,
SavedQuery,
Slice,
SqlaTable,
TableSchema,
TabState,
)
engine = session.get_bind()
Database.metadata.create_all(engine)
Permission.metadata.create_all(engine)
PermissionView.metadata.create_all(engine)
ViewMenu.metadata.create_all(engine)
mocker.patch("superset.migrations.shared.catalogs.op")
db = mocker.patch("superset.migrations.shared.catalogs.db")
db.Session.return_value = session
# Mock the db_engine_spec to support catalogs but fail on get_all_schema_names
mock_db_engine_spec = mocker.MagicMock()
mock_db_engine_spec.supports_catalog = True
mock_db_engine_spec.get_default_catalog.return_value = "db"
mocker.patch.object(
Database,
"get_all_schema_names",
side_effect=Exception("Failed to connect to the database"),
"db_engine_spec",
new_callable=mocker.PropertyMock,
return_value=mock_db_engine_spec,
)
mocker.patch.object(
Database,
"get_default_catalog",
return_value="db",
)
def get_all_schema_names_mock(catalog=None):
raise Exception("Failed to connect to the database")
mocker.patch("superset.migrations.shared.catalogs.op", session)
database = Database(
database_name="my_db",
id=1,
sqlalchemy_uri="postgresql://localhost/db",
)
database.database_name = "my_db"
session.add(database)
session.commit()
mocker.patch.object(
database,
"get_all_schema_names",
side_effect=get_all_schema_names_mock,
)
# Create initial permissions for testing
db_perm = ViewMenu(name="[my_db].(id:1)")
table_perm = ViewMenu(name="[my_db].[my_table](id:1)")
schema_perm = ViewMenu(name="[my_db].[public]")
database_access = Permission(name="database_access")
datasource_access = Permission(name="datasource_access")
schema_access = Permission(name="schema_access")
session.add_all(
[
db_perm,
table_perm,
schema_perm,
database_access,
datasource_access,
schema_access,
]
)
session.commit()
# Create permission view associations
pv1 = PermissionView(permission_id=database_access.id, view_menu_id=db_perm.id)
pv2 = PermissionView(permission_id=datasource_access.id, view_menu_id=table_perm.id)
pv3 = PermissionView(permission_id=schema_access.id, view_menu_id=schema_perm.id)
session.add_all([pv1, pv2, pv3])
session.commit()
dataset = SqlaTable(
table_name="my_table",
database=database,
id=1,
database_id=database.id,
perm="[my_db].[my_table](id:1)",
catalog=None,
schema="public",
schema_perm="[my_db].[public]",
@@ -279,31 +417,26 @@ def test_upgrade_catalog_perms_graceful(
session.commit()
chart = Slice(
slice_name="my_chart",
datasource_type="table",
datasource_id=dataset.id,
catalog_perm=None,
schema_perm="[my_db].[public]",
)
query = Query(
client_id="foo",
database=database,
database_id=database.id,
catalog=None,
schema="public",
)
saved_query = SavedQuery(
database=database,
sql="SELECT * FROM public.t",
db_id=database.id,
catalog=None,
schema="public",
)
tab_state = TabState(
database=database,
database_id=database.id,
catalog=None,
schema="public",
)
table_schema = TableSchema(
database=database,
database_id=database.id,
catalog=None,
schema="public",
)
session.add_all([chart, query, saved_query, tab_state, table_schema])
session.commit()
@@ -370,13 +503,21 @@ def test_upgrade_catalog_perms_oauth_connection(
schemas. This step should be skipped if the database is set up using OAuth and not
raise an exception.
"""
from superset.connectors.sqla.models import SqlaTable
from superset.models.core import Database
from superset.models.slice import Slice
from superset.models.sql_lab import Query, SavedQuery, TableSchema, TabState
from superset.migrations.shared.catalogs import (
Database,
Query,
SavedQuery,
Slice,
SqlaTable,
TableSchema,
TabState,
)
engine = session.get_bind()
Database.metadata.create_all(engine)
Permission.metadata.create_all(engine)
PermissionView.metadata.create_all(engine)
ViewMenu.metadata.create_all(engine)
mocker.patch("superset.migrations.shared.catalogs.op")
db = mocker.patch("superset.migrations.shared.catalogs.db")
@@ -386,14 +527,64 @@ def test_upgrade_catalog_perms_oauth_connection(
)
mocker.patch("superset.migrations.shared.catalogs.op", session)
# Mock the db_engine_spec for BigQuery with catalog support
mock_db_engine_spec = mocker.MagicMock()
mock_db_engine_spec.supports_catalog = True
mock_db_engine_spec.get_default_catalog.return_value = "my-test-project"
mocker.patch.object(
Database,
"db_engine_spec",
new_callable=mocker.PropertyMock,
return_value=mock_db_engine_spec,
)
mocker.patch.object(
Database,
"get_default_catalog",
return_value="my-test-project",
)
database = Database(
database_name="my_db",
id=1,
sqlalchemy_uri="bigquery://my-test-project",
encrypted_extra=json.dumps({"oauth2_client_info": oauth2_config}),
)
database.database_name = "my_db"
session.add(database)
session.commit()
# Create initial permissions for testing
db_perm = ViewMenu(name="[my_db].(id:1)")
table_perm = ViewMenu(name="[my_db].[my_table](id:1)")
schema_perm = ViewMenu(name="[my_db].[public]")
database_access = Permission(name="database_access")
datasource_access = Permission(name="datasource_access")
schema_access = Permission(name="schema_access")
session.add_all(
[
db_perm,
table_perm,
schema_perm,
database_access,
datasource_access,
schema_access,
]
)
session.commit()
# Create permission view associations
pv1 = PermissionView(permission_id=database_access.id, view_menu_id=db_perm.id)
pv2 = PermissionView(permission_id=datasource_access.id, view_menu_id=table_perm.id)
pv3 = PermissionView(permission_id=schema_access.id, view_menu_id=schema_perm.id)
session.add_all([pv1, pv2, pv3])
session.commit()
dataset = SqlaTable(
table_name="my_table",
database=database,
id=1,
database_id=database.id,
perm="[my_db].[my_table](id:1)",
catalog=None,
schema="public",
schema_perm="[my_db].[public]",
@@ -402,31 +593,26 @@ def test_upgrade_catalog_perms_oauth_connection(
session.commit()
chart = Slice(
slice_name="my_chart",
datasource_type="table",
datasource_id=dataset.id,
catalog_perm=None,
schema_perm="[my_db].[public]",
)
query = Query(
client_id="foo",
database=database,
database_id=database.id,
catalog=None,
schema="public",
)
saved_query = SavedQuery(
database=database,
sql="SELECT * FROM public.t",
db_id=database.id,
catalog=None,
schema="public",
)
tab_state = TabState(
database=database,
database_id=database.id,
catalog=None,
schema="public",
)
table_schema = TableSchema(
database=database,
database_id=database.id,
catalog=None,
schema="public",
)
session.add_all([chart, query, saved_query, tab_state, table_schema])
session.commit()
@@ -494,14 +680,22 @@ def test_upgrade_catalog_perms_simplified_migration(
This should only update existing permissions + create a new permission
for the default catalog.
"""
from superset.connectors.sqla.models import SqlaTable
from superset.models.core import Database
from superset.models.slice import Slice
from superset.models.sql_lab import Query, SavedQuery, TableSchema, TabState
from superset.migrations.shared.catalogs import (
Database,
Query,
SavedQuery,
Slice,
SqlaTable,
TableSchema,
TabState,
)
from superset.migrations.shared.security_converge import Base as SecurityBase
engine = session.get_bind()
Database.metadata.create_all(engine)
SecurityBase.metadata.create_all(engine)
mocker.patch("superset.migrations.shared.catalogs.op")
db = mocker.patch("superset.migrations.shared.catalogs.db")
db.Session.return_value = session
@@ -510,13 +704,65 @@ def test_upgrade_catalog_perms_simplified_migration(
)
mocker.patch("superset.migrations.shared.catalogs.op", session)
# Mock the db_engine_spec for BigQuery with catalog support
mock_db_engine_spec = mocker.MagicMock()
mock_db_engine_spec.supports_catalog = True
mock_db_engine_spec.get_default_catalog.return_value = "my-test-project"
mocker.patch.object(
Database,
"db_engine_spec",
new_callable=mocker.PropertyMock,
return_value=mock_db_engine_spec,
)
mocker.patch.object(
Database,
"get_default_catalog",
return_value="my-test-project",
)
database = Database(
database_name="my_db",
id=1,
sqlalchemy_uri="bigquery://my-test-project",
)
database.database_name = "my_db"
session.add(database)
session.commit()
# Create initial permissions for testing
db_perm = ViewMenu(name="[my_db].(id:1)")
table_perm = ViewMenu(name="[my_db].[my_table](id:1)")
schema_perm = ViewMenu(name="[my_db].[public]")
database_access = Permission(name="database_access")
datasource_access = Permission(name="datasource_access")
schema_access = Permission(name="schema_access")
catalog_access = Permission(name="catalog_access")
session.add_all(
[
db_perm,
table_perm,
schema_perm,
database_access,
datasource_access,
schema_access,
catalog_access,
]
)
session.commit()
# Create permission view associations
pv1 = PermissionView(permission_id=database_access.id, view_menu_id=db_perm.id)
pv2 = PermissionView(permission_id=datasource_access.id, view_menu_id=table_perm.id)
pv3 = PermissionView(permission_id=schema_access.id, view_menu_id=schema_perm.id)
session.add_all([pv1, pv2, pv3])
session.commit()
dataset = SqlaTable(
table_name="my_table",
database=database,
id=1,
database_id=database.id,
perm="[my_db].[my_table](id:1)",
catalog=None,
schema="public",
schema_perm="[my_db].[public]",
@@ -525,31 +771,26 @@ def test_upgrade_catalog_perms_simplified_migration(
session.commit()
chart = Slice(
slice_name="my_chart",
datasource_type="table",
datasource_id=dataset.id,
catalog_perm=None,
schema_perm="[my_db].[public]",
)
query = Query(
client_id="foo",
database=database,
database_id=database.id,
catalog=None,
schema="public",
)
saved_query = SavedQuery(
database=database,
sql="SELECT * FROM public.t",
db_id=database.id,
catalog=None,
schema="public",
)
tab_state = TabState(
database=database,
database_id=database.id,
catalog=None,
schema="public",
)
table_schema = TableSchema(
database=database,
database_id=database.id,
catalog=None,
schema="public",
)
session.add_all([chart, query, saved_query, tab_state, table_schema])
session.commit()