diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 394f36c3ff..053d8accaa 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -1,19 +1,21 @@ from flask import request -from flask_restful import marshal, reqparse +from flask_restful import marshal, marshal_with, reqparse from werkzeug.exceptions import Forbidden, NotFound import services.dataset_service from controllers.service_api import api from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError -from controllers.service_api.wraps import DatasetApiResource +from controllers.service_api.wraps import DatasetApiResource, validate_dataset_token from core.model_runtime.entities.model_entities import ModelType from core.plugin.entities.plugin import ModelProviderID from core.provider_manager import ProviderManager from fields.dataset_fields import dataset_detail_fields +from fields.tag_fields import tag_fields from libs.login import current_user from models.dataset import Dataset, DatasetPermissionEnum from services.dataset_service import DatasetPermissionService, DatasetService from services.entities.knowledge_entities.knowledge_entities import RetrievalModel +from services.tag_service import TagService def _validate_name(name): @@ -320,5 +322,133 @@ class DatasetApi(DatasetApiResource): raise DatasetInUseError() +class DatasetTagsApi(DatasetApiResource): + @validate_dataset_token + @marshal_with(tag_fields) + def get(self, _, dataset_id): + """Get all knowledge type tags.""" + tags = TagService.get_tags("knowledge", current_user.current_tenant_id) + + return tags, 200 + + @validate_dataset_token + def post(self, _, dataset_id): + """Add a knowledge type tag.""" + if not (current_user.is_editor or current_user.is_dataset_editor): + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument( + "name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", + type=_validate_name + ) + + args = parser.parse_args() + args["type"] = "knowledge" + tag = TagService.save_tags(args) + + response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} + + return response, 200 + + @validate_dataset_token + def patch(self, _, dataset_id): + if not (current_user.is_editor or current_user.is_dataset_editor): + raise Forbidden() + def _validate_tag_name(name): + if not name or len(name) < 1 or len(name) > 50: + raise ValueError("Name must be between 1 to 50 characters.") + return name + + parser = reqparse.RequestParser() + parser.add_argument( + "name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", + type=_validate_tag_name + ) + parser.add_argument( + "tag_id", nullable=False, required=True, help="Id of a tag.", type=str + ) + args = parser.parse_args() + tag = TagService.update_tags(args, args.get("tag_id")) + + binding_count = TagService.get_tag_binding_count(args.get("tag_id")) + + response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} + + return response, 200 + + @validate_dataset_token + def delete(self, _, dataset_id): + """Delete a knowledge type tag.""" + if not current_user.is_editor: + raise Forbidden() + parser = reqparse.RequestParser() + parser.add_argument( + "tag_id", nullable=False, required=True, help="Id of a tag.", type=str + ) + args = parser.parse_args() + TagService.delete_tag(args.get("tag_id")) + + return {"result": "success"}, 200 + + +class DatasetTagBindingApi(DatasetApiResource): + + @validate_dataset_token + def post(self, _, dataset_id): + # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator + if not (current_user.is_editor or current_user.is_dataset_editor): + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument( + "tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required." + ) + parser.add_argument( + "target_id", type=str, nullable=False, required=True, location="json", + help="Target Dataset ID is required." + ) + + args = parser.parse_args() + args["type"] = "knowledge" + TagService.save_tag_binding(args) + + return {"result": "success"}, 200 + + +class DatasetTagUnbindingApi(DatasetApiResource): + + @validate_dataset_token + def post(self, _, dataset_id): + # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator + if not (current_user.is_editor or current_user.is_dataset_editor): + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.") + parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.") + + args = parser.parse_args() + args["type"] = "knowledge" + TagService.delete_tag_binding(args) + + return {"result": "success"}, 200 + + +class DatasetTagsBindingStatusApi(DatasetApiResource): + + @validate_dataset_token + def get(self, _, *args, **kwargs): + """Get all knowledge type tags.""" + dataset_id = kwargs.get("dataset_id") + tags = TagService.get_tags_by_target_id("knowledge", current_user.current_tenant_id, str(dataset_id)) + tags_list = [{"id": tag.id, "name": tag.name} for tag in tags] + response = {"data": tags_list, "total": len(tags)} + return response, 200 + api.add_resource(DatasetListApi, "/datasets") api.add_resource(DatasetApi, "/datasets/") +api.add_resource(DatasetTagsApi, "/datasets/tags") +api.add_resource(DatasetTagBindingApi, "/datasets/tags/binding") +api.add_resource(DatasetTagUnbindingApi, "/datasets/tags/unbinding") +api.add_resource(DatasetTagsBindingStatusApi, "/datasets//tags")