Files
superset2/superset/security/api.py
2026-02-12 09:32:19 -08:00

367 lines
12 KiB
Python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any
from flask import current_app, request, Response
from flask_appbuilder import expose
from flask_appbuilder.api import rison, safe, SQLAInterface
from flask_appbuilder.api.schemas import get_list_schema
from flask_appbuilder.security.decorators import permission_name, protect
from flask_appbuilder.security.sqla.models import RegisterUser, Role
from flask_wtf.csrf import generate_csrf
from marshmallow import EXCLUDE, fields, post_load, Schema, ValidationError
from sqlalchemy import asc, desc
from sqlalchemy.orm import selectinload
from superset.commands.dashboard.embedded.exceptions import (
EmbeddedDashboardNotFoundError,
)
from superset.commands.exceptions import ForbiddenError
from superset.exceptions import SupersetGenericErrorException
from superset.extensions import db, event_logger
from superset.security.guest_token import GuestTokenResourceType
from superset.views.base_api import (
BaseSupersetApi,
BaseSupersetModelRestApi,
statsd_metrics,
)
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class PermissiveSchema(Schema):
    """
    A marshmallow schema that ignores unexpected fields, instead of throwing an error.

    The schemas defined below in this module inherit from it.
    """

    class Meta:  # pylint: disable=too-few-public-methods
        # EXCLUDE tells marshmallow to drop unknown keys instead of
        # raising a ValidationError for them.
        unknown = EXCLUDE
class UserSchema(PermissiveSchema):
    """
    Attributes of the user described in a guest token request.

    None of the fields is marked as required, so an empty object is valid.
    """

    username = fields.String()
    first_name = fields.String()
    last_name = fields.String()
class ResourceSchema(PermissiveSchema):
    """
    A resource reference: a ``GuestTokenResourceType`` value plus a string id.

    Both fields are required. After loading, ``type`` holds the enum's raw
    value rather than the enum member.
    """

    type = fields.Enum(GuestTokenResourceType, by_value=True, required=True)
    id = fields.String(required=True)

    @post_load
    def convert_enum_to_value(  # pylint: disable=unused-argument
        self,
        data: dict[str, Any],
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Replace the enum member stored under ``"type"`` with its value.

        The loaded dict is modified in place and returned.
        """
        # The enum wrapper is of no interest here; only its inner value is kept.
        resource_type = data["type"]
        data.update(type=resource_type.value)
        return data
class RlsRuleSchema(PermissiveSchema):
    """
    One entry of the ``rls`` list in a guest token request.

    ``clause`` is mandatory; ``dataset`` is optional.
    NOTE(review): ``dataset`` is presumably a dataset id that scopes the
    clause -- confirm against the guest token consumer.
    """

    dataset = fields.Integer()
    clause = fields.String(required=True)  # TODO: other options?
class GuestTokenCreateSchema(PermissiveSchema):
    """
    Request body of the guest token endpoint.

    ``resources`` and ``rls`` must be present (they are declared with
    ``required=True``); ``user`` is not marked as required.
    """

    user = fields.Nested(UserSchema)
    resources = fields.List(fields.Nested(ResourceSchema), required=True)
    rls = fields.List(fields.Nested(RlsRuleSchema), required=True)
class RoleResponseSchema(PermissiveSchema):
    """
    A single role as returned by the role search endpoint.

    Besides the role's own ``id`` and ``name`` it carries the ids of the
    users, permissions and groups attached to the role.
    """

    id = fields.Integer()
    name = fields.String()
    user_ids = fields.List(fields.Integer())
    permission_ids = fields.List(fields.Integer())
    # The search endpoint emits ``group_ids`` for every role, so the schema
    # that documents the response has to declare it as well.
    group_ids = fields.List(fields.Integer())
class RolesResponseSchema(PermissiveSchema):
    """
    Envelope of the role list response: a count, a list of role ids and the
    role entries themselves.
    """

    count = fields.Integer()
    ids = fields.List(fields.Integer())
    result = fields.List(fields.Nested(RoleResponseSchema))
# Single schema instance created at import time; the guest token endpoint
# uses it to load and validate incoming request bodies.
guest_token_create_schema = GuestTokenCreateSchema()
class SecurityRestApi(BaseSupersetApi):
    """
    Security endpoints: CSRF token retrieval and guest token creation.
    """

    resource_name = "security"
    allow_browser_login = True
    openapi_spec_tag = "Security"

    @expose("/csrf_token/", methods=("GET",))
    @event_logger.log_this
    @protect()
    @safe
    @statsd_metrics
    @permission_name("read")
    def csrf_token(self) -> Response:
        """Get the CSRF token.
        ---
        get:
          summary: Get the CSRF token
          responses:
            200:
              description: Result contains the CSRF token
              content:
                application/json:
                  schema:
                    type: object
                    properties:
                      result:
                        type: string
            401:
              $ref: '#/components/responses/401'
            500:
              $ref: '#/components/responses/500'
        """
        return self.response(200, result=generate_csrf())

    @expose("/guest_token/", methods=("POST",))
    @event_logger.log_this
    @protect()
    @safe
    @statsd_metrics
    @permission_name("grant_guest_token")
    def guest_token(self) -> Response:
        """Get a guest token that can be used for auth in embedded Superset.
        ---
        post:
          summary: Get a guest token
          requestBody:
            description: Parameters for the guest token
            required: true
            content:
              application/json:
                schema: GuestTokenCreateSchema
          responses:
            200:
              description: Result contains the guest token
              content:
                application/json:
                  schema:
                    type: object
                    properties:
                      token:
                        type: string
            401:
              $ref: '#/components/responses/401'
            400:
              $ref: '#/components/responses/400'
            500:
              $ref: '#/components/responses/500'
        """
        try:
            body = guest_token_create_schema.load(request.json)
            self.appbuilder.sm.validate_guest_token_resources(body["resources"])

            guest_token_validator_hook = current_app.config.get(
                "GUEST_TOKEN_VALIDATOR_HOOK"
            )
            # Run validator to ensure the token parameters are OK.
            if guest_token_validator_hook is not None:
                # A configured but non-callable hook is a deployment error,
                # not a problem with the request.
                if not callable(guest_token_validator_hook):
                    raise SupersetGenericErrorException(
                        message="Guest token validator hook not callable"
                    )
                if not guest_token_validator_hook(body):
                    raise ValidationError(message="Guest token validation failed")

            # TODO: Add generic validation:
            # make sure username doesn't reference an existing user
            # check rls rules for validity?

            # ``user`` is optional in GuestTokenCreateSchema, so it may be
            # absent from ``body``. Indexing it directly raised KeyError and
            # produced a 500. An empty dict is what an explicit ``"user": {}``
            # payload already loads to, so it is used as the fallback.
            token = self.appbuilder.sm.create_guest_access_token(
                body.get("user", {}), body["resources"], body["rls"]
            )
            return self.response(200, token=token)
        except EmbeddedDashboardNotFoundError as error:
            return self.response_400(message=error.message)
        except ValidationError as error:
            return self.response_400(message=error.messages)
class RoleRestAPI(BaseSupersetApi):
    """
    APIs for listing roles together with the ids of their users, permissions
    and groups.
    """

    resource_name = "security/roles"
    allow_browser_login = True
    openapi_spec_tag = "Security Roles"
    openapi_spec_component_schemas = (
        RoleResponseSchema,
        RolesResponseSchema,
    )

    @expose("/search/", methods=["GET"])
    @event_logger.log_this
    @protect()
    @safe
    @rison(get_list_schema)
    @statsd_metrics
    @permission_name("list_roles")
    def get_list(self, **kwargs: Any) -> Response:
        """
        List roles, including associated user, permission and group IDs.
        ---
        get:
          summary: List roles
          description: >-
            Fetch a paginated list of roles with user, permission and group IDs.
          parameters:
            - in: query
              name: q
              schema:
                type: object
                properties:
                  order_column:
                    type: string
                    enum: ["id", "name"]
                    default: "id"
                  order_direction:
                    type: string
                    enum: ["asc", "desc"]
                    default: "asc"
                  page:
                    type: integer
                    minimum: 0
                    default: 0
                  page_size:
                    type: integer
                    minimum: 0
                    default: 10
                  filters:
                    type: array
                    items:
                      type: object
                      properties:
                        col:
                          type: string
                          enum:
                            - "user_ids"
                            - "permission_ids"
                            - "group_ids"
                            - "name"
                        value:
                          type: string
          responses:
            200:
              description: Successfully retrieved roles
              content:
                application/json:
                  schema: RolesResponseSchema
            400:
              description: Bad request (invalid input)
              content:
                application/json:
                  schema:
                    type: object
                    properties:
                      error:
                        type: string
            403:
              description: Forbidden
              content:
                application/json:
                  schema:
                    type: object
                    properties:
                      error:
                        type: string
            500:
              $ref: '#/components/responses/500'
        """
        try:
            args = kwargs.get("rison", {})

            # Only whitelisted columns may be used for ordering, because the
            # name is resolved with getattr() on the model.
            order_column = args.get("order_column", "id")
            if order_column not in ("id", "name"):
                return self.response_400(
                    message=f"Invalid order column: {order_column}"
                )
            order_direction = args.get("order_direction", "asc")
            direction = asc if order_direction == "asc" else desc
            order_by = direction(getattr(Role, order_column))

            # Negative values would be passed straight into OFFSET/LIMIT, so
            # they are rejected up front as a client error.
            page = args.get("page", 0)
            page_size = args.get("page_size", 10)
            if page < 0 or page_size < 0:
                return self.response_400(
                    message="page and page_size must be non-negative"
                )

            # Eager-load the relationships that are serialized below so that
            # building the response does not trigger per-role lazy loads.
            query = db.session.query(Role).options(
                selectinload(Role.permissions),
                selectinload(Role.user),
                selectinload(Role.groups),
            )

            filters = args.get("filters", [])
            filter_dict = {f["col"]: f["value"] for f in filters if "col" in f}

            # Each of these filters keeps roles that have at least one related
            # row with the given id.
            relationship_filters = (
                ("user_ids", Role.user),
                ("permission_ids", Role.permissions),
                ("group_ids", Role.groups),
            )
            for key, relationship in relationship_filters:
                if key in filter_dict:
                    query = query.filter(relationship.any(id=filter_dict[key]))
            if "name" in filter_dict:
                # Case-insensitive substring match on the role name.
                query = query.filter(Role.name.ilike(f"%{filter_dict['name']}%"))

            # Counted before pagination: this is the number of matching
            # roles, not the size of the returned page.
            total_count = query.count()
            roles = (
                query.order_by(order_by).offset(page * page_size).limit(page_size).all()
            )

            return self.response(
                200,
                result=[
                    {
                        "id": role.id,
                        "name": role.name,
                        "user_ids": [user.id for user in role.user],
                        "permission_ids": [perm.id for perm in role.permissions],
                        "group_ids": [group.id for group in role.groups],
                    }
                    for role in roles
                ],
                count=total_count,
                ids=[role.id for role in roles],
            )
        except ForbiddenError as e:
            return self.response_403(message=str(e))
        except Exception as e:  # pylint: disable=broad-except
            # This handler intercepts the error before any outer decorator
            # can see it, so record the traceback here or it is lost.
            logger.exception("Failed to list roles")
            return self.response_500(message=str(e))
class UserRegistrationsRestAPI(BaseSupersetModelRestApi):
    """
    APIs for listing user registrations (Admin only)
    """

    resource_name = "security/user_registrations"
    # The API is backed by the RegisterUser model from Flask-AppBuilder.
    datamodel = SQLAInterface(RegisterUser)
    allow_browser_login = True
    # Columns exposed for each registration in list responses.
    # NOTE(review): registration_hash is included here; presumably it is the
    # value used to confirm a registration -- confirm that exposing it to
    # API consumers is intended.
    list_columns = [
        "id",
        "username",
        "email",
        "first_name",
        "last_name",
        "registration_date",
        "registration_hash",
    ]