Compare commits

..

3 Commits

Author SHA1 Message Date
Amin Ghadersohi
16d136ef8d fix(mcp): context-aware recovery hints and sanitize identifier in not-found errors
- In get_chart_preview, when identifier looks like a form_data_key (long
  non-numeric string), suggest regenerating the explore link rather than
  always pointing to list_charts, which is only relevant for chart IDs.
- Truncate request.identifier to 200 chars before embedding in error
  messages across get_chart_preview, get_chart_data, and update_chart
  to prevent injection via oversized attacker-controlled identifiers.
2026-05-09 00:26:11 +00:00
Amin Ghadersohi
c78658d852 fix(mcp): improve "not found" errors to suggest corresponding list_* tools
When MCP tools return "not found" errors for database, chart, dataset, or
dashboard IDs, include recovery guidance pointing to the appropriate list
tool (list_databases, list_charts, list_datasets, list_dashboards).

Affected tools: execute_sql, open_sql_lab_with_context, query_dataset,
get_chart_data, get_chart_preview, update_chart,
add_chart_to_existing_dashboard, generate_dashboard
2026-05-06 22:55:46 +00:00
bdonovan1
5b5dd01028 fix(sqla): parenthesize calculated column expressions in WHERE clause (#39793)
Co-authored-by: Brian Donovan <briand@netflix.com>
Co-authored-by: Vitor Avila <96086495+Vitor-Avila@users.noreply.github.com>
2026-05-06 19:45:27 -03:00
20 changed files with 300 additions and 188 deletions

View File

@@ -67,7 +67,7 @@ export const renameOperator: PostProcessingFactory<PostProcessingRename> = (
[...metricOffsetMap.entries()].forEach(
([metricWithOffset, metricOnly]) => {
const offsetLabel = timeOffsets.find(offset =>
metricWithOffset.endsWith(`${TIME_COMPARISON_SEPARATOR}${offset}`),
metricWithOffset.includes(offset),
);
renamePairs.push([
formData.comparison_type === ComparisonType.Values

View File

@@ -26,7 +26,7 @@ export const getTimeOffset = (
timeCompare.find(
timeOffset =>
// offset is represented as <offset>, group by list
series.name.startsWith(`${timeOffset},`) ||
series.name.includes(`${timeOffset},`) ||
// offset is represented as <metric>__<offset>
series.name.includes(`__${timeOffset}`) ||
// offset is represented as <metric>, <offset>
@@ -50,9 +50,7 @@ export const getOriginalSeries = (
// offset in the middle: <metric>, <offset>, <dimension>
result = result.replace(`, ${compare},`, ',');
// offset at start: <offset>, <dimension>
if (result.startsWith(`${compare},`)) {
result = result.slice(`${compare},`.length);
}
result = result.replace(`${compare},`, '');
// offset with double underscore: <metric>__<offset>
result = result.replace(`__${compare}`, '');
// offset at end: <metric>, <offset>

View File

@@ -303,30 +303,6 @@ test('should add renameOperator if multiple metrics exist', () => {
});
});
test('should correctly match offsets that share a numeric prefix', () => {
expect(
renameOperator(
{
...formData,
comparison_type: ComparisonType.Values,
time_compare: ['1 year ago', '11 year ago'],
},
queryObject,
),
).toEqual({
operation: 'rename',
options: {
columns: {
'count(*)__1 year ago': '1 year ago',
'count(*)__11 year ago': '11 year ago',
},
inplace: true,
level: 0,
},
});
});
test('should remove renameOperator', () => {
expect(
renameOperator(

View File

@@ -114,26 +114,3 @@ test('hasTimeOffset returns false when series name is not a string', () => {
const timeCompare = ['1 year ago'];
expect(hasTimeOffset(series, timeCompare)).toBe(false);
});
test('getTimeOffset correctly matches offsets that share a numeric prefix', () => {
const timeCompare = ['1 year ago', '11 year ago'];
expect(
getTimeOffset({ name: '11 year ago, Alexander' }, timeCompare),
).toEqual('11 year ago');
expect(getTimeOffset({ name: '1 year ago, Alexander' }, timeCompare)).toEqual(
'1 year ago',
);
expect(getTimeOffset({ name: 'Births__11 year ago' }, timeCompare)).toEqual(
'11 year ago',
);
});
test('getOriginalSeries correctly strips offsets that share a numeric prefix', () => {
const timeCompare = ['1 year ago', '11 year ago'];
expect(getOriginalSeries('11 year ago, Alexander', timeCompare)).toEqual(
'Alexander',
);
expect(getOriginalSeries('1 year ago, Alexander', timeCompare)).toEqual(
'Alexander',
);
});

View File

@@ -199,8 +199,12 @@ async def get_chart_data( # noqa: C901
if not chart:
await ctx.warning("Chart not found: identifier=%s" % (request.identifier,))
safe_id = str(request.identifier)[:200]
return ChartError(
error=f"No chart found with identifier: {request.identifier}",
error=(
f"No chart found with identifier: {safe_id}."
" Use list_charts to get valid chart IDs."
),
error_type="NotFound",
)

View File

@@ -1192,8 +1192,22 @@ async def _get_chart_preview_internal( # noqa: C901
if not chart:
await ctx.warning("Chart not found: identifier=%s" % (request.identifier,))
safe_id = str(request.identifier)[:200]
is_form_data_key = (
isinstance(request.identifier, str)
and len(request.identifier) > 8
and not request.identifier.isdigit()
)
if is_form_data_key:
recovery = (
"If using a form_data_key, it may have expired — "
"use generate_explore_link to get a fresh key, "
"or use list_charts to find a saved chart by ID."
)
else:
recovery = "Use list_charts to get valid chart IDs."
return ChartError(
error=f"No chart found with identifier: {request.identifier}",
error=f"No chart found with identifier: {safe_id}. {recovery}",
error_type="NotFound",
)

View File

@@ -337,17 +337,18 @@ async def update_chart( # noqa: C901
chart = find_chart_by_identifier(request.identifier)
if not chart:
safe_id = str(request.identifier)[:200]
not_found_msg = (
f"No chart found with identifier: {safe_id}."
" Use list_charts to get valid chart IDs."
)
return GenerateChartResponse.model_validate(
{
"chart": None,
"error": {
"error_type": "NotFound",
"message": (
f"No chart found with identifier: {request.identifier}"
),
"details": (
f"No chart found with identifier: {request.identifier}"
),
"message": not_found_msg,
"details": not_found_msg,
},
"success": False,
"schema_version": "2.0",

View File

@@ -334,7 +334,10 @@ def _find_and_authorize_dashboard(
dashboard=None,
dashboard_url=None,
position=None,
error=f"Dashboard with ID {dashboard_id} not found",
error=(
f"Dashboard with ID {dashboard_id} not found."
" Use list_dashboards to get valid dashboard IDs."
),
)
try:
@@ -392,7 +395,10 @@ def add_chart_to_existing_dashboard(
dashboard=None,
dashboard_url=None,
position=None,
error=f"Chart with ID {request.chart_id} not found",
error=(
f"Chart with ID {request.chart_id} not found."
" Use list_charts to get valid chart IDs."
),
)
# Validate dataset access for the chart.

View File

@@ -230,7 +230,10 @@ def generate_dashboard( # noqa: C901
return GenerateDashboardResponse(
dashboard=None,
dashboard_url=None,
error=f"Charts not found: {list(missing_chart_ids)}",
error=(
f"Charts not found: {list(missing_chart_ids)}."
" Use list_charts to get valid chart IDs."
),
)
# Validate dataset access for each chart.

View File

@@ -183,7 +183,10 @@ async def query_dataset( # noqa: C901
if dataset is None:
await ctx.error("Dataset not found: identifier=%s" % (request.dataset_id,))
return DatasetError.create(
error=f"No dataset found with identifier: {request.dataset_id}",
error=(
f"No dataset found with identifier: {request.dataset_id}."
" Use list_datasets to get valid dataset IDs."
),
error_type="NotFound",
)

View File

@@ -100,7 +100,10 @@ async def execute_sql(request: ExecuteSqlRequest, ctx: Context) -> ExecuteSqlRes
)
return ExecuteSqlResponse(
success=False,
error=f"Database with ID {request.database_id} not found",
error=(
f"Database with ID {request.database_id} not found."
" Use list_databases to get valid database IDs."
),
error_type=SupersetErrorType.DATABASE_NOT_FOUND_ERROR.value,
)

View File

@@ -103,7 +103,8 @@ def open_sql_lab_with_context(
database = DatabaseDAO.find_by_id(request.database_connection_id)
if not database:
error_message = (
f"Database with ID {request.database_connection_id} not found"
f"Database with ID {request.database_connection_id} not found."
" Use list_databases to get valid database IDs."
)
return _sanitize_sql_lab_response_for_llm_context(
SqlLabResponse(

View File

@@ -1887,26 +1887,10 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
time_grain
)
if not time_grain:
has_temporal_join_key = any(
pd.api.types.is_datetime64_any_dtype(df[key])
for key in join_keys
if key in df.columns
if join_column_producer and not time_grain:
raise QueryObjectValidationError(
_("Time Grain must be specified when using Time Shift.")
)
if has_temporal_join_key:
has_relative_offset = any(
not (
self.is_valid_date_range(offset)
and feature_flag_manager.is_feature_enabled(
"DATE_RANGE_TIMESHIFTS_ENABLED"
)
)
for offset in offset_dfs
)
if has_relative_offset:
raise QueryObjectValidationError(
_("Time Grain must be specified when using Time Comparison.")
)
for offset, offset_df in offset_dfs.items():
is_date_range_offset = self.is_valid_date_range(
@@ -3100,6 +3084,14 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
sqla_col = self.convert_tbl_column_to_sqla_col(
tbl_column=col_obj, template_processor=template_processor
)
# Parenthesize expression-based columns to prevent operator
# precedence issues (e.g. OR in a calculated column breaking
# surrounding AND filters). Same pattern as extras.where
# wrapping added in PR #38183.
if sqla_col is not None and (
(col_obj and col_obj.expression) or is_adhoc_column(flt_col)
):
sqla_col = Grouping(sqla_col)
col_type = col_obj.type if col_obj else None
col_spec = db_engine_spec.get_column_spec(native_type=col_type)
is_list_target = op in (

View File

@@ -1782,9 +1782,9 @@ def extract_dataframe_dtypes(
columns_by_name[column.column_name] = column
generic_types: list[GenericDataType] = []
for i, column in enumerate(df.columns):
for column in df.columns:
column_object = columns_by_name.get(str(column))
series = df.iloc[:, i]
series = df[column]
inferred_type: str = ""
if series.isna().all():
sql_type: Optional[str] = ""

View File

@@ -24,7 +24,7 @@ from flask_babel import gettext as _
from pandas import DataFrame, MultiIndex
from superset.exceptions import InvalidPostProcessingError
from superset.utils.core import PostProcessingContributionOrientation, TIME_COMPARISON
from superset.utils.core import PostProcessingContributionOrientation
from superset.utils.pandas_postprocessing.utils import validate_column_args
@@ -130,7 +130,7 @@ def get_column_groups(
time_shift = None
if time_shifts and isinstance(col_0, str):
for ts in time_shifts:
if col_0.endswith(TIME_COMPARISON + ts):
if col_0.endswith(ts):
time_shift = ts
break
if time_shift is not None:

View File

@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from pandas import DataFrame, Series, Timestamp
from pandas.testing import assert_frame_equal
from pytest import fixture, mark # noqa: PT013
@@ -24,7 +23,6 @@ from superset.common.query_context import QueryContext
from superset.common.query_context_processor import QueryContextProcessor
from superset.connectors.sqla.models import BaseDatasource
from superset.constants import TimeGrain
from superset.exceptions import QueryObjectValidationError
from superset.models.helpers import ExploreMixin
# Create processor and bind ExploreMixin methods to datasource
@@ -244,55 +242,3 @@ def test_join_offset_dfs_totals_query_no_dimensions():
)
assert_frame_equal(expected, result)
def test_join_offset_dfs_raises_without_time_grain():
"""Time comparison with relative offsets requires a time grain."""
df = DataFrame({"ds": [Timestamp("2021-01-01")], "D": [1]})
offset_df = DataFrame({"ds": [Timestamp("2021-02-01")], "B": [5]})
offset_dfs = {"1 year ago": offset_df}
with pytest.raises(
QueryObjectValidationError, match="Time Grain must be specified"
):
query_context_processor.join_offset_dfs(
df, offset_dfs, time_grain=None, join_keys=["ds"]
)
def test_join_offset_dfs_allows_non_temporal_join_without_time_grain():
"""Time comparison without time grain is valid when join keys are non-temporal."""
df = DataFrame({"country": ["US", "UK"], "metric": [10, 20]})
offset_df = DataFrame({"country": ["US", "UK"], "metric__1 year ago": [8, 15]})
offset_dfs = {"1 year ago": offset_df}
result = query_context_processor.join_offset_dfs(
df, offset_dfs, time_grain=None, join_keys=["country"]
)
assert "metric__1 year ago" in result.columns
def test_join_offset_dfs_raises_when_temporal_key_not_first():
"""Temporal join key detection works even when it's not the first key."""
df = DataFrame(
{
"country": ["US", "UK"],
"ds": [Timestamp("2021-01-01"), Timestamp("2021-02-01")],
"D": [1, 2],
}
)
offset_df = DataFrame(
{
"country": ["US", "UK"],
"ds": [Timestamp("2021-03-01"), Timestamp("2021-04-01")],
"B": [5, 6],
}
)
offset_dfs = {"1 year ago": offset_df}
with pytest.raises(
QueryObjectValidationError, match="Time Grain must be specified"
):
query_context_processor.join_offset_dfs(
df, offset_dfs, time_grain=None, join_keys=["country", "ds"]
)

View File

@@ -298,7 +298,8 @@ class TestOpenSqlLabWithContext:
field_path=("title",),
)
assert response.error == sanitize_for_llm_context(
"Database with ID 404 not found",
"Database with ID 404 not found."
" Use list_databases to get valid database IDs.",
field_path=("error",),
)
finally:

View File

@@ -1937,6 +1937,235 @@ def test_extras_having_is_parenthesized(
)
def test_calculated_column_filter_is_parenthesized(
database: Database,
) -> None:
"""
Test that calculated column expressions containing OR are wrapped in
parentheses when used in WHERE filters.
Without parentheses, a calculated column expression like
``status = 'active' OR status = 'pending'`` combined with other filters
via AND would produce unexpected evaluation order due to SQL operator
precedence (AND binds tighter than OR), potentially dropping time range
and other filters. Same class of bug as fixed in PR #38183 for
extras.where/having, but on the calculated column filter path.
"""
from superset.connectors.sqla.models import SqlaTable, TableColumn
table = SqlaTable(
database=database,
schema=None,
table_name="t",
columns=[
TableColumn(column_name="a", type="INTEGER"),
TableColumn(
column_name="is_active",
expression="status = 'active' OR status = 'pending'",
type="BOOLEAN",
),
],
)
sqla_query = table.get_sqla_query(
columns=["a"],
filter=[
{
"col": "is_active",
"op": "IS TRUE",
"val": None,
},
],
extras={},
is_timeseries=False,
metrics=[],
)
with database.get_sqla_engine() as engine:
sql = str(
sqla_query.sqla_query.compile(
dialect=engine.dialect,
compile_kwargs={"literal_binds": True},
)
)
assert "(status = 'active' OR status = 'pending')" in sql, (
f"Calculated column expression should be wrapped in parentheses. "
f"Generated SQL: {sql}"
)
def test_calculated_column_nested_or_and_is_parenthesized(
database: Database,
) -> None:
"""
Test that calculated column expressions with nested OR/AND combinations
are correctly parenthesized as a single unit in WHERE filters.
"""
from superset.connectors.sqla.models import SqlaTable, TableColumn
table = SqlaTable(
database=database,
schema=None,
table_name="t",
columns=[
TableColumn(column_name="a", type="INTEGER"),
TableColumn(
column_name="is_target",
expression=(
"(status = 'active' AND region = 'US') "
"OR (status = 'pending' AND region = 'EU')"
),
type="BOOLEAN",
),
],
)
sqla_query = table.get_sqla_query(
columns=["a"],
filter=[
{
"col": "is_target",
"op": "IS TRUE",
"val": None,
},
],
extras={},
is_timeseries=False,
metrics=[],
)
with database.get_sqla_engine() as engine:
sql = str(
sqla_query.sqla_query.compile(
dialect=engine.dialect,
compile_kwargs={"literal_binds": True},
)
)
assert (
"((status = 'active' AND region = 'US') "
"OR (status = 'pending' AND region = 'EU'))"
) in sql, (
f"Nested OR/AND expression should be wrapped in parentheses. "
f"Generated SQL: {sql}"
)
def test_calculated_column_non_boolean_filter_is_parenthesized(
database: Database,
) -> None:
"""
Test that non-boolean calculated column expressions are parenthesized
when used with IN filters.
"""
from superset.connectors.sqla.models import SqlaTable, TableColumn
table = SqlaTable(
database=database,
schema=None,
table_name="t",
columns=[
TableColumn(column_name="a", type="INTEGER"),
TableColumn(
column_name="full_name",
expression="first_name || ' ' || last_name",
type="TEXT",
),
],
)
sqla_query = table.get_sqla_query(
columns=["a"],
filter=[
{
"col": "full_name",
"op": "IN",
"val": ["John Doe", "Jane Doe"],
},
],
extras={},
is_timeseries=False,
metrics=[],
)
with database.get_sqla_engine() as engine:
sql = str(
sqla_query.sqla_query.compile(
dialect=engine.dialect,
compile_kwargs={"literal_binds": True},
)
)
assert "(first_name || ' ' || last_name)" in sql, (
f"Non-boolean calculated column should be wrapped in parentheses. "
f"Generated SQL: {sql}"
)
def test_multiple_calculated_columns_each_parenthesized(
database: Database,
) -> None:
"""
Test that multiple calculated columns used as filters are each
independently wrapped in parentheses.
"""
from superset.connectors.sqla.models import SqlaTable, TableColumn
table = SqlaTable(
database=database,
schema=None,
table_name="t",
columns=[
TableColumn(column_name="a", type="INTEGER"),
TableColumn(
column_name="is_active",
expression="status = 'active' OR status = 'pending'",
type="BOOLEAN",
),
TableColumn(
column_name="is_premium",
expression="tier = 'gold' OR tier = 'platinum'",
type="BOOLEAN",
),
],
)
sqla_query = table.get_sqla_query(
columns=["a"],
filter=[
{
"col": "is_active",
"op": "IS TRUE",
"val": None,
},
{
"col": "is_premium",
"op": "IS TRUE",
"val": None,
},
],
extras={},
is_timeseries=False,
metrics=[],
)
with database.get_sqla_engine() as engine:
sql = str(
sqla_query.sqla_query.compile(
dialect=engine.dialect,
compile_kwargs={"literal_binds": True},
)
)
assert "(status = 'active' OR status = 'pending')" in sql, (
f"First calculated column should be parenthesized. Generated SQL: {sql}"
)
assert "(tier = 'gold' OR tier = 'platinum')" in sql, (
f"Second calculated column should be parenthesized. Generated SQL: {sql}"
)
def _run_probe(
database: Database,
type_probe_needs_row: bool = False,

View File

@@ -124,36 +124,3 @@ def test_contribution_with_time_shift_columns():
assert_array_equal(processed_df["a__1 week ago"].tolist(), [0.5, 0.5])
assert_array_equal(processed_df["b__1 week ago"].tolist(), [0.25, 0.25])
assert_array_equal(processed_df["c__1 week ago"].tolist(), [0.25, 0.25])
def test_contribution_with_numeric_prefix_time_shifts():
"""Time shifts like '2 weeks ago' and '22 weeks ago' share a numeric suffix;
columns must be grouped by their exact offset, not by suffix matching."""
df = DataFrame(
{
DTTM_ALIAS: [
datetime(2020, 7, 16, 14, 49),
datetime(2020, 7, 16, 14, 50),
],
"a": [3, 6],
"b": [6, 3],
"a__2 weeks ago": [1, 1],
"b__2 weeks ago": [1, 1],
"a__22 weeks ago": [2, 4],
"b__22 weeks ago": [4, 2],
}
)
processed_df = contribution(
df,
orientation=PostProcessingContributionOrientation.ROW,
time_shifts=["2 weeks ago", "22 weeks ago"],
)
# Non-time-shift columns: a=3,b=6 -> a=1/3, b=2/3; a=6,b=3 -> a=2/3, b=1/3
assert_array_equal(processed_df["a"].tolist(), [1 / 3, 2 / 3])
assert_array_equal(processed_df["b"].tolist(), [2 / 3, 1 / 3])
# "2 weeks ago" group: a=1,b=1 -> 0.5,0.5 each row
assert_array_equal(processed_df["a__2 weeks ago"].tolist(), [0.5, 0.5])
assert_array_equal(processed_df["b__2 weeks ago"].tolist(), [0.5, 0.5])
# "22 weeks ago" group: a=2,b=4 -> 1/3,2/3; a=4,b=2 -> 2/3,1/3
assert_array_equal(processed_df["a__22 weeks ago"].tolist(), [1 / 3, 2 / 3])
assert_array_equal(processed_df["b__22 weeks ago"].tolist(), [2 / 3, 1 / 3])

View File

@@ -31,7 +31,6 @@ from superset.utils.core import (
cast_to_boolean,
check_is_safe_zip,
DateColumn,
extract_dataframe_dtypes,
FilterOperator,
generic_find_constraint_name,
generic_find_fk_constraint_name,
@@ -648,9 +647,8 @@ def test_get_user_agent(mocker: MockerFixture, app_context: None) -> None:
@with_config(
{
"USER_AGENT_FUNC": lambda database, source: (
f"{database.database_name} {source.name}"
)
"USER_AGENT_FUNC": lambda database,
source: f"{database.database_name} {source.name}"
}
)
def test_get_user_agent_custom(mocker: MockerFixture, app_context: None) -> None:
@@ -1732,10 +1730,3 @@ def test_markdown_with_markup_wrap() -> None:
assert isinstance(result, Markup)
assert "<strong>bold</strong>" in str(result)
def test_extract_dataframe_dtypes_with_duplicate_columns() -> None:
"""extract_dataframe_dtypes should not crash on duplicate column names."""
df = pd.DataFrame([[1, 2, 3]], columns=["a", "b", "a"])
result = extract_dataframe_dtypes(df)
assert len(result) == 3