diff --git a/superset/daos/chart.py b/superset/daos/chart.py index 352b9b9cb68..b9c28938b07 100644 --- a/superset/daos/chart.py +++ b/superset/daos/chart.py @@ -21,13 +21,15 @@ from datetime import datetime from typing import Dict, List from flask_appbuilder.models.sqla.interface import SQLAInterface +from sqlalchemy import or_, select +from sqlalchemy.orm import Query from superset.charts.filters import ChartFilter from superset.commands.chart.exceptions import ChartNotFoundError -from superset.daos.base import BaseDAO +from superset.daos.base import BaseDAO, ColumnOperator, ColumnOperatorEnum from superset.extensions import db from superset.models.core import FavStar, FavStarClassName -from superset.models.slice import id_or_uuid_filter, Slice +from superset.models.slice import id_or_uuid_filter, Slice, slice_user from superset.utils.core import get_user_id logger = logging.getLogger(__name__) @@ -36,12 +38,56 @@ logger = logging.getLogger(__name__) CHART_CUSTOM_FIELDS = { "viz_type": ["eq", "in", "like"], "datasource_name": ["eq", "in", "like"], + "owner": ["eq", "in"], } class ChartDAO(BaseDAO[Slice]): base_filter = ChartFilter + @classmethod + def apply_column_operators( + cls, + query: Query, + column_operators: list[ColumnOperator] | None = None, + ) -> Query: + """Override to handle owner filter via the slice_user M2M table.""" + if not column_operators: + return query + + remaining_operators: list[ColumnOperator] = [] + for c in column_operators: + if not isinstance(c, ColumnOperator): + c = ColumnOperator.model_validate(c) + if c.col == "owner": + operator_enum = ColumnOperatorEnum(c.opr) + subq = select(slice_user.c.slice_id).where( + operator_enum.apply(slice_user.c.user_id, c.value) + ) + query = query.filter( + Slice.id.in_(subq) # type: ignore[attr-defined,unused-ignore] + ) + elif c.col == "created_by_fk_or_owner": + if c.opr != "eq": + raise ValueError( + f"created_by_fk_or_owner only supports 'eq'; got '{c.opr}'" + ) + owner_subq = select(slice_user.c.slice_id).where( + slice_user.c.user_id == c.value + ) + query = query.filter( + or_( + Slice.created_by_fk == c.value, # type: ignore[attr-defined,unused-ignore] + Slice.id.in_(owner_subq), # type: ignore[attr-defined,unused-ignore] + ) + ) + else: + remaining_operators.append(c) + + if remaining_operators: + query = super().apply_column_operators(query, remaining_operators) + return query + @classmethod def get_filterable_columns_and_operators(cls) -> Dict[str, List[str]]: filterable = super().get_filterable_columns_and_operators() diff --git a/superset/daos/dashboard.py b/superset/daos/dashboard.py index d0d6a027a94..f473bc9492c 100644 --- a/superset/daos/dashboard.py +++ b/superset/daos/dashboard.py @@ -23,7 +23,7 @@ from typing import Any, Dict, List from flask import g from flask_appbuilder.models.sqla.interface import SQLAInterface -from sqlalchemy import select +from sqlalchemy import or_, select from sqlalchemy.orm import Query from superset import is_feature_enabled, security_manager @@ -38,7 +38,7 @@ from superset.dashboards.filters import DashboardAccessFilter, is_uuid from superset.exceptions import SupersetSecurityException from superset.extensions import db from superset.models.core import FavStar, FavStarClassName -from superset.models.dashboard import Dashboard, id_or_slug_filter +from superset.models.dashboard import Dashboard, dashboard_user, id_or_slug_filter from superset.models.embedded_dashboard import EmbeddedDashboard from superset.models.slice import Slice from superset.utils import json @@ -79,8 +79,6 @@ class DashboardDAO(BaseDAO[Dashboard]): if not isinstance(c, ColumnOperator): c = ColumnOperator.model_validate(c) if c.col == "owner": - from superset.models.dashboard import dashboard_user - operator_enum = ColumnOperatorEnum(c.opr) subq = select(dashboard_user.c.dashboard_id).where( operator_enum.apply(dashboard_user.c.user_id, c.value) @@ -88,6 +86,20 @@ class DashboardDAO(BaseDAO[Dashboard]): query = query.filter( Dashboard.id.in_(subq) # type: ignore[attr-defined,unused-ignore] ) + elif c.col == "created_by_fk_or_owner": + if c.opr != "eq": + raise ValueError( + f"created_by_fk_or_owner only supports 'eq'; got '{c.opr}'" + ) + owner_subq = select(dashboard_user.c.dashboard_id).where( + dashboard_user.c.user_id == c.value + ) + query = query.filter( + or_( + Dashboard.created_by_fk == c.value, # type: ignore[attr-defined,unused-ignore] + Dashboard.id.in_(owner_subq), # type: ignore[attr-defined,unused-ignore] + ) + ) elif c.col == "favorite": user_id = get_user_id() fav_subq = select(FavStar.obj_id).where( diff --git a/superset/daos/dataset.py b/superset/daos/dataset.py index 6f4bbc51320..1822fd71186 100644 --- a/superset/daos/dataset.py +++ b/superset/daos/dataset.py @@ -21,11 +21,16 @@ from datetime import datetime from typing import Any, Dict, List import dateutil.parser -from sqlalchemy import select +from sqlalchemy import or_, select from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import Query -from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn +from superset.connectors.sqla.models import ( + SqlaTable, + sqlatable_user, + SqlMetric, + TableColumn, +) from superset.daos.base import BaseDAO, ColumnOperator, ColumnOperatorEnum from superset.extensions import db from superset.models.core import Database @@ -79,8 +84,6 @@ class DatasetDAO(BaseDAO[SqlaTable]): ) query = query.filter(SqlaTable.database_id.in_(subq)) elif c.col == "owner": - from superset.connectors.sqla.models import sqlatable_user - operator_enum = ColumnOperatorEnum(c.opr) subq = select(sqlatable_user.c.table_id).where( operator_enum.apply(sqlatable_user.c.user_id, c.value) @@ -88,6 +91,20 @@ class DatasetDAO(BaseDAO[SqlaTable]): query = query.filter( SqlaTable.id.in_(subq) # type: ignore[attr-defined,unused-ignore] ) + elif c.col == "created_by_fk_or_owner": + if c.opr != "eq": + raise ValueError( + f"created_by_fk_or_owner only supports 'eq'; got '{c.opr}'" + ) + owner_subq = select(sqlatable_user.c.table_id).where( + sqlatable_user.c.user_id == c.value + ) + query = query.filter( + or_( + SqlaTable.created_by_fk == c.value, # type: ignore[attr-defined,unused-ignore] + SqlaTable.id.in_(owner_subq), # type: ignore[attr-defined,unused-ignore] + ) + ) else: remaining_operators.append(c) diff --git a/superset/mcp_service/app.py b/superset/mcp_service/app.py index d6b699e650f..e3490450187 100644 --- a/superset/mcp_service/app.py +++ b/superset/mcp_service/app.py @@ -142,6 +142,26 @@ To create a chart: "config": {{...}}, "save_chart": true }}) -> save permanently +To find your own charts/dashboards/datasets/databases: +- list_charts(request={{"created_by_me": true}}) — items you created +- list_dashboards(request={{"created_by_me": true}}) — items you created +- list_datasets(request={{"created_by_me": true}}) — items you created +- list_databases(request={{"created_by_me": true}}) — items you created + +To find items where you are listed as an owner (edit access): +- list_charts(request={{"owned_by_me": true}}) +- list_dashboards(request={{"owned_by_me": true}}) +- list_datasets(request={{"owned_by_me": true}}) + +To find all items you have any connection to (created OR own): +- list_charts(request={{"created_by_me": true, "owned_by_me": true}}) +- list_dashboards(request={{"created_by_me": true, "owned_by_me": true}}) +- list_datasets(request={{"created_by_me": true, "owned_by_me": true}}) + +Use created_by_me for authorship, owned_by_me for edit ownership, or both +together for the union. All flags can be combined with 'filters' but not +with 'search'. + To explore data with SQL: 1. list_datasets(request={{}}) -> find a dataset and note its database_id 2. execute_sql(request={{"database_id": , "sql": "SELECT ..."}}) @@ -202,6 +222,9 @@ Query Examples: list_charts(request={{"filters": [{{"col": "viz_type", "opr": "sw", "value": "echarts_timeseries"}}]}}) - Search by name: list_charts(request={{"search": "sales"}}) +- My charts: list_charts(request={{"created_by_me": true}}) +- My dashboards: list_dashboards(request={{"created_by_me": true}}) +- My databases: list_databases(request={{"created_by_me": true}}) To modify an existing chart (add filters, change metrics, etc.): 1. get_chart_info(request={{"identifier": }}) -> examine current configuration diff --git a/superset/mcp_service/chart/schemas.py b/superset/mcp_service/chart/schemas.py index b84b5068e00..decd3c46921 100644 --- a/superset/mcp_service/chart/schemas.py +++ b/superset/mcp_service/chart/schemas.py @@ -44,8 +44,10 @@ from superset.constants import TimeGrain from superset.daos.base import ColumnOperator, ColumnOperatorEnum from superset.mcp_service.common.cache_schemas import ( CacheStatus, + CreatedByMeMixin, FormDataCacheControl, MetadataCacheControl, + OwnedByMeMixin, QueryCacheControl, ) from superset.mcp_service.common.error_schemas import ChartGenerationError @@ -1262,7 +1264,7 @@ ChartConfig = Annotated[ # giving LLMs enough context to construct valid configs. -class ListChartsRequest(MetadataCacheControl): +class ListChartsRequest(OwnedByMeMixin, CreatedByMeMixin, MetadataCacheControl): """Request schema for list_charts with clear, unambiguous types.""" filters: Annotated[ @@ -1342,8 +1344,7 @@ class ListChartsRequest(MetadataCacheControl): @model_validator(mode="after") def validate_search_and_filters(self) -> "ListChartsRequest": - """Prevent using both search and filters simultaneously to avoid query - conflicts.""" + """Prevent using both search and filters simultaneously.""" if self.search and self.filters: raise ValueError( "Cannot use both 'search' and 'filters' parameters simultaneously. " diff --git a/superset/mcp_service/chart/tool/list_charts.py b/superset/mcp_service/chart/tool/list_charts.py index f18b5c08f73..c8d75f6054c 100644 --- a/superset/mcp_service/chart/tool/list_charts.py +++ b/superset/mcp_service/chart/tool/list_charts.py @@ -169,6 +169,8 @@ async def list_charts( order_direction=request.order_direction, page=max(request.page - 1, 0), page_size=request.page_size, + created_by_me=request.created_by_me, + owned_by_me=request.owned_by_me, ) count = len(result.charts) if hasattr(result, "charts") else 0 total_pages = getattr(result, "total_pages", None) diff --git a/superset/mcp_service/common/cache_schemas.py b/superset/mcp_service/common/cache_schemas.py index bf51bbf49b0..7cea16f899e 100644 --- a/superset/mcp_service/common/cache_schemas.py +++ b/superset/mcp_service/common/cache_schemas.py @@ -23,7 +23,9 @@ existing cache infrastructure including query result cache, metadata cache, form data cache, and dashboard cache. """ -from pydantic import BaseModel, Field +from typing import Annotated, Any + +from pydantic import BaseModel, Field, model_validator class CacheControlMixin(BaseModel): @@ -83,6 +85,69 @@ class FormDataCacheControl(CacheControlMixin): ) +class CreatedByMeMixin(BaseModel): + """Mixin that adds a created_by_me filter flag to list request schemas. + + Provides a clean caller-facing alternative to exposing foreign key IDs. + The server translates the flag into the appropriate FK filter and injects + the current user's ID automatically. + """ + + created_by_me: Annotated[ + bool, + Field( + default=False, + description=( + "When true, return only items created by the current user. " + "Can be combined with 'filters' but not with 'search'." + ), + ), + ] + + @model_validator(mode="after") + def _validate_created_by_me_with_search(self) -> Any: + if getattr(self, "search", None) and self.created_by_me: + raise ValueError( + "'created_by_me' cannot be combined with 'search'. " + "Use 'created_by_me' alone or with 'filters'." + ) + return self + + +class OwnedByMeMixin(BaseModel): + """Mixin that adds an owned_by_me filter flag to list request schemas. + + Provides a clean caller-facing alternative to exposing M2M owner IDs. + The server translates the flag into the appropriate owner filter and injects + the current user's ID automatically. + + When combined with created_by_me, returns items where the current user is + either the creator OR an owner (union, not intersection). + """ + + owned_by_me: Annotated[ + bool, + Field( + default=False, + description=( + "When true, return only items where the current user is listed as " + "an owner. Can be combined with 'filters' but not with 'search'. " + "Can be combined with 'created_by_me' to return items where the " + "current user is either the creator or an owner." + ), + ), + ] + + @model_validator(mode="after") + def _validate_owned_by_me(self) -> Any: + if getattr(self, "search", None) and self.owned_by_me: + raise ValueError( + "'owned_by_me' cannot be combined with 'search'. " + "Use 'owned_by_me' alone or with 'filters'." + ) + return self + + class CacheStatus(BaseModel): """ Information about cache usage in tool responses. diff --git a/superset/mcp_service/dashboard/schemas.py b/superset/mcp_service/dashboard/schemas.py index e8f2009126c..268df90f87f 100644 --- a/superset/mcp_service/dashboard/schemas.py +++ b/superset/mcp_service/dashboard/schemas.py @@ -85,7 +85,11 @@ if TYPE_CHECKING: from superset.models.dashboard import Dashboard from superset.daos.base import ColumnOperator, ColumnOperatorEnum -from superset.mcp_service.common.cache_schemas import MetadataCacheControl +from superset.mcp_service.common.cache_schemas import ( + CreatedByMeMixin, + MetadataCacheControl, + OwnedByMeMixin, +) from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE from superset.mcp_service.privacy import ( filter_user_directory_fields, @@ -163,8 +167,7 @@ class DashboardFilter(ColumnOperator): ..., description=( "Column to filter on. Use " - "get_schema(model_type='dashboard') for available " - "filter columns." + "get_schema(model_type='dashboard') for available filter columns." ), ) opr: ColumnOperatorEnum = Field( @@ -177,7 +180,7 @@ class DashboardFilter(ColumnOperator): ) -class ListDashboardsRequest(MetadataCacheControl): +class ListDashboardsRequest(OwnedByMeMixin, CreatedByMeMixin, MetadataCacheControl): """Request schema for list_dashboards with clear, unambiguous types.""" filters: Annotated[ @@ -257,8 +260,7 @@ class ListDashboardsRequest(MetadataCacheControl): @model_validator(mode="after") def validate_search_and_filters(self) -> "ListDashboardsRequest": - """Prevent using both search and filters simultaneously to avoid query - conflicts.""" + """Prevent using both search and filters simultaneously.""" if self.search and self.filters: raise ValueError( "Cannot use both 'search' and 'filters' parameters simultaneously. " diff --git a/superset/mcp_service/dashboard/tool/list_dashboards.py b/superset/mcp_service/dashboard/tool/list_dashboards.py index ded7d076144..7f8b0ee850c 100644 --- a/superset/mcp_service/dashboard/tool/list_dashboards.py +++ b/superset/mcp_service/dashboard/tool/list_dashboards.py @@ -145,6 +145,8 @@ async def list_dashboards( order_direction=request.order_direction, page=max(request.page - 1, 0), page_size=request.page_size, + created_by_me=request.created_by_me, + owned_by_me=request.owned_by_me, ) count = len(result.dashboards) if hasattr(result, "dashboards") else 0 total_pages = getattr(result, "total_pages", None) diff --git a/superset/mcp_service/database/schemas.py b/superset/mcp_service/database/schemas.py index 60b8ed5c100..020421550cc 100644 --- a/superset/mcp_service/database/schemas.py +++ b/superset/mcp_service/database/schemas.py @@ -36,7 +36,10 @@ from pydantic import ( ) from superset.daos.base import ColumnOperator, ColumnOperatorEnum -from superset.mcp_service.common.cache_schemas import MetadataCacheControl +from superset.mcp_service.common.cache_schemas import ( + CreatedByMeMixin, + MetadataCacheControl, +) from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE from superset.mcp_service.privacy import filter_user_directory_fields from superset.mcp_service.system.schemas import PaginationInfo @@ -59,14 +62,10 @@ class DatabaseFilter(ColumnOperator): "database_name", "expose_in_sqllab", "allow_file_upload", - "created_by_fk", - "changed_by_fk", ] = Field( ..., description="Column to filter on. Use get_schema(model_type='database') for " - "available filter columns. Use created_by_fk with the user " - "ID from get_instance_info's current_user to find " - "databases created by a specific user.", + "available filter columns.", ) opr: ColumnOperatorEnum = Field( ..., @@ -188,7 +187,7 @@ class DatabaseList(BaseModel): model_config = ConfigDict(ser_json_timedelta="iso8601") -class ListDatabasesRequest(MetadataCacheControl): +class ListDatabasesRequest(CreatedByMeMixin, MetadataCacheControl): """Request schema for list_databases with clear, unambiguous types.""" filters: Annotated[ diff --git a/superset/mcp_service/database/tool/list_databases.py b/superset/mcp_service/database/tool/list_databases.py index 6f3959f4528..b03949146ed 100644 --- a/superset/mcp_service/database/tool/list_databases.py +++ b/superset/mcp_service/database/tool/list_databases.py @@ -154,6 +154,7 @@ async def list_databases( order_direction=request.order_direction, page=max(request.page - 1, 0), page_size=request.page_size, + created_by_me=request.created_by_me, ) await ctx.info( diff --git a/superset/mcp_service/dataset/schemas.py b/superset/mcp_service/dataset/schemas.py index d09b312aa41..bbfe018dbfc 100644 --- a/superset/mcp_service/dataset/schemas.py +++ b/superset/mcp_service/dataset/schemas.py @@ -36,7 +36,11 @@ from pydantic import ( ) from superset.daos.base import ColumnOperator, ColumnOperatorEnum -from superset.mcp_service.common.cache_schemas import MetadataCacheControl +from superset.mcp_service.common.cache_schemas import ( + CreatedByMeMixin, + MetadataCacheControl, + OwnedByMeMixin, +) from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE from superset.mcp_service.privacy import filter_user_directory_fields from superset.mcp_service.system.schemas import ( @@ -213,7 +217,7 @@ class DatasetList(BaseModel): model_config = ConfigDict(ser_json_timedelta="iso8601") -class ListDatasetsRequest(MetadataCacheControl): +class ListDatasetsRequest(OwnedByMeMixin, CreatedByMeMixin, MetadataCacheControl): """Request schema for list_datasets with clear, unambiguous types.""" filters: Annotated[ @@ -266,8 +270,7 @@ class ListDatasetsRequest(MetadataCacheControl): @model_validator(mode="after") def validate_search_and_filters(self) -> "ListDatasetsRequest": - """Prevent using both search and filters simultaneously to avoid query - conflicts.""" + """Prevent using both search and filters simultaneously.""" if self.search and self.filters: raise ValueError( "Cannot use both 'search' and 'filters' parameters simultaneously. " diff --git a/superset/mcp_service/dataset/tool/list_datasets.py b/superset/mcp_service/dataset/tool/list_datasets.py index 945f9fd3c08..84b7a71d15f 100644 --- a/superset/mcp_service/dataset/tool/list_datasets.py +++ b/superset/mcp_service/dataset/tool/list_datasets.py @@ -179,6 +179,8 @@ async def list_datasets( order_direction=request.order_direction, page=max(request.page - 1, 0), page_size=request.page_size, + created_by_me=request.created_by_me, + owned_by_me=request.owned_by_me, ) await ctx.info( diff --git a/superset/mcp_service/mcp_core.py b/superset/mcp_service/mcp_core.py index 28198fef077..fc8d584002f 100644 --- a/superset/mcp_service/mcp_core.py +++ b/superset/mcp_service/mcp_core.py @@ -19,18 +19,26 @@ from __future__ import annotations import logging from abc import ABC, abstractmethod -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from typing import Any, Callable, Dict, Generic, List, Literal, Type, TypeVar from pydantic import BaseModel -from superset.daos.base import BaseDAO -from superset.mcp_service.constants import ModelType +from superset.daos.base import BaseDAO, ColumnOperator, ColumnOperatorEnum +from superset.mcp_service.constants import MAX_PAGE_SIZE, ModelType from superset.mcp_service.privacy import ( filter_user_directory_columns, + SELF_REFERENCING_FILTER_COLUMNS, USER_DIRECTORY_FIELDS, ) +from superset.mcp_service.system.schemas import PaginationInfo from superset.mcp_service.utils import _is_uuid +from superset.mcp_service.utils.permissions_utils import get_current_user +from superset.mcp_service.utils.schema_utils import ( + parse_json_or_list, + parse_json_or_passthrough, +) +from superset.utils import json # Type variables for generic model tools T = TypeVar("T") # For model objects @@ -153,8 +161,6 @@ class ModelListCore(BaseCore, Generic[L]): if not select_columns: return self.default_columns, list(self.default_columns) - from superset.mcp_service.utils.schema_utils import parse_json_or_list - parsed_columns = parse_json_or_list(select_columns, param_name="select_columns") columns_to_load = filter_user_directory_columns(parsed_columns) if not columns_to_load: @@ -179,6 +185,44 @@ class ModelListCore(BaseCore, Generic[L]): f"Allowed columns: {', '.join(self._sortable_columns)}" ) + @staticmethod + def _prepend_self_lookup_filters( + filters: Any, + created_by_me: bool, + owned_by_me: bool, + user: Any, + ) -> Any: + """Translate created_by_me/owned_by_me flags into ColumnOperator filters. + + Validates authentication and injects the current user's ID in one step, + so no placeholder value ever reaches the DAO layer. + + When both flags are set, a single combined OR filter is used so results + include items where the user is either the creator or an owner. + """ + if not (created_by_me or owned_by_me): + return filters + + if not user or not getattr(user, "is_authenticated", False): + raise ValueError("This operation requires an authenticated user") + + user_id: int = user.id + extra: ColumnOperator + if created_by_me and owned_by_me: + extra = ColumnOperator( + col="created_by_fk_or_owner", opr="eq", value=user_id + ) + elif created_by_me: + extra = ColumnOperator(col="created_by_fk", opr="eq", value=user_id) + else: + extra = ColumnOperator(col="owner", opr="eq", value=user_id) + + if filters is None: + return [extra] + if isinstance(filters, list): + return [extra] + filters + return [extra, filters] + def run_tool( self, filters: Any | None = None, @@ -188,19 +232,19 @@ class ModelListCore(BaseCore, Generic[L]): order_direction: Literal["asc", "desc"] | None = "asc", page: int = 0, page_size: int = 10, + created_by_me: bool = False, + owned_by_me: bool = False, ) -> L: - from superset.mcp_service.constants import MAX_PAGE_SIZE - # Clamp page_size to MAX_PAGE_SIZE as defense-in-depth page_size = min(page_size, MAX_PAGE_SIZE) # Parse filters using generic utility (accepts JSON string or object) - from superset.mcp_service.utils.schema_utils import ( - parse_json_or_passthrough, - ) - filters = parse_json_or_passthrough(filters, param_name="filters") + filters = self._prepend_self_lookup_filters( + filters, created_by_me, owned_by_me, get_current_user() + ) + # Parse select_columns using generic utility (accepts JSON, list, or CSV) columns_requested, columns_to_load = self._get_columns_to_load(select_columns) @@ -236,7 +280,6 @@ class ModelListCore(BaseCore, Generic[L]): if obj is not None: item_objs.append(obj) total_pages = (total_count + page_size - 1) // page_size if page_size > 0 else 0 - from superset.mcp_service.system.schemas import PaginationInfo # Report 1-based page in response to match the 1-based input convention # used by all list tool wrappers (list_charts, list_datasets, etc.) @@ -271,7 +314,12 @@ class ModelListCore(BaseCore, Generic[L]): "columns_loaded": columns_to_load, "columns_available": self.all_columns, "sortable_columns": self.sortable_columns, - "filters_applied": filters if isinstance(filters, list) else [], + "filters_applied": [ + f + for f in (filters if isinstance(filters, list) else []) + if (f.get("col") if isinstance(f, dict) else getattr(f, "col", None)) + not in SELF_REFERENCING_FILTER_COLUMNS + ], "pagination": pagination_info, "timestamp": datetime.now(timezone.utc), } @@ -433,10 +481,6 @@ class InstanceInfoCore(BaseCore): self, base_counts: Dict[str, int] ) -> Dict[str, Dict[str, int]]: """Calculate time-based metrics for recent activity.""" - from datetime import datetime, timedelta, timezone - - from superset.daos.base import ColumnOperator, ColumnOperatorEnum - now = datetime.now(timezone.utc) time_metrics = {} @@ -521,8 +565,6 @@ class InstanceInfoCore(BaseCore): def get_resource(self) -> str: """Resource interface for generating instance metadata as JSON.""" - from superset.utils import json - instance_info = self._generate_instance_info() return json.dumps(instance_info.model_dump(), indent=2) @@ -535,8 +577,6 @@ class InstanceInfoCore(BaseCore): custom_metrics = self._calculate_custom_metrics(base_counts, time_metrics) # Combine all data with fallbacks for required fields - from datetime import datetime, timezone - response_data = { **base_counts, **time_metrics, diff --git a/superset/mcp_service/privacy.py b/superset/mcp_service/privacy.py index 25b34a5b6c8..4f98775332f 100644 --- a/superset/mcp_service/privacy.py +++ b/superset/mcp_service/privacy.py @@ -44,6 +44,13 @@ USER_DIRECTORY_FIELDS = frozenset( } ) +# User-directory columns that are valid as filter inputs even though they are +# hidden from response payloads and select-column surfaces. The system injects +# the correct value server-side, so callers never need to supply user IDs. +SELF_REFERENCING_FILTER_COLUMNS = frozenset( + {"created_by_fk", "owner", "created_by_fk_or_owner"} +) + DATA_MODEL_METADATA_ACCESS_ATTR = "_requires_data_model_metadata_access" DATA_MODEL_METADATA_ERROR_TYPE = "DataModelMetadataRestricted" DATA_MODEL_METADATA_PRIVACY_SCOPE = "data_model" @@ -124,6 +131,30 @@ def user_can_view_data_model_metadata() -> bool: return False +def inject_current_user_for_self_referencing_filters(filters: Any, user: Any) -> Any: + """Replace the value of any self-referencing filter with the current user's ID. + + Callers specify the column and operator; the system fills in the value. + This prevents enumeration of other users' content. + """ + if not filters: + return filters + filter_list = filters if isinstance(filters, list) else [filters] + result = [] + for f in filter_list: + col = f.get("col") if isinstance(f, dict) else getattr(f, "col", None) + if col in SELF_REFERENCING_FILTER_COLUMNS: + if not user or not getattr(user, "is_authenticated", False): + raise ValueError("This operation requires an authenticated user") + f = ( + {**f, "value": user.id} + if isinstance(f, dict) + else f.model_copy(update={"value": user.id}) + ) + result.append(f) + return result + + def filter_user_directory_fields(data: dict[str, Any]) -> dict[str, Any]: """Remove fields that expose users, roles, owners, or access metadata.""" return { diff --git a/tests/unit_tests/mcp_service/chart/tool/test_list_charts.py b/tests/unit_tests/mcp_service/chart/tool/test_list_charts.py index e85dab43156..aeae501a756 100644 --- a/tests/unit_tests/mcp_service/chart/tool/test_list_charts.py +++ b/tests/unit_tests/mcp_service/chart/tool/test_list_charts.py @@ -318,3 +318,68 @@ class TestChartDataModelMetadataPrivacy: data = json.loads(result.content[0].text) assert data["error_type"] == DATA_MODEL_METADATA_ERROR_TYPE + + +class TestListChartsCreatedByMe: + """Tests for the created_by_me flag on ListChartsRequest.""" + + def test_created_by_me_default_is_false(self): + request = ListChartsRequest() + assert request.created_by_me is False + + def test_created_by_me_true_accepted(self): + request = ListChartsRequest(created_by_me=True) + assert request.created_by_me is True + + def test_created_by_me_combined_with_filters(self): + request = ListChartsRequest( + created_by_me=True, + filters=[ChartFilter(col="slice_name", opr="sw", value="My")], + ) + assert request.created_by_me is True + assert len(request.filters) == 1 + + def test_created_by_me_with_search_raises(self): + from pydantic import ValidationError + + with pytest.raises(ValidationError, match="created_by_me"): + ListChartsRequest(created_by_me=True, search="My charts") + + def test_chart_filter_rejects_created_by_fk(self): + """created_by_fk is not a public filter column; use created_by_me instead.""" + from pydantic import ValidationError + + with pytest.raises(ValidationError): + ChartFilter(col="created_by_fk", opr="eq", value=1) + + +class TestListChartsOwnedByMe: + """Tests for the owned_by_me flag on ListChartsRequest.""" + + def test_owned_by_me_default_is_false(self): + request = ListChartsRequest() + assert request.owned_by_me is False + + def test_owned_by_me_true_accepted(self): + request = ListChartsRequest(owned_by_me=True) + assert request.owned_by_me is True + + def test_owned_by_me_combined_with_filters(self): + request = ListChartsRequest( + owned_by_me=True, + filters=[ChartFilter(col="slice_name", opr="sw", value="My")], + ) + assert request.owned_by_me is True + assert len(request.filters) == 1 + + def test_owned_by_me_with_search_raises(self): + from pydantic import ValidationError + + with pytest.raises(ValidationError, match="owned_by_me"): + ListChartsRequest(owned_by_me=True, search="My charts") + + def test_owned_by_me_and_created_by_me_allowed(self): + """Both flags together are valid (OR logic — creator or owner).""" + request = ListChartsRequest(owned_by_me=True, created_by_me=True) + assert request.owned_by_me is True + assert request.created_by_me is True diff --git a/tests/unit_tests/mcp_service/dashboard/tool/test_dashboard_tools.py b/tests/unit_tests/mcp_service/dashboard/tool/test_dashboard_tools.py index 05d224fc513..66bd30b9fcd 100644 --- a/tests/unit_tests/mcp_service/dashboard/tool/test_dashboard_tools.py +++ b/tests/unit_tests/mcp_service/dashboard/tool/test_dashboard_tools.py @@ -29,6 +29,7 @@ from fastmcp.exceptions import ToolError from superset.mcp_service.app import mcp from superset.mcp_service.dashboard.schemas import ( + DashboardFilter, ListDashboardsRequest, ) from superset.utils import json @@ -983,3 +984,68 @@ class TestDashboardSortableColumns: assert "Sortable columns for order_column:" in list_dashboards.__doc__ for col in SORTABLE_DASHBOARD_COLUMNS: assert col in list_dashboards.__doc__ + + +class TestListDashboardsCreatedByMe: + """Tests for the created_by_me flag on ListDashboardsRequest.""" + + def test_created_by_me_default_is_false(self): + request = ListDashboardsRequest() + assert request.created_by_me is False + + def test_created_by_me_true_accepted(self): + request = ListDashboardsRequest(created_by_me=True) + assert request.created_by_me is True + + def test_created_by_me_combined_with_filters(self): + request = ListDashboardsRequest( + created_by_me=True, + filters=[DashboardFilter(col="published", opr="eq", value=True)], + ) + assert request.created_by_me is True + assert len(request.filters) == 1 + + def test_created_by_me_with_search_raises(self): + from pydantic import ValidationError + + with pytest.raises(ValidationError, match="created_by_me"): + ListDashboardsRequest(created_by_me=True, search="My dashboards") + + def test_dashboard_filter_rejects_created_by_fk(self): + """created_by_fk is not a public filter column; use created_by_me instead.""" + from pydantic import ValidationError + + with pytest.raises(ValidationError): + DashboardFilter(col="created_by_fk", opr="eq", value=1) + + +class TestListDashboardsOwnedByMe: + """Tests for the owned_by_me flag on ListDashboardsRequest.""" + + def test_owned_by_me_default_is_false(self): + request = ListDashboardsRequest() + assert request.owned_by_me is False + + def test_owned_by_me_true_accepted(self): + request = ListDashboardsRequest(owned_by_me=True) + assert request.owned_by_me is True + + def test_owned_by_me_combined_with_filters(self): + request = ListDashboardsRequest( + owned_by_me=True, + filters=[DashboardFilter(col="published", opr="eq", value=True)], + ) + assert request.owned_by_me is True + assert len(request.filters) == 1 + + def test_owned_by_me_with_search_raises(self): + from pydantic import ValidationError + + with pytest.raises(ValidationError, match="owned_by_me"): + ListDashboardsRequest(owned_by_me=True, search="My dashboards") + + def test_owned_by_me_and_created_by_me_allowed(self): + """Both flags together are valid (OR logic — creator or owner).""" + request = ListDashboardsRequest(owned_by_me=True, created_by_me=True) + assert request.owned_by_me is True + assert request.created_by_me is True diff --git a/tests/unit_tests/mcp_service/database/tool/test_database_tools.py b/tests/unit_tests/mcp_service/database/tool/test_database_tools.py index 17091c75b2d..8cb494e3092 100644 --- a/tests/unit_tests/mcp_service/database/tool/test_database_tools.py +++ b/tests/unit_tests/mcp_service/database/tool/test_database_tools.py @@ -43,15 +43,16 @@ get_database_info_module = importlib.import_module( class TestDatabaseFilterSchema: """Tests for DatabaseFilter schema — filterable columns.""" - def test_created_by_fk_is_valid_filter_column(self): - """created_by_fk must be accepted as a filter column.""" - f = DatabaseFilter(col="created_by_fk", opr="eq", value=1) - assert f.col == "created_by_fk" + def test_created_by_fk_is_rejected_as_filter_column(self): + """created_by_fk is not a public filter column; use created_by_me instead.""" + with pytest.raises(ValidationError): + DatabaseFilter(col="created_by_fk", opr="eq", value=1) - def test_changed_by_fk_is_valid_filter_column(self): - """changed_by_fk must be accepted as a filter column.""" - f = DatabaseFilter(col="changed_by_fk", opr="eq", value=1) - assert f.col == "changed_by_fk" + def test_changed_by_fk_is_rejected_as_filter_column(self): + """changed_by_fk is not a public filter column; it exposes a user enumeration + vector (caller can probe which databases a given user ID has touched).""" + with pytest.raises(ValidationError): + DatabaseFilter(col="changed_by_fk", opr="eq", value=1) def test_invalid_filter_column_rejected(self): """Columns not in the Literal set must be rejected.""" @@ -269,11 +270,10 @@ async def test_list_databases_does_not_expose_user_directory_fields( def test_database_filter_rejects_user_directory_fields() -> None: - """Test user directory string fields cannot be used for database filters. + """Test user directory fields cannot be used for database filters. - created_by_fk / changed_by_fk are integer FK IDs and ARE valid filter - columns. The user-directory *string* fields (created_by, created_by_name, - etc.) must still be rejected. + All FK columns (created_by_fk, changed_by_fk) and user-directory string + fields (created_by, created_by_name, etc.) must be rejected. """ with pytest.raises(ValidationError, match="created_by_name"): ListDatabasesRequest( @@ -281,6 +281,20 @@ def test_database_filter_rejects_user_directory_fields() -> None: ) +def test_database_filter_rejects_created_by_fk() -> None: + """created_by_fk is no longer a valid filter column; use created_by_me instead.""" + with pytest.raises(ValidationError, match="created_by_fk"): + ListDatabasesRequest( + filters=[{"col": "created_by_fk", "opr": "eq", "value": 0}], + ) + + +def test_database_request_accepts_created_by_me() -> None: + """created_by_me=True is the correct way to filter by current user.""" + request = ListDatabasesRequest(created_by_me=True) + assert request.created_by_me is True + + @patch("superset.daos.database.DatabaseDAO.list") @pytest.mark.asyncio async def test_list_databases_api_error(mock_list, mcp_server): diff --git a/tests/unit_tests/mcp_service/dataset/tool/test_dataset_tools.py b/tests/unit_tests/mcp_service/dataset/tool/test_dataset_tools.py index 1177a60533c..c937b2607f1 100644 --- a/tests/unit_tests/mcp_service/dataset/tool/test_dataset_tools.py +++ b/tests/unit_tests/mcp_service/dataset/tool/test_dataset_tools.py @@ -28,6 +28,7 @@ from fastmcp.exceptions import ToolError from superset.mcp_service.app import mcp from superset.mcp_service.dataset.schemas import ( CreateVirtualDatasetRequest, + DatasetFilter, ListDatasetsRequest, ) from superset.mcp_service.privacy import ( @@ -1890,3 +1891,68 @@ async def test_create_virtual_dataset_optional_fields_forwarded( assert props["schema"] == "public" assert props["catalog"] == "main" assert props["description"] == "A test dataset" + + +class TestListDatasetsCreatedByMe: + """Tests for the created_by_me flag on ListDatasetsRequest.""" + + def test_created_by_me_default_is_false(self): + request = ListDatasetsRequest() + assert request.created_by_me is False + + def test_created_by_me_true_accepted(self): + request = ListDatasetsRequest(created_by_me=True) + assert request.created_by_me is True + + def test_created_by_me_combined_with_filters(self): + request = ListDatasetsRequest( + created_by_me=True, + filters=[DatasetFilter(col="table_name", opr="sw", value="My")], + ) + assert request.created_by_me is True + assert len(request.filters) == 1 + + def test_created_by_me_with_search_raises(self): + from pydantic import ValidationError + + with pytest.raises(ValidationError, match="created_by_me"): + ListDatasetsRequest(created_by_me=True, search="My tables") + + def test_dataset_filter_rejects_created_by_fk(self): + """created_by_fk is not a public filter column; use created_by_me instead.""" + from pydantic import ValidationError + + with pytest.raises(ValidationError): + DatasetFilter(col="created_by_fk", opr="eq", value=1) + + +class TestListDatasetsOwnedByMe: + """Tests for the owned_by_me flag on ListDatasetsRequest.""" + + def test_owned_by_me_default_is_false(self): + request = ListDatasetsRequest() + assert request.owned_by_me is False + + def test_owned_by_me_true_accepted(self): + request = ListDatasetsRequest(owned_by_me=True) + assert request.owned_by_me is True + + def test_owned_by_me_combined_with_filters(self): + request = ListDatasetsRequest( + owned_by_me=True, + filters=[DatasetFilter(col="table_name", opr="sw", value="My")], + ) + assert request.owned_by_me is True + assert len(request.filters) == 1 + + def test_owned_by_me_with_search_raises(self): + from pydantic import ValidationError + + with pytest.raises(ValidationError, match="owned_by_me"): + ListDatasetsRequest(owned_by_me=True, search="My datasets") + + def test_owned_by_me_and_created_by_me_allowed(self): + """Both flags together are valid (OR logic — creator or owner).""" + request = ListDatasetsRequest(owned_by_me=True, created_by_me=True) + assert request.owned_by_me is True + assert request.created_by_me is True diff --git a/tests/unit_tests/mcp_service/system/tool/test_get_current_user.py b/tests/unit_tests/mcp_service/system/tool/test_get_current_user.py index ca234eeb5ac..e10d534d158 100644 --- a/tests/unit_tests/mcp_service/system/tool/test_get_current_user.py +++ b/tests/unit_tests/mcp_service/system/tool/test_get_current_user.py @@ -461,7 +461,7 @@ class TestGetInstanceInfoCurrentUserViaMCP: def test_chart_filter_rejects_created_by_fk() -> None: - """Test that ChartFilter rejects user-directory columns.""" + """created_by_fk is not a valid ChartFilter column; use created_by_me instead.""" with pytest.raises(ValidationError): ChartFilter(col="created_by_fk", opr="eq", value=42) @@ -473,7 +473,7 @@ def test_chart_filter_rejects_invalid_column(): def test_dashboard_filter_rejects_created_by_fk(): - """Test that DashboardFilter rejects user-directory columns.""" + """created_by_fk is not a valid DashboardFilter column; use created_by_me.""" with pytest.raises(ValidationError): DashboardFilter(col="created_by_fk", opr="eq", value=42) diff --git a/tests/unit_tests/mcp_service/system/tool/test_mcp_core.py b/tests/unit_tests/mcp_service/system/tool/test_mcp_core.py index 54c4906eaea..3538f9acce6 100644 --- a/tests/unit_tests/mcp_service/system/tool/test_mcp_core.py +++ b/tests/unit_tests/mcp_service/system/tool/test_mcp_core.py @@ -18,6 +18,7 @@ from datetime import datetime from types import SimpleNamespace from typing import Any, Dict, List +from unittest.mock import Mock, patch import pytest from pydantic import BaseModel @@ -126,6 +127,45 @@ def test_model_list_tool_with_filters_and_columns(): assert "id" in result.columns_loaded +def test_model_list_tool_keeps_single_filter_when_created_by_me_is_used(): + current_user = Mock() + current_user.is_authenticated = True + current_user.id = 42 + + captured = {} + + class CapturingDAO: + @classmethod + def list(cls, column_operators=None, **kwargs): + captured["filters"] = column_operators + return [], 0 + + tool = ModelListCore( + dao_class=CapturingDAO, + output_schema=DummyOutputSchema, + item_serializer=dummy_serializer, + filter_type=None, + default_columns=["id", "name"], + search_columns=["name"], + list_field_name="items", + output_list_schema=DummyListSchema, + ) + + with patch( + "superset.mcp_service.mcp_core.get_current_user", + return_value=current_user, + ): + tool.run_tool( + filters={"col": "name", "opr": "eq", "value": "foo"}, + created_by_me=True, + ) + + assert len(captured["filters"]) == 2 + assert captured["filters"][0].col == "created_by_fk" + assert captured["filters"][0].value == 42 + assert captured["filters"][1] == {"col": "name", "opr": "eq", "value": "foo"} + + def test_model_list_tool_rejects_only_user_directory_select_columns(): tool = ModelListCore( dao_class=DummyDAO, @@ -177,6 +217,160 @@ def test_model_list_tool_allows_order_column_when_sortable_columns_not_declared( tool.run_tool(order_column="name") +def test_model_list_tool_injects_current_user_id_for_created_by_me(): + """created_by_me=True adds a created_by_fk filter with the current user's ID.""" + current_user = Mock() + current_user.is_authenticated = True + current_user.id = 42 + + captured = {} + + class CapturingDAO: + @classmethod + def list(cls, column_operators=None, **kwargs): + captured["filters"] = column_operators + return [], 0 + + tool = ModelListCore( + dao_class=CapturingDAO, + output_schema=DummyOutputSchema, + item_serializer=dummy_serializer, + filter_type=None, + default_columns=["id", "name"], + search_columns=["name"], + list_field_name="items", + output_list_schema=DummyListSchema, + ) + + with patch( + "superset.mcp_service.mcp_core.get_current_user", + return_value=current_user, + ): + tool.run_tool(created_by_me=True) + + assert captured["filters"][0].col == "created_by_fk" + assert captured["filters"][0].value == 42 + + +def test_model_list_tool_created_by_me_requires_authenticated_user(): + """created_by_me=True raises when no authenticated user is present.""" + current_user = Mock() + current_user.is_authenticated = False + + tool = ModelListCore( + dao_class=DummyDAO, + output_schema=DummyOutputSchema, + item_serializer=dummy_serializer, + filter_type=None, + default_columns=["id", "name"], + search_columns=["name"], + list_field_name="items", + output_list_schema=DummyListSchema, + ) + + with patch( + "superset.mcp_service.mcp_core.get_current_user", + return_value=current_user, + ): + with pytest.raises(ValueError, match="authenticated user"): + tool.run_tool(created_by_me=True) + + +def test_model_list_tool_injects_current_user_id_for_owned_by_me(): + """owned_by_me=True adds an owner filter with the current user's ID.""" + current_user = Mock() + current_user.is_authenticated = True + current_user.id = 99 + + captured = {} + + class CapturingDAO: + @classmethod + def list(cls, column_operators=None, **kwargs): + captured["filters"] = column_operators + return [], 0 + + tool = ModelListCore( + dao_class=CapturingDAO, + output_schema=DummyOutputSchema, + item_serializer=dummy_serializer, + filter_type=None, + default_columns=["id", "name"], + search_columns=["name"], + list_field_name="items", + output_list_schema=DummyListSchema, + ) + + with patch( + "superset.mcp_service.mcp_core.get_current_user", + return_value=current_user, + ): + tool.run_tool(owned_by_me=True) + + assert captured["filters"][0].col == "owner" + assert captured["filters"][0].value == 99 + + +def test_model_list_tool_both_flags_uses_combined_or_filter(): + """created_by_me=True + owned_by_me=True generates a single OR filter.""" + current_user = Mock() + current_user.is_authenticated = True + current_user.id = 55 + + captured = {} + + class CapturingDAO: + @classmethod + def list(cls, column_operators=None, **kwargs): + captured["filters"] = column_operators + return [], 0 + + tool = ModelListCore( + dao_class=CapturingDAO, + output_schema=DummyOutputSchema, + item_serializer=dummy_serializer, + filter_type=None, + default_columns=["id", "name"], + search_columns=["name"], + list_field_name="items", + output_list_schema=DummyListSchema, + ) + + with patch( + "superset.mcp_service.mcp_core.get_current_user", + return_value=current_user, + ): + tool.run_tool(created_by_me=True, owned_by_me=True) + + assert len(captured["filters"]) == 1 + assert captured["filters"][0].col == "created_by_fk_or_owner" + assert captured["filters"][0].value == 55 + + +def test_model_list_tool_owned_by_me_requires_authenticated_user(): + """owned_by_me=True raises when no authenticated user is present.""" + current_user = Mock() + current_user.is_authenticated = False + + tool = ModelListCore( + dao_class=DummyDAO, + output_schema=DummyOutputSchema, + item_serializer=dummy_serializer, + filter_type=None, + default_columns=["id", "name"], + search_columns=["name"], + list_field_name="items", + output_list_schema=DummyListSchema, + ) + + with patch( + "superset.mcp_service.mcp_core.get_current_user", + return_value=current_user, + ): + with pytest.raises(ValueError, match="authenticated user"): + tool.run_tool(owned_by_me=True) + + def test_user_directory_fields_include_last_saved_relationships(): assert "last_saved_by" in USER_DIRECTORY_FIELDS assert "last_saved_by_name" in USER_DIRECTORY_FIELDS