feat: use seperate controllers
parent
ed2993cf5f
commit
43361f75e7
@ -0,0 +1,9 @@
|
||||
from flask import Blueprint
|
||||
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
bp = Blueprint("service_api_with_auth", __name__, url_prefix="/service")
|
||||
api = ExternalApi(bp)
|
||||
|
||||
from .auth import login
|
||||
from .user import profile
|
||||
@ -0,0 +1,57 @@
|
||||
from flask_restful import Resource, marshal_with # type: ignore
|
||||
|
||||
from controllers.common import fields
|
||||
from controllers.common import helpers as controller_helpers
|
||||
from controllers.service_api_with_auth import api
|
||||
from controllers.service_api_with_auth.app.error import AppUnavailableError
|
||||
from controllers.service_api_with_auth.wraps import validate_app_token
|
||||
from models.model import App, AppMode
|
||||
from services.app_service import AppService
|
||||
|
||||
|
||||
class AppParameterApi(Resource):
|
||||
"""Resource for app variables."""
|
||||
|
||||
@validate_app_token
|
||||
@marshal_with(fields.parameters_fields)
|
||||
def get(self, app_model: App):
|
||||
"""Retrieve app parameters."""
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
workflow = app_model.workflow
|
||||
if workflow is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = workflow.features_dict
|
||||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
app_model_config = app_model.app_model_config
|
||||
if app_model_config is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = app_model_config.to_dict()
|
||||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
return controller_helpers.get_parameters_from_feature_dict(
|
||||
features_dict=features_dict, user_input_form=user_input_form
|
||||
)
|
||||
|
||||
|
||||
class AppMetaApi(Resource):
|
||||
@validate_app_token
|
||||
def get(self, app_model: App):
|
||||
"""Get app meta"""
|
||||
return AppService().get_app_meta(app_model)
|
||||
|
||||
|
||||
class AppInfoApi(Resource):
|
||||
@validate_app_token
|
||||
def get(self, app_model: App):
|
||||
"""Get app information"""
|
||||
tags = [tag.name for tag in app_model.tags]
|
||||
return {"name": app_model.name, "description": app_model.description, "tags": tags}
|
||||
|
||||
|
||||
api.add_resource(AppParameterApi, "/parameters")
|
||||
api.add_resource(AppMetaApi, "/meta")
|
||||
api.add_resource(AppInfoApi, "/info")
|
||||
@ -0,0 +1,125 @@
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restful import Resource, reqparse # type: ignore
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.service_api_with_auth import api
|
||||
from controllers.service_api_with_auth.app.error import (
|
||||
AppUnavailableError,
|
||||
AudioTooLargeError,
|
||||
CompletionRequestError,
|
||||
NoAudioUploadedError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderNotSupportSpeechToTextError,
|
||||
ProviderQuotaExceededError,
|
||||
UnsupportedAudioTypeError,
|
||||
)
|
||||
from controllers.service_api_with_auth.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
NoAudioUploadedServiceError,
|
||||
ProviderNotSupportSpeechToTextServiceError,
|
||||
UnsupportedAudioTypeServiceError,
|
||||
)
|
||||
|
||||
|
||||
class AudioApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
file = request.files["file"]
|
||||
|
||||
try:
|
||||
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user)
|
||||
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
except AudioTooLargeServiceError as e:
|
||||
raise AudioTooLargeError(str(e))
|
||||
except UnsupportedAudioTypeServiceError:
|
||||
raise UnsupportedAudioTypeError()
|
||||
except ProviderNotSupportSpeechToTextServiceError:
|
||||
raise ProviderNotSupportSpeechToTextError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TextApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
try:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("message_id", type=str, required=False, location="json")
|
||||
parser.add_argument("voice", type=str, location="json")
|
||||
parser.add_argument("text", type=str, location="json")
|
||||
parser.add_argument("streaming", type=bool, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
if (
|
||||
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
|
||||
and app_model.workflow
|
||||
and app_model.workflow.features_dict
|
||||
):
|
||||
text_to_speech = app_model.workflow.features_dict.get("text_to_speech", {})
|
||||
voice = args.get("voice") or text_to_speech.get("voice")
|
||||
else:
|
||||
try:
|
||||
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
|
||||
except Exception:
|
||||
voice = None
|
||||
response = AudioService.transcript_tts(
|
||||
app_model=app_model, message_id=message_id, end_user=end_user.external_user_id, voice=voice, text=text
|
||||
)
|
||||
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
except AudioTooLargeServiceError as e:
|
||||
raise AudioTooLargeError(str(e))
|
||||
except UnsupportedAudioTypeServiceError:
|
||||
raise UnsupportedAudioTypeError()
|
||||
except ProviderNotSupportSpeechToTextServiceError:
|
||||
raise ProviderNotSupportSpeechToTextError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
api.add_resource(AudioApi, "/audio-to-text")
|
||||
api.add_resource(TextApi, "/text-to-audio")
|
||||
@ -0,0 +1,158 @@
|
||||
import logging
|
||||
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, reqparse # type: ignore
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.service_api_with_auth import api
|
||||
from controllers.service_api_with_auth.app.error import (
|
||||
AppUnavailableError,
|
||||
CompletionRequestError,
|
||||
ConversationCompletedError,
|
||||
NotChatAppError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.service_api_with_auth.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
|
||||
|
||||
class CompletionApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
if app_model.mode != "completion":
|
||||
raise AppUnavailableError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, location="json", default="")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args["response_mode"] == "streaming"
|
||||
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model,
|
||||
user=end_user,
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=streaming,
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class CompletionStopApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def post(self, app_model: App, end_user: EndUser, task_id):
|
||||
if app_model.mode != "completion":
|
||||
raise AppUnavailableError()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
class ChatApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, required=True, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
||||
parser.add_argument("auto_generate_name", type=bool, required=False, default=True, location="json")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args["response_mode"] == "streaming"
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class ChatStopApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def post(self, app_model: App, end_user: EndUser, task_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
api.add_resource(CompletionApi, "/completion-messages")
|
||||
api.add_resource(CompletionStopApi, "/completion-messages/<string:task_id>/stop")
|
||||
api.add_resource(ChatApi, "/chat-messages")
|
||||
api.add_resource(ChatStopApi, "/chat-messages/<string:task_id>/stop")
|
||||
@ -0,0 +1,98 @@
|
||||
from flask_restful import Resource, marshal_with, reqparse # type: ignore
|
||||
from flask_restful.inputs import int_range # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
import services
|
||||
from controllers.service_api_with_auth import api
|
||||
from controllers.service_api_with_auth.app.error import NotChatAppError
|
||||
from controllers.service_api_with_auth.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import (
|
||||
conversation_delete_fields,
|
||||
conversation_infinite_scroll_pagination_fields,
|
||||
simple_conversation_fields,
|
||||
)
|
||||
from libs.helper import uuid_value
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.conversation_service import ConversationService
|
||||
|
||||
|
||||
class ConversationApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
||||
@marshal_with(conversation_infinite_scroll_pagination_fields)
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("last_id", type=uuid_value, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
parser.add_argument(
|
||||
"sort_by",
|
||||
type=str,
|
||||
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
|
||||
required=False,
|
||||
default="-updated_at",
|
||||
location="args",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
return ConversationService.pagination_by_last_id(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
user=end_user,
|
||||
last_id=args["last_id"],
|
||||
limit=args["limit"],
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
sort_by=args["sort_by"],
|
||||
)
|
||||
except services.errors.conversation.LastConversationNotExistsError:
|
||||
raise NotFound("Last Conversation Not Exists.")
|
||||
|
||||
|
||||
class ConversationDetailApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
||||
@marshal_with(conversation_delete_fields)
|
||||
def delete(self, app_model: App, end_user: EndUser, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
|
||||
try:
|
||||
ConversationService.delete(app_model, conversation_id, end_user)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
class ConversationRenameApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
||||
@marshal_with(simple_conversation_fields)
|
||||
def post(self, app_model: App, end_user: EndUser, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=False, location="json")
|
||||
parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"])
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
|
||||
api.add_resource(ConversationRenameApi, "/conversations/<uuid:c_id>/name", endpoint="conversation_name")
|
||||
api.add_resource(ConversationApi, "/conversations")
|
||||
api.add_resource(ConversationDetailApi, "/conversations/<uuid:c_id>", endpoint="conversation_detail")
|
||||
@ -0,0 +1,109 @@
|
||||
from libs.exception import BaseHTTPException
|
||||
|
||||
|
||||
class AppUnavailableError(BaseHTTPException):
|
||||
error_code = "app_unavailable"
|
||||
description = "App unavailable, please check your app configurations."
|
||||
code = 400
|
||||
|
||||
|
||||
class NotCompletionAppError(BaseHTTPException):
|
||||
error_code = "not_completion_app"
|
||||
description = "Please check if your Completion app mode matches the right API route."
|
||||
code = 400
|
||||
|
||||
|
||||
class NotChatAppError(BaseHTTPException):
|
||||
error_code = "not_chat_app"
|
||||
description = "Please check if your app mode matches the right API route."
|
||||
code = 400
|
||||
|
||||
|
||||
class NotWorkflowAppError(BaseHTTPException):
|
||||
error_code = "not_workflow_app"
|
||||
description = "Please check if your app mode matches the right API route."
|
||||
code = 400
|
||||
|
||||
|
||||
class ConversationCompletedError(BaseHTTPException):
|
||||
error_code = "conversation_completed"
|
||||
description = "The conversation has ended. Please start a new conversation."
|
||||
code = 400
|
||||
|
||||
|
||||
class ProviderNotInitializeError(BaseHTTPException):
|
||||
error_code = "provider_not_initialize"
|
||||
description = (
|
||||
"No valid model provider credentials found. "
|
||||
"Please go to Settings -> Model Provider to complete your provider credentials."
|
||||
)
|
||||
code = 400
|
||||
|
||||
|
||||
class ProviderQuotaExceededError(BaseHTTPException):
|
||||
error_code = "provider_quota_exceeded"
|
||||
description = (
|
||||
"Your quota for Dify Hosted OpenAI has been exhausted. "
|
||||
"Please go to Settings -> Model Provider to complete your own provider credentials."
|
||||
)
|
||||
code = 400
|
||||
|
||||
|
||||
class ProviderModelCurrentlyNotSupportError(BaseHTTPException):
|
||||
error_code = "model_currently_not_support"
|
||||
description = "Dify Hosted OpenAI trial currently not support the GPT-4 model."
|
||||
code = 400
|
||||
|
||||
|
||||
class CompletionRequestError(BaseHTTPException):
|
||||
error_code = "completion_request_error"
|
||||
description = "Completion request failed."
|
||||
code = 400
|
||||
|
||||
|
||||
class NoAudioUploadedError(BaseHTTPException):
|
||||
error_code = "no_audio_uploaded"
|
||||
description = "Please upload your audio."
|
||||
code = 400
|
||||
|
||||
|
||||
class AudioTooLargeError(BaseHTTPException):
|
||||
error_code = "audio_too_large"
|
||||
description = "Audio size exceeded. {message}"
|
||||
code = 413
|
||||
|
||||
|
||||
class UnsupportedAudioTypeError(BaseHTTPException):
|
||||
error_code = "unsupported_audio_type"
|
||||
description = "Audio type not allowed."
|
||||
code = 415
|
||||
|
||||
|
||||
class ProviderNotSupportSpeechToTextError(BaseHTTPException):
|
||||
error_code = "provider_not_support_speech_to_text"
|
||||
description = "Provider not support speech to text."
|
||||
code = 400
|
||||
|
||||
|
||||
class NoFileUploadedError(BaseHTTPException):
|
||||
error_code = "no_file_uploaded"
|
||||
description = "Please upload your file."
|
||||
code = 400
|
||||
|
||||
|
||||
class TooManyFilesError(BaseHTTPException):
|
||||
error_code = "too_many_files"
|
||||
description = "Only one file is allowed."
|
||||
code = 400
|
||||
|
||||
|
||||
class FileTooLargeError(BaseHTTPException):
|
||||
error_code = "file_too_large"
|
||||
description = "File size exceeded. {message}"
|
||||
code = 413
|
||||
|
||||
|
||||
class UnsupportedFileTypeError(BaseHTTPException):
|
||||
error_code = "unsupported_file_type"
|
||||
description = "File type not allowed."
|
||||
code = 415
|
||||
@ -0,0 +1,53 @@
|
||||
from flask import request
|
||||
from flask_restful import Resource, marshal_with # type: ignore
|
||||
|
||||
import services
|
||||
from controllers.common.errors import FilenameNotExistsError
|
||||
from controllers.service_api_with_auth import api
|
||||
from controllers.service_api_with_auth.app.error import (
|
||||
FileTooLargeError,
|
||||
NoFileUploadedError,
|
||||
TooManyFilesError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.service_api_with_auth.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from fields.file_fields import file_fields
|
||||
from models.model import App, EndUser
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
class FileApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
|
||||
@marshal_with(file_fields)
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
file = request.files["file"]
|
||||
|
||||
# check file
|
||||
if "file" not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
|
||||
if not file.mimetype:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
if not file.filename:
|
||||
raise FilenameNotExistsError
|
||||
|
||||
try:
|
||||
upload_file = FileService.upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
user=end_user,
|
||||
)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
return upload_file, 201
|
||||
|
||||
|
||||
api.add_resource(FileApi, "/files/upload")
|
||||
@ -0,0 +1,149 @@
|
||||
import logging
|
||||
|
||||
from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore
|
||||
from flask_restful.inputs import int_range # type: ignore
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.service_api_with_auth import api
|
||||
from controllers.service_api_with_auth.app.error import NotChatAppError
|
||||
from controllers.service_api_with_auth.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from fields.conversation_fields import message_file_fields
|
||||
from fields.raws import FilesContainedField
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.errors.message import SuggestedQuestionsAfterAnswerDisabledError
|
||||
from services.message_service import MessageService
|
||||
|
||||
|
||||
class MessageListApi(Resource):
|
||||
feedback_fields = {"rating": fields.String}
|
||||
retriever_resource_fields = {
|
||||
"id": fields.String,
|
||||
"message_id": fields.String,
|
||||
"position": fields.Integer,
|
||||
"dataset_id": fields.String,
|
||||
"dataset_name": fields.String,
|
||||
"document_id": fields.String,
|
||||
"document_name": fields.String,
|
||||
"data_source_type": fields.String,
|
||||
"segment_id": fields.String,
|
||||
"score": fields.Float,
|
||||
"hit_count": fields.Integer,
|
||||
"word_count": fields.Integer,
|
||||
"segment_position": fields.Integer,
|
||||
"index_node_hash": fields.String,
|
||||
"content": fields.String,
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
|
||||
agent_thought_fields = {
|
||||
"id": fields.String,
|
||||
"chain_id": fields.String,
|
||||
"message_id": fields.String,
|
||||
"position": fields.Integer,
|
||||
"thought": fields.String,
|
||||
"tool": fields.String,
|
||||
"tool_labels": fields.Raw,
|
||||
"tool_input": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"observation": fields.String,
|
||||
"message_files": fields.List(fields.Nested(message_file_fields)),
|
||||
}
|
||||
|
||||
message_fields = {
|
||||
"id": fields.String,
|
||||
"conversation_id": fields.String,
|
||||
"parent_message_id": fields.String,
|
||||
"inputs": FilesContainedField,
|
||||
"query": fields.String,
|
||||
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
||||
"message_files": fields.List(fields.Nested(message_file_fields)),
|
||||
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
|
||||
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
|
||||
"created_at": TimestampField,
|
||||
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
|
||||
"status": fields.String,
|
||||
"error": fields.String,
|
||||
}
|
||||
|
||||
message_infinite_scroll_pagination_fields = {
|
||||
"limit": fields.Integer,
|
||||
"has_more": fields.Boolean,
|
||||
"data": fields.List(fields.Nested(message_fields)),
|
||||
}
|
||||
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
|
||||
parser.add_argument("first_id", type=uuid_value, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
return MessageService.pagination_by_first_id(
|
||||
app_model, end_user, args["conversation_id"], args["first_id"], args["limit"]
|
||||
)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.message.FirstMessageNotExistsError:
|
||||
raise NotFound("First Message Not Exists.")
|
||||
|
||||
|
||||
class MessageFeedbackApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def post(self, app_model: App, end_user: EndUser, message_id):
|
||||
message_id = str(message_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
|
||||
parser.add_argument("content", type=str, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
MessageService.create_feedback(
|
||||
app_model=app_model,
|
||||
message_id=message_id,
|
||||
user=end_user,
|
||||
rating=args.get("rating"),
|
||||
content=args.get("content"),
|
||||
)
|
||||
except services.errors.message.MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
class MessageSuggestedApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True))
|
||||
def get(self, app_model: App, end_user: EndUser, message_id):
|
||||
message_id = str(message_id)
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
try:
|
||||
questions = MessageService.get_suggested_questions_after_answer(
|
||||
app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.SERVICE_API
|
||||
)
|
||||
except services.errors.message.MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
except SuggestedQuestionsAfterAnswerDisabledError:
|
||||
raise BadRequest("Suggested Questions Is Disabled.")
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
return {"result": "success", "data": questions}
|
||||
|
||||
|
||||
api.add_resource(MessageListApi, "/messages")
|
||||
api.add_resource(MessageFeedbackApi, "/messages/<uuid:message_id>/feedbacks")
|
||||
api.add_resource(MessageSuggestedApi, "/messages/<uuid:message_id>/suggested")
|
||||
@ -0,0 +1,144 @@
|
||||
import logging
|
||||
|
||||
from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore
|
||||
from flask_restful.inputs import int_range # type: ignore
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from controllers.service_api_with_auth import api
|
||||
from controllers.service_api_with_auth.app.error import (
|
||||
CompletionRequestError,
|
||||
NotWorkflowAppError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.service_api_with_auth.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
|
||||
from libs import helper
|
||||
from models.model import App, AppMode, EndUser
|
||||
from models.workflow import WorkflowRun
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.workflow_app_service import WorkflowAppService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
workflow_run_fields = {
|
||||
"id": fields.String,
|
||||
"workflow_id": fields.String,
|
||||
"status": fields.String,
|
||||
"inputs": fields.Raw,
|
||||
"outputs": fields.Raw,
|
||||
"error": fields.String,
|
||||
"total_steps": fields.Integer,
|
||||
"total_tokens": fields.Integer,
|
||||
"created_at": fields.DateTime,
|
||||
"finished_at": fields.DateTime,
|
||||
"elapsed_time": fields.Float,
|
||||
}
|
||||
|
||||
|
||||
class WorkflowRunDetailApi(Resource):
|
||||
@validate_app_token
|
||||
@marshal_with(workflow_run_fields)
|
||||
def get(self, app_model: App, workflow_id: str):
|
||||
"""
|
||||
Get a workflow task running detail
|
||||
"""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_id).first()
|
||||
return workflow_run
|
||||
|
||||
|
||||
class WorkflowRunApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
"""
|
||||
Run workflow
|
||||
"""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args.get("response_mode") == "streaming"
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class WorkflowTaskStopApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def post(self, app_model: App, end_user: EndUser, task_id: str):
|
||||
"""
|
||||
Stop workflow task
|
||||
"""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
class WorkflowAppLogApi(Resource):
|
||||
@validate_app_token
|
||||
@marshal_with(workflow_app_log_pagination_fields)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get workflow app logs
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("keyword", type=str, location="args")
|
||||
parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
|
||||
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
# get paginate workflow app logs
|
||||
workflow_app_service = WorkflowAppService()
|
||||
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
|
||||
app_model=app_model, args=args
|
||||
)
|
||||
|
||||
return workflow_app_log_pagination
|
||||
|
||||
|
||||
api.add_resource(WorkflowRunApi, "/workflows/run")
|
||||
api.add_resource(WorkflowRunDetailApi, "/workflows/run/<string:workflow_id>")
|
||||
api.add_resource(WorkflowTaskStopApi, "/workflows/tasks/<string:task_id>/stop")
|
||||
api.add_resource(WorkflowAppLogApi, "/workflows/logs")
|
||||
@ -0,0 +1,103 @@
|
||||
from libs.exception import BaseHTTPException
|
||||
|
||||
|
||||
class AlreadySetupError(BaseHTTPException):
|
||||
error_code = "already_setup"
|
||||
description = "Dify has been successfully installed. Please refresh the page or return to the dashboard homepage."
|
||||
code = 403
|
||||
|
||||
|
||||
class NotSetupError(BaseHTTPException):
|
||||
error_code = "not_setup"
|
||||
description = (
|
||||
"Dify has not been initialized and installed yet. "
|
||||
"Please proceed with the initialization and installation process first."
|
||||
)
|
||||
code = 401
|
||||
|
||||
|
||||
class NotInitValidateError(BaseHTTPException):
|
||||
error_code = "not_init_validated"
|
||||
description = "Init validation has not been completed yet. Please proceed with the init validation process first."
|
||||
code = 401
|
||||
|
||||
|
||||
class InitValidateFailedError(BaseHTTPException):
|
||||
error_code = "init_validate_failed"
|
||||
description = "Init validation failed. Please check the password and try again."
|
||||
code = 401
|
||||
|
||||
|
||||
class AccountNotLinkTenantError(BaseHTTPException):
|
||||
error_code = "account_not_link_tenant"
|
||||
description = "Account not link tenant."
|
||||
code = 403
|
||||
|
||||
|
||||
class AlreadyActivateError(BaseHTTPException):
|
||||
error_code = "already_activate"
|
||||
description = "Auth Token is invalid or account already activated, please check again."
|
||||
code = 403
|
||||
|
||||
|
||||
class NotAllowedCreateWorkspace(BaseHTTPException):
|
||||
error_code = "not_allowed_create_workspace"
|
||||
description = "Workspace not found, please contact system admin to invite you to join in a workspace."
|
||||
code = 400
|
||||
|
||||
|
||||
class AccountBannedError(BaseHTTPException):
|
||||
error_code = "account_banned"
|
||||
description = "Account is banned."
|
||||
code = 400
|
||||
|
||||
|
||||
class AccountNotFound(BaseHTTPException):
|
||||
error_code = "account_not_found"
|
||||
description = "Account not found."
|
||||
code = 400
|
||||
|
||||
|
||||
class EmailSendIpLimitError(BaseHTTPException):
|
||||
error_code = "email_send_ip_limit"
|
||||
description = "Too many emails have been sent from this IP address recently. Please try again later."
|
||||
code = 429
|
||||
|
||||
|
||||
class FileTooLargeError(BaseHTTPException):
|
||||
error_code = "file_too_large"
|
||||
description = "File size exceeded. {message}"
|
||||
code = 413
|
||||
|
||||
|
||||
class UnsupportedFileTypeError(BaseHTTPException):
|
||||
error_code = "unsupported_file_type"
|
||||
description = "File type not allowed."
|
||||
code = 415
|
||||
|
||||
|
||||
class TooManyFilesError(BaseHTTPException):
|
||||
error_code = "too_many_files"
|
||||
description = "Only one file is allowed."
|
||||
code = 400
|
||||
|
||||
|
||||
class NoFileUploadedError(BaseHTTPException):
|
||||
error_code = "no_file_uploaded"
|
||||
description = "Please upload your file."
|
||||
code = 400
|
||||
|
||||
|
||||
class UnauthorizedAndForceLogout(BaseHTTPException):
|
||||
error_code = "unauthorized_and_force_logout"
|
||||
description = "Unauthorized and force logout."
|
||||
code = 401
|
||||
|
||||
|
||||
class AccountInFreezeError(BaseHTTPException):
|
||||
error_code = "account_in_freeze"
|
||||
code = 400
|
||||
description = (
|
||||
"This email account has been deleted within the past 30 days"
|
||||
"and is temporarily unavailable for new account registration."
|
||||
)
|
||||
@ -1,7 +1,7 @@
|
||||
from flask import Blueprint
|
||||
from flask_restful import Api, Resource # type: ignore
|
||||
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api_with_auth import api
|
||||
|
||||
class UserProfile(Resource):
|
||||
def get(self):
|
||||
@ -0,0 +1,254 @@
|
||||
from collections.abc import Callable
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from typing import Optional
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_login import user_logged_in # type: ignore
|
||||
from flask_restful import Resource # type: ignore
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, Unauthorized
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.login import _get_user
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantStatus
|
||||
from models.model import ApiToken, App, EndUser
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
class WhereisUserArg(Enum):
|
||||
"""
|
||||
Enum for whereis_user_arg.
|
||||
"""
|
||||
|
||||
QUERY = "query"
|
||||
JSON = "json"
|
||||
FORM = "form"
|
||||
|
||||
|
||||
class FetchUserArg(BaseModel):
|
||||
fetch_from: WhereisUserArg
|
||||
required: bool = False
|
||||
|
||||
|
||||
# TODO: add auth jwt token check
|
||||
def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optional[FetchUserArg] = None):
|
||||
def decorator(view_func):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args, **kwargs):
|
||||
api_token = validate_and_get_api_token("app")
|
||||
|
||||
app_model = db.session.query(App).filter(App.id == api_token.app_id).first()
|
||||
if not app_model:
|
||||
raise Forbidden("The app no longer exists.")
|
||||
|
||||
if app_model.status != "normal":
|
||||
raise Forbidden("The app's status is abnormal.")
|
||||
|
||||
if not app_model.enable_api:
|
||||
raise Forbidden("The app's API service has been disabled.")
|
||||
|
||||
tenant = db.session.query(Tenant).filter(Tenant.id == app_model.tenant_id).first()
|
||||
if tenant is None:
|
||||
raise ValueError("Tenant does not exist.")
|
||||
if tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden("The workspace's status is archived.")
|
||||
|
||||
kwargs["app_model"] = app_model
|
||||
|
||||
if fetch_user_arg:
|
||||
if fetch_user_arg.fetch_from == WhereisUserArg.QUERY:
|
||||
user_id = request.args.get("user")
|
||||
elif fetch_user_arg.fetch_from == WhereisUserArg.JSON:
|
||||
user_id = request.get_json().get("user")
|
||||
elif fetch_user_arg.fetch_from == WhereisUserArg.FORM:
|
||||
user_id = request.form.get("user")
|
||||
else:
|
||||
# use default-user
|
||||
user_id = None
|
||||
|
||||
if not user_id and fetch_user_arg.required:
|
||||
raise ValueError("Arg user must be provided.")
|
||||
|
||||
if user_id:
|
||||
user_id = str(user_id)
|
||||
|
||||
kwargs["end_user"] = create_or_update_end_user_for_user_id(app_model, user_id)
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
if view is None:
|
||||
return decorator
|
||||
else:
|
||||
return decorator(view)
|
||||
|
||||
|
||||
def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
|
||||
def interceptor(view):
|
||||
def decorated(*args, **kwargs):
|
||||
api_token = validate_and_get_api_token(api_token_type)
|
||||
features = FeatureService.get_features(api_token.tenant_id)
|
||||
|
||||
if features.billing.enabled:
|
||||
members = features.members
|
||||
apps = features.apps
|
||||
vector_space = features.vector_space
|
||||
documents_upload_quota = features.documents_upload_quota
|
||||
|
||||
if resource == "members" and 0 < members.limit <= members.size:
|
||||
raise Forbidden("The number of members has reached the limit of your subscription.")
|
||||
elif resource == "apps" and 0 < apps.limit <= apps.size:
|
||||
raise Forbidden("The number of apps has reached the limit of your subscription.")
|
||||
elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
|
||||
raise Forbidden("The capacity of the vector space has reached the limit of your subscription.")
|
||||
elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
|
||||
raise Forbidden("The number of documents has reached the limit of your subscription.")
|
||||
else:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
return interceptor
|
||||
|
||||
|
||||
def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str):
|
||||
def interceptor(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
api_token = validate_and_get_api_token(api_token_type)
|
||||
features = FeatureService.get_features(api_token.tenant_id)
|
||||
if features.billing.enabled:
|
||||
if resource == "add_segment":
|
||||
if features.billing.subscription.plan == "sandbox":
|
||||
raise Forbidden(
|
||||
"To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."
|
||||
)
|
||||
else:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
return interceptor
|
||||
|
||||
|
||||
def validate_dataset_token(view=None):
|
||||
def decorator(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
api_token = validate_and_get_api_token("dataset")
|
||||
tenant_account_join = (
|
||||
db.session.query(Tenant, TenantAccountJoin)
|
||||
.filter(Tenant.id == api_token.tenant_id)
|
||||
.filter(TenantAccountJoin.tenant_id == Tenant.id)
|
||||
.filter(TenantAccountJoin.role.in_(["owner"]))
|
||||
.filter(Tenant.status == TenantStatus.NORMAL)
|
||||
.one_or_none()
|
||||
) # TODO: only owner information is required, so only one is returned.
|
||||
if tenant_account_join:
|
||||
tenant, ta = tenant_account_join
|
||||
account = Account.query.filter_by(id=ta.account_id).first()
|
||||
# Login admin
|
||||
if account:
|
||||
account.current_tenant = tenant
|
||||
current_app.login_manager._update_request_context_with_user(account) # type: ignore
|
||||
user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore
|
||||
else:
|
||||
raise Unauthorized("Tenant owner account does not exist.")
|
||||
else:
|
||||
raise Unauthorized("Tenant does not exist.")
|
||||
return view(api_token.tenant_id, *args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
if view:
|
||||
return decorator(view)
|
||||
|
||||
# if view is None, it means that the decorator is used without parentheses
|
||||
# use the decorator as a function for method_decorators
|
||||
return decorator
|
||||
|
||||
|
||||
def validate_and_get_api_token(scope: str | None = None):
|
||||
"""
|
||||
Validate and get API token.
|
||||
"""
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header is None or " " not in auth_header:
|
||||
raise Unauthorized("Authorization header must be provided and start with 'Bearer'")
|
||||
|
||||
auth_scheme, auth_token = auth_header.split(None, 1)
|
||||
auth_scheme = auth_scheme.lower()
|
||||
|
||||
if auth_scheme != "bearer":
|
||||
raise Unauthorized("Authorization scheme must be 'Bearer'")
|
||||
|
||||
current_time = datetime.now(UTC).replace(tzinfo=None)
|
||||
cutoff_time = current_time - timedelta(minutes=1)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
update_stmt = (
|
||||
update(ApiToken)
|
||||
.where(
|
||||
ApiToken.token == auth_token,
|
||||
(ApiToken.last_used_at.is_(None) | (ApiToken.last_used_at < cutoff_time)),
|
||||
ApiToken.type == scope,
|
||||
)
|
||||
.values(last_used_at=current_time)
|
||||
.returning(ApiToken)
|
||||
)
|
||||
result = session.execute(update_stmt)
|
||||
api_token = result.scalar_one_or_none()
|
||||
|
||||
if not api_token:
|
||||
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
|
||||
api_token = session.scalar(stmt)
|
||||
if not api_token:
|
||||
raise Unauthorized("Access token is invalid")
|
||||
else:
|
||||
session.commit()
|
||||
|
||||
return api_token
|
||||
|
||||
|
||||
def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] = None) -> EndUser:
|
||||
"""
|
||||
Create or update session terminal based on user ID.
|
||||
"""
|
||||
if not user_id:
|
||||
user_id = "DEFAULT-USER"
|
||||
|
||||
end_user = (
|
||||
db.session.query(EndUser)
|
||||
.filter(
|
||||
EndUser.tenant_id == app_model.tenant_id,
|
||||
EndUser.app_id == app_model.id,
|
||||
EndUser.session_id == user_id,
|
||||
EndUser.type == "service_api",
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if end_user is None:
|
||||
end_user = EndUser(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type="service_api",
|
||||
is_anonymous=user_id == "DEFAULT-USER",
|
||||
session_id=user_id,
|
||||
)
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
|
||||
return end_user
|
||||
|
||||
|
||||
class DatasetApiResource(Resource):
|
||||
method_decorators = [validate_dataset_token]
|
||||
Loading…
Reference in New Issue