Compare commits

...

2 Commits

Author SHA1 Message Date
Evan
731d115ec4 chore(deps): make codebase SQLAlchemy 2.0-compatible (import/expression layer)
Companion code changes for the SQLAlchemy 1.4 -> 2.0 bump. These are
backward-compatible with SQLAlchemy 1.4 / Flask-SQLAlchemy 2.x, so they do
not regress the current baseline, while unblocking the SA2 import/boot path.

Mechanical SA2 breaking-change fixes:
- Removed `sqlalchemy.orm.eagerload` (alias of `joinedload`) in security manager.
- Legacy list-form `select([...])` -> `select(...)` and `case([...])` ->
  `case(...)` across runtime code (connectors, models/helpers, utils/core,
  extensions/metadb, common/tags, importers).
- `flask_sqlalchemy.BaseQuery` removed in FSA 3.x -> import from
  `flask_sqlalchemy.query.Query` with a fallback to FSA 2.x.
- `sqlalchemy.sql.visitors.VisitableType` removed -> `sqlalchemy.types.TypeEngine`.
- Annotated Declarative: `@declared_attr` columns/relationships now require
  `Mapped[...]` return annotations (created_by_fk/changed_by_fk, BaseDatasource.slices);
  set `__allow_unmapped__` on the FAB declarative base to allow remaining legacy
  1.x annotations during incremental migration.
- Raw-string execution / Engine.execute removal: wrap startup health check and
  secrets-migrator SQL in `text()`, use a Connection, pass a single bind dict
  instead of **kwargs, and access Row values via `row._mapping[...]`.
- Bump in-repo `apache-superset-core` SQLAlchemy pin to >=2.0.50,<3.

This does not complete the full 2.0 migration. The dominant remaining blocker is
SA 2.0 autobegin / connection-transaction semantics in the app/test bootstrap
(FAB `_create_db` -> `Model.metadata.create_all`), plus lockfile regeneration and
db_engine_spec raw-SQL execution. See the PR discussion for the full plan.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-08 18:20:29 -07:00
dependabot[bot]
61194174c7 chore(deps): bump sqlalchemy from 1.4.54 to 2.0.50
Bumps [sqlalchemy](https://github.com/sqlalchemy/sqlalchemy) from 1.4.54 to 2.0.50.
- [Release notes](https://github.com/sqlalchemy/sqlalchemy/releases)
- [Changelog](https://github.com/sqlalchemy/sqlalchemy/blob/main/CHANGES.rst)
- [Commits](https://github.com/sqlalchemy/sqlalchemy/commits)

---
updated-dependencies:
- dependency-name: sqlalchemy
  dependency-version: 2.0.50
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-06-07 10:04:02 +00:00
17 changed files with 71 additions and 43 deletions

View File

@@ -99,7 +99,7 @@ dependencies = [
"sshtunnel>=0.4.0, <0.5", "sshtunnel>=0.4.0, <0.5",
"simplejson>=3.15.0", "simplejson>=3.15.0",
"slack_sdk>=3.19.0, <4", "slack_sdk>=3.19.0, <4",
"sqlalchemy>=1.4, <2", "sqlalchemy>=2.0.50, <3",
"sqlalchemy-utils>=0.38.0, <0.43", # expanding lowerbound to work with pydoris "sqlalchemy-utils>=0.38.0, <0.43", # expanding lowerbound to work with pydoris
"sqlglot>=30.8.0, <31", "sqlglot>=30.8.0, <31",
# newer pandas needs 0.9+ # newer pandas needs 0.9+

View File

@@ -46,7 +46,7 @@ dependencies = [
"isodate>=0.7.0", "isodate>=0.7.0",
"pyarrow>=16.0.0", "pyarrow>=16.0.0",
"pydantic>=2.8.0", "pydantic>=2.8.0",
"sqlalchemy>=1.4.0,<2.0", "sqlalchemy>=2.0.50,<3",
"sqlalchemy-utils>=0.38.0, <0.43", # expanding lowerbound to work with pydoris "sqlalchemy-utils>=0.38.0, <0.43", # expanding lowerbound to work with pydoris
"sqlglot>=30.8.0, <31", "sqlglot>=30.8.0, <31",
"typing-extensions>=4.0.0", "typing-extensions>=4.0.0",

View File

@@ -14,10 +14,22 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import flask_appbuilder
from werkzeug.local import LocalProxy from werkzeug.local import LocalProxy
from superset.app import create_app # noqa: F401 # SQLAlchemy 2.0 enables "Annotated Declarative" mapping, which inspects class
from superset.extensions import ( # attribute type annotations and requires mapped attributes to use ``Mapped[...]``.
# Superset's models (and Flask-AppBuilder mixins) still carry legacy 1.x style
# annotations that are not wrapped in ``Mapped[...]``. Setting ``__allow_unmapped__``
# on the shared declarative base preserves the legacy behavior so those annotations
# are ignored by the ORM. This must run before any model class is defined (i.e.
# before importing ``superset.app``), since the annotation check happens at class
# creation time. Models can be migrated incrementally to the typed ``Mapped[...]``
# form.
flask_appbuilder.Model.__allow_unmapped__ = True
from superset.app import create_app # noqa: E402, F401
from superset.extensions import ( # noqa: E402
appbuilder, # noqa: F401 appbuilder, # noqa: F401
cache_manager, cache_manager,
db, # noqa: F401 db, # noqa: F401
@@ -28,7 +40,7 @@ from superset.extensions import (
security_manager, # noqa: F401 security_manager, # noqa: F401
talisman, # noqa: F401 talisman, # noqa: F401
) )
from superset.security import SupersetSecurityManager # noqa: F401 from superset.security import SupersetSecurityManager # noqa: E402, F401
# All of the fields located here should be considered legacy. The correct way to # All of the fields located here should be considered legacy. The correct way to
# declare "global" dependencies is to define it in extensions.py, # declare "global" dependencies is to define it in extensions.py,

View File

@@ -25,7 +25,7 @@ from flask import current_app as app
from pandas.errors import OutOfBoundsDatetime from pandas.errors import OutOfBoundsDatetime
from sqlalchemy import BigInteger, Boolean, Date, DateTime, Float, String, Text from sqlalchemy import BigInteger, Boolean, Date, DateTime, Float, String, Text
from sqlalchemy.exc import MultipleResultsFound from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.sql.visitors import VisitableType from sqlalchemy.types import TypeEngine
from superset import db, security_manager from superset import db, security_manager
from superset.commands.dataset.exceptions import ( from superset.commands.dataset.exceptions import (
@@ -65,7 +65,7 @@ type_map = {
} }
def get_sqla_type(native_type: str) -> VisitableType: def get_sqla_type(native_type: str) -> TypeEngine:
if native_type.upper() in type_map: if native_type.upper() in type_map:
return type_map[native_type.upper()] return type_map[native_type.upper()]
@@ -78,7 +78,7 @@ def get_sqla_type(native_type: str) -> VisitableType:
) )
def get_dtype(df: pd.DataFrame, dataset: SqlaTable) -> dict[str, VisitableType]: def get_dtype(df: pd.DataFrame, dataset: SqlaTable) -> dict[str, TypeEngine]:
return { return {
column.column_name: get_sqla_type(column.type) column.column_name: get_sqla_type(column.type)
for column in dataset.columns for column in dataset.columns

View File

@@ -362,7 +362,7 @@ def safe_insert_dashboard_chart_relationships(
# Get existing relationships only for dashboards being updated # Get existing relationships only for dashboards being updated
dashboard_ids = {dashboard_id for dashboard_id, _ in dashboard_chart_ids} dashboard_ids = {dashboard_id for dashboard_id, _ in dashboard_chart_ids}
existing_relationships = db.session.execute( existing_relationships = db.session.execute(
select([dashboard_slices.c.dashboard_id, dashboard_slices.c.slice_id]).where( select(dashboard_slices.c.dashboard_id, dashboard_slices.c.slice_id).where(
dashboard_slices.c.dashboard_id.in_(dashboard_ids) dashboard_slices.c.dashboard_id.in_(dashboard_ids)
) )
).fetchall() ).fetchall()

View File

@@ -440,7 +440,7 @@ def add_owners(metadata: MetaData) -> None:
columns = ["tag_id", "object_id", "object_type"] columns = ["tag_id", "object_id", "object_type"]
# create a custom tag for each user # create a custom tag for each user
ids = select([users.c.id]) ids = select(users.c.id)
insert = tag.insert() insert = tag.insert()
for (id_,) in db.session.execute(ids): for (id_,) in db.session.execute(ids):
with contextlib.suppress(IntegrityError): # already exists with contextlib.suppress(IntegrityError): # already exists
@@ -478,7 +478,7 @@ def add_favorites(metadata: MetaData) -> None:
columns = ["tag_id", "object_id", "object_type"] columns = ["tag_id", "object_id", "object_type"]
# create a custom tag for each user # create a custom tag for each user
ids = select([users.c.id]) ids = select(users.c.id)
insert = tag.insert() insert = tag.insert()
for (id_,) in db.session.execute(ids): for (id_,) in db.session.execute(ids):
with contextlib.suppress(IntegrityError): # already exists with contextlib.suppress(IntegrityError): # already exists

View File

@@ -335,7 +335,7 @@ class BaseDatasource(
return self.kind == DatasourceKind.VIRTUAL return self.kind == DatasourceKind.VIRTUAL
@declared_attr @declared_attr
def slices(self) -> RelationshipProperty: def slices(self) -> Mapped[list["Slice"]]:
return relationship( return relationship(
"Slice", "Slice",
overlaps="table", overlaps="table",
@@ -1791,11 +1791,9 @@ class SqlaTable(
# for those we fall back to LIMIT 1. # for those we fall back to LIMIT 1.
tbl, _unused_cte = self.get_from_clause(template_processor) tbl, _unused_cte = self.get_from_clause(template_processor)
if self.db_engine_spec.type_probe_needs_row: if self.db_engine_spec.type_probe_needs_row:
qry = sa.select([sqla_column]).limit(1).select_from(tbl) qry = sa.select(sqla_column).limit(1).select_from(tbl)
else: else:
qry = ( qry = sa.select(sqla_column).where(sa.false()).select_from(tbl)
sa.select([sqla_column]).where(sa.false()).select_from(tbl)
)
sql = self.database.compile_sqla_query( sql = self.database.compile_sqla_query(
qry, qry,
catalog=self.catalog, catalog=self.catalog,

View File

@@ -368,7 +368,7 @@ class SupersetShillelaghAdapter(Adapter):
""" """
Build SQLAlchemy query object. Build SQLAlchemy query object.
""" """
query = select([self._table]) query = select(self._table)
for column_name, filter_ in bounds.items(): for column_name, filter_ in bounds.items():
column = self._table.c[column_name] column = self._table.c[column_name]
@@ -452,7 +452,7 @@ class SupersetShillelaghAdapter(Adapter):
if self._rowid: if self._rowid:
return result.inserted_primary_key[0] return result.inserted_primary_key[0]
query = select([func.count()]).select_from(self._table) query = select(func.count()).select_from(self._table)
return connection.execute(query).scalar() return connection.execute(query).scalar()
@check_dml @check_dml

View File

@@ -22,6 +22,7 @@ import os
import sys import sys
from typing import Any, Callable, TYPE_CHECKING from typing import Any, Callable, TYPE_CHECKING
import sqlalchemy as sa
import wtforms_json import wtforms_json
from colorama import Fore, Style from colorama import Fore, Style
from deprecation import deprecated from deprecation import deprecated
@@ -753,7 +754,8 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
try: try:
with self.superset_app.app_context(): with self.superset_app.app_context():
# Simple connection test # Simple connection test
db.engine.execute("SELECT 1") with db.engine.connect() as connection:
connection.execute(sa.text("SELECT 1"))
except Exception: except Exception:
db_uri = self.database_uri db_uri = self.database_uri
safe_uri = make_url_safe(db_uri) if db_uri else "Not configured" safe_uri = make_url_safe(db_uri) if db_uri else "Not configured"

View File

@@ -60,7 +60,7 @@ from sqlalchemy import and_, Column, or_, UniqueConstraint
from sqlalchemy.exc import MultipleResultsFound from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import Mapper, Session, validates, with_loader_criteria from sqlalchemy.orm import Mapped, Mapper, Session, validates, with_loader_criteria
from sqlalchemy.orm.session import ORMExecuteState from sqlalchemy.orm.session import ORMExecuteState
from sqlalchemy.sql.elements import ColumnElement, Grouping, literal_column, TextClause from sqlalchemy.sql.elements import ColumnElement, Grouping, literal_column, TextClause
from sqlalchemy.sql.expression import Label, Select, TextAsFrom from sqlalchemy.sql.expression import Label, Select, TextAsFrom
@@ -582,7 +582,7 @@ class AuditMixinNullable(AuditMixin):
) )
@declared_attr @declared_attr
def created_by_fk(self) -> sa.Column: # pylint: disable=arguments-renamed def created_by_fk(self) -> Mapped[Optional[int]]: # pylint: disable=arguments-renamed
return sa.Column( return sa.Column(
sa.Integer, sa.Integer,
sa.ForeignKey("ab_user.id"), sa.ForeignKey("ab_user.id"),
@@ -591,7 +591,7 @@ class AuditMixinNullable(AuditMixin):
) )
@declared_attr @declared_attr
def changed_by_fk(self) -> sa.Column: # pylint: disable=arguments-renamed def changed_by_fk(self) -> Mapped[Optional[int]]: # pylint: disable=arguments-renamed
return sa.Column( return sa.Column(
sa.Integer, sa.Integer,
sa.ForeignKey("ab_user.id"), sa.ForeignKey("ab_user.id"),
@@ -2682,7 +2682,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
condition = condition_factory(expr.name, expr) condition = condition_factory(expr.name, expr)
# Create CASE expression: condition true -> original, else "Others" # Create CASE expression: condition true -> original, else "Others"
case_expr = sa.case([(condition, expr)], else_=sa.literal("Others")) case_expr = sa.case((condition, expr), else_=sa.literal("Others"))
case_expr = self.make_sqla_column_compatible(case_expr, expr.name) case_expr = self.make_sqla_column_compatible(case_expr, expr.name)
modified_select_exprs.append(case_expr) modified_select_exprs.append(case_expr)
else: else:
@@ -2900,15 +2900,15 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
) -> Select: ) -> Select:
"""Build validation query based on expression type. Raises on error.""" """Build validation query based on expression type. Raises on error."""
if expression_type == SqlExpressionType.COLUMN: if expression_type == SqlExpressionType.COLUMN:
return sa.select([sa.literal_column(expression).label("test_col")]) return sa.select(sa.literal_column(expression).label("test_col"))
elif expression_type == SqlExpressionType.METRIC: elif expression_type == SqlExpressionType.METRIC:
return sa.select([sa.literal_column(expression).label("test_metric")]) return sa.select(sa.literal_column(expression).label("test_metric"))
elif expression_type == SqlExpressionType.WHERE: elif expression_type == SqlExpressionType.WHERE:
return sa.select([sa.literal(1)]).where(sa.text(expression)) return sa.select(sa.literal(1)).where(sa.text(expression))
elif expression_type == SqlExpressionType.HAVING: elif expression_type == SqlExpressionType.HAVING:
dummy_col = sa.literal("A").label("dummy") dummy_col = sa.literal("A").label("dummy")
return ( return (
sa.select([dummy_col]) sa.select(dummy_col)
.group_by(sa.text("dummy")) .group_by(sa.text("dummy"))
.having(sa.text(expression)) .having(sa.text(expression))
) )
@@ -3846,7 +3846,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
) )
label = "rowcount" label = "rowcount"
col = self.make_sqla_column_compatible(literal_column("COUNT(*)"), label) col = self.make_sqla_column_compatible(literal_column("COUNT(*)"), label)
qry = sa.select([col]).select_from(qry.alias("rowcount_qry")) qry = sa.select(col).select_from(qry.alias("rowcount_qry"))
labels_expected = [label] labels_expected = [label]
filter_columns = [flt.get("col") for flt in filter] if filter else [] filter_columns = [flt.get("col") for flt in filter] if filter else []

View File

@@ -16,7 +16,12 @@
# under the License. # under the License.
from typing import Any from typing import Any
from flask_sqlalchemy import BaseQuery try:
# Flask-SQLAlchemy 3.x (required by SQLAlchemy 2.0)
from flask_sqlalchemy.query import Query as BaseQuery
except ImportError: # pragma: no cover
# Flask-SQLAlchemy 2.x
from flask_sqlalchemy import BaseQuery
from superset import security_manager from superset import security_manager
from superset.models.sql_lab import Query from superset.models.sql_lab import Query

View File

@@ -18,10 +18,16 @@ from typing import Any
from flask import g from flask import g
from flask_babel import lazy_gettext as _ from flask_babel import lazy_gettext as _
from flask_sqlalchemy import BaseQuery
from sqlalchemy import or_ from sqlalchemy import or_
from sqlalchemy.orm.query import Query from sqlalchemy.orm.query import Query
try:
# Flask-SQLAlchemy 3.x (required by SQLAlchemy 2.0)
from flask_sqlalchemy.query import Query as BaseQuery
except ImportError: # pragma: no cover
# Flask-SQLAlchemy 2.x
from flask_sqlalchemy import BaseQuery
from superset.models.sql_lab import SavedQuery from superset.models.sql_lab import SavedQuery
from superset.tags.filters import BaseTagIdFilter, BaseTagNameFilter from superset.tags.filters import BaseTagIdFilter, BaseTagNameFilter
from superset.views.base import BaseFilter from superset.views.base import BaseFilter

View File

@@ -53,7 +53,7 @@ from flask_login import AnonymousUserMixin, LoginManager
from jwt.api_jwt import _jwt_global_obj from jwt.api_jwt import _jwt_global_obj
from sqlalchemy import and_, func as sa_func, inspect, or_ from sqlalchemy import and_, func as sa_func, inspect, or_
from sqlalchemy.engine.base import Connection from sqlalchemy.engine.base import Connection
from sqlalchemy.orm import eagerload, joinedload from sqlalchemy.orm import joinedload
from sqlalchemy.orm.exc import MultipleResultsFound from sqlalchemy.orm.exc import MultipleResultsFound
from sqlalchemy.orm.mapper import Mapper from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.orm.query import Query as SqlaQuery from sqlalchemy.orm.query import Query as SqlaQuery
@@ -1569,8 +1569,8 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
pvms = ( pvms = (
self.session.query(self.permissionview_model) self.session.query(self.permissionview_model)
.options( .options(
eagerload(self.permissionview_model.permission), joinedload(self.permissionview_model.permission),
eagerload(self.permissionview_model.view_menu), joinedload(self.permissionview_model.view_menu),
) )
.all() .all()
) )

View File

@@ -767,7 +767,7 @@ def pessimistic_connection_handling(some_engine: Engine) -> None:
# run a SELECT 1. use a core select() so that # run a SELECT 1. use a core select() so that
# the SELECT of a scalar value without a table is # the SELECT of a scalar value without a table is
# appropriately formatted for the backend # appropriately formatted for the backend
connection.scalar(select([1])) connection.scalar(select(1))
except exc.DBAPIError as err: except exc.DBAPIError as err:
# catch SQLAlchemy's DBAPIError, which is a wrapper # catch SQLAlchemy's DBAPIError, which is a wrapper
# for the DBAPI's exception. It includes a .connection_invalidated # for the DBAPI's exception. It includes a .connection_invalidated
@@ -779,7 +779,7 @@ def pessimistic_connection_handling(some_engine: Engine) -> None:
# itself and establish a new connection. The disconnect detection # itself and establish a new connection. The disconnect detection
# here also causes the whole connection pool to be invalidated # here also causes the whole connection pool to be invalidated
# so that all stale connections are discarded. # so that all stale connections are discarded.
connection.scalar(select([1])) connection.scalar(select(1))
else: else:
raise raise
finally: finally:

View File

@@ -175,7 +175,7 @@ class SecretsMigrator:
table_name: str, table_name: str,
) -> Row: ) -> Row:
cols = ",".join(pk_columns + column_names) cols = ",".join(pk_columns + column_names)
return conn.execute(f"SELECT {cols} FROM {table_name}") # noqa: S608 return conn.execute(text(f"SELECT {cols} FROM {table_name}")) # noqa: S608
def _re_encrypt_row( def _re_encrypt_row(
self, self,
@@ -216,7 +216,7 @@ class SecretsMigrator:
re_encrypted_columns = {} re_encrypted_columns = {}
for column_name, encrypted_type in columns.items(): for column_name, encrypted_type in columns.items():
raw_value = self._read_bytes(column_name, row[column_name]) raw_value = self._read_bytes(column_name, row._mapping[column_name])
# NULL values aren't encrypted; there is nothing to migrate. # NULL values aren't encrypted; there is nothing to migrate.
if raw_value is None: if raw_value is None:
@@ -265,13 +265,12 @@ class SecretsMigrator:
set_cols = ",".join(f"{name} = :{name}" for name in re_encrypted_columns) set_cols = ",".join(f"{name} = :{name}" for name in re_encrypted_columns)
where_clause = " AND ".join(f"{pk} = :_pk_{pk}" for pk in pk_columns) where_clause = " AND ".join(f"{pk} = :_pk_{pk}" for pk in pk_columns)
pk_bind = {f"_pk_{pk}": row[pk] for pk in pk_columns} pk_bind = {f"_pk_{pk}": row._mapping[pk] for pk in pk_columns}
conn.execute( conn.execute(
text( text(
f"UPDATE {table_name} SET {set_cols} WHERE {where_clause}" # noqa: S608 f"UPDATE {table_name} SET {set_cols} WHERE {where_clause}" # noqa: S608
), ),
**pk_bind, {**pk_bind, **re_encrypted_columns},
**re_encrypted_columns,
) )
def run(self) -> ReEncryptStats: def run(self) -> ReEncryptStats:

View File

@@ -31,7 +31,7 @@ from flask_appbuilder import Model
from sqlalchemy import Column, inspect, MetaData, Table as DBTable from sqlalchemy import Column, inspect, MetaData, Table as DBTable
from sqlalchemy.dialects import postgresql from sqlalchemy.dialects import postgresql
from sqlalchemy.sql import func from sqlalchemy.sql import func
from sqlalchemy.sql.visitors import VisitableType from sqlalchemy.types import TypeEngine
from superset import db from superset import db
from superset.sql.parse import Table from superset.sql.parse import Table
@@ -42,7 +42,7 @@ logger = logging.getLogger(__name__)
class ColumnInfo(TypedDict): class ColumnInfo(TypedDict):
name: str name: str
type: VisitableType type: TypeEngine
nullable: bool nullable: bool
default: Optional[Any] default: Optional[Any]
autoincrement: str autoincrement: str

View File

@@ -23,7 +23,6 @@ from uuid import uuid4
import pytest import pytest
from flask.ctx import AppContext from flask.ctx import AppContext
from flask_appbuilder.security.sqla.models import User from flask_appbuilder.security.sqla.models import User
from flask_sqlalchemy import BaseQuery
from freezegun import freeze_time from freezegun import freeze_time
from slack_sdk.errors import ( from slack_sdk.errors import (
BotUserAccessError, BotUserAccessError,
@@ -37,6 +36,13 @@ from slack_sdk.errors import (
) )
from sqlalchemy.sql import func from sqlalchemy.sql import func
try:
# Flask-SQLAlchemy 3.x (required by SQLAlchemy 2.0)
from flask_sqlalchemy.query import Query as BaseQuery
except ImportError: # pragma: no cover
# Flask-SQLAlchemy 2.x
from flask_sqlalchemy import BaseQuery
from superset import db from superset import db
from superset.commands.report.exceptions import ( from superset.commands.report.exceptions import (
AlertQueryError, AlertQueryError,