Compare commits

...

4 Commits

Author SHA1 Message Date
Evan
d9e13b8271 test: read re-encrypt bind params from positional args, not kwargs
The migrator passes bind params positionally to conn.execute(stmt, params),
so the assertion must read call_args.args[1] rather than call_args.kwargs.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-17 21:33:46 -07:00
Evan
b07d4dae3b test: pass real SQLAlchemy Row to _re_encrypt_row in integration tests
The re-encrypt path reads column values via row._mapping (SQLAlchemy 2.0
Row API), but the integration tests passed a plain dict, which has no
_mapping, raising AttributeError. Build a genuine Row via a small helper
that handles both the 1.4 and 2.0 Row constructor signatures.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-17 20:25:48 -07:00
Evan
91de3513a2 test: update encrypt re-encrypt fixtures for SQLAlchemy 2.0 Row API
The migrator now reads columns via row._mapping[...] and binds the
UPDATE params as a single positional dict. Wrap the fixture rows in a
_Row stand-in and read the bind dict positionally so the unit tests
match the production interface.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-17 19:31:27 -07:00
Claude Code
18a2a4509e refactor: make import/expression layer SQLAlchemy 2.0-compatible
Forward-compatible groundwork for the SQLAlchemy 1.4 -> 2.0 bump. Every
change here is backward-compatible with the current SQLAlchemy 1.4.54 /
Flask-SQLAlchemy 2.5.1 baseline, so it does not regress anything while
shrinking the eventual version bump.

Removed-in-2.0 API fixes (these break on 2.0 but still work on 1.4):
- `sqlalchemy.orm.eagerload` (removed alias of `joinedload`) in the
  security manager.
- `sqlalchemy.sql.visitors.VisitableType` (removed) -> `sqlalchemy.types.TypeEngine`
  in dataset importer and mock_data type hints.
- `flask_sqlalchemy.BaseQuery` (removed in FSA 3.x, required by SA 2.0) ->
  import from `flask_sqlalchemy.query.Query` with a fallback to the FSA 2.x
  path, in the query/saved-query filters and the report command tests.
- `Engine.execute` / raw-string execution -> wrap startup health check in a
  `Connection` + `text()`, and the secrets re-encryptor in `text()` with a
  single bind dict and `row._mapping[...]` access.

Prospective Annotated Declarative prep (no-ops on 1.4, ease the 2.0 flip):
- `__allow_unmapped__` on the FAB declarative base so legacy 1.x annotations
  are tolerated during incremental migration.
- `Mapped[...]` return annotations on `created_by_fk`/`changed_by_fk` and
  `BaseDatasource.slices`.

The `select()` and `case()` syntax changes from the same effort already
landed via #40276 / #40275, so they are intentionally excluded here.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-17 18:24:27 -07:00
13 changed files with 112 additions and 45 deletions

View File

@@ -14,10 +14,22 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import flask_appbuilder
from werkzeug.local import LocalProxy
from superset.app import create_app # noqa: F401
from superset.extensions import (
# SQLAlchemy 2.0 enables "Annotated Declarative" mapping, which inspects class
# attribute type annotations and requires mapped attributes to use ``Mapped[...]``.
# Superset's models (and Flask-AppBuilder mixins) still carry legacy 1.x style
# annotations that are not wrapped in ``Mapped[...]``. Setting ``__allow_unmapped__``
# on the shared declarative base preserves the legacy behavior so those annotations
# are ignored by the ORM. This must run before any model class is defined (i.e.
# before importing ``superset.app``), since the annotation check happens at class
# creation time. Models can be migrated incrementally to the typed ``Mapped[...]``
# form.
flask_appbuilder.Model.__allow_unmapped__ = True
from superset.app import create_app # noqa: E402, F401
from superset.extensions import ( # noqa: E402
appbuilder, # noqa: F401
cache_manager,
db, # noqa: F401
@@ -28,7 +40,7 @@ from superset.extensions import (
security_manager, # noqa: F401
talisman, # noqa: F401
)
from superset.security import SupersetSecurityManager # noqa: F401
from superset.security import SupersetSecurityManager # noqa: E402, F401
# All of the fields located here should be considered legacy. The correct way to
# declare "global" dependencies is to define it in extensions.py,

View File

@@ -28,7 +28,7 @@ from flask import current_app as app
from pandas.errors import OutOfBoundsDatetime
from sqlalchemy import BigInteger, Boolean, Date, DateTime, Float, String, Text
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.sql.visitors import VisitableType
from sqlalchemy.types import TypeEngine
from superset import db, security_manager
from superset.commands.dataset.exceptions import (
@@ -94,7 +94,7 @@ type_map = {
}
def get_sqla_type(native_type: str) -> VisitableType:
def get_sqla_type(native_type: str) -> TypeEngine:
if native_type.upper() in type_map:
return type_map[native_type.upper()]
@@ -107,7 +107,7 @@ def get_sqla_type(native_type: str) -> VisitableType:
)
def get_dtype(df: pd.DataFrame, dataset: SqlaTable) -> dict[str, VisitableType]:
def get_dtype(df: pd.DataFrame, dataset: SqlaTable) -> dict[str, TypeEngine]:
return {
column.column_name: get_sqla_type(column.type)
for column in dataset.columns

View File

@@ -335,7 +335,7 @@ class BaseDatasource(
return self.kind == DatasourceKind.VIRTUAL
@declared_attr
def slices(self) -> RelationshipProperty:
def slices(self) -> Mapped[list["Slice"]]:
return relationship(
"Slice",
overlaps="table",

View File

@@ -22,6 +22,7 @@ import os
import sys
from typing import Any, Callable, TYPE_CHECKING
import sqlalchemy as sa
import wtforms_json
from colorama import Fore, Style
from deprecation import deprecated
@@ -808,7 +809,8 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
try:
with self.superset_app.app_context():
# Simple connection test
db.engine.execute("SELECT 1")
with db.engine.connect() as connection:
connection.execute(sa.text("SELECT 1"))
except Exception:
db_uri = self.database_uri
safe_uri = make_url_safe(db_uri) if db_uri else "Not configured"

View File

@@ -60,7 +60,7 @@ from sqlalchemy import and_, Column, or_, UniqueConstraint
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import Mapper, Session, validates, with_loader_criteria
from sqlalchemy.orm import Mapped, Mapper, Session, validates, with_loader_criteria
from sqlalchemy.orm.session import ORMExecuteState
from sqlalchemy.sql.elements import ColumnElement, Grouping, literal_column, TextClause
from sqlalchemy.sql.expression import Label, Select, TextAsFrom
@@ -583,7 +583,7 @@ class AuditMixinNullable(AuditMixin):
)
@declared_attr
def created_by_fk(self) -> sa.Column: # pylint: disable=arguments-renamed
def created_by_fk(self) -> Mapped[Optional[int]]: # pylint: disable=arguments-renamed
return sa.Column(
sa.Integer,
sa.ForeignKey("ab_user.id"),
@@ -592,7 +592,7 @@ class AuditMixinNullable(AuditMixin):
)
@declared_attr
def changed_by_fk(self) -> sa.Column: # pylint: disable=arguments-renamed
def changed_by_fk(self) -> Mapped[Optional[int]]: # pylint: disable=arguments-renamed
return sa.Column(
sa.Integer,
sa.ForeignKey("ab_user.id"),

View File

@@ -16,7 +16,12 @@
# under the License.
from typing import Any
from flask_sqlalchemy import BaseQuery
try:
# Flask-SQLAlchemy 3.x (required by SQLAlchemy 2.0)
from flask_sqlalchemy.query import Query as BaseQuery
except ImportError: # pragma: no cover
# Flask-SQLAlchemy 2.x
from flask_sqlalchemy import BaseQuery
from superset import security_manager
from superset.models.sql_lab import Query

View File

@@ -18,10 +18,16 @@ from typing import Any
from flask import g
from flask_babel import lazy_gettext as _
from flask_sqlalchemy import BaseQuery
from sqlalchemy import or_
from sqlalchemy.orm.query import Query
try:
# Flask-SQLAlchemy 3.x (required by SQLAlchemy 2.0)
from flask_sqlalchemy.query import Query as BaseQuery
except ImportError: # pragma: no cover
# Flask-SQLAlchemy 2.x
from flask_sqlalchemy import BaseQuery
from superset.models.sql_lab import SavedQuery
from superset.tags.filters import BaseTagIdFilter, BaseTagNameFilter
from superset.views.base import BaseFilter

View File

@@ -54,7 +54,7 @@ from flask_login import AnonymousUserMixin, LoginManager
from jwt.api_jwt import _jwt_global_obj
from sqlalchemy import and_, func as sa_func, inspect, or_
from sqlalchemy.engine.base import Connection
from sqlalchemy.orm import eagerload, joinedload
from sqlalchemy.orm import joinedload
from sqlalchemy.orm.exc import MultipleResultsFound
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.orm.query import Query as SqlaQuery
@@ -1800,8 +1800,8 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
pvms = (
self.session.query(self.permissionview_model)
.options(
eagerload(self.permissionview_model.permission),
eagerload(self.permissionview_model.view_menu),
joinedload(self.permissionview_model.permission),
joinedload(self.permissionview_model.view_menu),
)
.all()
)

View File

@@ -302,7 +302,7 @@ class SecretsMigrator:
table_name: str,
) -> Row:
cols = ",".join(pk_columns + column_names)
return conn.execute(f"SELECT {cols} FROM {table_name}") # noqa: S608
return conn.execute(text(f"SELECT {cols} FROM {table_name}")) # noqa: S608
def _target_type(self, encrypted_type: EncryptedType) -> EncryptedType:
"""The EncryptedType to re-encrypt a value *into*.
@@ -430,7 +430,7 @@ class SecretsMigrator:
re_encrypted_columns = {}
for column_name, encrypted_type in columns.items():
raw_value = self._read_bytes(column_name, row[column_name])
raw_value = self._read_bytes(column_name, row._mapping[column_name])
# NULL values aren't encrypted; there is nothing to migrate.
if raw_value is None:
@@ -508,13 +508,12 @@ class SecretsMigrator:
set_cols = ",".join(f"{name} = :{name}" for name in re_encrypted_columns)
where_clause = " AND ".join(f"{pk} = :_pk_{pk}" for pk in pk_columns)
pk_bind = {f"_pk_{pk}": row[pk] for pk in pk_columns}
pk_bind = {f"_pk_{pk}": row._mapping[pk] for pk in pk_columns}
conn.execute(
text(
f"UPDATE {table_name} SET {set_cols} WHERE {where_clause}" # noqa: S608
),
**pk_bind,
**re_encrypted_columns,
{**pk_bind, **re_encrypted_columns},
)
def run(self) -> ReEncryptStats:

View File

@@ -31,7 +31,7 @@ from flask_appbuilder import Model
from sqlalchemy import Column, inspect, MetaData, Table as DBTable
from sqlalchemy.dialects import postgresql
from sqlalchemy.sql import func
from sqlalchemy.sql.visitors import VisitableType
from sqlalchemy.types import TypeEngine
from superset import db
from superset.sql.parse import Table
@@ -42,7 +42,7 @@ logger = logging.getLogger(__name__)
class ColumnInfo(TypedDict):
name: str
type: VisitableType
type: TypeEngine
nullable: bool
default: Optional[Any]
autoincrement: str

View File

@@ -23,7 +23,6 @@ from uuid import uuid4
import pytest
from flask.ctx import AppContext
from flask_appbuilder.security.sqla.models import User
from flask_sqlalchemy import BaseQuery
from freezegun import freeze_time
from slack_sdk.errors import (
BotUserAccessError,
@@ -37,6 +36,13 @@ from slack_sdk.errors import (
)
from sqlalchemy.sql import func
try:
# Flask-SQLAlchemy 3.x (required by SQLAlchemy 2.0)
from flask_sqlalchemy.query import Query as BaseQuery
except ImportError: # pragma: no cover
# Flask-SQLAlchemy 2.x
from flask_sqlalchemy import BaseQuery
from superset import db
from superset.commands.report.exceptions import (
AlertQueryError,

View File

@@ -31,6 +31,28 @@ from superset.utils.encrypt import (
from tests.integration_tests.base_tests import SupersetTestCase
def make_row(values: dict[str, Any]) -> Any:
"""Build a genuine SQLAlchemy ``Row`` from a mapping.
``SecretsMigrator._re_encrypt_row`` consumes the ``Row`` objects yielded by
``conn.execute(...)``, reading column values through ``row._mapping`` per the
SQLAlchemy 2.0 Row API. Tests must therefore pass a real ``Row`` rather than a
plain ``dict`` (which lacks ``_mapping``). The constructor signature differs
between SQLAlchemy 1.4 and 2.0, so both are handled here.
"""
from sqlalchemy.engine.result import SimpleResultMetaData
from sqlalchemy.engine.row import Row
metadata = SimpleResultMetaData(tuple(values))
data = tuple(values.values())
try:
# SQLAlchemy 2.0: Row(parent, processors, key_to_index, data)
return Row(metadata, None, metadata._key_to_index, data)
except AttributeError:
# SQLAlchemy 1.4: Row(parent, processors, keymap, key_style, data)
return Row(metadata, None, metadata._keymap, Row._default_key_style, data)
class CustomEncFieldAdapter(AbstractEncryptedFieldAdapter):
def create(
self,
@@ -224,7 +246,8 @@ class EncryptedFieldTest(SupersetTestCase):
current_field = encrypted_field_factory.create(String(1024))
conn = MagicMock()
row = {"uuid": b"\x00" * 16, "configuration": ciphertext}
pk_value = b"\x00" * 16
row = make_row({"uuid": pk_value, "configuration": ciphertext})
stats = ReEncryptStats()
migrator._re_encrypt_row( # noqa: SLF001
@@ -239,9 +262,11 @@ class EncryptedFieldTest(SupersetTestCase):
assert conn.execute.call_count == 1
stmt = str(conn.execute.call_args.args[0])
assert "WHERE uuid = :_pk_uuid" in stmt
kwargs = conn.execute.call_args.kwargs
assert kwargs["_pk_uuid"] == row["uuid"]
assert "configuration" in kwargs
# The migrator passes bind params positionally (conn.execute(stmt, params)),
# so read them from args[1] rather than kwargs.
params = conn.execute.call_args.args[1]
assert params["_pk_uuid"] == pk_value
assert "configuration" in params
assert stats == ReEncryptStats(re_encrypted=1, skipped=0, failed=0)
def test_re_encrypt_row_is_idempotent(self):
@@ -264,7 +289,7 @@ class EncryptedFieldTest(SupersetTestCase):
assert field.process_result_value(ciphertext, dialect) == "hunter2"
conn = MagicMock()
row = {"uuid": b"\x00" * 16, "configuration": ciphertext}
row = make_row({"uuid": b"\x00" * 16, "configuration": ciphertext})
stats = ReEncryptStats()
migrator._re_encrypt_row( # noqa: SLF001
@@ -308,7 +333,7 @@ class EncryptedFieldTest(SupersetTestCase):
ciphertext = field.process_bind_param("hunter2", dialect)
conn = MagicMock()
row = {"uuid": b"\x00" * 16, "configuration": ciphertext}
row = make_row({"uuid": b"\x00" * 16, "configuration": ciphertext})
stats = ReEncryptStats()
migrator._re_encrypt_row( # noqa: SLF001
@@ -342,7 +367,7 @@ class EncryptedFieldTest(SupersetTestCase):
field = encrypted_field_factory.create(String(1024))
conn = MagicMock()
row = {"uuid": b"\x00" * 16, "configuration": b"not-valid-ciphertext"}
row = make_row({"uuid": b"\x00" * 16, "configuration": b"not-valid-ciphertext"})
stats = ReEncryptStats()
migrator._re_encrypt_row( # noqa: SLF001
@@ -374,7 +399,7 @@ class EncryptedFieldTest(SupersetTestCase):
field = encrypted_field_factory.create(String(1024))
conn = MagicMock()
row = {"uuid": b"\x00" * 16, "configuration": None}
row = make_row({"uuid": b"\x00" * 16, "configuration": None})
stats = ReEncryptStats()
migrator._re_encrypt_row( # noqa: SLF001

View File

@@ -57,6 +57,18 @@ def _engine_migrator(target_engine: type) -> SecretsMigrator:
return migrator
class _Row:
"""Minimal stand-in for a SQLAlchemy ``Row``.
``_re_encrypt_row`` accesses columns via ``row._mapping[...]`` (the
SQLAlchemy 2.0-compatible idiom), so the fixtures wrap their column dicts
in an object exposing that attribute rather than passing a bare ``dict``.
"""
def __init__(self, mapping: dict[str, object]) -> None:
self._mapping = mapping
def test_default_engine_is_aes_cbc() -> None:
"""Without config, the adapter keeps the historical AES-CBC engine."""
field = SQLAlchemyUtilsAdapter().create(SECRET, String(128))
@@ -156,7 +168,7 @@ def test_engine_migration_cbc_to_gcm_re_encrypts() -> None:
migrator = _engine_migrator(AesGcmEngine)
conn = MagicMock()
row = {"id": 1, "password": ciphertext}
row = _Row({"id": 1, "password": ciphertext})
stats = ReEncryptStats()
migrator._re_encrypt_row( # noqa: SLF001
@@ -165,7 +177,7 @@ def test_engine_migration_cbc_to_gcm_re_encrypts() -> None:
assert stats == ReEncryptStats(re_encrypted=1)
assert conn.execute.call_count == 1
new_value = conn.execute.call_args.kwargs["password"]
new_value = conn.execute.call_args.args[1]["password"]
# The stored value changed and now decrypts as GCM back to the plaintext.
assert new_value != ciphertext
gcm = _encrypted_type(AesGcmEngine)
@@ -183,7 +195,7 @@ def test_engine_migration_idempotent_for_already_target() -> None:
migrator = _engine_migrator(AesGcmEngine)
conn = MagicMock()
row = {"id": 1, "password": gcm_value}
row = _Row({"id": 1, "password": gcm_value})
stats = ReEncryptStats()
migrator._re_encrypt_row( # noqa: SLF001
@@ -206,7 +218,7 @@ def test_engine_migration_reads_cbc_after_config_already_flipped() -> None:
migrator = _engine_migrator(AesGcmEngine)
conn = MagicMock()
row = {"id": 1, "password": cbc_value}
row = _Row({"id": 1, "password": cbc_value})
stats = ReEncryptStats()
migrator._re_encrypt_row( # noqa: SLF001
@@ -214,7 +226,7 @@ def test_engine_migration_reads_cbc_after_config_already_flipped() -> None:
)
assert stats == ReEncryptStats(re_encrypted=1)
new_value = conn.execute.call_args.kwargs["password"]
new_value = conn.execute.call_args.args[1]["password"]
assert gcm_column.process_result_value(new_value, DIALECT) == "hunter2"
@@ -231,7 +243,7 @@ def test_engine_migration_gcm_to_cbc_rolls_back() -> None:
migrator = _engine_migrator(AesEngine)
conn = MagicMock()
row = {"id": 1, "password": gcm_value}
row = _Row({"id": 1, "password": gcm_value})
stats = ReEncryptStats()
migrator._re_encrypt_row( # noqa: SLF001
@@ -239,7 +251,7 @@ def test_engine_migration_gcm_to_cbc_rolls_back() -> None:
)
assert stats == ReEncryptStats(re_encrypted=1)
new_value = conn.execute.call_args.kwargs["password"]
new_value = conn.execute.call_args.args[1]["password"]
assert new_value != gcm_value
# The rolled-back value now decrypts as AES-CBC back to the plaintext.
assert _encrypted_type(AesEngine).process_result_value(new_value, DIALECT) == (
@@ -272,7 +284,7 @@ def test_rollback_authenticated_probe_wins_over_spurious_cbc_skip() -> None:
spurious_target.process_bind_param.return_value = b"new-cbc-ciphertext"
conn = MagicMock()
row = {"id": 1, "password": gcm_value}
row = _Row({"id": 1, "password": gcm_value})
stats = ReEncryptStats()
with mock.patch.object(migrator, "_target_type", return_value=spurious_target):
@@ -302,7 +314,7 @@ def test_combined_key_rotation_and_engine_migration() -> None:
migrator._previous_secret_key = old_key # noqa: SLF001 # rotate key too
conn = MagicMock()
row = {"id": 1, "password": old_value}
row = _Row({"id": 1, "password": old_value})
stats = ReEncryptStats()
migrator._re_encrypt_row( # noqa: SLF001
@@ -310,7 +322,7 @@ def test_combined_key_rotation_and_engine_migration() -> None:
)
assert stats == ReEncryptStats(re_encrypted=1)
new_value = conn.execute.call_args.kwargs["password"]
new_value = conn.execute.call_args.args[1]["password"]
# The migrated value decrypts as GCM under the *current* key.
assert _encrypted_type(AesGcmEngine).process_result_value(new_value, DIALECT) == (
"hunter2"
@@ -346,7 +358,7 @@ def test_key_rotation_for_aes_gcm_column() -> None:
migrator = _key_rotation_migrator(previous_secret_key=old_key)
conn = MagicMock()
row = {"id": 1, "password": old_value}
row = _Row({"id": 1, "password": old_value})
stats = ReEncryptStats()
migrator._re_encrypt_row( # noqa: SLF001
@@ -354,7 +366,7 @@ def test_key_rotation_for_aes_gcm_column() -> None:
)
assert stats == ReEncryptStats(re_encrypted=1)
new_value = conn.execute.call_args.kwargs["password"]
new_value = conn.execute.call_args.args[1]["password"]
assert gcm_column.process_result_value(new_value, DIALECT) == "hunter2"
@@ -362,7 +374,7 @@ def test_engine_migration_unreadable_value_counts_as_failure() -> None:
"""A value no engine/key can read is a failure, not a silent pass-through."""
migrator = _engine_migrator(AesGcmEngine)
conn = MagicMock()
row = {"id": 1, "password": b"not-valid-ciphertext"}
row = _Row({"id": 1, "password": b"not-valid-ciphertext"})
stats = ReEncryptStats()
migrator._re_encrypt_row( # noqa: SLF001