mirror of
https://github.com/apache/superset.git
synced 2026-04-07 18:35:15 +00:00
367 lines
12 KiB
Python
367 lines
12 KiB
Python
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
import logging
|
|
from typing import Any
|
|
|
|
from flask import current_app, request, Response
|
|
from flask_appbuilder import expose
|
|
from flask_appbuilder.api import rison, safe, SQLAInterface
|
|
from flask_appbuilder.api.schemas import get_list_schema
|
|
from flask_appbuilder.security.decorators import permission_name, protect
|
|
from flask_appbuilder.security.sqla.models import RegisterUser, Role
|
|
from flask_wtf.csrf import generate_csrf
|
|
from marshmallow import EXCLUDE, fields, post_load, Schema, ValidationError
|
|
from sqlalchemy import asc, desc
|
|
from sqlalchemy.orm import selectinload
|
|
|
|
from superset.commands.dashboard.embedded.exceptions import (
|
|
EmbeddedDashboardNotFoundError,
|
|
)
|
|
from superset.commands.exceptions import ForbiddenError
|
|
from superset.exceptions import SupersetGenericErrorException
|
|
from superset.extensions import db, event_logger
|
|
from superset.security.guest_token import GuestTokenResourceType
|
|
from superset.views.base_api import (
|
|
BaseSupersetApi,
|
|
BaseSupersetModelRestApi,
|
|
statsd_metrics,
|
|
)
|
|
|
|
# Module-scoped logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
class PermissiveSchema(Schema):
    """
    Base marshmallow schema that silently discards undeclared input keys
    instead of raising a ValidationError for them.
    """

    class Meta:  # pylint: disable=too-few-public-methods
        # EXCLUDE: drop any field that is not declared on the schema.
        unknown = EXCLUDE
|
|
|
|
|
|
class UserSchema(PermissiveSchema):
    # Optional identity details for the user a guest token is issued to;
    # every field is an optional string.
    username = fields.String()
    first_name = fields.String()
    last_name = fields.String()
|
|
|
|
|
|
class ResourceSchema(PermissiveSchema):
    # One resource entry of a guest-token request: a required type (validated
    # against GuestTokenResourceType by value) and a required string id.
    type = fields.Enum(GuestTokenResourceType, by_value=True, required=True)
    id = fields.String(required=True)

    @post_load
    def convert_enum_to_value(  # pylint: disable=unused-argument
        self,
        data: dict[str, Any],
        **kwargs: Any,
    ) -> dict[str, Any]:
        # Callers want the raw enum value rather than the enum member, so the
        # deserialized member is swapped for its ``.value`` in place.
        resource_type = data["type"]
        data["type"] = resource_type.value
        return data
|
|
|
|
|
|
class RlsRuleSchema(PermissiveSchema):
    # A row-level-security rule attached to a guest token: an optional
    # integer dataset id and a mandatory SQL clause string.
    dataset = fields.Integer()
    clause = fields.String(required=True)  # TODO: support other rule options?
|
|
|
|
|
|
class GuestTokenCreateSchema(PermissiveSchema):
    """Request body accepted by ``SecurityRestApi.guest_token``."""

    # ``guest_token`` reads ``body["user"]`` unconditionally. Without
    # ``required=True`` an omitted ``user`` passed validation and then failed
    # with a KeyError (surfacing as a 500); now it is reported as a 400
    # validation error like the other mandatory fields.
    user = fields.Nested(UserSchema, required=True)
    resources = fields.List(fields.Nested(ResourceSchema), required=True)
    rls = fields.List(fields.Nested(RlsRuleSchema), required=True)
|
|
|
|
|
|
class RoleResponseSchema(PermissiveSchema):
    """A single role entry as returned by ``RoleRestAPI.get_list``."""

    id = fields.Integer()
    name = fields.String()
    user_ids = fields.List(fields.Integer())
    permission_ids = fields.List(fields.Integer())
    # get_list also emits the IDs of each role's groups; declare the field so
    # the published OpenAPI component matches the actual payload.
    group_ids = fields.List(fields.Integer())
|
|
|
|
|
|
class RolesResponseSchema(PermissiveSchema):
    # Envelope of the paginated role listing: the total number of matching
    # roles, the ids on the current page, and the serialized roles themselves.
    count = fields.Integer()
    ids = fields.List(fields.Integer())
    result = fields.List(fields.Nested(RoleResponseSchema))
|
|
|
|
|
|
# Shared schema instance used to validate guest-token request bodies.
guest_token_create_schema: GuestTokenCreateSchema = GuestTokenCreateSchema()
|
|
|
|
|
|
class SecurityRestApi(BaseSupersetApi):
    resource_name = "security"
    allow_browser_login = True
    openapi_spec_tag = "Security"

    @expose("/csrf_token/", methods=("GET",))
    @event_logger.log_this
    @protect()
    @safe
    @statsd_metrics
    @permission_name("read")
    def csrf_token(self) -> Response:
        """Get the CSRF token.
        ---
        get:
          summary: Get the CSRF token
          responses:
            200:
              description: Result contains the CSRF token
              content:
                application/json:
                  schema:
                    type: object
                    properties:
                      result:
                        type: string
            401:
              $ref: '#/components/responses/401'
            500:
              $ref: '#/components/responses/500'
        """
        # The token comes straight from flask_wtf's generate_csrf().
        csrf = generate_csrf()
        return self.response(200, result=csrf)

    @expose("/guest_token/", methods=("POST",))
    @event_logger.log_this
    @protect()
    @safe
    @statsd_metrics
    @permission_name("grant_guest_token")
    def guest_token(self) -> Response:
        """Get a guest token that can be used for auth in embedded Superset.
        ---
        post:
          summary: Get a guest token
          requestBody:
            description: Parameters for the guest token
            required: true
            content:
              application/json:
                schema: GuestTokenCreateSchema
          responses:
            200:
              description: Result contains the guest token
              content:
                application/json:
                  schema:
                    type: object
                    properties:
                      token:
                        type: string
            401:
              $ref: '#/components/responses/401'
            400:
              $ref: '#/components/responses/400'
            500:
              $ref: '#/components/responses/500'
        """
        try:
            body = guest_token_create_schema.load(request.json)
            self.appbuilder.sm.validate_guest_token_resources(body["resources"])

            # A deployment may configure a hook that vetoes the token
            # parameters. A configured-but-uncallable hook is a server-side
            # misconfiguration; a hook returning falsy is a client error.
            validator = current_app.config.get("GUEST_TOKEN_VALIDATOR_HOOK")
            if validator is not None and not callable(validator):
                raise SupersetGenericErrorException(
                    message="Guest token validator hook not callable"
                )
            if validator is not None and not validator(body):
                raise ValidationError(message="Guest token validation failed")

            # TODO: Add generic validation:
            # make sure username doesn't reference an existing user
            # check rls rules for validity?
            issued = self.appbuilder.sm.create_guest_access_token(
                body["user"], body["resources"], body["rls"]
            )
            return self.response(200, token=issued)
        except EmbeddedDashboardNotFoundError as error:
            return self.response_400(message=error.message)
        except ValidationError as error:
            return self.response_400(message=error.messages)
|
|
|
|
|
|
class RoleRestAPI(BaseSupersetApi):
    """
    APIs for listing roles along with the IDs of their users, permissions
    and groups.
    """

    resource_name = "security/roles"
    allow_browser_login = True
    openapi_spec_tag = "Security Roles"
    openapi_spec_component_schemas = (
        RoleResponseSchema,
        RolesResponseSchema,
    )

    @expose("/search/", methods=["GET"])
    @event_logger.log_this
    @protect()
    @safe
    @rison(get_list_schema)
    @statsd_metrics
    @permission_name("list_roles")
    def get_list(self, **kwargs: Any) -> Response:
        """
        List roles, including associated user IDs and permission IDs.

        ---
        get:
          summary: List roles
          description: Fetch a paginated list of roles with user, permission and group IDs.
          parameters:
            - in: query
              name: q
              schema:
                type: object
                properties:
                  order_column:
                    type: string
                    enum: ["id", "name"]
                    default: "id"
                  order_direction:
                    type: string
                    enum: ["asc", "desc"]
                    default: "asc"
                  page:
                    type: integer
                    default: 0
                  page_size:
                    type: integer
                    default: 10
                  filters:
                    type: array
                    items:
                      type: object
                      properties:
                        col:
                          type: string
                          enum: ["user_ids", "permission_ids", "group_ids", "name"]
                        value:
                          type: string
          responses:
            200:
              description: Successfully retrieved roles
              content:
                application/json:
                  schema: RolesResponseSchema
            400:
              description: Bad request (invalid input)
              content:
                application/json:
                  schema:
                    type: object
                    properties:
                      error:
                        type: string
            403:
              description: Forbidden
              content:
                application/json:
                  schema:
                    type: object
                    properties:
                      error:
                        type: string
            500:
              $ref: '#/components/responses/500'
        """
        try:
            args = kwargs.get("rison", {})
            order_column = args.get("order_column", "id")
            order_direction = args.get("order_direction", "asc")

            valid_columns = ["id", "name"]
            if order_column not in valid_columns:
                return self.response_400(
                    message=f"Invalid order column: {order_column}"
                )

            page = args.get("page", 0)
            page_size = args.get("page_size", 10)
            # These values go straight into OFFSET/LIMIT. Negative values are
            # rejected by some backends (e.g. PostgreSQL, which would surface
            # here as a 500) and mean "no limit" on others (SQLite LIMIT -1),
            # so refuse them up front.
            if page < 0 or page_size < 0:
                return self.response_400(
                    message="page and page_size must be non-negative integers"
                )

            order_by = getattr(Role, order_column)
            order_by = asc(order_by) if order_direction == "asc" else desc(order_by)

            # Eager-load the relationships serialized below to avoid one
            # lazy-load query per role.
            query = db.session.query(Role).options(
                selectinload(Role.permissions),
                selectinload(Role.user),
                selectinload(Role.groups),
            )

            filters = args.get("filters", [])
            filter_dict = {f["col"]: f["value"] for f in filters if "col" in f}

            if "user_ids" in filter_dict:
                query = query.filter(Role.user.any(id=filter_dict["user_ids"]))

            if "permission_ids" in filter_dict:
                query = query.filter(
                    Role.permissions.any(id=filter_dict["permission_ids"])
                )

            if "group_ids" in filter_dict:
                query = query.filter(Role.groups.any(id=filter_dict["group_ids"]))

            if "name" in filter_dict:
                query = query.filter(Role.name.ilike(f"%{filter_dict['name']}%"))

            # Total matches before pagination is applied.
            total_count = query.count()

            roles = (
                query.order_by(order_by).offset(page * page_size).limit(page_size).all()
            )

            return self.response(
                200,
                result=[
                    {
                        "id": role.id,
                        "name": role.name,
                        "user_ids": [user.id for user in role.user],
                        "permission_ids": [perm.id for perm in role.permissions],
                        "group_ids": [group.id for group in role.groups],
                    }
                    for role in roles
                ],
                count=total_count,
                ids=[role.id for role in roles],
            )
        except ForbiddenError as e:
            return self.response_403(message=str(e))
        except Exception as e:  # pylint: disable=broad-except
            # Previously swallowed without a trace; keep the response
            # unchanged but record the traceback for operators.
            logger.exception("Failed to list roles")
            return self.response_500(message=str(e))
|
|
|
|
|
|
class UserRegistrationsRestAPI(BaseSupersetModelRestApi):
|
|
"""
|
|
APIs for listing user registrations (Admin only)
|
|
"""
|
|
|
|
resource_name = "security/user_registrations"
|
|
datamodel = SQLAInterface(RegisterUser)
|
|
allow_browser_login = True
|
|
list_columns = [
|
|
"id",
|
|
"username",
|
|
"email",
|
|
"first_name",
|
|
"last_name",
|
|
"registration_date",
|
|
"registration_hash",
|
|
]
|