Compare commits

..

1 Commits

Author SHA1 Message Date
Elizabeth Thompson
57b93055b6 fix(dashboard): let CSV exports use query cache instead of always force-querying
Two bugs caused dashboard CSV exports to always re-execute the query instead
of using the cache (while the Explore page correctly used the cache):

1. Chart.tsx passed `force: true` to `exportChart()`, unconditionally
   bypassing the cache on every dashboard export.

2. buildQuery.ts coerced a missing `row_limit` to `0` via `|| 0` for
   download queries, while display queries left it as `undefined`. This
   produced a different cache key so exports would miss the cache even
   without the `force` flag.

Both fixes are needed together: removing `force: true` lets the backend
consult the cache, and preserving `undefined` for a missing row_limit
ensures the export query produces the same cache key as the display query
that populated the cache.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-27 00:20:12 +00:00
17 changed files with 30 additions and 1755 deletions

View File

@@ -519,80 +519,6 @@ For a connection to a SQL endpoint you need to use the HTTP path from the endpoi
{"connect_args": {"http_path": "/sql/1.0/endpoints/****", "driver_path": "/path/to/odbc/driver"}}
```
##### OAuth2 Authentication
Superset supports OAuth2 authentication for Databricks, allowing users to authenticate with their personal Databricks accounts instead of using shared access tokens. This provides better security and audit capabilities.
###### Prerequisites
1. Create an OAuth2 application in your Databricks account:
- Go to your Databricks account console
- Navigate to **Settings** → **Developer** → **OAuth apps**
- Create a new OAuth app with the redirect URI: `http://your-superset-host:port/api/v1/database/oauth2/`
2. Configure OAuth2 in your `superset_config.py`:
```python
from datetime import timedelta
# OAuth2 configuration for Databricks
# The authorization endpoint is derived from your Databricks workspace host; the
# token endpoint must be set explicitly (see notes below).
DATABASE_OAUTH2_CLIENTS = {
"Databricks (legacy)": {
"id": "your-databricks-client-id",
"secret": "your-databricks-client-secret",
"scope": "sql",
"token_request_uri": "https://your-workspace-host/oidc/v1/token",
},
"Databricks": {
"id": "your-databricks-client-id",
"secret": "your-databricks-client-secret",
"scope": "sql",
"token_request_uri": "https://your-workspace-host/oidc/v1/token",
},
}
# OAuth2 redirect URI (adjust hostname/port for your setup)
DATABASE_OAUTH2_REDIRECT_URI = "http://your-superset-host:port/api/v1/database/oauth2/"
# Optional: OAuth2 timeout
DATABASE_OAUTH2_TIMEOUT = timedelta(seconds=30)
```
Replace the following placeholders:
- `your-databricks-client-id`: Your Databricks OAuth2 application client ID
- `your-databricks-client-secret`: Your Databricks OAuth2 application client secret
- `your-superset-host:port`: Your Superset instance hostname and port
**Multi-Cloud Provider Support**
Databricks fronts the user-to-machine (U2M) OAuth2 flow on every workspace at
`https://<workspace-host>/oidc/v1/authorize` and
`https://<workspace-host>/oidc/v1/token`, regardless of whether the workspace
runs on AWS, Azure, or GCP. Superset derives the **authorization** endpoint
directly from your connection's host, so no cloud provider or account/tenant
identifier needs to be configured.
The **token** endpoint cannot be auto-derived (token exchange has no database
context to read the host), so you must supply `token_request_uri` in
`DATABASE_OAUTH2_CLIENTS`, set to `https://<workspace-host>/oidc/v1/token` for
your workspace.
If you supply a fully-resolved `authorization_request_uri` (and/or
`token_request_uri`), those values take precedence over the host-derived
defaults.
###### Usage
Once configured, users can:
1. Connect to Databricks databases normally using access tokens
2. When querying data, Superset will automatically redirect users to authenticate with Databricks if needed
3. User-specific OAuth2 tokens will be used for database connections, providing better security and audit trails
This feature works with both "Databricks (legacy)" and "Databricks" engine types and automatically supports all major cloud providers (AWS, Azure, GCP).
#### Denodo
The recommended connector library for Denodo is

View File

@@ -344,16 +344,6 @@ export default function transformProps(
data2,
currencyCodeColumn,
);
const getAxisFormatterConfig = (axisIndex?: number) =>
axisIndex === 1
? {
customFormatters: customFormattersSecondary,
formatter: formatterSecondary,
}
: {
customFormatters,
formatter,
};
const primarySeries = new Set<string>();
const secondarySeries = new Set<string>();
@@ -432,8 +422,6 @@ export default function transformProps(
let [minSecondary, maxSecondary] = (yAxisBoundsSecondary || []).map(
parseAxisBound,
);
const getAxisMax = (axisIndex?: number) =>
axisIndex === 1 ? maxSecondary : yAxisMax;
const array = ensureIsArray(chartProps.rawFormData?.time_compare);
const inverted = invert(verboseMap);
@@ -457,11 +445,10 @@ export default function transformProps(
// When no groupby, format as just the entry name with optional query identifier
displayName = showQueryIdentifiers ? `${entryName} (Query A)` : entryName;
}
const axisFormatterConfig = getAxisFormatterConfig(yAxisIndex);
const seriesFormatter = getFormatter(
axisFormatterConfig.customFormatters,
axisFormatterConfig.formatter,
customFormatters,
formatter,
metrics,
labelMap?.[seriesName]?.[0],
!!contributionMode,
@@ -493,7 +480,7 @@ export default function transformProps(
formatter:
seriesType === EchartsTimeseriesSeriesType.Bar
? getOverMaxHiddenFormatter({
max: getAxisMax(yAxisIndex),
max: yAxisMax,
formatter: seriesFormatter,
})
: seriesFormatter,
@@ -531,11 +518,10 @@ export default function transformProps(
// When no groupby, format as just the entry name with optional query identifier
displayName = showQueryIdentifiers ? `${entryName} (Query B)` : entryName;
}
const axisFormatterConfig = getAxisFormatterConfig(yAxisIndexB);
const seriesFormatter = getFormatter(
axisFormatterConfig.customFormatters,
axisFormatterConfig.formatter,
customFormattersSecondary,
formatterSecondary,
metricsB,
labelMapB?.[seriesName]?.[0],
!!contributionMode,
@@ -568,7 +554,7 @@ export default function transformProps(
formatter:
seriesTypeB === EchartsTimeseriesSeriesType.Bar
? getOverMaxHiddenFormatter({
max: getAxisMax(yAxisIndexB),
max: maxSecondary,
formatter: seriesFormatter,
})
: seriesFormatter,

View File

@@ -35,26 +35,13 @@ import {
} from '../../src';
import transformProps from '../../src/MixedTimeseries/transformProps';
import {
DEFAULT_FORM_DATA,
EchartsMixedTimeseriesFormData,
EchartsMixedTimeseriesProps,
} from '../../src/MixedTimeseries/types';
import { DEFAULT_FORM_DATA } from '../../src/MixedTimeseries/types';
import { createEchartsTimeseriesTestChartProps } from '../helpers';
import type { SeriesOption } from 'echarts';
type LabelFormatterParams = {
value: [number, number];
dataIndex: number;
seriesIndex: number;
seriesName: string;
};
type SeriesWithLabelFormatter = SeriesOption & {
label?: {
formatter?: (params: LabelFormatterParams) => string | number;
};
};
/**
* Creates a partial ChartDataResponseResult for testing.
* Only includes the fields needed for tests, with sensible defaults for required fields.
@@ -161,30 +148,6 @@ const queriesData: ChartDataResponseResult[] = [
createTestQueryData(defaultQueryRows, { label_map: defaultLabelMap }),
];
function getSeriesWithLabelFormatter(
series: SeriesOption[],
name: string,
): SeriesWithLabelFormatter {
const result = series.find(seriesOption => seriesOption.name === name);
expect(result).toBeDefined();
expect((result as SeriesWithLabelFormatter).label?.formatter).toBeDefined();
return result as SeriesWithLabelFormatter;
}
function formatSeriesLabel(
series: SeriesWithLabelFormatter,
value: [number, number],
) {
const formatter = series.label?.formatter;
expect(formatter).toBeDefined();
return formatter?.({
dataIndex: 0,
seriesIndex: 0,
seriesName: String(series.name),
value,
});
}
test('should transform chart props for viz with showQueryIdentifiers=false', () => {
const chartProps = createEchartsTimeseriesTestChartProps<
EchartsMixedTimeseriesFormData,
@@ -269,162 +232,6 @@ test('should transform chart props for viz with showQueryIdentifiers=true', () =
]);
});
test('formats value labels with the formatter for the assigned y-axis', () => {
const timestamp = 1704067200000;
const queryAData = createTestQueryData(
[{ __timestamp: timestamp, lineMetric: 0.25 }],
{
colnames: ['__timestamp', 'lineMetric'],
coltypes: [GenericDataType.Temporal, GenericDataType.Numeric],
label_map: { lineMetric: ['lineMetric'] },
},
);
const queryBData = createTestQueryData(
[{ __timestamp: timestamp, barMetric: 0.5 }],
{
colnames: ['__timestamp', 'barMetric'],
coltypes: [GenericDataType.Temporal, GenericDataType.Numeric],
label_map: { 'barMetric (1)': ['barMetric'] },
},
);
const chartProps = createEchartsTimeseriesTestChartProps<
EchartsMixedTimeseriesFormData,
EchartsMixedTimeseriesProps
>({
...MIXED_TIMESERIES_CHART_PROPS_DEFAULTS,
defaultQueriesData: [queryAData, queryBData],
formData: {
...formData,
groupby: [],
groupbyB: [],
metrics: ['lineMetric'],
metricsB: ['barMetric'],
showValue: true,
showValueB: true,
stack: null,
stackB: null,
x_axis: '__timestamp',
yAxisFormat: '.0%',
yAxisFormatSecondary: ',.1f',
yAxisIndex: 1,
yAxisIndexB: 0,
},
queriesData: [queryAData, queryBData],
});
const { echartOptions } = transformProps(chartProps);
const series = echartOptions.series as SeriesOption[];
const lineSeries = getSeriesWithLabelFormatter(series, 'lineMetric');
const barSeries = getSeriesWithLabelFormatter(series, 'barMetric');
expect(formatSeriesLabel(lineSeries, [timestamp, 0.25])).toBe('0.3');
expect(formatSeriesLabel(barSeries, [timestamp, 0.5])).toBe('50%');
});
test('formats value labels correctly when y-axis assignments are reversed', () => {
const timestamp = 1704067200000;
const queryAData = createTestQueryData(
[{ __timestamp: timestamp, lineMetric: 0.25 }],
{
colnames: ['__timestamp', 'lineMetric'],
coltypes: [GenericDataType.Temporal, GenericDataType.Numeric],
label_map: { lineMetric: ['lineMetric'] },
},
);
const queryBData = createTestQueryData(
[{ __timestamp: timestamp, barMetric: 0.5 }],
{
colnames: ['__timestamp', 'barMetric'],
coltypes: [GenericDataType.Temporal, GenericDataType.Numeric],
label_map: { 'barMetric (1)': ['barMetric'] },
},
);
const chartProps = createEchartsTimeseriesTestChartProps<
EchartsMixedTimeseriesFormData,
EchartsMixedTimeseriesProps
>({
...MIXED_TIMESERIES_CHART_PROPS_DEFAULTS,
defaultQueriesData: [queryAData, queryBData],
formData: {
...formData,
groupby: [],
groupbyB: [],
metrics: ['lineMetric'],
metricsB: ['barMetric'],
showValue: true,
showValueB: true,
stack: null,
stackB: null,
x_axis: '__timestamp',
yAxisFormat: '.0%',
yAxisFormatSecondary: ',.1f',
yAxisIndex: 0,
yAxisIndexB: 1,
},
queriesData: [queryAData, queryBData],
});
const { echartOptions } = transformProps(chartProps);
const series = echartOptions.series as SeriesOption[];
const lineSeries = getSeriesWithLabelFormatter(series, 'lineMetric');
const barSeries = getSeriesWithLabelFormatter(series, 'barMetric');
expect(formatSeriesLabel(lineSeries, [timestamp, 0.25])).toBe('25%');
expect(formatSeriesLabel(barSeries, [timestamp, 0.5])).toBe('0.5');
});
test('keeps bar value label clipping aligned with the assigned y-axis', () => {
const timestamp = 1704067200000;
const queryAData = createTestQueryData(
[{ __timestamp: timestamp, lineMetric: 0.25 }],
{
colnames: ['__timestamp', 'lineMetric'],
coltypes: [GenericDataType.Temporal, GenericDataType.Numeric],
label_map: { lineMetric: ['lineMetric'] },
},
);
const queryBData = createTestQueryData(
[{ __timestamp: timestamp, barMetric: 0.5 }],
{
colnames: ['__timestamp', 'barMetric'],
coltypes: [GenericDataType.Temporal, GenericDataType.Numeric],
label_map: { 'barMetric (1)': ['barMetric'] },
},
);
const chartProps = createEchartsTimeseriesTestChartProps<
EchartsMixedTimeseriesFormData,
EchartsMixedTimeseriesProps
>({
...MIXED_TIMESERIES_CHART_PROPS_DEFAULTS,
defaultQueriesData: [queryAData, queryBData],
formData: {
...formData,
groupby: [],
groupbyB: [],
metrics: ['lineMetric'],
metricsB: ['barMetric'],
showValue: true,
showValueB: true,
stack: null,
stackB: null,
x_axis: '__timestamp',
yAxisBounds: [undefined, 1],
yAxisBoundsSecondary: [undefined, 0.1],
yAxisFormat: '.0%',
yAxisFormatSecondary: ',.1f',
yAxisIndex: 0,
yAxisIndexB: 1,
},
queriesData: [queryAData, queryBData],
});
const { echartOptions } = transformProps(chartProps);
const series = echartOptions.series as SeriesOption[];
const barSeries = getSeriesWithLabelFormatter(series, 'barMetric');
expect(formatSeriesLabel(barSeries, [timestamp, 0.5])).toBe('');
});
describe('legend sorting', () => {
const getChartProps = (overrides = {}) =>
createEchartsTimeseriesTestChartProps<

View File

@@ -227,7 +227,10 @@ export const buildQuery: BuildQuery<TableChartFormData> = (
formData?.result_type === 'results');
if (isDownloadQuery) {
moreProps.row_limit = Number(formDataCopy.row_limit) || 0;
moreProps.row_limit =
formDataCopy.row_limit != null
? Number(formDataCopy.row_limit)
: undefined;
moreProps.row_offset = 0;
}

View File

@@ -44,13 +44,13 @@ const fakeTableApiResult = {
result: [
{
id: 1,
value: 'fake_api_result1',
value: 'fake api result1',
label: 'fake api label1',
type: 'table',
},
{
id: 2,
value: 'fake_api_result2',
value: 'fake api result2',
label: 'fake api label2',
type: 'table',
},
@@ -152,64 +152,6 @@ test('returns keywords including fetched function_names data', async () => {
});
});
test('quotes table identifiers that require quoting in the inserted value', async () => {
const dbFunctionNamesApiRoute = `glob:*/api/v1/database/${expectDbId}/function_names/`;
fetchMock.get(dbFunctionNamesApiRoute, fakeFunctionNamesApiResult);
act(() => {
store.dispatch(
tableApiUtil.upsertQueryData(
'tables',
{ dbId: expectDbId, schema: expectSchema },
{
options: [
{ value: 'COVID Vaccines', label: 'COVID Vaccines', type: 'table' },
{ value: 'simple_table', label: 'simple_table', type: 'table' },
],
hasMore: false,
},
),
);
});
const { result } = renderHook(
() =>
useKeywords({
queryEditorId: 'testqueryid',
dbId: expectDbId,
schema: expectSchema,
}),
{
wrapper: createWrapper({
useRedux: true,
store,
}),
},
);
await waitFor(() =>
expect(fetchMock.callHistory.calls(dbFunctionNamesApiRoute).length).toBe(1),
);
// A name that needs quoting is inserted as a double-quoted identifier,
// while its display name stays human-readable.
expect(result.current).toContainEqual(
expect.objectContaining({
name: 'COVID Vaccines',
value: '"COVID Vaccines"',
meta: 'table',
}),
);
// A simple identifier is inserted as-is, without quotes.
expect(result.current).toContainEqual(
expect.objectContaining({
name: 'simple_table',
value: 'simple_table',
meta: 'table',
}),
);
});
test('skip fetching if autocomplete skipped', () => {
const { result } = renderHook(
() =>

View File

@@ -53,14 +53,6 @@ const getHelperText = (value: string) =>
detail: value,
};
// Names that aren't simple identifiers (spaces, punctuation, leading digits)
// must be double-quoted to be valid SQL, with embedded quotes doubled.
const SIMPLE_IDENTIFIER_RE = /^[A-Za-z_][A-Za-z0-9_]*$/;
const quoteIdentifier = (identifier: string) =>
SIMPLE_IDENTIFIER_RE.test(identifier)
? identifier
: `"${identifier.replace(/"/g, '""')}"`;
const extensionsRegistry = getExtensionsRegistry();
export function useKeywords(
@@ -205,7 +197,7 @@ export function useKeywords(
() =>
allCachedTables.map(({ value, label, schema: tableSchema }) => ({
name: label,
value: quoteIdentifier(value),
value,
schema: tableSchema,
score: TABLE_AUTOCOMPLETE_SCORE,
meta: 'table',

View File

@@ -562,7 +562,6 @@ const Chart = (props: ChartProps) => {
exportFormData as unknown as import('@superset-ui/core').QueryFormData,
resultType,
resultFormat: format,
force: true,
ownState: exportOwnState,
onStartStreamingExport: shouldUseStreaming
? (exportParams: JsonObject) => {

View File

@@ -100,10 +100,7 @@ class CacheRestApi(BaseSupersetModelRestApi):
)
cache_keys = [c.cache_key for c in cache_key_objs]
if cache_key_objs:
# Chart query results live in ``data_cache``, not the default
# ``cache`` — using the wrong backend silently misses the Redis
# keys when ``CACHE_KEY_PREFIX`` differs between the two configs.
all_keys_deleted = cache_manager.data_cache.delete_many(*cache_keys)
all_keys_deleted = cache_manager.cache.delete_many(*cache_keys)
if not all_keys_deleted:
# expected behavior as keys may expire and cache is not a

View File

@@ -23,7 +23,7 @@ import sys
import urllib
from datetime import datetime
from re import Pattern
from typing import Any, Callable, TYPE_CHECKING, TypedDict
from typing import Any, TYPE_CHECKING, TypedDict
import pandas as pd
from apispec import APISpec
@@ -83,97 +83,6 @@ if TYPE_CHECKING:
logger = logging.getLogger()
# BigQuery string escape sequences keyed off documented escapes in
# https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#string_and_bytes_literals.
# Backslash MUST be first so subsequent escapes don't double-escape their own
# backslash. ``\?``, ``\"`` and ``\``` are valid BigQuery escapes but
# intentionally omitted because those characters do not require escaping
# inside a single-quoted literal. ``\0`` is NOT a valid BigQuery escape
# (octal escapes require exactly three digits); the null byte instead falls
# through to the ``\xhh`` fallback below.
_BIGQUERY_STRING_ESCAPES = {
"\\": "\\\\",
"'": "\\'",
"\n": "\\n",
"\r": "\\r",
"\t": "\\t",
"\b": "\\b",
"\f": "\\f",
"\v": "\\v",
"\a": "\\a",
}
def _process_string_literal(value: str) -> str:
"""
Escape a string value for use as a BigQuery SQL literal.
BigQuery requires backslash escaping for single quotes inside string
literals (``'O\\'Brien'``). Doubled single quotes (``'O''Brien'``) are
**not** valid — BigQuery parses them as two concatenated string literals
without whitespace, causing a syntax error:
``concatenated string literals must be separated by whitespace``.
BigQuery also forbids literal newlines, carriage returns, and other
control characters inside a quoted string; those must be written using
escape sequences (``\\n``, ``\\r``, ``\\t`` …). Control characters
without a named escape are emitted as a ``\\xhh`` hex escape; printable
Unicode passes through unchanged because BigQuery accepts UTF-8 inside
string literals.
The upstream ``sqlalchemy-bigquery`` dialect relies on Python's ``repr()``
to quote values, which switches to double-quote delimiters when the
string contains an apostrophe (e.g. ``repr("O'Brien")`` → ``"O'Brien"``).
Double-quoted tokens inside compiled SQL would be parsed as identifiers,
so the query also fails. This helper always produces a single-quoted
literal.
"""
parts = []
for ch in value:
escape = _BIGQUERY_STRING_ESCAPES.get(ch)
if escape is not None:
parts.append(escape)
elif ord(ch) < 0x20 or ord(ch) == 0x7F:
parts.append(f"\\x{ord(ch):02x}")
else:
parts.append(ch)
return f"'{''.join(parts)}'"
def _monkeypatch_bigquery_string_literal() -> None:
"""
Patch the sqlalchemy-bigquery dialect so that string literals containing
apostrophes are rendered correctly when ``literal_binds=True``.
Without this patch, a filter value like ``O'Brien`` is compiled as the
double-quoted identifier ``"O'Brien"`` instead of the single-quoted literal
``'O\\'Brien'``, causing BigQuery to return a syntax error.
This follows the same pattern used for the Databricks dialect fix in
``superset/db_engine_specs/databricks.py``.
"""
try:
from sqlalchemy_bigquery import BigQueryDialect
class BigQuerySafeString(types.TypeDecorator):
impl = types.String
cache_ok = True
def literal_processor(self, dialect: Any) -> Callable[[str], str]:
if dialect.name == "bigquery":
return _process_string_literal
return super().literal_processor(dialect)
BigQueryDialect.colspecs[types.String] = BigQuerySafeString
except ImportError:
pass
_monkeypatch_bigquery_string_literal()
CONNECTION_DATABASE_PERMISSIONS_REGEX = re.compile(
"Access Denied: Project (?P<project_name>.+?): User does not have "
+ "bigquery.jobs.create permission in project (?P<project>.+?)"

View File

@@ -17,11 +17,10 @@
from __future__ import annotations
from datetime import datetime
from typing import Any, Callable, cast, TYPE_CHECKING, TypedDict, Union
from typing import Any, Callable, TYPE_CHECKING, TypedDict, Union
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
from flask import g
from flask_babel import gettext as __
from marshmallow import fields, Schema
from marshmallow.validate import Range
@@ -39,18 +38,12 @@ from superset.db_engine_specs.base import (
)
from superset.db_engine_specs.hive import HiveEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import OAuth2Error
from superset.utils import json
from superset.utils.core import get_user_agent, QuerySource
from superset.utils.network import is_hostname_valid, is_port_open
if TYPE_CHECKING:
from superset.models.core import Database
from superset.superset_typing import (
OAuth2ClientConfig,
OAuth2State,
OAuth2TokenResponse,
)
try:
@@ -284,135 +277,6 @@ class DatabricksDynamicBaseEngineSpec(BasicParametersMixin, DatabricksBaseEngine
"port": "port",
}
# The Databricks SQL driver has no dedicated authentication exception, so an
# expired or missing token surfaces as a generic driver error. These case-
# insensitive substrings flag the errors that should bootstrap a re-auth.
oauth2_auth_failure_signals = (
"http 401",
"unauthorized",
"unauthenticated",
"invalid access token",
"invalid token",
"expired token",
"token expired",
)
@classmethod
def _workspace_oauth2_endpoint(cls, database: Database, path: str) -> str:
"""
Build a Databricks OAuth2 (U2M) endpoint from the workspace host.
Databricks fronts the user-to-machine OAuth2 flow on every workspace at
``https://<workspace-host>/oidc/v1/{authorize,token}`` across AWS, Azure
and GCP, so the endpoints derive directly from the connection host and
need no account or tenant identifier.
"""
host = database.url_object.host
if not host:
raise OAuth2Error(
"Databricks OAuth2 endpoint could not be resolved: the database "
"connection has no host."
)
return f"https://{host}/oidc/v1/{path}"
@classmethod
def needs_oauth2(cls, ex: Exception) -> bool:
"""
Identify driver errors that should trigger the OAuth2 dance.
Unlike Trino (``TrinoAuthError``) or GSheets (``UnauthenticatedError``),
the Databricks driver raises no dedicated auth exception, so in addition
to the base ``isinstance`` check we match the auth signals above on the
error message (mirrors ``GSheetsEngineSpec.needs_oauth2``).
"""
if not (g and hasattr(g, "user")):
return False
if isinstance(ex, cls.oauth2_exception):
return True
message = str(ex).lower()
return any(signal in message for signal in cls.oauth2_auth_failure_signals)
@classmethod
def get_oauth2_authorization_uri(
cls,
config: "OAuth2ClientConfig",
state: "OAuth2State",
code_verifier: str | None = None,
) -> str:
"""
Return the URI for the initial OAuth2 request.
A fully-resolved ``authorization_request_uri`` from
``DATABASE_OAUTH2_CLIENTS`` is preserved; otherwise the endpoint is
derived from the workspace host (``https://<host>/oidc/v1/authorize``),
which is valid on AWS, Azure and GCP.
"""
if not config.get("authorization_request_uri"):
from superset import db
from superset.models.core import Database
database_id = state["database_id"]
if database := db.session.get(Database, database_id):
config = cast(
"OAuth2ClientConfig",
dict(config)
| {
"authorization_request_uri": cls._workspace_oauth2_endpoint(
database, "authorize"
)
},
)
return super().get_oauth2_authorization_uri(config, state, code_verifier)
@classmethod
def get_oauth2_token(
cls,
config: "OAuth2ClientConfig",
code: str,
code_verifier: str | None = None,
) -> "OAuth2TokenResponse":
"""
Exchange the authorization code for refresh/access tokens.
Token exchange runs in a separate request with no database context, so
the workspace host is not available to derive the endpoint here. Require
a configured ``token_request_uri``
(``https://<workspace-host>/oidc/v1/token``) and fail fast rather than
POST to an unresolved endpoint.
"""
if not config.get("token_request_uri"):
raise OAuth2Error(
"Databricks OAuth2 token endpoint is not configured: set "
"`token_request_uri` to https://<workspace-host>/oidc/v1/token "
"in DATABASE_OAUTH2_CLIENTS."
)
return super().get_oauth2_token(config, code, code_verifier)
@classmethod
def impersonate_user(
cls,
database: Database,
username: str | None,
user_token: str | None,
url: URL,
engine_kwargs: dict[str, Any],
) -> tuple[URL, dict[str, Any]]:
"""
Update connection with OAuth2 access token for user impersonation.
"""
if user_token:
# Replace the access token in the URL with the user's OAuth2 token
url = url.set(password=user_token)
# Also update connect_args if they contain access token
connect_args = engine_kwargs.setdefault("connect_args", {})
if "access_token" in connect_args:
connect_args["access_token"] = user_token
return url, engine_kwargs
@staticmethod
def get_extra_params(
database: Database, source: QuerySource | None = None
@@ -610,16 +474,6 @@ class DatabricksNativeEngineSpec(DatabricksDynamicBaseEngineSpec):
supports_dynamic_catalog = True
supports_cross_catalog_queries = True
# OAuth 2.0 support. The flow (endpoint resolution from the workspace host,
# `needs_oauth2` detection) is shared via `DatabricksDynamicBaseEngineSpec`.
supports_oauth2 = True
oauth2_scope = "sql"
# Authorization endpoint is derived from the workspace host at runtime; the
# token endpoint must be configured (no DB context at exchange time).
oauth2_authorization_request_uri = ""
oauth2_token_request_uri = ""
@classmethod
def build_sqlalchemy_uri( # type: ignore
cls, parameters: DatabricksNativeParametersType, *_
@@ -831,16 +685,6 @@ class DatabricksPythonConnectorEngineSpec(DatabricksDynamicBaseEngineSpec):
supports_dynamic_schema = supports_catalog = supports_dynamic_catalog = True
# OAuth 2.0 support. The flow (endpoint resolution from the workspace host,
# `needs_oauth2` detection) is shared via `DatabricksDynamicBaseEngineSpec`.
supports_oauth2 = True
oauth2_scope = "sql"
# Authorization endpoint is derived from the workspace host at runtime; the
# token endpoint must be configured (no DB context at exchange time).
oauth2_authorization_request_uri = ""
oauth2_token_request_uri = ""
@classmethod
def build_sqlalchemy_uri( # type: ignore
cls, parameters: DatabricksPythonConnectorParametersType, *_

View File

@@ -303,7 +303,6 @@ DEFAULT_GET_DASHBOARD_INFO_COLUMNS: List[str] = [
"created_on",
"changed_on",
"uuid",
"embedded_uuid",
"url",
"created_on_humanized",
"changed_on_humanized",
@@ -428,18 +427,6 @@ class DashboardInfo(BaseModel):
created_on: str | datetime | None = None
changed_on: str | datetime | None = None
uuid: str | None = None
embedded_uuid: str | None = Field(
None,
description=(
"Embedded UUID for this dashboard. This is the UUID required when "
"generating guest tokens for embedded dashboards "
"(resources[].id in the guest token payload). "
"Only present when the dashboard has been configured for embedding "
"via the Embed Dashboard UI. Distinct from `uuid` (the internal "
"dashboard UUID) — using the wrong one causes 403 errors in guest "
"token validation."
),
)
url: str | None = None
created_on_humanized: str | None = None
changed_on_humanized: str | None = None
@@ -1365,9 +1352,6 @@ def dashboard_serializer(dashboard: "Dashboard") -> DashboardInfo:
created_on=dashboard.created_on,
changed_on=dashboard.changed_on,
uuid=str(dashboard.uuid) if dashboard.uuid else None,
embedded_uuid=str(dashboard.embedded[0].uuid)
if dashboard.embedded
else None,
url=absolute_url,
created_on_humanized=dashboard.created_on_humanized,
changed_on_humanized=dashboard.changed_on_humanized,

View File

@@ -155,11 +155,10 @@ async def get_dashboard_info(
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
# Eager load slices, tags, and embedded to avoid N+1 queries.
# Eager load slices and tags to avoid N+1 queries during serialization.
eager_options = [
subqueryload(Dashboard.slices).subqueryload(Slice.tags),
subqueryload(Dashboard.tags),
subqueryload(Dashboard.embedded),
]
with event_logger.log_context(action="mcp.get_dashboard_info.lookup"):

View File

@@ -18,7 +18,6 @@
"""Unit tests for Superset"""
from typing import Any
from unittest.mock import patch
import pytest
@@ -53,43 +52,17 @@ def test_invalidate_cache(invalidate):
def test_invalidate_existing_cache(invalidate):
db.session.add(CacheKey(cache_key="cache_key", datasource_uid="3__table"))
db.session.commit()
cache_manager.data_cache.set("cache_key", "value")
cache_manager.cache.set("cache_key", "value")
rv = invalidate({"datasource_uids": ["3__table"]})
assert rv.status_code == 201
assert cache_manager.data_cache.get("cache_key") is None # noqa: E711
assert cache_manager.cache.get("cache_key") is None # noqa: E711
assert (
not db.session.query(CacheKey).filter(CacheKey.cache_key == "cache_key").first()
)
def test_invalidate_uses_data_cache_not_default_cache(invalidate):
"""Regression test for #40489.
Chart query results are written through ``cache_manager.data_cache``
(``DATA_CACHE_CONFIG``). When ``CACHE_CONFIG`` and ``DATA_CACHE_CONFIG``
use distinct ``CACHE_KEY_PREFIX`` values, deleting via the default
``cache_manager.cache`` silently misses the underlying Redis keys
because flask-caching prepends the wrong prefix to the DEL call.
"""
db.session.add(CacheKey(cache_key="cache_key", datasource_uid="3__table"))
db.session.commit()
with (
patch.object(cache_manager.data_cache, "delete_many") as data_delete,
patch.object(cache_manager.cache, "delete_many") as default_delete,
):
data_delete.return_value = True
rv = invalidate({"datasource_uids": ["3__table"]})
assert rv.status_code == 201
# Chart-data cache backend (the one that wrote the keys) must be hit.
data_delete.assert_called_once_with("cache_key")
# The default cache must NOT be touched — that's the #40489 regression.
default_delete.assert_not_called()
def test_invalidate_cache_empty_input(invalidate):
rv = invalidate({"datasource_uids": []})
assert rv.status_code == 201
@@ -138,10 +111,10 @@ def test_invalidate_existing_caches(invalidate):
db.session.add(CacheKey(cache_key="cache_keyX", datasource_uid="X__table"))
db.session.commit()
cache_manager.data_cache.set("cache_key1", "value")
cache_manager.data_cache.set("cache_key2", "value")
cache_manager.data_cache.set("cache_key4", "value")
cache_manager.data_cache.set("cache_keyX", "value")
cache_manager.cache.set("cache_key1", "value")
cache_manager.cache.set("cache_key2", "value")
cache_manager.cache.set("cache_key4", "value")
cache_manager.cache.set("cache_keyX", "value")
rv = invalidate(
{
@@ -182,10 +155,10 @@ def test_invalidate_existing_caches(invalidate):
)
assert rv.status_code == 201
assert cache_manager.data_cache.get("cache_key1") is None
assert cache_manager.data_cache.get("cache_key2") is None
assert cache_manager.data_cache.get("cache_key4") is None
assert cache_manager.data_cache.get("cache_keyX") == "value"
assert cache_manager.cache.get("cache_key1") is None
assert cache_manager.cache.get("cache_key2") is None
assert cache_manager.cache.get("cache_key4") is None
assert cache_manager.cache.get("cache_keyX") == "value"
assert (
not db.session.query(CacheKey)
.filter(CacheKey.cache_key.in_({"cache_key1", "cache_key2", "cache_key4"}))

View File

@@ -767,265 +767,3 @@ def test_fetch_data_converts_bigquery_row_objects(mocker: MockerFixture) -> None
assert result == [(1, "a"), (2, "b")]
assert flask_g.bq_memory_limited is False
def test_string_literal_with_apostrophe() -> None:
"""
Test that string literals containing apostrophes are properly escaped
for BigQuery using backslash escaping.
BigQuery requires backslash escaping for single quotes ('O\\'Brien').
Doubled single quotes ('O''Brien') are NOT valid — BigQuery parses them
as two concatenated string literals, causing a syntax error.
The upstream sqlalchemy-bigquery dialect uses ``repr()`` which switches
to double-quote delimiters when the value contains an apostrophe.
Double-quoted tokens are identifiers in BigQuery, causing syntax errors.
"""
from sqlalchemy import column as sa_column
from superset.db_engine_specs.bigquery import BigQueryEngineSpec # noqa: F811
# Trigger module load to ensure the monkey-patch is applied
assert BigQueryEngineSpec is not None
dialect = BigQueryDialect()
stmt = select(sa_column("name")).where(sa_column("name") == "Fernando's")
compiled_sql = str(
stmt.compile(dialect=dialect, compile_kwargs={"literal_binds": True})
)
# The compiled SQL must use single-quoted literal with backslash-escaped
# apostrophes. Doubled single quotes are NOT valid in BigQuery.
assert "= 'Fernando\\'s'" in compiled_sql
# Must NOT contain doubled-quote escaping (BigQuery rejects this)
assert "''" not in compiled_sql
# Must NOT contain double-quoted identifiers
assert '\\"' not in compiled_sql
def test_string_literal_without_apostrophe() -> None:
"""
Test that normal string literals (without apostrophes) still compile
correctly after the monkey-patch.
"""
from sqlalchemy import column as sa_column
from superset.db_engine_specs.bigquery import BigQueryEngineSpec # noqa: F811
assert BigQueryEngineSpec is not None
dialect = BigQueryDialect()
stmt = select(sa_column("name")).where(sa_column("name") == "Fernando")
compiled_sql = str(
stmt.compile(dialect=dialect, compile_kwargs={"literal_binds": True})
)
assert "= 'Fernando'" in compiled_sql
def test_string_literal_in_filter_with_apostrophe() -> None:
"""
Test that IN filters with apostrophes in values compile correctly
using backslash escaping.
"""
from sqlalchemy import column as sa_column
from superset.db_engine_specs.bigquery import BigQueryEngineSpec # noqa: F811
assert BigQueryEngineSpec is not None
dialect = BigQueryDialect()
stmt = select(sa_column("name")).where(
sa_column("name").in_(["Fernando's", "O'Brien"])
)
compiled_sql = str(
stmt.compile(dialect=dialect, compile_kwargs={"literal_binds": True})
)
assert "'Fernando\\'s'" in compiled_sql
assert "'O\\'Brien'" in compiled_sql
# Must NOT contain doubled-quote escaping
assert "''" not in compiled_sql
def test_process_string_literal_directly() -> None:
"""
Test _process_string_literal covers backslash escaping for apostrophes,
control-character escaping (newline/CR/tab/etc.), the ``\\xhh`` fallback
for control chars without a named escape, and pass-through for printable
Unicode and other characters BigQuery accepts unescaped.
"""
from superset.db_engine_specs.bigquery import _process_string_literal
# Plain values
assert _process_string_literal("hello") == "'hello'"
assert _process_string_literal("") == "''"
# Apostrophes (the original fix)
assert _process_string_literal("O'Brien") == "'O\\'Brien'"
assert _process_string_literal("it's a test") == "'it\\'s a test'"
# Backslashes must be escaped before apostrophes
assert _process_string_literal("C:\\path") == "'C:\\\\path'"
assert _process_string_literal("it's C:\\path") == "'it\\'s C:\\\\path'"
# Literal backslash followed by 'n' (two characters, not a newline)
# must produce the two-char sequence '\\n' (escaped backslash + n) so
# BigQuery does not misread it as a newline escape.
assert _process_string_literal("\\n") == "'\\\\n'"
# Control characters must be escaped using named escapes — BigQuery
# rejects literal control characters inside quoted strings.
assert _process_string_literal("foo\nbar") == "'foo\\nbar'"
assert _process_string_literal("foo\rbar") == "'foo\\rbar'"
assert _process_string_literal("foo\tbar") == "'foo\\tbar'"
assert _process_string_literal("a\bb\fc\vd\ae") == "'a\\bb\\fc\\vd\\ae'"
# Control characters without a named escape fall through to ``\\xhh``.
assert _process_string_literal("null\0byte") == "'null\\x00byte'"
assert _process_string_literal("a\x01b") == "'a\\x01b'"
assert _process_string_literal("a\x1bb") == "'a\\x1bb'"
assert _process_string_literal("a\x7fb") == "'a\\x7fb'"
# Double quotes do NOT need escaping in single-quoted BigQuery literals.
assert _process_string_literal('say "hello"') == "'say \"hello\"'"
# Printable Unicode and percent signs pass through unchanged.
assert _process_string_literal("café") == "'café'"
assert _process_string_literal("日本") == "'日本'"
assert _process_string_literal("100%") == "'100%'"
# Combined: apostrophe + newline + backslash + unicode.
assert _process_string_literal("it's\nC:\\café") == "'it\\'s\\nC:\\\\café'"
def test_process_string_literal_no_literal_control_chars() -> None:
"""
Regression test for the issue raised in PR #38835 review: BigQuery
rejects literal control characters inside quoted string literals, so the
output must never contain them as literal characters.
"""
from superset.db_engine_specs.bigquery import _process_string_literal
for char in ["\n", "\r", "\t", "\b", "\f", "\v", "\a", "\0", "\x01", "\x7f"]:
result = _process_string_literal(f"prefix{char}suffix")
assert char not in result, (
f"Literal {char!r} leaked into output {result!r}; "
"BigQuery would reject this literal."
)
def test_string_literal_with_newline_in_filter() -> None:
"""
End-to-end regression test for @rusackas's review feedback on PR #38835:
a filter value containing a newline must compile to valid BigQuery SQL
using the ``\\n`` escape sequence, not a literal newline.
"""
from sqlalchemy import column as sa_column
from superset.db_engine_specs.bigquery import BigQueryEngineSpec # noqa: F811
assert BigQueryEngineSpec is not None
dialect = BigQueryDialect()
stmt = select(sa_column("note")).where(sa_column("note") == "line1\nline2")
compiled_sql = str(
stmt.compile(dialect=dialect, compile_kwargs={"literal_binds": True})
)
# Must use the escape sequence form, not a literal newline.
assert "'line1\\nline2'" in compiled_sql
assert "\n" not in compiled_sql.split("note")[-1]
def test_literal_processor_non_bigquery_dialect() -> None:
"""
Test that BigQuerySafeString.literal_processor falls back to the parent
implementation when used with a non-BigQuery dialect.
"""
from sqlalchemy import create_engine
from superset.db_engine_specs.bigquery import (
_monkeypatch_bigquery_string_literal, # noqa: F811
)
_monkeypatch_bigquery_string_literal()
safe_cls = BigQueryDialect.colspecs[sqltypes.String]
instance = safe_cls()
# Use a non-BigQuery dialect (sqlite)
sqlite_dialect = create_engine("sqlite://").dialect
processor = instance.literal_processor(sqlite_dialect)
# The fallback processor should still produce a valid quoted string
assert processor is not None
def test_monkeypatch_is_applied() -> None:
"""
Test that _monkeypatch_bigquery_string_literal installs the custom
type decorator into BigQueryDialect.colspecs.
"""
from sqlalchemy.sql import sqltypes as sa_sqltypes
from superset.db_engine_specs.bigquery import (
BigQueryEngineSpec, # noqa: F811
)
assert BigQueryEngineSpec is not None
colspecs = BigQueryDialect.colspecs
assert sa_sqltypes.String in colspecs
safe_cls = colspecs[sa_sqltypes.String]
assert safe_cls.__name__ == "BigQuerySafeString"
def test_literal_processor_returns_process_string_literal_for_bigquery() -> None:
"""
Test that BigQuerySafeString.literal_processor returns the
_process_string_literal function when given a BigQuery dialect,
and that calling it produces correctly escaped output.
"""
from superset.db_engine_specs.bigquery import (
_monkeypatch_bigquery_string_literal,
_process_string_literal,
)
_monkeypatch_bigquery_string_literal()
safe_cls = BigQueryDialect.colspecs[sqltypes.String]
instance = safe_cls()
dialect = BigQueryDialect()
processor = instance.literal_processor(dialect)
assert processor is _process_string_literal
assert processor("O'Brien") == "'O\\'Brien'"
assert processor("plain") == "'plain'"
def test_monkeypatch_handles_missing_bigquery_package() -> None:
"""
Test that _monkeypatch_bigquery_string_literal gracefully handles
the case where sqlalchemy_bigquery is not installed.
"""
import builtins
from superset.db_engine_specs.bigquery import (
_monkeypatch_bigquery_string_literal,
)
original_import = builtins.__import__
def mock_import(name: str, *args: Any, **kwargs: Any) -> Any:
if name == "sqlalchemy_bigquery":
raise ImportError("mocked missing package")
return original_import(name, *args, **kwargs)
with mock.patch("builtins.__import__", side_effect=mock_import):
# Should not raise — the except ImportError branch handles it
_monkeypatch_bigquery_string_literal()

View File

@@ -17,23 +17,14 @@
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
from datetime import datetime
from typing import Any, Optional
from urllib.parse import parse_qs, urlparse
from typing import Optional
import pytest
from pytest_mock import MockerFixture
from sqlalchemy.engine.url import make_url
from superset.db_engine_specs.base import OAuth2State
from superset.db_engine_specs.databricks import (
DatabricksNativeEngineSpec,
DatabricksPythonConnectorEngineSpec,
)
from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import OAuth2Error, OAuth2RedirectError
from superset.superset_typing import OAuth2ClientConfig
from superset.utils import json
from superset.utils.oauth2 import decode_oauth2_state
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm # noqa: F401
@@ -300,595 +291,3 @@ def test_get_prequeries(mocker: MockerFixture) -> None:
"USE CATALOG `evil`` USE CATALOG bad`",
"USE SCHEMA `evil`` USE SCHEMA bad`",
]
# OAuth2 Tests
def test_oauth2_attributes() -> None:
"""
Test that OAuth2 attributes are properly set for both engine specs.
"""
# Test DatabricksNativeEngineSpec
assert DatabricksNativeEngineSpec.supports_oauth2 is True
assert DatabricksNativeEngineSpec.oauth2_scope == "sql"
# The authorization endpoint is derived from the workspace host at runtime;
# the token endpoint must be configured explicitly.
assert DatabricksNativeEngineSpec.oauth2_authorization_request_uri == ""
assert DatabricksNativeEngineSpec.oauth2_token_request_uri == ""
# Test DatabricksPythonConnectorEngineSpec
assert DatabricksPythonConnectorEngineSpec.supports_oauth2 is True
assert DatabricksPythonConnectorEngineSpec.oauth2_scope == "sql"
assert DatabricksPythonConnectorEngineSpec.oauth2_authorization_request_uri == ""
assert DatabricksPythonConnectorEngineSpec.oauth2_token_request_uri == ""
@pytest.mark.parametrize(
"spec",
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
)
@pytest.mark.parametrize(
"message",
[
"Error during request to server: HTTP 401 Unauthorized",
"Invalid access token",
"The access token expired",
"UNAUTHENTICATED: token is no longer valid",
],
)
def test_needs_oauth2_detects_auth_failure_from_message(
mocker: MockerFixture,
spec: Any,
message: str,
) -> None:
"""
The Databricks driver has no dedicated auth exception, so `needs_oauth2`
matches auth-failure signals in the error message to bootstrap a re-auth.
"""
g = mocker.patch("superset.db_engine_specs.databricks.g")
g.user = mocker.MagicMock()
assert spec.needs_oauth2(Exception(message)) is True
@pytest.mark.parametrize(
"spec",
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
)
@pytest.mark.parametrize(
"message",
[
"Table not found",
# A bare 401 in an unrelated position must not look like an auth failure.
"Query failed at line 401: syntax error",
],
)
def test_needs_oauth2_ignores_unrelated_errors(
mocker: MockerFixture,
spec: Any,
message: str,
) -> None:
"""
A non-auth driver error must not trigger the OAuth2 dance.
"""
g = mocker.patch("superset.db_engine_specs.databricks.g")
g.user = mocker.MagicMock()
assert spec.needs_oauth2(Exception(message)) is False
@pytest.mark.parametrize(
"spec",
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
)
def test_needs_oauth2_matches_oauth2_redirect_error(
mocker: MockerFixture,
spec: Any,
) -> None:
"""
The inherited `isinstance` check against `oauth2_exception` still holds.
"""
g = mocker.patch("superset.db_engine_specs.databricks.g")
g.user = mocker.MagicMock()
ex = OAuth2RedirectError("https://example/authorize", "tab", "redirect")
assert spec.needs_oauth2(ex) is True
def test_impersonate_user_with_token(mocker: MockerFixture) -> None:
"""
Test impersonate_user method with OAuth2 token for DatabricksNativeEngineSpec.
"""
database = mocker.MagicMock()
original_url = make_url(
"databricks+connector://token:original-token@host:443/database"
)
engine_kwargs = {"connect_args": {"access_token": "original-token"}}
# Test with user token
url, kwargs = DatabricksNativeEngineSpec.impersonate_user(
database=database,
username="user1",
user_token="user-oauth-token", # noqa: S106
url=original_url,
engine_kwargs=engine_kwargs,
)
# Check that the password (token) was updated in the URL
assert url.password == "user-oauth-token" # noqa: S105
# Check that access_token was updated in connect_args
assert kwargs["connect_args"]["access_token"] == "user-oauth-token" # noqa: S105
def test_impersonate_user_without_token(mocker: MockerFixture) -> None:
"""
Test impersonate_user method without OAuth2 token.
"""
database = mocker.MagicMock()
original_url = make_url(
"databricks+connector://token:original-token@host:443/database"
)
engine_kwargs = {"connect_args": {"access_token": "original-token"}}
# Test without user token
url, kwargs = DatabricksNativeEngineSpec.impersonate_user(
database=database,
username="user1",
user_token=None,
url=original_url,
engine_kwargs=engine_kwargs,
)
# Check that nothing was changed
assert url.password == "original-token" # noqa: S105
assert kwargs["connect_args"]["access_token"] == "original-token" # noqa: S105
def test_impersonate_user_python_connector(mocker: MockerFixture) -> None:
"""
Test impersonate_user method for DatabricksPythonConnectorEngineSpec.
"""
database = mocker.MagicMock()
original_url = make_url(
"databricks://token:original-token@host:443?http_path=path&catalog=main&schema=default"
)
engine_kwargs = {"connect_args": {"access_token": "original-token"}}
# Test with user token
url, kwargs = DatabricksPythonConnectorEngineSpec.impersonate_user(
database=database,
username="user1",
user_token="user-oauth-token", # noqa: S106
url=original_url,
engine_kwargs=engine_kwargs,
)
# Check that the password (token) was updated in the URL
assert url.password == "user-oauth-token" # noqa: S105
# Check that access_token was updated in connect_args
assert kwargs["connect_args"]["access_token"] == "user-oauth-token" # noqa: S105
@pytest.fixture
def oauth2_config_native() -> OAuth2ClientConfig:
"""
Config for Databricks Native OAuth2.
"""
return {
"id": "databricks-client-id",
"secret": "databricks-client-secret",
"scope": "sql",
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"authorization_request_uri": "https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/authorize",
"token_request_uri": "https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
"request_content_type": "json",
}
@pytest.fixture
def oauth2_config_python() -> OAuth2ClientConfig:
"""
Config for Databricks Python Connector OAuth2.
"""
return {
"id": "databricks-client-id",
"secret": "databricks-client-secret",
"scope": "sql",
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"authorization_request_uri": "https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/authorize",
"token_request_uri": "https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
"request_content_type": "json",
}
def test_is_oauth2_enabled_no_config_native(mocker: MockerFixture) -> None:
"""
Test `is_oauth2_enabled` when OAuth2 is not configured for Native engine.
"""
mocker.patch(
"flask.current_app.config",
new={"DATABASE_OAUTH2_CLIENTS": {}},
)
assert DatabricksNativeEngineSpec.is_oauth2_enabled() is False
def test_is_oauth2_enabled_config_native(mocker: MockerFixture) -> None:
"""
Test `is_oauth2_enabled` when OAuth2 is configured for Native engine.
"""
mocker.patch(
"flask.current_app.config",
new={
"DATABASE_OAUTH2_CLIENTS": {
"Databricks (legacy)": {
"id": "client-id",
"secret": "client-secret",
},
}
},
)
assert DatabricksNativeEngineSpec.is_oauth2_enabled() is True
def test_is_oauth2_enabled_no_config_python(mocker: MockerFixture) -> None:
"""
Test `is_oauth2_enabled` when OAuth2 is not configured for Python Connector engine.
"""
mocker.patch(
"flask.current_app.config",
new={"DATABASE_OAUTH2_CLIENTS": {}},
)
assert DatabricksPythonConnectorEngineSpec.is_oauth2_enabled() is False
def test_is_oauth2_enabled_config_python(mocker: MockerFixture) -> None:
"""
Test `is_oauth2_enabled` when OAuth2 is configured for Python Connector engine.
"""
mocker.patch(
"flask.current_app.config",
new={
"DATABASE_OAUTH2_CLIENTS": {
"Databricks": {
"id": "client-id",
"secret": "client-secret",
},
}
},
)
assert DatabricksPythonConnectorEngineSpec.is_oauth2_enabled() is True
def test_get_oauth2_authorization_uri_native(
mocker: MockerFixture,
oauth2_config_native: OAuth2ClientConfig,
) -> None:
"""
Test `get_oauth2_authorization_uri` for Native engine.
"""
from superset.db_engine_specs.base import OAuth2State
state: OAuth2State = {
"database_id": 1,
"user_id": 1,
"default_redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"tab_id": "1234",
}
url = DatabricksNativeEngineSpec.get_oauth2_authorization_uri(
oauth2_config_native, state
)
parsed = urlparse(url)
assert parsed.netloc == "accounts.cloud.databricks.com"
assert parsed.path == "/oidc/accounts/12345/v1/authorize"
query = parse_qs(parsed.query)
assert query["scope"][0] == "sql"
encoded_state = query["state"][0].replace("%2E", ".")
assert decode_oauth2_state(encoded_state) == state
def test_get_oauth2_authorization_uri_python(
mocker: MockerFixture,
oauth2_config_python: OAuth2ClientConfig,
) -> None:
"""
Test `get_oauth2_authorization_uri` for Python Connector engine.
"""
from superset.db_engine_specs.base import OAuth2State
state: OAuth2State = {
"database_id": 1,
"user_id": 1,
"default_redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"tab_id": "1234",
}
url = DatabricksPythonConnectorEngineSpec.get_oauth2_authorization_uri(
oauth2_config_python, state
)
parsed = urlparse(url)
assert parsed.netloc == "accounts.cloud.databricks.com"
assert parsed.path == "/oidc/accounts/12345/v1/authorize"
query = parse_qs(parsed.query)
assert query["scope"][0] == "sql"
encoded_state = query["state"][0].replace("%2E", ".")
assert decode_oauth2_state(encoded_state) == state
def test_get_oauth2_token_native(
mocker: MockerFixture,
oauth2_config_native: OAuth2ClientConfig,
) -> None:
"""
Test `get_oauth2_token` for Native engine.
"""
requests = mocker.patch("superset.db_engine_specs.base.requests")
requests.post().json.return_value = {
"access_token": "access-token",
"expires_in": 3600,
"scope": "sql",
"token_type": "Bearer",
"refresh_token": "refresh-token",
}
assert DatabricksNativeEngineSpec.get_oauth2_token(
oauth2_config_native, "authorization-code"
) == {
"access_token": "access-token",
"expires_in": 3600,
"scope": "sql",
"token_type": "Bearer",
"refresh_token": "refresh-token",
}
requests.post.assert_called_with(
"https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
json={
"code": "authorization-code",
"client_id": "databricks-client-id",
"client_secret": "databricks-client-secret",
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"grant_type": "authorization_code",
},
timeout=30.0,
)
def test_get_oauth2_token_python(
mocker: MockerFixture,
oauth2_config_python: OAuth2ClientConfig,
) -> None:
"""
Test `get_oauth2_token` for Python Connector engine.
"""
requests = mocker.patch("superset.db_engine_specs.base.requests")
requests.post().json.return_value = {
"access_token": "access-token",
"expires_in": 3600,
"scope": "sql",
"token_type": "Bearer",
"refresh_token": "refresh-token",
}
assert DatabricksPythonConnectorEngineSpec.get_oauth2_token(
oauth2_config_python, "authorization-code"
) == {
"access_token": "access-token",
"expires_in": 3600,
"scope": "sql",
"token_type": "Bearer",
"refresh_token": "refresh-token",
}
requests.post.assert_called_with(
"https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
json={
"code": "authorization-code",
"client_id": "databricks-client-id",
"client_secret": "databricks-client-secret",
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"grant_type": "authorization_code",
},
timeout=30.0,
)
def test_get_oauth2_fresh_token_native(
mocker: MockerFixture,
oauth2_config_native: OAuth2ClientConfig,
) -> None:
"""
Test `get_oauth2_fresh_token` for Native engine.
"""
requests = mocker.patch("superset.db_engine_specs.base.requests")
requests.post().json.return_value = {
"access_token": "new-access-token",
"expires_in": 3600,
"scope": "sql",
"token_type": "Bearer",
"refresh_token": "new-refresh-token",
}
assert DatabricksNativeEngineSpec.get_oauth2_fresh_token(
oauth2_config_native, "old-refresh-token"
) == {
"access_token": "new-access-token",
"expires_in": 3600,
"scope": "sql",
"token_type": "Bearer",
"refresh_token": "new-refresh-token",
}
requests.post.assert_called_with(
"https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
json={
"client_id": "databricks-client-id",
"client_secret": "databricks-client-secret",
"refresh_token": "old-refresh-token",
"grant_type": "refresh_token",
},
timeout=30.0,
)
def _oauth2_state() -> OAuth2State:
"""
Build the default OAuth2 state shared by the OAuth2 tests.
"""
state: OAuth2State = {
"database_id": 1,
"user_id": 1,
"default_redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"tab_id": "1234",
}
return state
def _unresolved_oauth2_config() -> OAuth2ClientConfig:
"""
Config as built by `get_oauth2_config` when no endpoints are overridden:
the URIs default to the spec's empty class attributes.
"""
return {
"id": "databricks-client-id",
"secret": "databricks-client-secret",
"scope": "sql",
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"authorization_request_uri": "",
"token_request_uri": "",
"request_content_type": "json",
}
@pytest.mark.parametrize(
"spec",
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
)
@pytest.mark.parametrize(
"host",
[
"dbc-abc.cloud.databricks.com",
"adb-123456789.12.azuredatabricks.net",
"123456789.gcp.databricks.com",
],
)
def test_get_oauth2_authorization_uri_derives_from_workspace_host(
mocker: MockerFixture,
spec: Any,
host: str,
) -> None:
"""
With no configured `authorization_request_uri`, the endpoint is derived from
the workspace host (`https://<host>/oidc/v1/authorize`) on every cloud, with
no account/tenant identifier required.
"""
database = mocker.MagicMock()
database.url_object.host = host
mocker.patch("superset.db.session.get", return_value=database)
url = spec.get_oauth2_authorization_uri(
_unresolved_oauth2_config(), _oauth2_state()
)
parsed = urlparse(url)
assert parsed.netloc == host
assert parsed.path == "/oidc/v1/authorize"
@pytest.mark.parametrize(
"spec",
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
)
def test_get_oauth2_authorization_uri_preserves_configured(
mocker: MockerFixture,
spec: Any,
) -> None:
"""
A fully-resolved `authorization_request_uri` is never overwritten by the
host-derived endpoint, and no database lookup is needed.
"""
session_get = mocker.patch("superset.db.session.get")
config = _unresolved_oauth2_config()
config["authorization_request_uri"] = (
"https://accounts.cloud.databricks.com/oidc/accounts/override/v1/authorize"
)
url = spec.get_oauth2_authorization_uri(config, _oauth2_state())
assert urlparse(url).path == "/oidc/accounts/override/v1/authorize"
session_get.assert_not_called()
@pytest.mark.parametrize(
"spec",
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
)
def test_get_oauth2_authorization_uri_fails_without_host(
mocker: MockerFixture,
spec: Any,
) -> None:
"""
When the endpoint must be derived but the connection has no host, fail fast
instead of emitting an invalid `https:///oidc/v1/authorize` URL.
"""
database = mocker.MagicMock()
database.url_object.host = None
mocker.patch("superset.db.session.get", return_value=database)
with pytest.raises(OAuth2Error):
spec.get_oauth2_authorization_uri(_unresolved_oauth2_config(), _oauth2_state())
@pytest.mark.parametrize(
"spec",
[DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec],
)
def test_get_oauth2_token_fails_without_uri(
mocker: MockerFixture,
spec: Any,
) -> None:
"""
Token exchange has no database context to auto-detect the endpoint, so a
missing `token_request_uri` fails fast rather than POSTing to `.../{}/...`.
"""
with pytest.raises(OAuth2Error):
spec.get_oauth2_token(_unresolved_oauth2_config(), "authorization-code")
def test_get_oauth2_fresh_token_python(
mocker: MockerFixture,
oauth2_config_python: OAuth2ClientConfig,
) -> None:
"""
Test `get_oauth2_fresh_token` for Python Connector engine.
"""
requests = mocker.patch("superset.db_engine_specs.base.requests")
requests.post().json.return_value = {
"access_token": "new-access-token",
"expires_in": 3600,
"scope": "sql",
"token_type": "Bearer",
"refresh_token": "new-refresh-token",
}
assert DatabricksPythonConnectorEngineSpec.get_oauth2_fresh_token(
oauth2_config_python, "old-refresh-token"
) == {
"access_token": "new-access-token",
"expires_in": 3600,
"scope": "sql",
"token_type": "Bearer",
"refresh_token": "new-refresh-token",
}
requests.post.assert_called_with(
"https://accounts.cloud.databricks.com/oidc/accounts/12345/v1/token",
json={
"client_id": "databricks-client-id",
"client_secret": "databricks-client-secret",
"refresh_token": "old-refresh-token",
"grant_type": "refresh_token",
},
timeout=30.0,
)

View File

@@ -1,127 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
from typing import Any
from unittest.mock import MagicMock
from urllib.parse import parse_qs, urlparse
import pytest
from pytest_mock import MockerFixture
from superset.db_engine_specs.databricks import (
DatabricksNativeEngineSpec,
DatabricksPythonConnectorEngineSpec,
)
from superset.superset_typing import OAuth2ClientConfig
from superset.utils.oauth2 import decode_oauth2_state
# Multi-Cloud Provider Tests
#
# Databricks fronts the user-to-machine OAuth2 flow on every workspace at
# `https://<workspace-host>/oidc/v1/{authorize,token}`, regardless of cloud, so
# the authorization endpoint derives from the connection host with no per-cloud
# account/tenant identifier.
SPECS = [DatabricksNativeEngineSpec, DatabricksPythonConnectorEngineSpec]
# Representative workspace hosts for each cloud provider.
CLOUD_HOSTS = [
"my-cluster.cloud.databricks.com", # AWS
"adb-123456789.12.azuredatabricks.net", # Azure
"123456789.gcp.databricks.com", # GCP
]
@pytest.fixture
def oauth2_config_no_uri() -> OAuth2ClientConfig:
"""
Config for Databricks OAuth2 without a pre-configured endpoint, so the
authorization endpoint is derived from the workspace host.
"""
return {
"id": "databricks-client-id",
"secret": "databricks-client-secret",
"scope": "sql",
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"authorization_request_uri": "",
"token_request_uri": "",
"request_content_type": "json",
}
def _mock_database(mocker: MockerFixture, host: str) -> MagicMock:
"""
Build a mock database whose URL resolves to the given workspace host.
"""
database = mocker.MagicMock()
database.url_object.host = host
database.id = 1
return database
@pytest.mark.parametrize("spec", SPECS)
@pytest.mark.parametrize("host", CLOUD_HOSTS)
def test_get_oauth2_authorization_uri_uses_workspace_host(
mocker: MockerFixture,
spec: Any,
host: str,
oauth2_config_no_uri: OAuth2ClientConfig,
) -> None:
"""
The authorization endpoint is the workspace host on AWS, Azure, and GCP.
"""
from superset.db_engine_specs.base import OAuth2State
mocker.patch(
"superset.extensions.db.session.get",
return_value=_mock_database(mocker, host),
)
state: OAuth2State = {
"database_id": 1,
"user_id": 1,
"default_redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"tab_id": "1234",
}
url = spec.get_oauth2_authorization_uri(oauth2_config_no_uri, state)
parsed = urlparse(url)
assert parsed.netloc == host
assert parsed.path == "/oidc/v1/authorize"
query = parse_qs(parsed.query)
assert query["scope"][0] == "sql"
encoded_state = query["state"][0].replace("%2E", ".")
assert decode_oauth2_state(encoded_state) == state
@pytest.mark.parametrize("spec", SPECS)
@pytest.mark.parametrize("host", CLOUD_HOSTS)
def test_workspace_oauth2_endpoint_builds_token_uri(
mocker: MockerFixture,
spec: Any,
host: str,
) -> None:
"""
The helper builds the matching token endpoint from the same workspace host.
"""
database = _mock_database(mocker, host)
assert (
spec._workspace_oauth2_endpoint(database, "token")
== f"https://{host}/oidc/v1/token"
)

View File

@@ -96,7 +96,6 @@ async def test_list_dashboards_basic(mock_list, mcp_server):
dashboard.uuid = "test-dashboard-uuid-1"
dashboard.thumbnail_url = None
dashboard.roles = []
dashboard.embedded = []
dashboard.charts = []
dashboard._mapping = {
"id": dashboard.id,
@@ -163,7 +162,6 @@ async def test_list_dashboards_with_filters(mock_list, mcp_server):
dashboard.uuid = None
dashboard.thumbnail_url = None
dashboard.roles = []
dashboard.embedded = []
dashboard.charts = []
dashboard._mapping = {
"id": dashboard.id,
@@ -259,7 +257,6 @@ async def test_list_dashboards_with_search(mock_list, mcp_server):
dashboard.uuid = None
dashboard.thumbnail_url = None
dashboard.roles = []
dashboard.embedded = []
dashboard.charts = []
dashboard._mapping = {
"id": dashboard.id,
@@ -354,7 +351,6 @@ async def test_get_dashboard_info_success(
dashboard.owners = []
dashboard.tags = []
dashboard.roles = []
dashboard.embedded = []
dashboard.charts = []
dashboard._mapping = {
"id": dashboard.id,
@@ -433,7 +429,6 @@ async def test_get_dashboard_info_permalink_does_not_double_sanitize(
dashboard.owners = []
dashboard.tags = []
dashboard.roles = []
dashboard.embedded = []
dashboard.charts = []
mock_info.return_value = dashboard
permalink_value = {
@@ -526,7 +521,6 @@ async def test_get_dashboard_info_permalink_key_includes_filter_state(
dashboard.owners = []
dashboard.tags = []
dashboard.roles = []
dashboard.embedded = []
dashboard.charts = []
mock_info.return_value = dashboard
@@ -773,7 +767,6 @@ async def test_get_dashboard_info_does_not_expose_access_list_or_roles(
dashboard.owners = [owner]
dashboard.tags = []
dashboard.roles = [dashboard_role]
dashboard.embedded = []
mock_info.return_value = dashboard
@@ -845,7 +838,6 @@ async def test_get_dashboard_info_restricted_user_redacts_data_model_metadata(
dashboard.owners = []
dashboard.tags = []
dashboard.roles = []
dashboard.embedded = []
mock_info.return_value = dashboard
@@ -898,7 +890,6 @@ async def test_get_dashboard_info_restricted_user_redacts_permalink_filter_state
dashboard.owners = []
dashboard.tags = []
dashboard.roles = []
dashboard.embedded = []
mock_info.return_value = dashboard
@@ -1021,88 +1012,6 @@ async def test_list_dashboards_omits_requested_user_directory_fields(
assert field not in data["columns_available"]
@patch("superset.mcp_service.mcp_core.ModelGetInfoCore._find_object")
@pytest.mark.asyncio
async def test_get_dashboard_info_includes_embedded_uuid(mock_find_object, mcp_server):
"""Test that get_dashboard_info returns embedded_uuid when set."""
from superset.models.embedded_dashboard import EmbeddedDashboard
dashboard = Mock()
dashboard.id = 1
dashboard.dashboard_title = "Embedded Dashboard"
dashboard.slug = ""
dashboard.description = None
dashboard.css = None
dashboard.certified_by = None
dashboard.certification_details = None
dashboard.json_metadata = "{}"
dashboard.published = True
dashboard.is_managed_externally = False
dashboard.external_url = None
dashboard.created_on = None
dashboard.changed_on = None
dashboard.created_on_humanized = None
dashboard.changed_on_humanized = None
dashboard.uuid = "94b826a5-dbd5-473d-ab58-1af676ee07e4"
dashboard.url = "/dashboard/1"
dashboard.slices = []
dashboard.owners = []
dashboard.tags = []
dashboard.roles = []
dashboard.embedded = []
embedded = Mock(spec=EmbeddedDashboard)
embedded.uuid = "37c56048-d3f1-452d-b3ae-0879802dcb1f"
dashboard.embedded = [embedded]
mock_find_object.return_value = dashboard
async with Client(mcp_server) as client:
result = await client.call_tool(
"get_dashboard_info", {"request": {"identifier": 1}}
)
assert result.data["uuid"] == "94b826a5-dbd5-473d-ab58-1af676ee07e4"
assert result.data["embedded_uuid"] == "37c56048-d3f1-452d-b3ae-0879802dcb1f"
@patch("superset.mcp_service.mcp_core.ModelGetInfoCore._find_object")
@pytest.mark.asyncio
async def test_get_dashboard_info_embedded_uuid_none_when_not_embedded(
mock_find_object, mcp_server
):
"""Test that embedded_uuid is None when the dashboard has not been configured
for embedding."""
dashboard = Mock()
dashboard.id = 2
dashboard.dashboard_title = "Non-Embedded Dashboard"
dashboard.slug = ""
dashboard.description = None
dashboard.css = None
dashboard.certified_by = None
dashboard.certification_details = None
dashboard.json_metadata = "{}"
dashboard.published = True
dashboard.is_managed_externally = False
dashboard.external_url = None
dashboard.created_on = None
dashboard.changed_on = None
dashboard.created_on_humanized = None
dashboard.changed_on_humanized = None
dashboard.uuid = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
dashboard.url = "/dashboard/2"
dashboard.slices = []
dashboard.owners = []
dashboard.tags = []
dashboard.roles = []
dashboard.embedded = []
mock_find_object.return_value = dashboard
async with Client(mcp_server) as client:
result = await client.call_tool(
"get_dashboard_info", {"request": {"identifier": 2}}
)
assert result.data.get("embedded_uuid") is None
# TODO (Phase 3+): Add tests for get_dashboard_available_filters tool
@@ -1135,7 +1044,6 @@ async def test_get_dashboard_info_by_uuid(mock_find_object, mcp_server):
dashboard.owners = []
dashboard.tags = []
dashboard.roles = []
dashboard.embedded = []
mock_find_object.return_value = dashboard
async with Client(mcp_server) as client:
@@ -1175,7 +1083,6 @@ async def test_get_dashboard_info_by_slug(mock_find_object, mcp_server):
dashboard.owners = []
dashboard.tags = []
dashboard.roles = []
dashboard.embedded = []
mock_find_object.return_value = dashboard
async with Client(mcp_server) as client:
@@ -1215,7 +1122,6 @@ async def test_list_dashboards_custom_uuid_slug_columns(mock_list, mcp_server):
dashboard.external_url = None
dashboard.thumbnail_url = None
dashboard.roles = []
dashboard.embedded = []
dashboard.charts = []
dashboard._mapping = {
"id": dashboard.id,
@@ -1297,7 +1203,6 @@ async def test_list_dashboards_sanitizes_dashboard_descriptions_and_filter_text(
dashboard.external_url = None
dashboard.thumbnail_url = None
dashboard.roles = []
dashboard.embedded = []
dashboard.charts = []
dashboard._mapping = {
"id": dashboard.id,
@@ -1438,7 +1343,6 @@ class TestDashboardDefaultColumnFiltering:
dashboard.external_url = None
dashboard.thumbnail_url = None
dashboard.roles = []
dashboard.embedded = []
dashboard.charts = []
mock_list.return_value = ([dashboard], 1)