Compare commits

...

3 Commits

Author SHA1 Message Date
Amin Ghadersohi
16d136ef8d fix(mcp): context-aware recovery hints and sanitize identifier in not-found errors
- In get_chart_preview, when identifier looks like a form_data_key (long
  non-numeric string), suggest regenerating the explore link rather than
  always pointing to list_charts, which is only relevant for chart IDs.
- Truncate request.identifier to 200 chars before embedding in error
  messages across get_chart_preview, get_chart_data, and update_chart
  to prevent injection via oversized attacker-controlled identifiers.
2026-05-09 00:26:11 +00:00
Amin Ghadersohi
c78658d852 fix(mcp): improve "not found" errors to suggest corresponding list_* tools
When MCP tools return "not found" errors for database, chart, dataset, or
dashboard IDs, include recovery guidance pointing to the appropriate list
tool (list_databases, list_charts, list_datasets, list_dashboards).

Affected tools: execute_sql, open_sql_lab_with_context, query_dataset,
get_chart_data, get_chart_preview, update_chart,
add_chart_to_existing_dashboard, generate_dashboard
2026-05-06 22:55:46 +00:00
bdonovan1
5b5dd01028 fix(sqla): parenthesize calculated column expressions in WHERE clause (#39793)
Co-authored-by: Brian Donovan <briand@netflix.com>
Co-authored-by: Vitor Avila <96086495+Vitor-Avila@users.noreply.github.com>
2026-05-06 19:45:27 -03:00
11 changed files with 288 additions and 15 deletions

View File

@@ -199,8 +199,12 @@ async def get_chart_data( # noqa: C901
if not chart:
await ctx.warning("Chart not found: identifier=%s" % (request.identifier,))
safe_id = str(request.identifier)[:200]
return ChartError(
error=f"No chart found with identifier: {request.identifier}",
error=(
f"No chart found with identifier: {safe_id}."
" Use list_charts to get valid chart IDs."
),
error_type="NotFound",
)

View File

@@ -1192,8 +1192,22 @@ async def _get_chart_preview_internal( # noqa: C901
if not chart:
await ctx.warning("Chart not found: identifier=%s" % (request.identifier,))
safe_id = str(request.identifier)[:200]
is_form_data_key = (
isinstance(request.identifier, str)
and len(request.identifier) > 8
and not request.identifier.isdigit()
)
if is_form_data_key:
recovery = (
"If using a form_data_key, it may have expired — "
"use generate_explore_link to get a fresh key, "
"or use list_charts to find a saved chart by ID."
)
else:
recovery = "Use list_charts to get valid chart IDs."
return ChartError(
error=f"No chart found with identifier: {request.identifier}",
error=f"No chart found with identifier: {safe_id}. {recovery}",
error_type="NotFound",
)

View File

@@ -337,17 +337,18 @@ async def update_chart( # noqa: C901
chart = find_chart_by_identifier(request.identifier)
if not chart:
safe_id = str(request.identifier)[:200]
not_found_msg = (
f"No chart found with identifier: {safe_id}."
" Use list_charts to get valid chart IDs."
)
return GenerateChartResponse.model_validate(
{
"chart": None,
"error": {
"error_type": "NotFound",
"message": (
f"No chart found with identifier: {request.identifier}"
),
"details": (
f"No chart found with identifier: {request.identifier}"
),
"message": not_found_msg,
"details": not_found_msg,
},
"success": False,
"schema_version": "2.0",

View File

@@ -334,7 +334,10 @@ def _find_and_authorize_dashboard(
dashboard=None,
dashboard_url=None,
position=None,
error=f"Dashboard with ID {dashboard_id} not found",
error=(
f"Dashboard with ID {dashboard_id} not found."
" Use list_dashboards to get valid dashboard IDs."
),
)
try:
@@ -392,7 +395,10 @@ def add_chart_to_existing_dashboard(
dashboard=None,
dashboard_url=None,
position=None,
error=f"Chart with ID {request.chart_id} not found",
error=(
f"Chart with ID {request.chart_id} not found."
" Use list_charts to get valid chart IDs."
),
)
# Validate dataset access for the chart.

View File

@@ -230,7 +230,10 @@ def generate_dashboard( # noqa: C901
return GenerateDashboardResponse(
dashboard=None,
dashboard_url=None,
error=f"Charts not found: {list(missing_chart_ids)}",
error=(
f"Charts not found: {list(missing_chart_ids)}."
" Use list_charts to get valid chart IDs."
),
)
# Validate dataset access for each chart.

View File

@@ -183,7 +183,10 @@ async def query_dataset( # noqa: C901
if dataset is None:
await ctx.error("Dataset not found: identifier=%s" % (request.dataset_id,))
return DatasetError.create(
error=f"No dataset found with identifier: {request.dataset_id}",
error=(
f"No dataset found with identifier: {request.dataset_id}."
" Use list_datasets to get valid dataset IDs."
),
error_type="NotFound",
)

View File

@@ -100,7 +100,10 @@ async def execute_sql(request: ExecuteSqlRequest, ctx: Context) -> ExecuteSqlRes
)
return ExecuteSqlResponse(
success=False,
error=f"Database with ID {request.database_id} not found",
error=(
f"Database with ID {request.database_id} not found."
" Use list_databases to get valid database IDs."
),
error_type=SupersetErrorType.DATABASE_NOT_FOUND_ERROR.value,
)

View File

@@ -103,7 +103,8 @@ def open_sql_lab_with_context(
database = DatabaseDAO.find_by_id(request.database_connection_id)
if not database:
error_message = (
f"Database with ID {request.database_connection_id} not found"
f"Database with ID {request.database_connection_id} not found."
" Use list_databases to get valid database IDs."
)
return _sanitize_sql_lab_response_for_llm_context(
SqlLabResponse(

View File

@@ -3084,6 +3084,14 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
sqla_col = self.convert_tbl_column_to_sqla_col(
tbl_column=col_obj, template_processor=template_processor
)
# Parenthesize expression-based columns to prevent operator
# precedence issues (e.g. OR in a calculated column breaking
# surrounding AND filters). Same pattern as extras.where
# wrapping added in PR #38183.
if sqla_col is not None and (
(col_obj and col_obj.expression) or is_adhoc_column(flt_col)
):
sqla_col = Grouping(sqla_col)
col_type = col_obj.type if col_obj else None
col_spec = db_engine_spec.get_column_spec(native_type=col_type)
is_list_target = op in (

View File

@@ -298,7 +298,8 @@ class TestOpenSqlLabWithContext:
field_path=("title",),
)
assert response.error == sanitize_for_llm_context(
"Database with ID 404 not found",
"Database with ID 404 not found."
" Use list_databases to get valid database IDs.",
field_path=("error",),
)
finally:

View File

@@ -1937,6 +1937,235 @@ def test_extras_having_is_parenthesized(
)
def test_calculated_column_filter_is_parenthesized(
    database: Database,
) -> None:
    """
    Check that an OR-based calculated column is parenthesized in the SQL.

    A calculated column whose expression is
    ``status = 'active' OR status = 'pending'`` is used as an ``IS TRUE``
    filter. Because AND binds tighter than OR, an unwrapped expression
    combined with other AND-ed filters would be evaluated in an unexpected
    order; the compiled SQL must therefore contain the expression wrapped
    in parentheses. This mirrors the extras.where/having case referenced
    as PR #38183, but for the calculated column filter path.
    """
    from superset.connectors.sqla.models import SqlaTable, TableColumn

    # One physical column plus the calculated boolean column under test.
    plain_column = TableColumn(column_name="a", type="INTEGER")
    calculated_column = TableColumn(
        column_name="is_active",
        expression="status = 'active' OR status = 'pending'",
        type="BOOLEAN",
    )
    dataset = SqlaTable(
        database=database,
        schema=None,
        table_name="t",
        columns=[plain_column, calculated_column],
    )

    query = dataset.get_sqla_query(
        columns=["a"],
        filter=[
            {
                "col": "is_active",
                "op": "IS TRUE",
                "val": None,
            },
        ],
        extras={},
        is_timeseries=False,
        metrics=[],
    )

    # Render with literal binds so the assertion can match plain SQL text.
    with database.get_sqla_engine() as engine:
        compiled = query.sqla_query.compile(
            dialect=engine.dialect,
            compile_kwargs={"literal_binds": True},
        )
        sql = str(compiled)

    assert "(status = 'active' OR status = 'pending')" in sql, (
        f"Calculated column expression should be wrapped in parentheses. "
        f"Generated SQL: {sql}"
    )
def test_calculated_column_nested_or_and_is_parenthesized(
    database: Database,
) -> None:
    """
    Check that a nested OR/AND calculated column is wrapped as one unit.

    The calculated column combines two parenthesized AND groups with OR.
    When it is used as an ``IS TRUE`` filter, the compiled SQL must contain
    the whole expression inside an additional outer pair of parentheses.
    """
    from superset.connectors.sqla.models import SqlaTable, TableColumn

    # One physical column plus the nested boolean expression under test.
    plain_column = TableColumn(column_name="a", type="INTEGER")
    calculated_column = TableColumn(
        column_name="is_target",
        expression=(
            "(status = 'active' AND region = 'US') "
            "OR (status = 'pending' AND region = 'EU')"
        ),
        type="BOOLEAN",
    )
    dataset = SqlaTable(
        database=database,
        schema=None,
        table_name="t",
        columns=[plain_column, calculated_column],
    )

    query = dataset.get_sqla_query(
        columns=["a"],
        filter=[
            {
                "col": "is_target",
                "op": "IS TRUE",
                "val": None,
            },
        ],
        extras={},
        is_timeseries=False,
        metrics=[],
    )

    # Render with literal binds so the assertion can match plain SQL text.
    with database.get_sqla_engine() as engine:
        compiled = query.sqla_query.compile(
            dialect=engine.dialect,
            compile_kwargs={"literal_binds": True},
        )
        sql = str(compiled)

    assert (
        "((status = 'active' AND region = 'US') "
        "OR (status = 'pending' AND region = 'EU'))"
    ) in sql, (
        f"Nested OR/AND expression should be wrapped in parentheses. "
        f"Generated SQL: {sql}"
    )
def test_calculated_column_non_boolean_filter_is_parenthesized(
    database: Database,
) -> None:
    """
    Check that a non-boolean calculated column is parenthesized for IN.

    The calculated column concatenates ``first_name`` and ``last_name``.
    When it is filtered with an ``IN`` list, the compiled SQL must contain
    the concatenation expression wrapped in parentheses.
    """
    from superset.connectors.sqla.models import SqlaTable, TableColumn

    # One physical column plus the text-valued calculated column under test.
    plain_column = TableColumn(column_name="a", type="INTEGER")
    calculated_column = TableColumn(
        column_name="full_name",
        expression="first_name || ' ' || last_name",
        type="TEXT",
    )
    dataset = SqlaTable(
        database=database,
        schema=None,
        table_name="t",
        columns=[plain_column, calculated_column],
    )

    query = dataset.get_sqla_query(
        columns=["a"],
        filter=[
            {
                "col": "full_name",
                "op": "IN",
                "val": ["John Doe", "Jane Doe"],
            },
        ],
        extras={},
        is_timeseries=False,
        metrics=[],
    )

    # Render with literal binds so the assertion can match plain SQL text.
    with database.get_sqla_engine() as engine:
        compiled = query.sqla_query.compile(
            dialect=engine.dialect,
            compile_kwargs={"literal_binds": True},
        )
        sql = str(compiled)

    assert "(first_name || ' ' || last_name)" in sql, (
        f"Non-boolean calculated column should be wrapped in parentheses. "
        f"Generated SQL: {sql}"
    )
def test_multiple_calculated_columns_each_parenthesized(
    database: Database,
) -> None:
    """
    Check that several calculated column filters are wrapped independently.

    Two OR-based calculated columns are both used as ``IS TRUE`` filters in
    the same query. The compiled SQL must contain each expression inside
    its own pair of parentheses.
    """
    from superset.connectors.sqla.models import SqlaTable, TableColumn

    # One physical column plus the two calculated boolean columns under test.
    plain_column = TableColumn(column_name="a", type="INTEGER")
    active_column = TableColumn(
        column_name="is_active",
        expression="status = 'active' OR status = 'pending'",
        type="BOOLEAN",
    )
    premium_column = TableColumn(
        column_name="is_premium",
        expression="tier = 'gold' OR tier = 'platinum'",
        type="BOOLEAN",
    )
    dataset = SqlaTable(
        database=database,
        schema=None,
        table_name="t",
        columns=[plain_column, active_column, premium_column],
    )

    query = dataset.get_sqla_query(
        columns=["a"],
        filter=[
            {
                "col": "is_active",
                "op": "IS TRUE",
                "val": None,
            },
            {
                "col": "is_premium",
                "op": "IS TRUE",
                "val": None,
            },
        ],
        extras={},
        is_timeseries=False,
        metrics=[],
    )

    # Render with literal binds so the assertions can match plain SQL text.
    with database.get_sqla_engine() as engine:
        compiled = query.sqla_query.compile(
            dialect=engine.dialect,
            compile_kwargs={"literal_binds": True},
        )
        sql = str(compiled)

    assert "(status = 'active' OR status = 'pending')" in sql, (
        f"First calculated column should be parenthesized. Generated SQL: {sql}"
    )
    assert "(tier = 'gold' OR tier = 'platinum')" in sql, (
        f"Second calculated column should be parenthesized. Generated SQL: {sql}"
    )
def _run_probe(
database: Database,
type_probe_needs_row: bool = False,