mirror of
https://github.com/apache/superset.git
synced 2026-04-25 19:14:27 +00:00
366 lines
12 KiB
Python
366 lines
12 KiB
Python
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
import logging
|
|
from operator import and_
|
|
from typing import Any, Optional
|
|
|
|
from flask import g
|
|
from sqlalchemy.exc import SQLAlchemyError
|
|
|
|
from superset.daos.base import BaseDAO
|
|
from superset.daos.exceptions import DAOCreateFailedError, DAODeleteFailedError
|
|
from superset.exceptions import MissingUserContextException
|
|
from superset.extensions import db
|
|
from superset.models.dashboard import Dashboard
|
|
from superset.models.slice import Slice
|
|
from superset.models.sql_lab import SavedQuery
|
|
from superset.tags.commands.exceptions import TagNotFoundError
|
|
from superset.tags.models import (
|
|
get_tag,
|
|
ObjectTypes,
|
|
Tag,
|
|
TaggedObject,
|
|
TagTypes,
|
|
user_favorite_tag_table,
|
|
)
|
|
from superset.utils.core import get_user_id
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TagDAO(BaseDAO[Tag]):
|
|
# base_filter = TagAccessFilter
|
|
|
|
@staticmethod
def validate_tag_name(tag_name: str) -> bool:
    """
    Report whether ``tag_name`` is free of reserved characters.

    A name is rejected when it contains a colon or a comma anywhere in the
    string; every other string, including the empty string, is accepted.
    """
    forbidden = (":", ",")
    return not any(symbol in tag_name for symbol in forbidden)
|
|
|
|
@staticmethod
def create_custom_tagged_objects(
    object_type: ObjectTypes, object_id: int, tag_names: list[str]
) -> None:
    """
    Attach custom tags to a single object and commit the session.

    All names are validated before any tag is looked up, so an invalid name
    anywhere in ``tag_names`` aborts the call before ``get_by_name`` (which
    may create missing tags) has run for any of the other names. Names are
    stripped of surrounding whitespace and de-duplicated, keeping the
    caller's order, so repeating a name does not yield two associations
    from the same call.

    Args:
        object_type: type of the object being tagged.
        object_id: id of the object being tagged.
        tag_names: names of the custom tags to attach.

    Raises:
        DAOCreateFailedError: if any name contains ':' or ','.
    """
    for name in tag_names:
        if not TagDAO.validate_tag_name(name):
            raise DAOCreateFailedError(
                message="Invalid Tag Name (cannot contain ':' or ',')"
            )

    # dict.fromkeys drops repeats while preserving first-seen order
    unique_names = dict.fromkeys(name.strip() for name in tag_names)

    tagged_objects = [
        TaggedObject(
            object_id=object_id,
            object_type=object_type,
            tag=TagDAO.get_by_name(tag_name, TagTypes.custom),
        )
        for tag_name in unique_names
    ]

    db.session.add_all(tagged_objects)
    db.session.commit()
|
|
|
|
@staticmethod
def delete_tagged_object(
    object_type: ObjectTypes, object_id: int, tag_name: str
) -> None:
    """
    Delete the association between a tag and an object.

    The tag is resolved by its stripped name, then the tagged object is
    looked up by tag id, object type and object id and removed.

    Args:
        object_type: type of the tagged object.
        object_id: id of the tagged object.
        tag_name: name of the tag to detach.

    Raises:
        DAODeleteFailedError: if the tag does not exist, if no matching
            tagged object exists, or if the delete/commit fails (the
            session is rolled back in that case).
    """
    tag = TagDAO.find_by_name(tag_name.strip())
    if not tag:
        raise DAODeleteFailedError(
            message=f"Tag with name {tag_name} does not exist."
        )

    # Fetch the row itself rather than testing the un-executed query: a
    # Query object is always truthy, so it can never signal "not found".
    tagged_object = TagDAO.find_tagged_object(object_type, object_id, tag.id)
    if not tagged_object:
        # Adjacent f-strings keep source indentation out of the message.
        raise DAODeleteFailedError(
            message=(
                f"Tagged object with object_id: {object_id} "
                f"object_type: {object_type} "
                f'and tag name: "{tag_name}" could not be found'
            )
        )
    try:
        db.session.delete(tagged_object)
        db.session.commit()
    except SQLAlchemyError as ex:  # pragma: no cover
        db.session.rollback()
        raise DAODeleteFailedError(exception=ex) from ex
|
|
|
|
@staticmethod
def delete_tags(tag_names: list[str]) -> None:
    """
    Delete every tag named in ``tag_names`` as a single unit of work.

    Each name is stripped and checked for existence first; nothing is
    deleted unless all names resolve. The deletes are then committed once,
    so a failure rolls back the whole batch instead of leaving some tags
    deleted and others not.

    Args:
        tag_names: names of the tags to delete.

    Raises:
        DAODeleteFailedError: if a name does not match an existing tag, or
            if the delete/commit fails (the session is rolled back).
    """
    tags_to_delete = []
    for name in tag_names:
        tag_name = name.strip()
        if not TagDAO.find_by_name(tag_name):
            raise DAODeleteFailedError(
                message=f"Tag with name {tag_name} does not exist."
            )
        tags_to_delete.append(tag_name)

    # Nothing to do; also avoids issuing a query with an empty IN list.
    if not tags_to_delete:
        return

    tag_objects = db.session.query(Tag).filter(Tag.name.in_(tags_to_delete)).all()
    try:
        for tag in tag_objects:
            db.session.delete(tag)
        # Single commit keeps the batch atomic.
        db.session.commit()
    except SQLAlchemyError as ex:  # pragma: no cover
        db.session.rollback()
        raise DAODeleteFailedError(exception=ex) from ex
|
|
|
|
@staticmethod
def get_by_name(name: str, type_: TagTypes = TagTypes.custom) -> Tag:
    """
    Fetch the tag matching ``name`` and ``type_``, falling back to ``get_tag``.

    The lookup compares ``Tag.name`` with ``name`` and ``Tag.type`` with
    ``type_.name``. When no row matches, the result of
    ``get_tag(name, db.session, type_)`` is returned instead.

    important!: unlike ``find_by_name``, this is intended to create the tag
    when it is not found.
    """
    lookup = db.session.query(Tag).filter(
        Tag.name == name, Tag.type == type_.name
    )
    existing = lookup.first()
    if existing:
        return existing
    return get_tag(name, db.session, type_)
|
|
|
|
@staticmethod
def find_by_name(name: str) -> Optional[Tag]:
    """
    Return the first tag whose name equals ``name``, or None if there is none.

    Does NOT create a tag if the tag is not found. The tag type is not part
    of the lookup.
    """
    return db.session.query(Tag).filter(Tag.name == name).first()
|
|
|
|
@staticmethod
def find_tagged_object(
    object_type: ObjectTypes, object_id: int, tag_id: int
) -> Optional[TaggedObject]:
    """
    Return the first tagged object matching the given tag id, object id
    and object type, or None if no such association exists.
    """
    return (
        db.session.query(TaggedObject)
        .filter(
            TaggedObject.tag_id == tag_id,
            TaggedObject.object_id == object_id,
            TaggedObject.object_type == object_type,
        )
        .first()
    )
|
|
|
|
@staticmethod
def get_tagged_objects_for_tags(
    tags: Optional[list[str]] = None, obj_types: Optional[list[str]] = None
) -> list[dict[str, Any]]:
    """
    Return tagged dashboards, charts and saved queries as plain dicts.

    Args:
        tags: tag names to restrict the results to; when empty or None the
            results are not filtered by tag name.
        obj_types: any of "dashboard", "chart", "query"; when empty or None
            all three kinds are returned.

    Returns:
        A list of dicts with the keys "id", "type", "name", "url",
        "changed_on", "created_by" and "creator", ordered dashboards first,
        then charts, then saved queries.
    """
    # One row per supported kind:
    # (obj_types key, model, ObjectTypes member, name getter, url getter).
    # Dashboard and Slice expose ``url`` as an attribute, whereas SavedQuery
    # exposes it as a method, hence the separate url getters.
    sources = (
        (
            "dashboard",
            Dashboard,
            ObjectTypes.dashboard,
            lambda obj: obj.dashboard_title,
            lambda obj: obj.url,
        ),
        (
            "chart",
            Slice,
            ObjectTypes.chart,
            lambda obj: obj.slice_name,
            lambda obj: obj.url,
        ),
        (
            "query",
            SavedQuery,
            ObjectTypes.query,
            lambda obj: obj.label,
            lambda obj: obj.url(),
        ),
    )

    results: list[dict[str, Any]] = []
    for type_key, model, object_type, get_name, get_url in sources:
        if obj_types and type_key not in obj_types:
            continue

        # ``and_`` is the name imported at module level; it is called with
        # exactly two operands, as in each of the original joins.
        query = (
            db.session.query(model)
            .join(
                TaggedObject,
                and_(
                    TaggedObject.object_id == model.id,
                    TaggedObject.object_type == object_type,
                ),
            )
            .join(Tag, TaggedObject.tag_id == Tag.id)
        )
        # Only add the name criterion when there is something to filter on,
        # instead of handing a Python bool to ``filter``.
        if tags:
            query = query.filter(Tag.name.in_(tags))

        results.extend(
            {
                "id": obj.id,
                "type": object_type.name,
                "name": get_name(obj),
                "url": get_url(obj),
                "changed_on": obj.changed_on,
                "created_by": obj.created_by_fk,
                "creator": obj.creator(),
            }
            for obj in query
        )
    return results
|
|
|
|
@staticmethod
def favorite_tag_by_id_for_current_user(  # pylint: disable=invalid-name
    tag_id: int,
) -> None:
    """
    Marks a specific tag as a favorite for the current user.

    The current user is read from the global 'g' object, the tag is looked
    up by id, the user is added to the tag's ``users_favorited`` collection
    and the session is committed. Favoriting a tag the user has already
    favorited leaves the collection unchanged.

    Args:
        tag_id: The id of the tag that is to be marked as favorite.

    Raises:
        MissingUserContextException: if there is no current user.
        TagNotFoundError: if no tag with ``tag_id`` exists.
        Any exception raised by find_by_id or by the database session's
        commit propagates up to the caller.

    Returns:
        None.
    """
    user = g.user
    if not user:
        raise MissingUserContextException(message="User doesn't exist")

    # Looked up only once a user is known, so no query runs without one.
    tag = TagDAO.find_by_id(tag_id)
    if not tag:
        raise TagNotFoundError()

    # Keep the operation idempotent: do not append the same user twice.
    if user not in tag.users_favorited:
        tag.users_favorited.append(user)
    db.session.commit()
|
|
|
|
@staticmethod
def remove_user_favorite_tag(tag_id: int) -> None:
    """
    Removes a tag from the current user's favorite tags.

    The current user is read from the global 'g' object, the tag is looked
    up by id, the user is removed from the tag's ``users_favorited``
    collection and the session is committed. If the user had not favorited
    the tag, nothing is removed and no error is raised.

    Args:
        tag_id: The id of the tag that is to be removed from the favorite tags.

    Raises:
        MissingUserContextException: if there is no current user.
        TagNotFoundError: if no tag with ``tag_id`` exists.
        Any exception raised by find_by_id or by the database session's
        commit propagates up to the caller.

    Returns:
        None.
    """
    user = g.user
    if not user:
        raise MissingUserContextException(message="User doesn't exist")

    # Looked up only once a user is known, so no query runs without one.
    tag = TagDAO.find_by_id(tag_id)
    if not tag:
        raise TagNotFoundError()

    # list.remove raises ValueError for a missing element; un-favoriting a
    # tag that was never favorited is treated as a no-op instead.
    if user in tag.users_favorited:
        tag.users_favorited.remove(user)

    # Commit to save the changes
    db.session.commit()
|
|
|
|
@staticmethod
def favorited_ids(tags: list[Tag]) -> list[int]:
    """
    Returns the IDs of the given tags that the current user has favorited.

    The IDs of the supplied Tag objects are collected and matched against
    the rows of user_favorite_tag_table whose user_id equals the value
    returned by get_user_id().

    Args:
        tags (list[Tag]): A list of Tag objects.

    Returns:
        list[int]: The tag IDs, taken from the matching rows, that are
        favorited by the current user.

    Example:
        favorited_ids([tag1, tag2, tag3])
        Output: [tag_id1, tag_id3]  # if the current user has favorited tag1 and tag3
    """
    candidate_ids = [tag.id for tag in tags]
    favorites = user_favorite_tag_table.c
    rows = (
        db.session.query(favorites.tag_id)
        .filter(
            favorites.tag_id.in_(candidate_ids),
            favorites.user_id == get_user_id(),
        )
        .all()
    )
    return [row.tag_id for row in rows]
|