mirror of
https://github.com/apache/superset.git
synced 2026-04-19 16:14:52 +00:00
feat: catalog support for Databricks native (#28394)
This commit is contained in:
@@ -18,23 +18,84 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Type
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from superset import db, security_manager
|
||||
from superset.daos.database import DatabaseDAO
|
||||
from superset.migrations.shared.security_converge import add_pvms, ViewMenu
|
||||
from superset.models.core import Database
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def upgrade_schema_perms(engine: str | None = None) -> None:
|
||||
"""
|
||||
Update schema permissions to include the catalog part.
|
||||
Base: Type[Any] = declarative_base()
|
||||
|
||||
|
||||
class SqlaTable(Base):
|
||||
__tablename__ = "tables"
|
||||
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
database_id = sa.Column(sa.Integer, nullable=False)
|
||||
schema_perm = sa.Column(sa.String(1000))
|
||||
schema = sa.Column(sa.String(255))
|
||||
catalog = sa.Column(sa.String(256), nullable=True, default=None)
|
||||
|
||||
|
||||
class Query(Base):
|
||||
__tablename__ = "query"
|
||||
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
database_id = sa.Column(sa.Integer, nullable=False)
|
||||
catalog = sa.Column(sa.String(256), nullable=True, default=None)
|
||||
|
||||
|
||||
class SavedQuery(Base):
|
||||
__tablename__ = "saved_query"
|
||||
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
db_id = sa.Column(sa.Integer, nullable=False)
|
||||
catalog = sa.Column(sa.String(256), nullable=True, default=None)
|
||||
|
||||
|
||||
class TabState(Base):
|
||||
__tablename__ = "tab_state"
|
||||
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
database_id = sa.Column(sa.Integer, nullable=False)
|
||||
catalog = sa.Column(sa.String(256), nullable=True, default=None)
|
||||
|
||||
|
||||
class TableSchema(Base):
|
||||
__tablename__ = "table_schema"
|
||||
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
database_id = sa.Column(sa.Integer, nullable=False)
|
||||
catalog = sa.Column(sa.String(256), nullable=True, default=None)
|
||||
|
||||
|
||||
class Slice(Base):
|
||||
__tablename__ = "slices"
|
||||
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
datasource_id = sa.Column(sa.Integer)
|
||||
datasource_type = sa.Column(sa.String(200))
|
||||
schema_perm = sa.Column(sa.String(1000))
|
||||
|
||||
|
||||
def upgrade_catalog_perms(engine: str | None = None) -> None:
|
||||
"""
|
||||
Update models when catalogs are introduced in a DB engine spec.
|
||||
|
||||
When an existing DB engine spec starts to support catalogs we need to:
|
||||
|
||||
- Add a `catalog_access` permission for each catalog.
|
||||
- Populate the `catalog` field with the default catalog for each related model.
|
||||
- Update `schema_perm` to include the default catalog.
|
||||
|
||||
Before SIP-95 schema permissions were stored in the format `[db].[schema]`. With the
|
||||
introduction of catalogs, any existing permissions need to be renamed to include the
|
||||
catalog: `[db].[catalog].[schema]`.
|
||||
"""
|
||||
bind = op.get_bind()
|
||||
session = db.Session(bind=bind)
|
||||
@@ -46,6 +107,16 @@ def upgrade_schema_perms(engine: str | None = None) -> None:
|
||||
continue
|
||||
|
||||
catalog = database.get_default_catalog()
|
||||
if catalog is None:
|
||||
continue
|
||||
|
||||
perm = security_manager.get_catalog_perm(
|
||||
database.database_name,
|
||||
catalog,
|
||||
)
|
||||
add_pvms(session, {perm: ("catalog_access",)})
|
||||
|
||||
# update schema_perms
|
||||
ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
|
||||
for schema in database.get_all_schema_names(
|
||||
catalog=catalog,
|
||||
@@ -57,29 +128,47 @@ def upgrade_schema_perms(engine: str | None = None) -> None:
|
||||
None,
|
||||
schema,
|
||||
)
|
||||
existing_pvm = security_manager.find_permission_view_menu(
|
||||
"schema_access",
|
||||
perm,
|
||||
)
|
||||
existing_pvm = session.query(ViewMenu).filter_by(name=perm).one_or_none()
|
||||
if existing_pvm:
|
||||
existing_pvm.view_menu.name = security_manager.get_schema_perm(
|
||||
existing_pvm.name = security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
catalog,
|
||||
schema,
|
||||
)
|
||||
|
||||
# update existing models
|
||||
models = [
|
||||
(Query, "database_id"),
|
||||
(SavedQuery, "db_id"),
|
||||
(TabState, "database_id"),
|
||||
(TableSchema, "database_id"),
|
||||
(SqlaTable, "database_id"),
|
||||
]
|
||||
for model, column in models:
|
||||
for instance in session.query(model).filter(
|
||||
getattr(model, column) == database.id
|
||||
):
|
||||
instance.catalog = catalog
|
||||
|
||||
for table in session.query(SqlaTable).filter_by(database_id=database.id):
|
||||
schema_perm = security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
catalog,
|
||||
table.schema,
|
||||
)
|
||||
table.schema_perm = schema_perm
|
||||
for chart in session.query(Slice).filter_by(
|
||||
datasource_id=table.id,
|
||||
datasource_type="table",
|
||||
):
|
||||
chart.schema_perm = schema_perm
|
||||
|
||||
session.commit()
|
||||
|
||||
|
||||
def downgrade_schema_perms(engine: str | None = None) -> None:
|
||||
def downgrade_catalog_perms(engine: str | None = None) -> None:
|
||||
"""
|
||||
Update schema permissions to not have the catalog part.
|
||||
|
||||
Before SIP-95 schema permissions were stored in the format `[db].[schema]`. With the
|
||||
introduction of catalogs, any existing permissions need to be renamed to include the
|
||||
catalog: `[db].[catalog].[schema]`.
|
||||
|
||||
This helped function reverts the process.
|
||||
Reverse the process of `upgrade_catalog_perms`.
|
||||
"""
|
||||
bind = op.get_bind()
|
||||
session = db.Session(bind=bind)
|
||||
@@ -91,6 +180,10 @@ def downgrade_schema_perms(engine: str | None = None) -> None:
|
||||
continue
|
||||
|
||||
catalog = database.get_default_catalog()
|
||||
if catalog is None:
|
||||
continue
|
||||
|
||||
# update schema_perms
|
||||
ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
|
||||
for schema in database.get_all_schema_names(
|
||||
catalog=catalog,
|
||||
@@ -102,15 +195,39 @@ def downgrade_schema_perms(engine: str | None = None) -> None:
|
||||
catalog,
|
||||
schema,
|
||||
)
|
||||
existing_pvm = security_manager.find_permission_view_menu(
|
||||
"schema_access",
|
||||
perm,
|
||||
)
|
||||
existing_pvm = session.query(ViewMenu).filter_by(name=perm).one_or_none()
|
||||
if existing_pvm:
|
||||
existing_pvm.view_menu.name = security_manager.get_schema_perm(
|
||||
existing_pvm.name = security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
None,
|
||||
schema,
|
||||
)
|
||||
|
||||
# update existing models
|
||||
models = [
|
||||
(Query, "database_id"),
|
||||
(SavedQuery, "db_id"),
|
||||
(TabState, "database_id"),
|
||||
(TableSchema, "database_id"),
|
||||
(SqlaTable, "database_id"),
|
||||
]
|
||||
for model, column in models:
|
||||
for instance in session.query(model).filter(
|
||||
getattr(model, column) == database.id
|
||||
):
|
||||
instance.catalog = None
|
||||
|
||||
for table in session.query(SqlaTable).filter_by(database_id=database.id):
|
||||
schema_perm = security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
None,
|
||||
table.schema,
|
||||
)
|
||||
table.schema_perm = schema_perm
|
||||
for chart in session.query(Slice).filter_by(
|
||||
datasource_id=table.id,
|
||||
datasource_type="table",
|
||||
):
|
||||
chart.schema_perm = schema_perm
|
||||
|
||||
session.commit()
|
||||
|
||||
Reference in New Issue
Block a user