feat: universal chat in explore (#649)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>pull/654/head
parent
94b54b7ca9
commit
4fdb37771a
@ -0,0 +1,66 @@
|
|||||||
|
# -*- coding:utf-8 -*-
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from flask import request
|
||||||
|
from werkzeug.exceptions import InternalServerError
|
||||||
|
|
||||||
|
import services
|
||||||
|
from controllers.console import api
|
||||||
|
from controllers.console.app.error import AppUnavailableError, ProviderNotInitializeError, \
|
||||||
|
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError, \
|
||||||
|
NoAudioUploadedError, AudioTooLargeError, \
|
||||||
|
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
|
||||||
|
from controllers.console.universal_chat.wraps import UniversalChatResource
|
||||||
|
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||||
|
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||||
|
from services.audio_service import AudioService
|
||||||
|
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
|
||||||
|
UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
|
||||||
|
from models.model import AppModelConfig
|
||||||
|
|
||||||
|
|
||||||
|
class UniversalChatAudioApi(UniversalChatResource):
|
||||||
|
def post(self, universal_app):
|
||||||
|
app_model = universal_app
|
||||||
|
app_model_config: AppModelConfig = app_model.app_model_config
|
||||||
|
|
||||||
|
if not app_model_config.speech_to_text_dict['enabled']:
|
||||||
|
raise AppUnavailableError()
|
||||||
|
|
||||||
|
file = request.files['file']
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = AudioService.transcript(
|
||||||
|
tenant_id=app_model.tenant_id,
|
||||||
|
file=file,
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
|
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||||
|
logging.exception("App model config broken.")
|
||||||
|
raise AppUnavailableError()
|
||||||
|
except NoAudioUploadedServiceError:
|
||||||
|
raise NoAudioUploadedError()
|
||||||
|
except AudioTooLargeServiceError as e:
|
||||||
|
raise AudioTooLargeError(str(e))
|
||||||
|
except UnsupportedAudioTypeServiceError:
|
||||||
|
raise UnsupportedAudioTypeError()
|
||||||
|
except ProviderNotSupportSpeechToTextServiceError:
|
||||||
|
raise ProviderNotSupportSpeechToTextError()
|
||||||
|
except ProviderTokenNotInitError:
|
||||||
|
raise ProviderNotInitializeError()
|
||||||
|
except QuotaExceededError:
|
||||||
|
raise ProviderQuotaExceededError()
|
||||||
|
except ModelCurrentlyNotSupportError:
|
||||||
|
raise ProviderModelCurrentlyNotSupportError()
|
||||||
|
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||||
|
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||||
|
raise CompletionRequestError(str(e))
|
||||||
|
except ValueError as e:
|
||||||
|
raise e
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception("internal server error.")
|
||||||
|
raise InternalServerError()
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(UniversalChatAudioApi, '/universal-chat/audio-to-text')
|
||||||
@ -0,0 +1,127 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Generator, Union
|
||||||
|
|
||||||
|
from flask import Response, stream_with_context
|
||||||
|
from flask_login import current_user
|
||||||
|
from flask_restful import reqparse
|
||||||
|
from werkzeug.exceptions import InternalServerError, NotFound
|
||||||
|
|
||||||
|
import services
|
||||||
|
from controllers.console import api
|
||||||
|
from controllers.console.app.error import ConversationCompletedError, AppUnavailableError, ProviderNotInitializeError, \
|
||||||
|
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
|
||||||
|
from controllers.console.universal_chat.wraps import UniversalChatResource
|
||||||
|
from core.constant import llm_constant
|
||||||
|
from core.conversation_message_task import PubHandler
|
||||||
|
from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
|
||||||
|
LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError
|
||||||
|
from libs.helper import uuid_value
|
||||||
|
from services.completion_service import CompletionService
|
||||||
|
|
||||||
|
|
||||||
|
class UniversalChatApi(UniversalChatResource):
|
||||||
|
def post(self, universal_app):
|
||||||
|
app_model = universal_app
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument('query', type=str, required=True, location='json')
|
||||||
|
parser.add_argument('conversation_id', type=uuid_value, location='json')
|
||||||
|
parser.add_argument('model', type=str, required=True, location='json')
|
||||||
|
parser.add_argument('tools', type=list, required=True, location='json')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
app_model_config = app_model.app_model_config
|
||||||
|
|
||||||
|
# update app model config
|
||||||
|
args['model_config'] = app_model_config.to_dict()
|
||||||
|
args['model_config']['model']['name'] = args['model']
|
||||||
|
|
||||||
|
if not llm_constant.models[args['model']]:
|
||||||
|
raise ValueError("Model not exists.")
|
||||||
|
|
||||||
|
args['model_config']['model']['provider'] = llm_constant.models[args['model']]
|
||||||
|
args['model_config']['agent_mode']['tools'] = args['tools']
|
||||||
|
|
||||||
|
args['inputs'] = {}
|
||||||
|
|
||||||
|
del args['model']
|
||||||
|
del args['tools']
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = CompletionService.completion(
|
||||||
|
app_model=app_model,
|
||||||
|
user=current_user,
|
||||||
|
args=args,
|
||||||
|
from_source='console',
|
||||||
|
streaming=True,
|
||||||
|
is_model_config_override=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return compact_response(response)
|
||||||
|
except services.errors.conversation.ConversationNotExistsError:
|
||||||
|
raise NotFound("Conversation Not Exists.")
|
||||||
|
except services.errors.conversation.ConversationCompletedError:
|
||||||
|
raise ConversationCompletedError()
|
||||||
|
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||||
|
logging.exception("App model config broken.")
|
||||||
|
raise AppUnavailableError()
|
||||||
|
except ProviderTokenNotInitError:
|
||||||
|
raise ProviderNotInitializeError()
|
||||||
|
except QuotaExceededError:
|
||||||
|
raise ProviderQuotaExceededError()
|
||||||
|
except ModelCurrentlyNotSupportError:
|
||||||
|
raise ProviderModelCurrentlyNotSupportError()
|
||||||
|
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||||
|
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||||
|
raise CompletionRequestError(str(e))
|
||||||
|
except ValueError as e:
|
||||||
|
raise e
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception("internal server error.")
|
||||||
|
raise InternalServerError()
|
||||||
|
|
||||||
|
|
||||||
|
class UniversalChatStopApi(UniversalChatResource):
|
||||||
|
def post(self, universal_app, task_id):
|
||||||
|
PubHandler.stop(current_user, task_id)
|
||||||
|
|
||||||
|
return {'result': 'success'}, 200
|
||||||
|
|
||||||
|
|
||||||
|
def compact_response(response: Union[dict | Generator]) -> Response:
|
||||||
|
if isinstance(response, dict):
|
||||||
|
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||||
|
else:
|
||||||
|
def generate() -> Generator:
|
||||||
|
try:
|
||||||
|
for chunk in response:
|
||||||
|
yield chunk
|
||||||
|
except services.errors.conversation.ConversationNotExistsError:
|
||||||
|
yield "data: " + json.dumps(api.handle_error(NotFound("Conversation Not Exists.")).get_json()) + "\n\n"
|
||||||
|
except services.errors.conversation.ConversationCompletedError:
|
||||||
|
yield "data: " + json.dumps(api.handle_error(ConversationCompletedError()).get_json()) + "\n\n"
|
||||||
|
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||||
|
logging.exception("App model config broken.")
|
||||||
|
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
|
||||||
|
except ProviderTokenNotInitError:
|
||||||
|
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
|
||||||
|
except QuotaExceededError:
|
||||||
|
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
|
||||||
|
except ModelCurrentlyNotSupportError:
|
||||||
|
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
|
||||||
|
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||||
|
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||||
|
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
|
||||||
|
except ValueError as e:
|
||||||
|
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
|
||||||
|
except Exception:
|
||||||
|
logging.exception("internal server error.")
|
||||||
|
yield "data: " + json.dumps(api.handle_error(InternalServerError()).get_json()) + "\n\n"
|
||||||
|
|
||||||
|
return Response(stream_with_context(generate()), status=200,
|
||||||
|
mimetype='text/event-stream')
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(UniversalChatApi, '/universal-chat/messages')
|
||||||
|
api.add_resource(UniversalChatStopApi, '/universal-chat/messages/<string:task_id>/stop')
|
||||||
@ -0,0 +1,118 @@
|
|||||||
|
# -*- coding:utf-8 -*-
|
||||||
|
from flask_login import current_user
|
||||||
|
from flask_restful import fields, reqparse, marshal_with
|
||||||
|
from flask_restful.inputs import int_range
|
||||||
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
|
from controllers.console import api
|
||||||
|
from controllers.console.universal_chat.wraps import UniversalChatResource
|
||||||
|
from libs.helper import TimestampField, uuid_value
|
||||||
|
from services.conversation_service import ConversationService
|
||||||
|
from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError
|
||||||
|
from services.web_conversation_service import WebConversationService
|
||||||
|
|
||||||
|
conversation_fields = {
|
||||||
|
'id': fields.String,
|
||||||
|
'name': fields.String,
|
||||||
|
'inputs': fields.Raw,
|
||||||
|
'status': fields.String,
|
||||||
|
'introduction': fields.String,
|
||||||
|
'created_at': TimestampField,
|
||||||
|
'model_config': fields.Raw,
|
||||||
|
}
|
||||||
|
|
||||||
|
conversation_infinite_scroll_pagination_fields = {
|
||||||
|
'limit': fields.Integer,
|
||||||
|
'has_more': fields.Boolean,
|
||||||
|
'data': fields.List(fields.Nested(conversation_fields))
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class UniversalChatConversationListApi(UniversalChatResource):
|
||||||
|
|
||||||
|
@marshal_with(conversation_infinite_scroll_pagination_fields)
|
||||||
|
def get(self, universal_app):
|
||||||
|
app_model = universal_app
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument('last_id', type=uuid_value, location='args')
|
||||||
|
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
|
||||||
|
parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
pinned = None
|
||||||
|
if 'pinned' in args and args['pinned'] is not None:
|
||||||
|
pinned = True if args['pinned'] == 'true' else False
|
||||||
|
|
||||||
|
try:
|
||||||
|
return WebConversationService.pagination_by_last_id(
|
||||||
|
app_model=app_model,
|
||||||
|
user=current_user,
|
||||||
|
last_id=args['last_id'],
|
||||||
|
limit=args['limit'],
|
||||||
|
pinned=pinned
|
||||||
|
)
|
||||||
|
except LastConversationNotExistsError:
|
||||||
|
raise NotFound("Last Conversation Not Exists.")
|
||||||
|
|
||||||
|
|
||||||
|
class UniversalChatConversationApi(UniversalChatResource):
|
||||||
|
def delete(self, universal_app, c_id):
|
||||||
|
app_model = universal_app
|
||||||
|
conversation_id = str(c_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
ConversationService.delete(app_model, conversation_id, current_user)
|
||||||
|
except ConversationNotExistsError:
|
||||||
|
raise NotFound("Conversation Not Exists.")
|
||||||
|
|
||||||
|
WebConversationService.unpin(app_model, conversation_id, current_user)
|
||||||
|
|
||||||
|
return {"result": "success"}, 204
|
||||||
|
|
||||||
|
|
||||||
|
class UniversalChatConversationRenameApi(UniversalChatResource):
|
||||||
|
|
||||||
|
@marshal_with(conversation_fields)
|
||||||
|
def post(self, universal_app, c_id):
|
||||||
|
app_model = universal_app
|
||||||
|
conversation_id = str(c_id)
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument('name', type=str, required=True, location='json')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
return ConversationService.rename(app_model, conversation_id, current_user, args['name'])
|
||||||
|
except ConversationNotExistsError:
|
||||||
|
raise NotFound("Conversation Not Exists.")
|
||||||
|
|
||||||
|
|
||||||
|
class UniversalChatConversationPinApi(UniversalChatResource):
|
||||||
|
|
||||||
|
def patch(self, universal_app, c_id):
|
||||||
|
app_model = universal_app
|
||||||
|
conversation_id = str(c_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
WebConversationService.pin(app_model, conversation_id, current_user)
|
||||||
|
except ConversationNotExistsError:
|
||||||
|
raise NotFound("Conversation Not Exists.")
|
||||||
|
|
||||||
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
|
class UniversalChatConversationUnPinApi(UniversalChatResource):
|
||||||
|
def patch(self, universal_app, c_id):
|
||||||
|
app_model = universal_app
|
||||||
|
conversation_id = str(c_id)
|
||||||
|
WebConversationService.unpin(app_model, conversation_id, current_user)
|
||||||
|
|
||||||
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(UniversalChatConversationRenameApi, '/universal-chat/conversations/<uuid:c_id>/name')
|
||||||
|
api.add_resource(UniversalChatConversationListApi, '/universal-chat/conversations')
|
||||||
|
api.add_resource(UniversalChatConversationApi, '/universal-chat/conversations/<uuid:c_id>')
|
||||||
|
api.add_resource(UniversalChatConversationPinApi, '/universal-chat/conversations/<uuid:c_id>/pin')
|
||||||
|
api.add_resource(UniversalChatConversationUnPinApi, '/universal-chat/conversations/<uuid:c_id>/unpin')
|
||||||
@ -0,0 +1,127 @@
|
|||||||
|
# -*- coding:utf-8 -*-
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from flask_login import current_user
|
||||||
|
from flask_restful import reqparse, fields, marshal_with
|
||||||
|
from flask_restful.inputs import int_range
|
||||||
|
from werkzeug.exceptions import NotFound, InternalServerError
|
||||||
|
|
||||||
|
import services
|
||||||
|
from controllers.console import api
|
||||||
|
from controllers.console.app.error import ProviderNotInitializeError, \
|
||||||
|
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
|
||||||
|
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
|
||||||
|
from controllers.console.universal_chat.wraps import UniversalChatResource
|
||||||
|
from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||||
|
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||||
|
from libs.helper import uuid_value, TimestampField
|
||||||
|
from services.errors.conversation import ConversationNotExistsError
|
||||||
|
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||||
|
from services.message_service import MessageService
|
||||||
|
|
||||||
|
|
||||||
|
class UniversalChatMessageListApi(UniversalChatResource):
|
||||||
|
feedback_fields = {
|
||||||
|
'rating': fields.String
|
||||||
|
}
|
||||||
|
|
||||||
|
agent_thought_fields = {
|
||||||
|
'id': fields.String,
|
||||||
|
'chain_id': fields.String,
|
||||||
|
'message_id': fields.String,
|
||||||
|
'position': fields.Integer,
|
||||||
|
'thought': fields.String,
|
||||||
|
'tool': fields.String,
|
||||||
|
'tool_input': fields.String,
|
||||||
|
'created_at': TimestampField
|
||||||
|
}
|
||||||
|
|
||||||
|
message_fields = {
|
||||||
|
'id': fields.String,
|
||||||
|
'conversation_id': fields.String,
|
||||||
|
'inputs': fields.Raw,
|
||||||
|
'query': fields.String,
|
||||||
|
'answer': fields.String,
|
||||||
|
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
|
||||||
|
'created_at': TimestampField,
|
||||||
|
'agent_thoughts': fields.List(fields.Nested(agent_thought_fields))
|
||||||
|
}
|
||||||
|
|
||||||
|
message_infinite_scroll_pagination_fields = {
|
||||||
|
'limit': fields.Integer,
|
||||||
|
'has_more': fields.Boolean,
|
||||||
|
'data': fields.List(fields.Nested(message_fields))
|
||||||
|
}
|
||||||
|
|
||||||
|
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||||
|
def get(self, universal_app):
|
||||||
|
app_model = universal_app
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
|
||||||
|
parser.add_argument('first_id', type=uuid_value, location='args')
|
||||||
|
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
return MessageService.pagination_by_first_id(app_model, current_user,
|
||||||
|
args['conversation_id'], args['first_id'], args['limit'])
|
||||||
|
except services.errors.conversation.ConversationNotExistsError:
|
||||||
|
raise NotFound("Conversation Not Exists.")
|
||||||
|
except services.errors.message.FirstMessageNotExistsError:
|
||||||
|
raise NotFound("First Message Not Exists.")
|
||||||
|
|
||||||
|
|
||||||
|
class UniversalChatMessageFeedbackApi(UniversalChatResource):
|
||||||
|
def post(self, universal_app, message_id):
|
||||||
|
app_model = universal_app
|
||||||
|
message_id = str(message_id)
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
MessageService.create_feedback(app_model, message_id, current_user, args['rating'])
|
||||||
|
except services.errors.message.MessageNotExistsError:
|
||||||
|
raise NotFound("Message Not Exists.")
|
||||||
|
|
||||||
|
return {'result': 'success'}
|
||||||
|
|
||||||
|
|
||||||
|
class UniversalChatMessageSuggestedQuestionApi(UniversalChatResource):
|
||||||
|
def get(self, universal_app, message_id):
|
||||||
|
app_model = universal_app
|
||||||
|
message_id = str(message_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
questions = MessageService.get_suggested_questions_after_answer(
|
||||||
|
app_model=app_model,
|
||||||
|
user=current_user,
|
||||||
|
message_id=message_id
|
||||||
|
)
|
||||||
|
except MessageNotExistsError:
|
||||||
|
raise NotFound("Message not found")
|
||||||
|
except ConversationNotExistsError:
|
||||||
|
raise NotFound("Conversation not found")
|
||||||
|
except SuggestedQuestionsAfterAnswerDisabledError:
|
||||||
|
raise AppSuggestedQuestionsAfterAnswerDisabledError()
|
||||||
|
except ProviderTokenNotInitError:
|
||||||
|
raise ProviderNotInitializeError()
|
||||||
|
except QuotaExceededError:
|
||||||
|
raise ProviderQuotaExceededError()
|
||||||
|
except ModelCurrentlyNotSupportError:
|
||||||
|
raise ProviderModelCurrentlyNotSupportError()
|
||||||
|
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||||
|
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||||
|
raise CompletionRequestError(str(e))
|
||||||
|
except Exception:
|
||||||
|
logging.exception("internal server error.")
|
||||||
|
raise InternalServerError()
|
||||||
|
|
||||||
|
return {'data': questions}
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(UniversalChatMessageListApi, '/universal-chat/messages')
|
||||||
|
api.add_resource(UniversalChatMessageFeedbackApi, '/universal-chat/messages/<uuid:message_id>/feedbacks')
|
||||||
|
api.add_resource(UniversalChatMessageSuggestedQuestionApi, '/universal-chat/messages/<uuid:message_id>/suggested-questions')
|
||||||
@ -0,0 +1,36 @@
|
|||||||
|
# -*- coding:utf-8 -*-
|
||||||
|
from flask_restful import marshal_with, fields
|
||||||
|
|
||||||
|
from controllers.console import api
|
||||||
|
from controllers.console.universal_chat.wraps import UniversalChatResource
|
||||||
|
|
||||||
|
from core.llm.llm_builder import LLMBuilder
|
||||||
|
from models.provider import ProviderName
|
||||||
|
from models.model import App
|
||||||
|
|
||||||
|
|
||||||
|
class UniversalChatParameterApi(UniversalChatResource):
|
||||||
|
"""Resource for app variables."""
|
||||||
|
parameters_fields = {
|
||||||
|
'opening_statement': fields.String,
|
||||||
|
'suggested_questions': fields.Raw,
|
||||||
|
'suggested_questions_after_answer': fields.Raw,
|
||||||
|
'speech_to_text': fields.Raw,
|
||||||
|
}
|
||||||
|
|
||||||
|
@marshal_with(parameters_fields)
|
||||||
|
def get(self, universal_app: App):
|
||||||
|
"""Retrieve app parameters."""
|
||||||
|
app_model = universal_app
|
||||||
|
app_model_config = app_model.app_model_config
|
||||||
|
provider_name = LLMBuilder.get_default_provider(universal_app.tenant_id, 'whisper-1')
|
||||||
|
|
||||||
|
return {
|
||||||
|
'opening_statement': app_model_config.opening_statement,
|
||||||
|
'suggested_questions': app_model_config.suggested_questions_list,
|
||||||
|
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
|
||||||
|
'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(UniversalChatParameterApi, '/universal-chat/parameters')
|
||||||
@ -0,0 +1,84 @@
|
|||||||
|
import json
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
from flask_login import login_required, current_user
|
||||||
|
from flask_restful import Resource
|
||||||
|
from controllers.console.setup import setup_required
|
||||||
|
from controllers.console.wraps import account_initialization_required
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.model import App, AppModelConfig
|
||||||
|
|
||||||
|
|
||||||
|
def universal_chat_app_required(view=None):
|
||||||
|
def decorator(view):
|
||||||
|
@wraps(view)
|
||||||
|
def decorated(*args, **kwargs):
|
||||||
|
# get universal chat app
|
||||||
|
universal_app = db.session.query(App).filter(
|
||||||
|
App.tenant_id == current_user.current_tenant_id,
|
||||||
|
App.is_universal == True
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if universal_app is None:
|
||||||
|
# create universal app if not exists
|
||||||
|
universal_app = App(
|
||||||
|
tenant_id=current_user.current_tenant_id,
|
||||||
|
name='Universal Chat',
|
||||||
|
mode='chat',
|
||||||
|
is_universal=True,
|
||||||
|
icon='',
|
||||||
|
icon_background='',
|
||||||
|
api_rpm=0,
|
||||||
|
api_rph=0,
|
||||||
|
enable_site=False,
|
||||||
|
enable_api=False,
|
||||||
|
status='normal'
|
||||||
|
)
|
||||||
|
|
||||||
|
db.session.add(universal_app)
|
||||||
|
db.session.flush()
|
||||||
|
|
||||||
|
app_model_config = AppModelConfig(
|
||||||
|
provider="",
|
||||||
|
model_id="",
|
||||||
|
configs={},
|
||||||
|
opening_statement='',
|
||||||
|
suggested_questions=json.dumps([]),
|
||||||
|
suggested_questions_after_answer=json.dumps({'enabled': True}),
|
||||||
|
speech_to_text=json.dumps({'enabled': True}),
|
||||||
|
more_like_this=None,
|
||||||
|
sensitive_word_avoidance=None,
|
||||||
|
model=json.dumps({
|
||||||
|
"provider": "openai",
|
||||||
|
"name": "gpt-3.5-turbo-16k",
|
||||||
|
"completion_params": {
|
||||||
|
"max_tokens": 800,
|
||||||
|
"temperature": 0.8,
|
||||||
|
"top_p": 1,
|
||||||
|
"presence_penalty": 0,
|
||||||
|
"frequency_penalty": 0
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
user_input_form=json.dumps([]),
|
||||||
|
pre_prompt='',
|
||||||
|
agent_mode=json.dumps({"enabled": True, "strategy": "function_call", "tools": []}),
|
||||||
|
)
|
||||||
|
|
||||||
|
app_model_config.app_id = universal_app.id
|
||||||
|
db.session.add(app_model_config)
|
||||||
|
db.session.flush()
|
||||||
|
|
||||||
|
universal_app.app_model_config_id = app_model_config.id
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
return view(universal_app, *args, **kwargs)
|
||||||
|
return decorated
|
||||||
|
|
||||||
|
if view:
|
||||||
|
return decorator(view)
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
class UniversalChatResource(Resource):
|
||||||
|
# must be reversed if there are multiple decorators
|
||||||
|
method_decorators = [universal_chat_app_required, account_initialization_required, login_required, setup_required]
|
||||||
@ -0,0 +1,136 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
from flask_login import login_required, current_user
|
||||||
|
from flask_restful import Resource, abort, reqparse
|
||||||
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
|
from controllers.console import api
|
||||||
|
from controllers.console.setup import setup_required
|
||||||
|
from controllers.console.wraps import account_initialization_required
|
||||||
|
from core.tool.provider.errors import ToolValidateFailedError
|
||||||
|
from core.tool.provider.tool_provider_service import ToolProviderService
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.tool import ToolProvider, ToolProviderName
|
||||||
|
|
||||||
|
|
||||||
|
class ToolProviderListApi(Resource):
|
||||||
|
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def get(self):
|
||||||
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
|
tool_credential_dict = {}
|
||||||
|
for tool_name in ToolProviderName:
|
||||||
|
tool_credential_dict[tool_name.value] = {
|
||||||
|
'tool_name': tool_name.value,
|
||||||
|
'is_enabled': False,
|
||||||
|
'credentials': None
|
||||||
|
}
|
||||||
|
|
||||||
|
tool_providers = db.session.query(ToolProvider).filter(ToolProvider.tenant_id == tenant_id).all()
|
||||||
|
|
||||||
|
for p in tool_providers:
|
||||||
|
if p.is_enabled:
|
||||||
|
tool_credential_dict[p.tool_name] = {
|
||||||
|
'tool_name': p.tool_name,
|
||||||
|
'is_enabled': p.is_enabled,
|
||||||
|
'credentials': ToolProviderService(tenant_id, p.tool_name).get_credentials(obfuscated=True)
|
||||||
|
}
|
||||||
|
|
||||||
|
return list(tool_credential_dict.values())
|
||||||
|
|
||||||
|
|
||||||
|
class ToolProviderCredentialsApi(Resource):
|
||||||
|
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def post(self, provider):
|
||||||
|
if provider not in [p.value for p in ToolProviderName]:
|
||||||
|
abort(404)
|
||||||
|
|
||||||
|
# The role of the current user in the ta table must be admin or owner
|
||||||
|
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||||
|
raise Forbidden(f'User {current_user.id} is not authorized to update provider token, '
|
||||||
|
f'current_role is {current_user.current_tenant.current_role}')
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
|
tool_provider_service = ToolProviderService(tenant_id, provider)
|
||||||
|
|
||||||
|
try:
|
||||||
|
tool_provider_service.credentials_validate(args['credentials'])
|
||||||
|
except ToolValidateFailedError as ex:
|
||||||
|
raise ValueError(str(ex))
|
||||||
|
|
||||||
|
encrypted_credentials = json.dumps(tool_provider_service.encrypt_credentials(args['credentials']))
|
||||||
|
|
||||||
|
tenant = current_user.current_tenant
|
||||||
|
|
||||||
|
tool_provider_model = db.session.query(ToolProvider).filter(
|
||||||
|
ToolProvider.tenant_id == tenant.id,
|
||||||
|
ToolProvider.tool_name == provider,
|
||||||
|
).first()
|
||||||
|
|
||||||
|
# Only allow updating token for CUSTOM provider type
|
||||||
|
if tool_provider_model:
|
||||||
|
tool_provider_model.encrypted_credentials = encrypted_credentials
|
||||||
|
tool_provider_model.is_enabled = True
|
||||||
|
else:
|
||||||
|
tool_provider_model = ToolProvider(
|
||||||
|
tenant_id=tenant.id,
|
||||||
|
tool_name=provider,
|
||||||
|
encrypted_credentials=encrypted_credentials,
|
||||||
|
is_enabled=True
|
||||||
|
)
|
||||||
|
db.session.add(tool_provider_model)
|
||||||
|
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
return {'result': 'success'}, 201
|
||||||
|
|
||||||
|
|
||||||
|
class ToolProviderCredentialsValidateApi(Resource):
|
||||||
|
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def post(self, provider):
|
||||||
|
if provider not in [p.value for p in ToolProviderName]:
|
||||||
|
abort(404)
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
result = True
|
||||||
|
error = None
|
||||||
|
|
||||||
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
|
tool_provider_service = ToolProviderService(tenant_id, provider)
|
||||||
|
|
||||||
|
try:
|
||||||
|
tool_provider_service.credentials_validate(args['credentials'])
|
||||||
|
except ToolValidateFailedError as ex:
|
||||||
|
result = False
|
||||||
|
error = str(ex)
|
||||||
|
|
||||||
|
response = {'result': 'success' if result else 'error'}
|
||||||
|
|
||||||
|
if not result:
|
||||||
|
response['error'] = error
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(ToolProviderListApi, '/workspaces/current/tool-providers')
|
||||||
|
api.add_resource(ToolProviderCredentialsApi, '/workspaces/current/tool-providers/<provider>/credentials')
|
||||||
|
api.add_resource(ToolProviderCredentialsValidateApi,
|
||||||
|
'/workspaces/current/tool-providers/<provider>/credentials-validate')
|
||||||
@ -0,0 +1,35 @@
|
|||||||
|
from typing import cast, List
|
||||||
|
|
||||||
|
from langchain import OpenAI
|
||||||
|
from langchain.base_language import BaseLanguageModel
|
||||||
|
from langchain.chat_models.openai import ChatOpenAI
|
||||||
|
from langchain.schema import BaseMessage
|
||||||
|
|
||||||
|
from core.constant import llm_constant
|
||||||
|
|
||||||
|
|
||||||
|
class CalcTokenMixin:
|
||||||
|
|
||||||
|
def get_num_tokens_from_messages(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
|
||||||
|
llm = cast(ChatOpenAI, llm)
|
||||||
|
return llm.get_num_tokens_from_messages(messages)
|
||||||
|
|
||||||
|
def get_message_rest_tokens(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
|
||||||
|
"""
|
||||||
|
Got the rest tokens available for the model after excluding messages tokens and completion max tokens
|
||||||
|
|
||||||
|
:param llm:
|
||||||
|
:param messages:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
llm = cast(ChatOpenAI, llm)
|
||||||
|
llm_max_tokens = llm_constant.max_context_token_length[llm.model_name]
|
||||||
|
completion_max_tokens = llm.max_tokens
|
||||||
|
used_tokens = self.get_num_tokens_from_messages(llm, messages, **kwargs)
|
||||||
|
rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens
|
||||||
|
|
||||||
|
return rest_tokens
|
||||||
|
|
||||||
|
|
||||||
|
class ExceededLLMTokensLimitError(Exception):
|
||||||
|
pass
|
||||||
@ -0,0 +1,84 @@
|
|||||||
|
from typing import Tuple, List, Any, Union, Sequence, Optional, cast
|
||||||
|
|
||||||
|
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
|
||||||
|
from langchain.callbacks.base import BaseCallbackManager
|
||||||
|
from langchain.callbacks.manager import Callbacks
|
||||||
|
from langchain.prompts.chat import BaseMessagePromptTemplate
|
||||||
|
from langchain.schema import AgentAction, AgentFinish, BaseLanguageModel, SystemMessage
|
||||||
|
from langchain.tools import BaseTool
|
||||||
|
|
||||||
|
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||||
|
|
||||||
|
|
||||||
|
class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||||
|
"""
|
||||||
|
An Multi Dataset Retrieve Agent driven by Router.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def should_use_agent(self, query: str):
|
||||||
|
"""
|
||||||
|
return should use agent
|
||||||
|
|
||||||
|
:param query:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return True
|
||||||
|
|
||||||
|
def plan(
|
||||||
|
self,
|
||||||
|
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||||
|
callbacks: Callbacks = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Union[AgentAction, AgentFinish]:
|
||||||
|
"""Given input, decided what to do.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
intermediate_steps: Steps the LLM has taken to date, along with observations
|
||||||
|
**kwargs: User inputs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Action specifying what tool to use.
|
||||||
|
"""
|
||||||
|
if len(self.tools) == 0:
|
||||||
|
return AgentFinish(return_values={"output": ''}, log='')
|
||||||
|
elif len(self.tools) == 1:
|
||||||
|
tool = next(iter(self.tools))
|
||||||
|
tool = cast(DatasetRetrieverTool, tool)
|
||||||
|
rst = tool.run(tool_input={'dataset_id': tool.dataset_id, 'query': kwargs['input']})
|
||||||
|
return AgentFinish(return_values={"output": rst}, log=rst)
|
||||||
|
|
||||||
|
if intermediate_steps:
|
||||||
|
_, observation = intermediate_steps[-1]
|
||||||
|
return AgentFinish(return_values={"output": observation}, log=observation)
|
||||||
|
|
||||||
|
return super().plan(intermediate_steps, callbacks, **kwargs)
|
||||||
|
|
||||||
|
async def aplan(
|
||||||
|
self,
|
||||||
|
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||||
|
callbacks: Callbacks = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Union[AgentAction, AgentFinish]:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_llm_and_tools(
|
||||||
|
cls,
|
||||||
|
llm: BaseLanguageModel,
|
||||||
|
tools: Sequence[BaseTool],
|
||||||
|
callback_manager: Optional[BaseCallbackManager] = None,
|
||||||
|
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
|
||||||
|
system_message: Optional[SystemMessage] = SystemMessage(
|
||||||
|
content="You are a helpful AI assistant."
|
||||||
|
),
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> BaseSingleActionAgent:
|
||||||
|
llm.model_name = 'gpt-3.5-turbo'
|
||||||
|
return super().from_llm_and_tools(
|
||||||
|
llm=llm,
|
||||||
|
tools=tools,
|
||||||
|
callback_manager=callback_manager,
|
||||||
|
extra_prompt_messages=extra_prompt_messages,
|
||||||
|
system_message=system_message,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
@ -0,0 +1,120 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Tuple, Any, Union, Sequence, Optional
|
||||||
|
|
||||||
|
import pytz
|
||||||
|
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
|
||||||
|
from langchain.agents.openai_functions_agent.base import _parse_ai_message, \
|
||||||
|
_format_intermediate_steps
|
||||||
|
from langchain.callbacks.base import BaseCallbackManager
|
||||||
|
from langchain.callbacks.manager import Callbacks
|
||||||
|
from langchain.prompts.chat import BaseMessagePromptTemplate
|
||||||
|
from langchain.schema import AgentAction, AgentFinish, SystemMessage, BaseLanguageModel
|
||||||
|
from langchain.tools import BaseTool
|
||||||
|
|
||||||
|
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
|
||||||
|
from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin
|
||||||
|
|
||||||
|
|
||||||
|
class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctionCallSummarizeMixin):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_llm_and_tools(
|
||||||
|
cls,
|
||||||
|
llm: BaseLanguageModel,
|
||||||
|
tools: Sequence[BaseTool],
|
||||||
|
callback_manager: Optional[BaseCallbackManager] = None,
|
||||||
|
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
|
||||||
|
system_message: Optional[SystemMessage] = SystemMessage(
|
||||||
|
content="You are a helpful AI assistant."
|
||||||
|
),
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> BaseSingleActionAgent:
|
||||||
|
return super().from_llm_and_tools(
|
||||||
|
llm=llm,
|
||||||
|
tools=tools,
|
||||||
|
callback_manager=callback_manager,
|
||||||
|
extra_prompt_messages=extra_prompt_messages,
|
||||||
|
system_message=cls.get_system_message(),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def should_use_agent(self, query: str):
|
||||||
|
"""
|
||||||
|
return should use agent
|
||||||
|
|
||||||
|
:param query:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
original_max_tokens = self.llm.max_tokens
|
||||||
|
self.llm.max_tokens = 15
|
||||||
|
|
||||||
|
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
|
||||||
|
messages = prompt.to_messages()
|
||||||
|
|
||||||
|
predicted_message = self.llm.predict_messages(
|
||||||
|
messages, functions=self.functions, callbacks=None
|
||||||
|
)
|
||||||
|
|
||||||
|
function_call = predicted_message.additional_kwargs.get("function_call", {})
|
||||||
|
|
||||||
|
self.llm.max_tokens = original_max_tokens
|
||||||
|
|
||||||
|
return True if function_call else False
|
||||||
|
|
||||||
|
def plan(
|
||||||
|
self,
|
||||||
|
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||||
|
callbacks: Callbacks = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Union[AgentAction, AgentFinish]:
|
||||||
|
"""Given input, decided what to do.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
intermediate_steps: Steps the LLM has taken to date, along with observations
|
||||||
|
**kwargs: User inputs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Action specifying what tool to use.
|
||||||
|
"""
|
||||||
|
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
|
||||||
|
selected_inputs = {
|
||||||
|
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
|
||||||
|
}
|
||||||
|
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
|
||||||
|
prompt = self.prompt.format_prompt(**full_inputs)
|
||||||
|
messages = prompt.to_messages()
|
||||||
|
|
||||||
|
# summarize messages if rest_tokens < 0
|
||||||
|
try:
|
||||||
|
messages = self.summarize_messages_if_needed(self.llm, messages, functions=self.functions)
|
||||||
|
except ExceededLLMTokensLimitError as e:
|
||||||
|
return AgentFinish(return_values={"output": str(e)}, log=str(e))
|
||||||
|
|
||||||
|
predicted_message = self.llm.predict_messages(
|
||||||
|
messages, functions=self.functions, callbacks=callbacks
|
||||||
|
)
|
||||||
|
agent_decision = _parse_ai_message(predicted_message)
|
||||||
|
return agent_decision
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_system_message(cls):
|
||||||
|
# get current time
|
||||||
|
current_time = datetime.now()
|
||||||
|
current_timezone = pytz.timezone('UTC')
|
||||||
|
current_time = current_timezone.localize(current_time)
|
||||||
|
|
||||||
|
return SystemMessage(content="You are a helpful AI assistant.\n"
|
||||||
|
"Current time: {}\n"
|
||||||
|
"Respond directly if appropriate.".format(
|
||||||
|
current_time.strftime("%Y-%m-%d %H:%M:%S %Z%z")))
|
||||||
|
|
||||||
|
def return_stopped_response(
|
||||||
|
self,
|
||||||
|
early_stopping_method: str,
|
||||||
|
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> AgentFinish:
|
||||||
|
try:
|
||||||
|
return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs)
|
||||||
|
except ValueError:
|
||||||
|
return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
|
||||||
@ -0,0 +1,132 @@
|
|||||||
|
from typing import cast, List
|
||||||
|
|
||||||
|
from langchain.chat_models import ChatOpenAI
|
||||||
|
from langchain.chat_models.openai import _convert_message_to_dict
|
||||||
|
from langchain.memory.summary import SummarizerMixin
|
||||||
|
from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage, BaseLanguageModel
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin):
|
||||||
|
moving_summary_buffer: str = ""
|
||||||
|
moving_summary_index: int = 0
|
||||||
|
summary_llm: BaseLanguageModel
|
||||||
|
|
||||||
|
def summarize_messages_if_needed(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
|
||||||
|
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
|
||||||
|
rest_tokens = self.get_message_rest_tokens(llm, messages, **kwargs)
|
||||||
|
rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens
|
||||||
|
if rest_tokens >= 0:
|
||||||
|
return messages
|
||||||
|
|
||||||
|
system_message = None
|
||||||
|
human_message = None
|
||||||
|
should_summary_messages = []
|
||||||
|
for message in messages:
|
||||||
|
if isinstance(message, SystemMessage):
|
||||||
|
system_message = message
|
||||||
|
elif isinstance(message, HumanMessage):
|
||||||
|
human_message = message
|
||||||
|
else:
|
||||||
|
should_summary_messages.append(message)
|
||||||
|
|
||||||
|
if len(should_summary_messages) > 2:
|
||||||
|
ai_message = should_summary_messages[-2]
|
||||||
|
function_message = should_summary_messages[-1]
|
||||||
|
should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
|
||||||
|
self.moving_summary_index = len(should_summary_messages)
|
||||||
|
else:
|
||||||
|
error_msg = "Exceeded LLM tokens limit, stopped."
|
||||||
|
raise ExceededLLMTokensLimitError(error_msg)
|
||||||
|
|
||||||
|
new_messages = [system_message, human_message]
|
||||||
|
|
||||||
|
if self.moving_summary_index == 0:
|
||||||
|
should_summary_messages.insert(0, human_message)
|
||||||
|
|
||||||
|
summary_handler = SummarizerMixin(llm=self.summary_llm)
|
||||||
|
self.moving_summary_buffer = summary_handler.predict_new_summary(
|
||||||
|
messages=should_summary_messages,
|
||||||
|
existing_summary=self.moving_summary_buffer
|
||||||
|
)
|
||||||
|
|
||||||
|
new_messages.append(AIMessage(content=self.moving_summary_buffer))
|
||||||
|
new_messages.append(ai_message)
|
||||||
|
new_messages.append(function_message)
|
||||||
|
|
||||||
|
return new_messages
|
||||||
|
|
||||||
|
def get_num_tokens_from_messages(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
|
||||||
|
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
||||||
|
|
||||||
|
Official documentation: https://github.com/openai/openai-cookbook/blob/
|
||||||
|
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
|
||||||
|
llm = cast(ChatOpenAI, llm)
|
||||||
|
model, encoding = llm._get_encoding_model()
|
||||||
|
if model.startswith("gpt-3.5-turbo"):
|
||||||
|
# every message follows <im_start>{role/name}\n{content}<im_end>\n
|
||||||
|
tokens_per_message = 4
|
||||||
|
# if there's a name, the role is omitted
|
||||||
|
tokens_per_name = -1
|
||||||
|
elif model.startswith("gpt-4"):
|
||||||
|
tokens_per_message = 3
|
||||||
|
tokens_per_name = 1
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"get_num_tokens_from_messages() is not presently implemented "
|
||||||
|
f"for model {model}."
|
||||||
|
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
|
||||||
|
"information on how messages are converted to tokens."
|
||||||
|
)
|
||||||
|
num_tokens = 0
|
||||||
|
for m in messages:
|
||||||
|
message = _convert_message_to_dict(m)
|
||||||
|
num_tokens += tokens_per_message
|
||||||
|
for key, value in message.items():
|
||||||
|
if key == "function_call":
|
||||||
|
for f_key, f_value in value.items():
|
||||||
|
num_tokens += len(encoding.encode(f_key))
|
||||||
|
num_tokens += len(encoding.encode(f_value))
|
||||||
|
else:
|
||||||
|
num_tokens += len(encoding.encode(value))
|
||||||
|
|
||||||
|
if key == "name":
|
||||||
|
num_tokens += tokens_per_name
|
||||||
|
# every reply is primed with <im_start>assistant
|
||||||
|
num_tokens += 3
|
||||||
|
|
||||||
|
if kwargs.get('functions'):
|
||||||
|
for function in kwargs.get('functions'):
|
||||||
|
num_tokens += len(encoding.encode('name'))
|
||||||
|
num_tokens += len(encoding.encode(function.get("name")))
|
||||||
|
num_tokens += len(encoding.encode('description'))
|
||||||
|
num_tokens += len(encoding.encode(function.get("description")))
|
||||||
|
parameters = function.get("parameters")
|
||||||
|
num_tokens += len(encoding.encode('parameters'))
|
||||||
|
if 'title' in parameters:
|
||||||
|
num_tokens += len(encoding.encode('title'))
|
||||||
|
num_tokens += len(encoding.encode(parameters.get("title")))
|
||||||
|
num_tokens += len(encoding.encode('type'))
|
||||||
|
num_tokens += len(encoding.encode(parameters.get("type")))
|
||||||
|
if 'properties' in parameters:
|
||||||
|
num_tokens += len(encoding.encode('properties'))
|
||||||
|
for key, value in parameters.get('properties').items():
|
||||||
|
num_tokens += len(encoding.encode(key))
|
||||||
|
for field_key, field_value in value.items():
|
||||||
|
num_tokens += len(encoding.encode(field_key))
|
||||||
|
if field_key == 'enum':
|
||||||
|
for enum_field in field_value:
|
||||||
|
num_tokens += 3
|
||||||
|
num_tokens += len(encoding.encode(enum_field))
|
||||||
|
else:
|
||||||
|
num_tokens += len(encoding.encode(field_key))
|
||||||
|
num_tokens += len(encoding.encode(str(field_value)))
|
||||||
|
if 'required' in parameters:
|
||||||
|
num_tokens += len(encoding.encode('required'))
|
||||||
|
for required_field in parameters['required']:
|
||||||
|
num_tokens += 3
|
||||||
|
num_tokens += len(encoding.encode(required_field))
|
||||||
|
|
||||||
|
return num_tokens
|
||||||
@ -0,0 +1,109 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Tuple, Any, Union, Sequence, Optional
|
||||||
|
|
||||||
|
import pytz
|
||||||
|
from langchain.agents import BaseMultiActionAgent
|
||||||
|
from langchain.agents.openai_functions_multi_agent.base import OpenAIMultiFunctionsAgent, _format_intermediate_steps, \
|
||||||
|
_parse_ai_message
|
||||||
|
from langchain.callbacks.base import BaseCallbackManager
|
||||||
|
from langchain.callbacks.manager import Callbacks
|
||||||
|
from langchain.prompts.chat import BaseMessagePromptTemplate
|
||||||
|
from langchain.schema import AgentAction, AgentFinish, SystemMessage, BaseLanguageModel
|
||||||
|
from langchain.tools import BaseTool
|
||||||
|
|
||||||
|
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
|
||||||
|
from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin
|
||||||
|
|
||||||
|
|
||||||
|
class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, OpenAIFunctionCallSummarizeMixin):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_llm_and_tools(
|
||||||
|
cls,
|
||||||
|
llm: BaseLanguageModel,
|
||||||
|
tools: Sequence[BaseTool],
|
||||||
|
callback_manager: Optional[BaseCallbackManager] = None,
|
||||||
|
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
|
||||||
|
system_message: Optional[SystemMessage] = SystemMessage(
|
||||||
|
content="You are a helpful AI assistant."
|
||||||
|
),
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> BaseMultiActionAgent:
|
||||||
|
return super().from_llm_and_tools(
|
||||||
|
llm=llm,
|
||||||
|
tools=tools,
|
||||||
|
callback_manager=callback_manager,
|
||||||
|
extra_prompt_messages=extra_prompt_messages,
|
||||||
|
system_message=cls.get_system_message(),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def should_use_agent(self, query: str):
|
||||||
|
"""
|
||||||
|
return should use agent
|
||||||
|
|
||||||
|
:param query:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
original_max_tokens = self.llm.max_tokens
|
||||||
|
self.llm.max_tokens = 15
|
||||||
|
|
||||||
|
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
|
||||||
|
messages = prompt.to_messages()
|
||||||
|
|
||||||
|
predicted_message = self.llm.predict_messages(
|
||||||
|
messages, functions=self.functions, callbacks=None
|
||||||
|
)
|
||||||
|
|
||||||
|
function_call = predicted_message.additional_kwargs.get("function_call", {})
|
||||||
|
|
||||||
|
self.llm.max_tokens = original_max_tokens
|
||||||
|
|
||||||
|
return True if function_call else False
|
||||||
|
|
||||||
|
def plan(
|
||||||
|
self,
|
||||||
|
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||||
|
callbacks: Callbacks = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Union[AgentAction, AgentFinish]:
|
||||||
|
"""Given input, decided what to do.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
intermediate_steps: Steps the LLM has taken to date, along with observations
|
||||||
|
**kwargs: User inputs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Action specifying what tool to use.
|
||||||
|
"""
|
||||||
|
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
|
||||||
|
selected_inputs = {
|
||||||
|
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
|
||||||
|
}
|
||||||
|
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
|
||||||
|
prompt = self.prompt.format_prompt(**full_inputs)
|
||||||
|
messages = prompt.to_messages()
|
||||||
|
|
||||||
|
# summarize messages if rest_tokens < 0
|
||||||
|
try:
|
||||||
|
messages = self.summarize_messages_if_needed(self.llm, messages, functions=self.functions)
|
||||||
|
except ExceededLLMTokensLimitError as e:
|
||||||
|
return AgentFinish(return_values={"output": str(e)}, log=str(e))
|
||||||
|
|
||||||
|
predicted_message = self.llm.predict_messages(
|
||||||
|
messages, functions=self.functions, callbacks=callbacks
|
||||||
|
)
|
||||||
|
agent_decision = _parse_ai_message(predicted_message)
|
||||||
|
return agent_decision
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_system_message(cls):
|
||||||
|
# get current time
|
||||||
|
current_time = datetime.now()
|
||||||
|
current_timezone = pytz.timezone('UTC')
|
||||||
|
current_time = current_timezone.localize(current_time)
|
||||||
|
|
||||||
|
return SystemMessage(content="You are a helpful AI assistant.\n"
|
||||||
|
"Current time: {}\n"
|
||||||
|
"Respond directly if appropriate.".format(
|
||||||
|
current_time.strftime("%Y-%m-%d %H:%M:%S %Z%z")))
|
||||||
@ -0,0 +1,29 @@
|
|||||||
|
import json
|
||||||
|
import re
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser as LCStructuredChatOutputParser, \
|
||||||
|
logger
|
||||||
|
from langchain.schema import AgentAction, AgentFinish, OutputParserException
|
||||||
|
|
||||||
|
|
||||||
|
class StructuredChatOutputParser(LCStructuredChatOutputParser):
|
||||||
|
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
||||||
|
try:
|
||||||
|
action_match = re.search(r"```(.*?)\n(.*?)```?", text, re.DOTALL)
|
||||||
|
if action_match is not None:
|
||||||
|
response = json.loads(action_match.group(2).strip(), strict=False)
|
||||||
|
if isinstance(response, list):
|
||||||
|
# gpt turbo frequently ignores the directive to emit a single action
|
||||||
|
logger.warning("Got multiple action responses: %s", response)
|
||||||
|
response = response[0]
|
||||||
|
if response["action"] == "Final Answer":
|
||||||
|
return AgentFinish({"output": response["action_input"]}, text)
|
||||||
|
else:
|
||||||
|
return AgentAction(
|
||||||
|
response["action"], response.get("action_input", {}), text
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return AgentFinish({"output": text}, text)
|
||||||
|
except Exception as e:
|
||||||
|
raise OutputParserException(f"Could not parse LLM output: {text}") from e
|
||||||
@ -0,0 +1,182 @@
|
|||||||
|
import re
|
||||||
|
from typing import List, Tuple, Any, Union, Sequence, Optional
|
||||||
|
|
||||||
|
from langchain import BasePromptTemplate
|
||||||
|
from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
|
||||||
|
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
|
||||||
|
from langchain.base_language import BaseLanguageModel
|
||||||
|
from langchain.callbacks.base import BaseCallbackManager
|
||||||
|
from langchain.callbacks.manager import Callbacks
|
||||||
|
from langchain.memory.summary import SummarizerMixin
|
||||||
|
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
|
||||||
|
from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage
|
||||||
|
from langchain.tools import BaseTool
|
||||||
|
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
|
||||||
|
|
||||||
|
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
|
||||||
|
|
||||||
|
|
||||||
|
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
||||||
|
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
|
||||||
|
Valid "action" values: "Final Answer" or {tool_names}
|
||||||
|
|
||||||
|
Provide only ONE action per $JSON_BLOB, as shown:
|
||||||
|
|
||||||
|
```
|
||||||
|
{{{{
|
||||||
|
"action": $TOOL_NAME,
|
||||||
|
"action_input": $INPUT
|
||||||
|
}}}}
|
||||||
|
```
|
||||||
|
|
||||||
|
Follow this format:
|
||||||
|
|
||||||
|
Question: input question to answer
|
||||||
|
Thought: consider previous and subsequent steps
|
||||||
|
Action:
|
||||||
|
```
|
||||||
|
$JSON_BLOB
|
||||||
|
```
|
||||||
|
Observation: action result
|
||||||
|
... (repeat Thought/Action/Observation N times)
|
||||||
|
Thought: I know what to respond
|
||||||
|
Action:
|
||||||
|
```
|
||||||
|
{{{{
|
||||||
|
"action": "Final Answer",
|
||||||
|
"action_input": "Final response to human"
|
||||||
|
}}}}
|
||||||
|
```"""
|
||||||
|
|
||||||
|
|
||||||
|
class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||||
|
moving_summary_buffer: str = ""
|
||||||
|
moving_summary_index: int = 0
|
||||||
|
summary_llm: BaseLanguageModel
|
||||||
|
|
||||||
|
def should_use_agent(self, query: str):
|
||||||
|
"""
|
||||||
|
return should use agent
|
||||||
|
Using the ReACT mode to determine whether an agent is needed is costly,
|
||||||
|
so it's better to just use an Agent for reasoning, which is cheaper.
|
||||||
|
|
||||||
|
:param query:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return True
|
||||||
|
|
||||||
|
def plan(
|
||||||
|
self,
|
||||||
|
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||||
|
callbacks: Callbacks = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Union[AgentAction, AgentFinish]:
|
||||||
|
"""Given input, decided what to do.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
intermediate_steps: Steps the LLM has taken to date,
|
||||||
|
along with observations
|
||||||
|
callbacks: Callbacks to run.
|
||||||
|
**kwargs: User inputs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Action specifying what tool to use.
|
||||||
|
"""
|
||||||
|
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
|
||||||
|
|
||||||
|
prompts, _ = self.llm_chain.prep_prompts(input_list=[self.llm_chain.prep_inputs(full_inputs)])
|
||||||
|
messages = []
|
||||||
|
if prompts:
|
||||||
|
messages = prompts[0].to_messages()
|
||||||
|
|
||||||
|
rest_tokens = self.get_message_rest_tokens(self.llm_chain.llm, messages)
|
||||||
|
if rest_tokens < 0:
|
||||||
|
full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
|
||||||
|
|
||||||
|
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
|
||||||
|
return self.output_parser.parse(full_output)
|
||||||
|
|
||||||
|
def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
|
||||||
|
if len(intermediate_steps) >= 2:
|
||||||
|
should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
|
||||||
|
should_summary_messages = [AIMessage(content=observation)
|
||||||
|
for _, observation in should_summary_intermediate_steps]
|
||||||
|
if self.moving_summary_index == 0:
|
||||||
|
should_summary_messages.insert(0, HumanMessage(content=kwargs.get("input")))
|
||||||
|
|
||||||
|
self.moving_summary_index = len(intermediate_steps)
|
||||||
|
else:
|
||||||
|
error_msg = "Exceeded LLM tokens limit, stopped."
|
||||||
|
raise ExceededLLMTokensLimitError(error_msg)
|
||||||
|
|
||||||
|
summary_handler = SummarizerMixin(llm=self.summary_llm)
|
||||||
|
if self.moving_summary_buffer and 'chat_history' in kwargs:
|
||||||
|
kwargs["chat_history"].pop()
|
||||||
|
|
||||||
|
self.moving_summary_buffer = summary_handler.predict_new_summary(
|
||||||
|
messages=should_summary_messages,
|
||||||
|
existing_summary=self.moving_summary_buffer
|
||||||
|
)
|
||||||
|
|
||||||
|
if 'chat_history' in kwargs:
|
||||||
|
kwargs["chat_history"].append(AIMessage(content=self.moving_summary_buffer))
|
||||||
|
|
||||||
|
return self.get_full_inputs([intermediate_steps[-1]], **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_prompt(
|
||||||
|
cls,
|
||||||
|
tools: Sequence[BaseTool],
|
||||||
|
prefix: str = PREFIX,
|
||||||
|
suffix: str = SUFFIX,
|
||||||
|
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
|
||||||
|
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||||
|
input_variables: Optional[List[str]] = None,
|
||||||
|
memory_prompts: Optional[List[BasePromptTemplate]] = None,
|
||||||
|
) -> BasePromptTemplate:
|
||||||
|
tool_strings = []
|
||||||
|
for tool in tools:
|
||||||
|
args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
|
||||||
|
tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
|
||||||
|
formatted_tools = "\n".join(tool_strings)
|
||||||
|
tool_names = ", ".join([('"' + tool.name + '"') for tool in tools])
|
||||||
|
format_instructions = format_instructions.format(tool_names=tool_names)
|
||||||
|
template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
|
||||||
|
if input_variables is None:
|
||||||
|
input_variables = ["input", "agent_scratchpad"]
|
||||||
|
_memory_prompts = memory_prompts or []
|
||||||
|
messages = [
|
||||||
|
SystemMessagePromptTemplate.from_template(template),
|
||||||
|
*_memory_prompts,
|
||||||
|
HumanMessagePromptTemplate.from_template(human_message_template),
|
||||||
|
]
|
||||||
|
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_llm_and_tools(
|
||||||
|
cls,
|
||||||
|
llm: BaseLanguageModel,
|
||||||
|
tools: Sequence[BaseTool],
|
||||||
|
callback_manager: Optional[BaseCallbackManager] = None,
|
||||||
|
output_parser: Optional[AgentOutputParser] = None,
|
||||||
|
prefix: str = PREFIX,
|
||||||
|
suffix: str = SUFFIX,
|
||||||
|
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
|
||||||
|
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||||
|
input_variables: Optional[List[str]] = None,
|
||||||
|
memory_prompts: Optional[List[BasePromptTemplate]] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Agent:
|
||||||
|
return super().from_llm_and_tools(
|
||||||
|
llm=llm,
|
||||||
|
tools=tools,
|
||||||
|
callback_manager=callback_manager,
|
||||||
|
output_parser=output_parser,
|
||||||
|
prefix=prefix,
|
||||||
|
suffix=suffix,
|
||||||
|
human_message_template=human_message_template,
|
||||||
|
format_instructions=format_instructions,
|
||||||
|
input_variables=input_variables,
|
||||||
|
memory_prompts=memory_prompts,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
@ -1,86 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
from langchain import LLMChain
|
|
||||||
from langchain.agents import ZeroShotAgent, AgentExecutor, ConversationalAgent
|
|
||||||
from langchain.callbacks.manager import CallbackManager
|
|
||||||
from langchain.memory.chat_memory import BaseChatMemory
|
|
||||||
|
|
||||||
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
|
|
||||||
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
|
|
||||||
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
|
|
||||||
from core.llm.llm_builder import LLMBuilder
|
|
||||||
|
|
||||||
|
|
||||||
class AgentBuilder:
|
|
||||||
@classmethod
|
|
||||||
def to_agent_chain(cls, tenant_id: str, tools, memory: Optional[BaseChatMemory],
|
|
||||||
dataset_tool_callback_handler: DatasetToolCallbackHandler,
|
|
||||||
agent_loop_gather_callback_handler: AgentLoopGatherCallbackHandler):
|
|
||||||
llm = LLMBuilder.to_llm(
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
model_name=agent_loop_gather_callback_handler.model_name,
|
|
||||||
temperature=0,
|
|
||||||
max_tokens=1024,
|
|
||||||
callbacks=[agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()]
|
|
||||||
)
|
|
||||||
|
|
||||||
for tool in tools:
|
|
||||||
tool.callbacks = [
|
|
||||||
agent_loop_gather_callback_handler,
|
|
||||||
dataset_tool_callback_handler,
|
|
||||||
DifyStdOutCallbackHandler()
|
|
||||||
]
|
|
||||||
|
|
||||||
prompt = cls.build_agent_prompt_template(
|
|
||||||
tools=tools,
|
|
||||||
memory=memory,
|
|
||||||
)
|
|
||||||
|
|
||||||
agent_llm_chain = LLMChain(
|
|
||||||
llm=llm,
|
|
||||||
prompt=prompt,
|
|
||||||
)
|
|
||||||
|
|
||||||
agent = cls.build_agent(agent_llm_chain=agent_llm_chain, memory=memory)
|
|
||||||
|
|
||||||
agent_callback_manager = CallbackManager(
|
|
||||||
[agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()]
|
|
||||||
)
|
|
||||||
|
|
||||||
agent_chain = AgentExecutor.from_agent_and_tools(
|
|
||||||
tools=tools,
|
|
||||||
agent=agent,
|
|
||||||
memory=memory,
|
|
||||||
callbacks=agent_callback_manager,
|
|
||||||
max_iterations=6,
|
|
||||||
early_stopping_method="generate",
|
|
||||||
# `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
|
|
||||||
)
|
|
||||||
|
|
||||||
return agent_chain
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def build_agent_prompt_template(cls, tools, memory: Optional[BaseChatMemory]):
|
|
||||||
if memory:
|
|
||||||
prompt = ConversationalAgent.create_prompt(
|
|
||||||
tools=tools,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
prompt = ZeroShotAgent.create_prompt(
|
|
||||||
tools=tools,
|
|
||||||
)
|
|
||||||
|
|
||||||
return prompt
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def build_agent(cls, agent_llm_chain: LLMChain, memory: Optional[BaseChatMemory]):
|
|
||||||
if memory:
|
|
||||||
agent = ConversationalAgent(
|
|
||||||
llm_chain=agent_llm_chain
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
agent = ZeroShotAgent(
|
|
||||||
llm_chain=agent_llm_chain
|
|
||||||
)
|
|
||||||
|
|
||||||
return agent
|
|
||||||
@ -0,0 +1,121 @@
|
|||||||
|
import enum
|
||||||
|
import logging
|
||||||
|
from typing import Union, Optional
|
||||||
|
|
||||||
|
from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent
|
||||||
|
from langchain.base_language import BaseLanguageModel
|
||||||
|
from langchain.callbacks.manager import Callbacks
|
||||||
|
from langchain.memory.chat_memory import BaseChatMemory
|
||||||
|
from langchain.tools import BaseTool
|
||||||
|
from pydantic import BaseModel, Extra
|
||||||
|
|
||||||
|
from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
|
||||||
|
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
|
||||||
|
from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
|
||||||
|
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
|
||||||
|
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
|
||||||
|
from langchain.agents import AgentExecutor as LCAgentExecutor
|
||||||
|
|
||||||
|
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||||
|
|
||||||
|
|
||||||
|
class PlanningStrategy(str, enum.Enum):
|
||||||
|
ROUTER = 'router'
|
||||||
|
REACT = 'react'
|
||||||
|
FUNCTION_CALL = 'function_call'
|
||||||
|
MULTI_FUNCTION_CALL = 'multi_function_call'
|
||||||
|
|
||||||
|
|
||||||
|
class AgentConfiguration(BaseModel):
|
||||||
|
strategy: PlanningStrategy
|
||||||
|
llm: BaseLanguageModel
|
||||||
|
tools: list[BaseTool]
|
||||||
|
summary_llm: BaseLanguageModel
|
||||||
|
memory: Optional[BaseChatMemory] = None
|
||||||
|
callbacks: Callbacks = None
|
||||||
|
max_iterations: int = 6
|
||||||
|
max_execution_time: Optional[float] = None
|
||||||
|
early_stopping_method: str = "generate"
|
||||||
|
# `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
|
extra = Extra.forbid
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
|
||||||
|
class AgentExecuteResult(BaseModel):
|
||||||
|
strategy: PlanningStrategy
|
||||||
|
output: Optional[str]
|
||||||
|
configuration: AgentConfiguration
|
||||||
|
|
||||||
|
|
||||||
|
class AgentExecutor:
|
||||||
|
def __init__(self, configuration: AgentConfiguration):
|
||||||
|
self.configuration = configuration
|
||||||
|
self.agent = self._init_agent()
|
||||||
|
|
||||||
|
def _init_agent(self) -> Union[BaseSingleActionAgent | BaseMultiActionAgent]:
|
||||||
|
if self.configuration.strategy == PlanningStrategy.REACT:
|
||||||
|
agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
|
||||||
|
llm=self.configuration.llm,
|
||||||
|
tools=self.configuration.tools,
|
||||||
|
output_parser=StructuredChatOutputParser(),
|
||||||
|
summary_llm=self.configuration.summary_llm,
|
||||||
|
verbose=True
|
||||||
|
)
|
||||||
|
elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
|
||||||
|
agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
|
||||||
|
llm=self.configuration.llm,
|
||||||
|
tools=self.configuration.tools,
|
||||||
|
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
|
||||||
|
summary_llm=self.configuration.summary_llm,
|
||||||
|
verbose=True
|
||||||
|
)
|
||||||
|
elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
|
||||||
|
agent = AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools(
|
||||||
|
llm=self.configuration.llm,
|
||||||
|
tools=self.configuration.tools,
|
||||||
|
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
|
||||||
|
summary_llm=self.configuration.summary_llm,
|
||||||
|
verbose=True
|
||||||
|
)
|
||||||
|
elif self.configuration.strategy == PlanningStrategy.ROUTER:
|
||||||
|
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
|
||||||
|
agent = MultiDatasetRouterAgent.from_llm_and_tools(
|
||||||
|
llm=self.configuration.llm,
|
||||||
|
tools=self.configuration.tools,
|
||||||
|
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
|
||||||
|
verbose=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Unknown Agent Strategy: {self.configuration.strategy}")
|
||||||
|
|
||||||
|
return agent
|
||||||
|
|
||||||
|
def should_use_agent(self, query: str) -> bool:
|
||||||
|
return self.agent.should_use_agent(query)
|
||||||
|
|
||||||
|
def run(self, query: str) -> AgentExecuteResult:
|
||||||
|
agent_executor = LCAgentExecutor.from_agent_and_tools(
|
||||||
|
agent=self.agent,
|
||||||
|
tools=self.configuration.tools,
|
||||||
|
memory=self.configuration.memory,
|
||||||
|
max_iterations=self.configuration.max_iterations,
|
||||||
|
max_execution_time=self.configuration.max_execution_time,
|
||||||
|
early_stopping_method=self.configuration.early_stopping_method,
|
||||||
|
callbacks=self.configuration.callbacks
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
output = agent_executor.run(query)
|
||||||
|
except Exception:
|
||||||
|
logging.exception("agent_executor run failed")
|
||||||
|
output = None
|
||||||
|
|
||||||
|
return AgentExecuteResult(
|
||||||
|
output=output,
|
||||||
|
strategy=self.configuration.strategy,
|
||||||
|
configuration=self.configuration
|
||||||
|
)
|
||||||
@ -1,32 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
|
|
||||||
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
|
|
||||||
from core.chain.tool_chain import ToolChain
|
|
||||||
|
|
||||||
|
|
||||||
class ChainBuilder:
|
|
||||||
@classmethod
|
|
||||||
def to_tool_chain(cls, tool, **kwargs) -> ToolChain:
|
|
||||||
return ToolChain(
|
|
||||||
tool=tool,
|
|
||||||
input_key=kwargs.get('input_key', 'input'),
|
|
||||||
output_key=kwargs.get('output_key', 'tool_output'),
|
|
||||||
callbacks=[DifyStdOutCallbackHandler()]
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def to_sensitive_word_avoidance_chain(cls, tool_config: dict, **kwargs) -> Optional[
|
|
||||||
SensitiveWordAvoidanceChain]:
|
|
||||||
sensitive_words = tool_config.get("words", "")
|
|
||||||
if tool_config.get("enabled", False) \
|
|
||||||
and sensitive_words:
|
|
||||||
return SensitiveWordAvoidanceChain(
|
|
||||||
sensitive_words=sensitive_words.split(","),
|
|
||||||
canned_response=tool_config.get("canned_response", ''),
|
|
||||||
output_key="sensitive_word_avoidance_output",
|
|
||||||
callbacks=[DifyStdOutCallbackHandler()],
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
return None
|
|
||||||
@ -1,111 +0,0 @@
|
|||||||
"""Base classes for LLM-powered router chains."""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Type, cast, NamedTuple
|
|
||||||
|
|
||||||
from langchain.base_language import BaseLanguageModel
|
|
||||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
|
||||||
from langchain.chains.base import Chain
|
|
||||||
from pydantic import root_validator
|
|
||||||
|
|
||||||
from langchain.chains import LLMChain
|
|
||||||
from langchain.prompts import BasePromptTemplate
|
|
||||||
from langchain.schema import BaseOutputParser, OutputParserException
|
|
||||||
|
|
||||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
|
||||||
|
|
||||||
|
|
||||||
class Route(NamedTuple):
|
|
||||||
destination: Optional[str]
|
|
||||||
next_inputs: Dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
class LLMRouterChain(Chain):
|
|
||||||
"""A router chain that uses an LLM chain to perform routing."""
|
|
||||||
|
|
||||||
llm_chain: LLMChain
|
|
||||||
"""LLM chain used to perform routing"""
|
|
||||||
|
|
||||||
@root_validator()
|
|
||||||
def validate_prompt(cls, values: dict) -> dict:
|
|
||||||
prompt = values["llm_chain"].prompt
|
|
||||||
if prompt.output_parser is None:
|
|
||||||
raise ValueError(
|
|
||||||
"LLMRouterChain requires base llm_chain prompt to have an output"
|
|
||||||
" parser that converts LLM text output to a dictionary with keys"
|
|
||||||
" 'destination' and 'next_inputs'. Received a prompt with no output"
|
|
||||||
" parser."
|
|
||||||
)
|
|
||||||
return values
|
|
||||||
|
|
||||||
@property
|
|
||||||
def input_keys(self) -> List[str]:
|
|
||||||
"""Will be whatever keys the LLM chain prompt expects.
|
|
||||||
|
|
||||||
:meta private:
|
|
||||||
"""
|
|
||||||
return self.llm_chain.input_keys
|
|
||||||
|
|
||||||
def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
|
|
||||||
super()._validate_outputs(outputs)
|
|
||||||
if not isinstance(outputs["next_inputs"], dict):
|
|
||||||
raise ValueError
|
|
||||||
|
|
||||||
def _call(
|
|
||||||
self,
|
|
||||||
inputs: Dict[str, Any],
|
|
||||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
output = cast(
|
|
||||||
Dict[str, Any],
|
|
||||||
self.llm_chain.predict_and_parse(**inputs),
|
|
||||||
)
|
|
||||||
return output
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_llm(
|
|
||||||
cls, llm: BaseLanguageModel, prompt: BasePromptTemplate, **kwargs: Any
|
|
||||||
) -> LLMRouterChain:
|
|
||||||
"""Convenience constructor."""
|
|
||||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
|
||||||
return cls(llm_chain=llm_chain, **kwargs)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def output_keys(self) -> List[str]:
|
|
||||||
return ["destination", "next_inputs"]
|
|
||||||
|
|
||||||
def route(self, inputs: Dict[str, Any]) -> Route:
|
|
||||||
result = self(inputs)
|
|
||||||
return Route(result["destination"], result["next_inputs"])
|
|
||||||
|
|
||||||
|
|
||||||
class RouterOutputParser(BaseOutputParser[Dict[str, str]]):
|
|
||||||
"""Parser for output of router chain int he multi-prompt chain."""
|
|
||||||
|
|
||||||
default_destination: str = "DEFAULT"
|
|
||||||
next_inputs_type: Type = str
|
|
||||||
next_inputs_inner_key: str = "input"
|
|
||||||
|
|
||||||
def parse(self, text: str) -> Dict[str, Any]:
|
|
||||||
try:
|
|
||||||
expected_keys = ["destination", "next_inputs"]
|
|
||||||
parsed = parse_and_check_json_markdown(text, expected_keys)
|
|
||||||
if not isinstance(parsed["destination"], str):
|
|
||||||
raise ValueError("Expected 'destination' to be a string.")
|
|
||||||
if not isinstance(parsed["next_inputs"], self.next_inputs_type):
|
|
||||||
raise ValueError(
|
|
||||||
f"Expected 'next_inputs' to be {self.next_inputs_type}."
|
|
||||||
)
|
|
||||||
parsed["next_inputs"] = {self.next_inputs_inner_key: parsed["next_inputs"]}
|
|
||||||
if (
|
|
||||||
parsed["destination"].strip().lower()
|
|
||||||
== self.default_destination.lower()
|
|
||||||
):
|
|
||||||
parsed["destination"] = None
|
|
||||||
else:
|
|
||||||
parsed["destination"] = parsed["destination"].strip()
|
|
||||||
return parsed
|
|
||||||
except Exception as e:
|
|
||||||
raise OutputParserException(
|
|
||||||
f"Parsing text\n{text}\n of llm router raised following error:\n{e}"
|
|
||||||
)
|
|
||||||
@ -1,110 +0,0 @@
|
|||||||
from typing import Optional, List, cast
|
|
||||||
|
|
||||||
from langchain.chains import SequentialChain
|
|
||||||
from langchain.chains.base import Chain
|
|
||||||
from langchain.memory.chat_memory import BaseChatMemory
|
|
||||||
|
|
||||||
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
|
|
||||||
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
|
|
||||||
from core.chain.chain_builder import ChainBuilder
|
|
||||||
from core.chain.multi_dataset_router_chain import MultiDatasetRouterChain
|
|
||||||
from core.conversation_message_task import ConversationMessageTask
|
|
||||||
from extensions.ext_database import db
|
|
||||||
from models.dataset import Dataset
|
|
||||||
|
|
||||||
|
|
||||||
class MainChainBuilder:
|
|
||||||
@classmethod
|
|
||||||
def to_langchain_components(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
|
|
||||||
rest_tokens: int,
|
|
||||||
conversation_message_task: ConversationMessageTask):
|
|
||||||
first_input_key = "input"
|
|
||||||
final_output_key = "output"
|
|
||||||
|
|
||||||
chains = []
|
|
||||||
|
|
||||||
chain_callback_handler = MainChainGatherCallbackHandler(conversation_message_task)
|
|
||||||
|
|
||||||
# agent mode
|
|
||||||
tool_chains, chains_output_key = cls.get_agent_chains(
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
agent_mode=agent_mode,
|
|
||||||
rest_tokens=rest_tokens,
|
|
||||||
memory=memory,
|
|
||||||
conversation_message_task=conversation_message_task
|
|
||||||
)
|
|
||||||
chains += tool_chains
|
|
||||||
|
|
||||||
if chains_output_key:
|
|
||||||
final_output_key = chains_output_key
|
|
||||||
|
|
||||||
if len(chains) == 0:
|
|
||||||
return None
|
|
||||||
|
|
||||||
for chain in chains:
|
|
||||||
chain = cast(Chain, chain)
|
|
||||||
chain.callbacks.append(chain_callback_handler)
|
|
||||||
|
|
||||||
# build main chain
|
|
||||||
overall_chain = SequentialChain(
|
|
||||||
chains=chains,
|
|
||||||
input_variables=[first_input_key],
|
|
||||||
output_variables=[final_output_key],
|
|
||||||
memory=memory, # only for use the memory prompt input key
|
|
||||||
)
|
|
||||||
|
|
||||||
return overall_chain
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_agent_chains(cls, tenant_id: str, agent_mode: dict,
|
|
||||||
rest_tokens: int,
|
|
||||||
memory: Optional[BaseChatMemory],
|
|
||||||
conversation_message_task: ConversationMessageTask):
|
|
||||||
# agent mode
|
|
||||||
chains = []
|
|
||||||
if agent_mode and agent_mode.get('enabled'):
|
|
||||||
tools = agent_mode.get('tools', [])
|
|
||||||
|
|
||||||
pre_fixed_chains = []
|
|
||||||
# agent_tools = []
|
|
||||||
datasets = []
|
|
||||||
for tool in tools:
|
|
||||||
tool_type = list(tool.keys())[0]
|
|
||||||
tool_config = list(tool.values())[0]
|
|
||||||
if tool_type == 'sensitive-word-avoidance':
|
|
||||||
chain = ChainBuilder.to_sensitive_word_avoidance_chain(tool_config)
|
|
||||||
if chain:
|
|
||||||
pre_fixed_chains.append(chain)
|
|
||||||
elif tool_type == "dataset":
|
|
||||||
# get dataset from dataset id
|
|
||||||
dataset = db.session.query(Dataset).filter(
|
|
||||||
Dataset.tenant_id == tenant_id,
|
|
||||||
Dataset.id == tool_config.get("id")
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if dataset:
|
|
||||||
datasets.append(dataset)
|
|
||||||
|
|
||||||
# add pre-fixed chains
|
|
||||||
chains += pre_fixed_chains
|
|
||||||
|
|
||||||
if len(datasets) > 0:
|
|
||||||
# tool to chain
|
|
||||||
multi_dataset_router_chain = MultiDatasetRouterChain.from_datasets(
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
datasets=datasets,
|
|
||||||
conversation_message_task=conversation_message_task,
|
|
||||||
rest_tokens=rest_tokens,
|
|
||||||
callbacks=[DifyStdOutCallbackHandler()]
|
|
||||||
)
|
|
||||||
chains.append(multi_dataset_router_chain)
|
|
||||||
|
|
||||||
final_output_key = cls.get_chains_output_key(chains)
|
|
||||||
|
|
||||||
return chains, final_output_key
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_chains_output_key(cls, chains: List[Chain]):
|
|
||||||
if len(chains) > 0:
|
|
||||||
return chains[-1].output_keys[0]
|
|
||||||
return None
|
|
||||||
@ -1,198 +0,0 @@
|
|||||||
import math
|
|
||||||
import re
|
|
||||||
from typing import Mapping, List, Dict, Any, Optional
|
|
||||||
|
|
||||||
from langchain import PromptTemplate
|
|
||||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
|
||||||
from langchain.chains.base import Chain
|
|
||||||
from pydantic import Extra
|
|
||||||
|
|
||||||
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
|
|
||||||
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
|
|
||||||
from core.chain.llm_router_chain import LLMRouterChain, RouterOutputParser
|
|
||||||
from core.conversation_message_task import ConversationMessageTask
|
|
||||||
from core.llm.llm_builder import LLMBuilder
|
|
||||||
from core.tool.dataset_index_tool import DatasetTool
|
|
||||||
from models.dataset import Dataset, DatasetProcessRule
|
|
||||||
|
|
||||||
DEFAULT_K = 2
|
|
||||||
CONTEXT_TOKENS_PERCENT = 0.3
|
|
||||||
MULTI_PROMPT_ROUTER_TEMPLATE = """
|
|
||||||
Given a raw text input to a language model select the model prompt best suited for \
|
|
||||||
the input. You will be given the names of the available prompts and a description of \
|
|
||||||
what the prompt is best suited for. You may also revise the original input if you \
|
|
||||||
think that revising it will ultimately lead to a better response from the language \
|
|
||||||
model.
|
|
||||||
|
|
||||||
<< FORMATTING >>
|
|
||||||
Return a markdown code snippet with a JSON object formatted to look like, \
|
|
||||||
no any other string out of markdown code snippet:
|
|
||||||
```json
|
|
||||||
{{{{
|
|
||||||
"destination": string \\ name of the prompt to use or "DEFAULT"
|
|
||||||
"next_inputs": string \\ a potentially modified version of the original input
|
|
||||||
}}}}
|
|
||||||
```
|
|
||||||
|
|
||||||
REMEMBER: "destination" MUST be one of the candidate prompt names specified below OR \
|
|
||||||
it can be "DEFAULT" if the input is not well suited for any of the candidate prompts.
|
|
||||||
REMEMBER: "next_inputs" can just be the original input if you don't think any \
|
|
||||||
modifications are needed.
|
|
||||||
|
|
||||||
<< CANDIDATE PROMPTS >>
|
|
||||||
{destinations}
|
|
||||||
|
|
||||||
<< INPUT >>
|
|
||||||
{{input}}
|
|
||||||
|
|
||||||
<< OUTPUT >>
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class MultiDatasetRouterChain(Chain):
|
|
||||||
"""Use a single chain to route an input to one of multiple candidate chains."""
|
|
||||||
|
|
||||||
router_chain: LLMRouterChain
|
|
||||||
"""Chain for deciding a destination chain and the input to it."""
|
|
||||||
dataset_tools: Mapping[str, DatasetTool]
|
|
||||||
"""Map of name to candidate chains that inputs can be routed to."""
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
"""Configuration for this pydantic object."""
|
|
||||||
|
|
||||||
extra = Extra.forbid
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
|
|
||||||
@property
|
|
||||||
def input_keys(self) -> List[str]:
|
|
||||||
"""Will be whatever keys the router chain prompt expects.
|
|
||||||
|
|
||||||
:meta private:
|
|
||||||
"""
|
|
||||||
return self.router_chain.input_keys
|
|
||||||
|
|
||||||
@property
|
|
||||||
def output_keys(self) -> List[str]:
|
|
||||||
return ["text"]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_datasets(
|
|
||||||
cls,
|
|
||||||
tenant_id: str,
|
|
||||||
datasets: List[Dataset],
|
|
||||||
conversation_message_task: ConversationMessageTask,
|
|
||||||
rest_tokens: int,
|
|
||||||
**kwargs: Any,
|
|
||||||
):
|
|
||||||
"""Convenience constructor for instantiating from destination prompts."""
|
|
||||||
llm = LLMBuilder.to_llm(
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
model_name='gpt-3.5-turbo',
|
|
||||||
temperature=0,
|
|
||||||
max_tokens=1024,
|
|
||||||
callbacks=[DifyStdOutCallbackHandler()]
|
|
||||||
)
|
|
||||||
|
|
||||||
destinations = ["[[{}]]: {}".format(d.id, d.description.replace('\n', ' ') if d.description
|
|
||||||
else ('useful for when you want to answer queries about the ' + d.name))
|
|
||||||
for d in datasets]
|
|
||||||
destinations_str = "\n".join(destinations)
|
|
||||||
router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(
|
|
||||||
destinations=destinations_str
|
|
||||||
)
|
|
||||||
|
|
||||||
router_prompt = PromptTemplate(
|
|
||||||
template=router_template,
|
|
||||||
input_variables=["input"],
|
|
||||||
output_parser=RouterOutputParser(),
|
|
||||||
)
|
|
||||||
|
|
||||||
router_chain = LLMRouterChain.from_llm(llm, router_prompt)
|
|
||||||
dataset_tools = {}
|
|
||||||
for dataset in datasets:
|
|
||||||
# fulfill description when it is empty
|
|
||||||
if dataset.available_document_count == 0 or dataset.available_document_count == 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
description = dataset.description
|
|
||||||
if not description:
|
|
||||||
description = 'useful for when you want to answer queries about the ' + dataset.name
|
|
||||||
|
|
||||||
k = cls._dynamic_calc_retrieve_k(dataset, rest_tokens)
|
|
||||||
if k == 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
dataset_tool = DatasetTool(
|
|
||||||
name=f"dataset-{dataset.id}",
|
|
||||||
description=description,
|
|
||||||
k=k,
|
|
||||||
dataset=dataset,
|
|
||||||
callbacks=[DatasetToolCallbackHandler(conversation_message_task), DifyStdOutCallbackHandler()]
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset_tools[str(dataset.id)] = dataset_tool
|
|
||||||
|
|
||||||
return cls(
|
|
||||||
router_chain=router_chain,
|
|
||||||
dataset_tools=dataset_tools,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
|
|
||||||
processing_rule = dataset.latest_process_rule
|
|
||||||
if not processing_rule:
|
|
||||||
return DEFAULT_K
|
|
||||||
|
|
||||||
if processing_rule.mode == "custom":
|
|
||||||
rules = processing_rule.rules_dict
|
|
||||||
if not rules:
|
|
||||||
return DEFAULT_K
|
|
||||||
|
|
||||||
segmentation = rules["segmentation"]
|
|
||||||
segment_max_tokens = segmentation["max_tokens"]
|
|
||||||
else:
|
|
||||||
segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']
|
|
||||||
|
|
||||||
# when rest_tokens is less than default context tokens
|
|
||||||
if rest_tokens < segment_max_tokens * DEFAULT_K:
|
|
||||||
return rest_tokens // segment_max_tokens
|
|
||||||
|
|
||||||
context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)
|
|
||||||
|
|
||||||
# when context_limit_tokens is less than default context tokens, use default_k
|
|
||||||
if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
|
|
||||||
return DEFAULT_K
|
|
||||||
|
|
||||||
# Expand the k value when there's still some room left in the 30% rest tokens space
|
|
||||||
return context_limit_tokens // segment_max_tokens
|
|
||||||
|
|
||||||
def _call(
|
|
||||||
self,
|
|
||||||
inputs: Dict[str, Any],
|
|
||||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
if len(self.dataset_tools) == 0:
|
|
||||||
return {"text": ''}
|
|
||||||
elif len(self.dataset_tools) == 1:
|
|
||||||
return {"text": next(iter(self.dataset_tools.values())).run(inputs['input'])}
|
|
||||||
|
|
||||||
route = self.router_chain.route(inputs)
|
|
||||||
|
|
||||||
destination = ''
|
|
||||||
if route.destination:
|
|
||||||
pattern = r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b'
|
|
||||||
match = re.search(pattern, route.destination, re.IGNORECASE)
|
|
||||||
if match:
|
|
||||||
destination = match.group()
|
|
||||||
|
|
||||||
if not destination:
|
|
||||||
return {"text": ''}
|
|
||||||
elif destination in self.dataset_tools:
|
|
||||||
return {"text": self.dataset_tools[destination].run(
|
|
||||||
route.next_inputs['input']
|
|
||||||
)}
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"Received invalid destination chain name '{destination}'"
|
|
||||||
)
|
|
||||||
@ -1,51 +0,0 @@
|
|||||||
from typing import List, Dict, Optional, Any
|
|
||||||
|
|
||||||
from langchain.callbacks.manager import CallbackManagerForChainRun, AsyncCallbackManagerForChainRun
|
|
||||||
from langchain.chains.base import Chain
|
|
||||||
from langchain.tools import BaseTool
|
|
||||||
|
|
||||||
|
|
||||||
class ToolChain(Chain):
|
|
||||||
input_key: str = "input" #: :meta private:
|
|
||||||
output_key: str = "output" #: :meta private:
|
|
||||||
|
|
||||||
tool: BaseTool
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _chain_type(self) -> str:
|
|
||||||
return "tool_chain"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def input_keys(self) -> List[str]:
|
|
||||||
"""Expect input key.
|
|
||||||
|
|
||||||
:meta private:
|
|
||||||
"""
|
|
||||||
return [self.input_key]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def output_keys(self) -> List[str]:
|
|
||||||
"""Return output key.
|
|
||||||
|
|
||||||
:meta private:
|
|
||||||
"""
|
|
||||||
return [self.output_key]
|
|
||||||
|
|
||||||
def _call(
|
|
||||||
self,
|
|
||||||
inputs: Dict[str, Any],
|
|
||||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
input = inputs[self.input_key]
|
|
||||||
output = self.tool.run(input, self.verbose)
|
|
||||||
return {self.output_key: output}
|
|
||||||
|
|
||||||
async def _acall(
|
|
||||||
self,
|
|
||||||
inputs: Dict[str, Any],
|
|
||||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""Run the logic of this chain and return the output."""
|
|
||||||
input = inputs[self.input_key]
|
|
||||||
output = await self.tool.arun(input, self.verbose)
|
|
||||||
return {self.output_key: output}
|
|
||||||
@ -0,0 +1,59 @@
|
|||||||
|
import time
|
||||||
|
from typing import List, Optional, Any, Mapping
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||||
|
from langchain.chat_models.base import SimpleChatModel
|
||||||
|
from langchain.schema import BaseMessage, ChatResult, AIMessage, ChatGeneration, BaseLanguageModel
|
||||||
|
|
||||||
|
|
||||||
|
class FakeLLM(SimpleChatModel):
|
||||||
|
"""Fake ChatModel for testing purposes."""
|
||||||
|
|
||||||
|
streaming: bool = False
|
||||||
|
"""Whether to stream the results or not."""
|
||||||
|
response: str
|
||||||
|
origin_llm: Optional[BaseLanguageModel] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
return "fake-chat-model"
|
||||||
|
|
||||||
|
def _call(
|
||||||
|
self,
|
||||||
|
messages: List[BaseMessage],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> str:
|
||||||
|
"""First try to lookup in queries, else return 'foo' or 'bar'."""
|
||||||
|
return self.response
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _identifying_params(self) -> Mapping[str, Any]:
|
||||||
|
return {"response": self.response}
|
||||||
|
|
||||||
|
def get_num_tokens(self, text: str) -> int:
|
||||||
|
return self.origin_llm.get_num_tokens(text) if self.origin_llm else 0
|
||||||
|
|
||||||
|
def _generate(
|
||||||
|
self,
|
||||||
|
messages: List[BaseMessage],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatResult:
|
||||||
|
output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
|
||||||
|
if self.streaming:
|
||||||
|
for token in output_str:
|
||||||
|
if run_manager:
|
||||||
|
run_manager.on_llm_new_token(token)
|
||||||
|
time.sleep(0.01)
|
||||||
|
|
||||||
|
message = AIMessage(content=output_str)
|
||||||
|
generation = ChatGeneration(message=message)
|
||||||
|
llm_output = {"token_usage": {
|
||||||
|
'prompt_tokens': 0,
|
||||||
|
'completion_tokens': 0,
|
||||||
|
'total_tokens': 0,
|
||||||
|
}}
|
||||||
|
return ChatResult(generations=[generation], llm_output=llm_output)
|
||||||
@ -0,0 +1,277 @@
|
|||||||
|
import math
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from langchain import WikipediaAPIWrapper
|
||||||
|
from langchain.callbacks.manager import Callbacks
|
||||||
|
from langchain.chat_models import ChatOpenAI
|
||||||
|
from langchain.memory.chat_memory import BaseChatMemory
|
||||||
|
from langchain.tools import BaseTool, Tool, WikipediaQueryRun
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from core.agent.agent_executor import AgentExecutor, PlanningStrategy, AgentConfiguration
|
||||||
|
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
|
||||||
|
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
|
||||||
|
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
|
||||||
|
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
|
||||||
|
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
|
||||||
|
from core.conversation_message_task import ConversationMessageTask
|
||||||
|
from core.llm.llm_builder import LLMBuilder
|
||||||
|
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||||
|
from core.tool.provider.serpapi_provider import SerpAPIToolProvider
|
||||||
|
from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput
|
||||||
|
from core.tool.web_reader_tool import WebReaderTool
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.dataset import Dataset, DatasetProcessRule
|
||||||
|
from models.model import AppModelConfig
|
||||||
|
|
||||||
|
|
||||||
|
class OrchestratorRuleParser:
|
||||||
|
"""Parse the orchestrator rule to entities."""
|
||||||
|
|
||||||
|
def __init__(self, tenant_id: str, app_model_config: AppModelConfig):
|
||||||
|
self.tenant_id = tenant_id
|
||||||
|
self.app_model_config = app_model_config
|
||||||
|
self.agent_summary_model_name = "gpt-3.5-turbo-16k"
|
||||||
|
|
||||||
|
def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
|
||||||
|
rest_tokens: int, chain_callback: MainChainGatherCallbackHandler) \
|
||||||
|
-> Optional[AgentExecutor]:
|
||||||
|
if not self.app_model_config.agent_mode_dict:
|
||||||
|
return None
|
||||||
|
|
||||||
|
agent_mode_config = self.app_model_config.agent_mode_dict
|
||||||
|
model_dict = self.app_model_config.model_dict
|
||||||
|
|
||||||
|
chain = None
|
||||||
|
if agent_mode_config and agent_mode_config.get('enabled'):
|
||||||
|
tool_configs = agent_mode_config.get('tools', [])
|
||||||
|
agent_model_name = model_dict.get('name', 'gpt-4')
|
||||||
|
|
||||||
|
# add agent callback to record agent thoughts
|
||||||
|
agent_callback = AgentLoopGatherCallbackHandler(
|
||||||
|
model_name=agent_model_name,
|
||||||
|
conversation_message_task=conversation_message_task
|
||||||
|
)
|
||||||
|
|
||||||
|
chain_callback.agent_callback = agent_callback
|
||||||
|
|
||||||
|
agent_llm = LLMBuilder.to_llm(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
model_name=agent_model_name,
|
||||||
|
temperature=0,
|
||||||
|
max_tokens=1500,
|
||||||
|
callbacks=[agent_callback, DifyStdOutCallbackHandler()]
|
||||||
|
)
|
||||||
|
|
||||||
|
planning_strategy = PlanningStrategy(agent_mode_config.get('strategy', 'router'))
|
||||||
|
|
||||||
|
# only OpenAI chat model (include Azure) support function call, use ReACT instead
|
||||||
|
if not isinstance(agent_llm, ChatOpenAI) \
|
||||||
|
and planning_strategy in [PlanningStrategy.FUNCTION_CALL, PlanningStrategy.MULTI_FUNCTION_CALL]:
|
||||||
|
planning_strategy = PlanningStrategy.REACT
|
||||||
|
|
||||||
|
summary_llm = LLMBuilder.to_llm(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
model_name=self.agent_summary_model_name,
|
||||||
|
temperature=0,
|
||||||
|
max_tokens=500,
|
||||||
|
callbacks=[DifyStdOutCallbackHandler()]
|
||||||
|
)
|
||||||
|
|
||||||
|
tools = self.to_tools(
|
||||||
|
tool_configs=tool_configs,
|
||||||
|
conversation_message_task=conversation_message_task,
|
||||||
|
model_name=self.agent_summary_model_name,
|
||||||
|
rest_tokens=rest_tokens,
|
||||||
|
callbacks=[agent_callback, DifyStdOutCallbackHandler()]
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(tools) == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
agent_configuration = AgentConfiguration(
|
||||||
|
strategy=planning_strategy,
|
||||||
|
llm=agent_llm,
|
||||||
|
tools=tools,
|
||||||
|
summary_llm=summary_llm,
|
||||||
|
memory=memory,
|
||||||
|
callbacks=[chain_callback, agent_callback],
|
||||||
|
max_iterations=10,
|
||||||
|
max_execution_time=400.0,
|
||||||
|
early_stopping_method="generate"
|
||||||
|
)
|
||||||
|
|
||||||
|
return AgentExecutor(agent_configuration)
|
||||||
|
|
||||||
|
return chain
|
||||||
|
|
||||||
|
def to_sensitive_word_avoidance_chain(self, callbacks: Callbacks = None, **kwargs) \
|
||||||
|
-> Optional[SensitiveWordAvoidanceChain]:
|
||||||
|
"""
|
||||||
|
Convert app sensitive word avoidance config to chain
|
||||||
|
|
||||||
|
:param kwargs:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if not self.app_model_config.sensitive_word_avoidance_dict:
|
||||||
|
return None
|
||||||
|
|
||||||
|
sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict
|
||||||
|
sensitive_words = sensitive_word_avoidance_config.get("words", "")
|
||||||
|
if sensitive_word_avoidance_config.get("enabled", False) and sensitive_words:
|
||||||
|
return SensitiveWordAvoidanceChain(
|
||||||
|
sensitive_words=sensitive_words.split(","),
|
||||||
|
canned_response=sensitive_word_avoidance_config.get("canned_response", ''),
|
||||||
|
output_key="sensitive_word_avoidance_output",
|
||||||
|
callbacks=callbacks,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def to_tools(self, tool_configs: list, conversation_message_task: ConversationMessageTask,
|
||||||
|
model_name: str, rest_tokens: int, callbacks: Callbacks = None) -> list[BaseTool]:
|
||||||
|
"""
|
||||||
|
Convert app agent tool configs to tools
|
||||||
|
|
||||||
|
:param rest_tokens:
|
||||||
|
:param tool_configs: app agent tool configs
|
||||||
|
:param model_name:
|
||||||
|
:param conversation_message_task:
|
||||||
|
:param callbacks:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
tools = []
|
||||||
|
for tool_config in tool_configs:
|
||||||
|
tool_type = list(tool_config.keys())[0]
|
||||||
|
tool_val = list(tool_config.values())[0]
|
||||||
|
if not tool_val.get("enabled") or tool_val.get("enabled") is not True:
|
||||||
|
continue
|
||||||
|
|
||||||
|
tool = None
|
||||||
|
if tool_type == "dataset":
|
||||||
|
tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens)
|
||||||
|
elif tool_type == "web_reader":
|
||||||
|
tool = self.to_web_reader_tool(model_name)
|
||||||
|
elif tool_type == "google_search":
|
||||||
|
tool = self.to_google_search_tool()
|
||||||
|
elif tool_type == "wikipedia":
|
||||||
|
tool = self.to_wikipedia_tool()
|
||||||
|
|
||||||
|
if tool:
|
||||||
|
tool.callbacks.extend(callbacks)
|
||||||
|
tools.append(tool)
|
||||||
|
|
||||||
|
return tools
|
||||||
|
|
||||||
|
def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask,
|
||||||
|
rest_tokens: int) \
|
||||||
|
-> Optional[BaseTool]:
|
||||||
|
"""
|
||||||
|
A dataset tool is a tool that can be used to retrieve information from a dataset
|
||||||
|
:param rest_tokens:
|
||||||
|
:param tool_config:
|
||||||
|
:param conversation_message_task:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
# get dataset from dataset id
|
||||||
|
dataset = db.session.query(Dataset).filter(
|
||||||
|
Dataset.tenant_id == self.tenant_id,
|
||||||
|
Dataset.id == tool_config.get("id")
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
k = self._dynamic_calc_retrieve_k(dataset, rest_tokens)
|
||||||
|
tool = DatasetRetrieverTool.from_dataset(
|
||||||
|
dataset=dataset,
|
||||||
|
k=k,
|
||||||
|
callbacks=[DatasetToolCallbackHandler(conversation_message_task)]
|
||||||
|
)
|
||||||
|
|
||||||
|
return tool
|
||||||
|
|
||||||
|
def to_web_reader_tool(self, model_name: str) -> Optional[BaseTool]:
|
||||||
|
"""
|
||||||
|
A tool for reading web pages
|
||||||
|
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
summary_llm = LLMBuilder.to_llm(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
model_name=model_name,
|
||||||
|
temperature=0,
|
||||||
|
max_tokens=500,
|
||||||
|
callbacks=[DifyStdOutCallbackHandler()]
|
||||||
|
)
|
||||||
|
|
||||||
|
tool = WebReaderTool(
|
||||||
|
llm=summary_llm,
|
||||||
|
max_chunk_length=4000,
|
||||||
|
continue_reading=True,
|
||||||
|
callbacks=[DifyStdOutCallbackHandler()]
|
||||||
|
)
|
||||||
|
|
||||||
|
return tool
|
||||||
|
|
||||||
|
def to_google_search_tool(self) -> Optional[BaseTool]:
|
||||||
|
tool_provider = SerpAPIToolProvider(tenant_id=self.tenant_id)
|
||||||
|
func_kwargs = tool_provider.credentials_to_func_kwargs()
|
||||||
|
if not func_kwargs:
|
||||||
|
return None
|
||||||
|
|
||||||
|
tool = Tool(
|
||||||
|
name="google_search",
|
||||||
|
description="A tool for performing a Google search and extracting snippets and webpages "
|
||||||
|
"when you need to search for something you don't know or when your information "
|
||||||
|
"is not up to date."
|
||||||
|
"Input should be a search query.",
|
||||||
|
func=OptimizedSerpAPIWrapper(**func_kwargs).run,
|
||||||
|
args_schema=OptimizedSerpAPIInput,
|
||||||
|
callbacks=[DifyStdOutCallbackHandler()]
|
||||||
|
)
|
||||||
|
|
||||||
|
return tool
|
||||||
|
|
||||||
|
def to_wikipedia_tool(self) -> Optional[BaseTool]:
|
||||||
|
class WikipediaInput(BaseModel):
|
||||||
|
query: str = Field(..., description="search query.")
|
||||||
|
|
||||||
|
return WikipediaQueryRun(
|
||||||
|
name="wikipedia",
|
||||||
|
api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
|
||||||
|
args_schema=WikipediaInput,
|
||||||
|
callbacks=[DifyStdOutCallbackHandler()]
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
|
||||||
|
DEFAULT_K = 2
|
||||||
|
CONTEXT_TOKENS_PERCENT = 0.3
|
||||||
|
processing_rule = dataset.latest_process_rule
|
||||||
|
if not processing_rule:
|
||||||
|
return DEFAULT_K
|
||||||
|
|
||||||
|
if processing_rule.mode == "custom":
|
||||||
|
rules = processing_rule.rules_dict
|
||||||
|
if not rules:
|
||||||
|
return DEFAULT_K
|
||||||
|
|
||||||
|
segmentation = rules["segmentation"]
|
||||||
|
segment_max_tokens = segmentation["max_tokens"]
|
||||||
|
else:
|
||||||
|
segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']
|
||||||
|
|
||||||
|
# when rest_tokens is less than default context tokens
|
||||||
|
if rest_tokens < segment_max_tokens * DEFAULT_K:
|
||||||
|
return rest_tokens // segment_max_tokens
|
||||||
|
|
||||||
|
context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)
|
||||||
|
|
||||||
|
# when context_limit_tokens is less than default context tokens, use default_k
|
||||||
|
if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
|
||||||
|
return DEFAULT_K
|
||||||
|
|
||||||
|
# Expand the k value when there's still some room left in the 30% rest tokens space
|
||||||
|
return context_limit_tokens // segment_max_tokens
|
||||||
@ -1,87 +0,0 @@
|
|||||||
from flask import current_app
|
|
||||||
from langchain.embeddings import OpenAIEmbeddings
|
|
||||||
from langchain.tools import BaseTool
|
|
||||||
|
|
||||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
|
||||||
from core.embedding.cached_embedding import CacheEmbedding
|
|
||||||
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
|
|
||||||
from core.index.vector_index.vector_index import VectorIndex
|
|
||||||
from core.llm.llm_builder import LLMBuilder
|
|
||||||
from models.dataset import Dataset
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetTool(BaseTool):
|
|
||||||
"""Tool for querying a Dataset."""
|
|
||||||
|
|
||||||
dataset: Dataset
|
|
||||||
k: int = 2
|
|
||||||
|
|
||||||
def _run(self, tool_input: str) -> str:
|
|
||||||
if self.dataset.indexing_technique == "economy":
|
|
||||||
# use keyword table query
|
|
||||||
kw_table_index = KeywordTableIndex(
|
|
||||||
dataset=self.dataset,
|
|
||||||
config=KeywordTableConfig(
|
|
||||||
max_keywords_per_chunk=5
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
documents = kw_table_index.search(tool_input, search_kwargs={'k': self.k})
|
|
||||||
else:
|
|
||||||
model_credentials = LLMBuilder.get_model_credentials(
|
|
||||||
tenant_id=self.dataset.tenant_id,
|
|
||||||
model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'),
|
|
||||||
model_name='text-embedding-ada-002'
|
|
||||||
)
|
|
||||||
|
|
||||||
embeddings = CacheEmbedding(OpenAIEmbeddings(
|
|
||||||
**model_credentials
|
|
||||||
))
|
|
||||||
|
|
||||||
vector_index = VectorIndex(
|
|
||||||
dataset=self.dataset,
|
|
||||||
config=current_app.config,
|
|
||||||
embeddings=embeddings
|
|
||||||
)
|
|
||||||
|
|
||||||
documents = vector_index.search(
|
|
||||||
tool_input,
|
|
||||||
search_type='similarity',
|
|
||||||
search_kwargs={
|
|
||||||
'k': self.k
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
|
|
||||||
hit_callback.on_tool_end(documents)
|
|
||||||
|
|
||||||
return str("\n".join([document.page_content for document in documents]))
|
|
||||||
|
|
||||||
async def _arun(self, tool_input: str) -> str:
|
|
||||||
model_credentials = LLMBuilder.get_model_credentials(
|
|
||||||
tenant_id=self.dataset.tenant_id,
|
|
||||||
model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'),
|
|
||||||
model_name='text-embedding-ada-002'
|
|
||||||
)
|
|
||||||
|
|
||||||
embeddings = CacheEmbedding(OpenAIEmbeddings(
|
|
||||||
**model_credentials
|
|
||||||
))
|
|
||||||
|
|
||||||
vector_index = VectorIndex(
|
|
||||||
dataset=self.dataset,
|
|
||||||
config=current_app.config,
|
|
||||||
embeddings=embeddings
|
|
||||||
)
|
|
||||||
|
|
||||||
documents = await vector_index.asearch(
|
|
||||||
tool_input,
|
|
||||||
search_type='similarity',
|
|
||||||
search_kwargs={
|
|
||||||
'k': 10
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
|
|
||||||
hit_callback.on_tool_end(documents)
|
|
||||||
return str("\n".join([document.page_content for document in documents]))
|
|
||||||
@ -0,0 +1,105 @@
|
|||||||
|
import re
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
|
from flask import current_app
|
||||||
|
from langchain.embeddings import OpenAIEmbeddings
|
||||||
|
from langchain.tools import BaseTool
|
||||||
|
from pydantic import Field, BaseModel
|
||||||
|
|
||||||
|
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||||
|
from core.embedding.cached_embedding import CacheEmbedding
|
||||||
|
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
|
||||||
|
from core.index.vector_index.vector_index import VectorIndex
|
||||||
|
from core.llm.llm_builder import LLMBuilder
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.dataset import Dataset
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetRetrieverToolInput(BaseModel):
|
||||||
|
dataset_id: str = Field(..., description="ID of dataset to be queried. MUST be UUID format.")
|
||||||
|
query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetRetrieverTool(BaseTool):
|
||||||
|
"""Tool for querying a Dataset."""
|
||||||
|
name: str = "dataset"
|
||||||
|
args_schema: Type[BaseModel] = DatasetRetrieverToolInput
|
||||||
|
description: str = "use this to retrieve a dataset. "
|
||||||
|
|
||||||
|
tenant_id: str
|
||||||
|
dataset_id: str
|
||||||
|
k: int = 3
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dataset(cls, dataset: Dataset, **kwargs):
|
||||||
|
description = dataset.description.replace('\n', '').replace('\r', '')
|
||||||
|
if not description:
|
||||||
|
description = 'useful for when you want to answer queries about the ' + dataset.name
|
||||||
|
|
||||||
|
description += '\nID of dataset MUST be ' + dataset.id
|
||||||
|
return cls(
|
||||||
|
tenant_id=dataset.tenant_id,
|
||||||
|
dataset_id=dataset.id,
|
||||||
|
description=description,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
def _run(self, dataset_id: str, query: str) -> str:
|
||||||
|
pattern = r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b'
|
||||||
|
match = re.search(pattern, dataset_id, re.IGNORECASE)
|
||||||
|
if match:
|
||||||
|
dataset_id = match.group()
|
||||||
|
|
||||||
|
dataset = db.session.query(Dataset).filter(
|
||||||
|
Dataset.tenant_id == self.tenant_id,
|
||||||
|
Dataset.id == dataset_id
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not dataset:
|
||||||
|
return f'[{self.name} failed to find dataset with id {dataset_id}.]'
|
||||||
|
|
||||||
|
if dataset.indexing_technique == "economy":
|
||||||
|
# use keyword table query
|
||||||
|
kw_table_index = KeywordTableIndex(
|
||||||
|
dataset=dataset,
|
||||||
|
config=KeywordTableConfig(
|
||||||
|
max_keywords_per_chunk=5
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
documents = kw_table_index.search(query, search_kwargs={'k': self.k})
|
||||||
|
else:
|
||||||
|
model_credentials = LLMBuilder.get_model_credentials(
|
||||||
|
tenant_id=dataset.tenant_id,
|
||||||
|
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
|
||||||
|
model_name='text-embedding-ada-002'
|
||||||
|
)
|
||||||
|
|
||||||
|
embeddings = CacheEmbedding(OpenAIEmbeddings(
|
||||||
|
**model_credentials
|
||||||
|
))
|
||||||
|
|
||||||
|
vector_index = VectorIndex(
|
||||||
|
dataset=dataset,
|
||||||
|
config=current_app.config,
|
||||||
|
embeddings=embeddings
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.k > 0:
|
||||||
|
documents = vector_index.search(
|
||||||
|
query,
|
||||||
|
search_type='similarity',
|
||||||
|
search_kwargs={
|
||||||
|
'k': self.k
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
documents = []
|
||||||
|
|
||||||
|
hit_callback = DatasetIndexToolCallbackHandler(dataset.id)
|
||||||
|
hit_callback.on_tool_end(documents)
|
||||||
|
|
||||||
|
return str("\n".join([document.page_content for document in documents]))
|
||||||
|
|
||||||
|
async def _arun(self, tool_input: str) -> str:
|
||||||
|
raise NotImplementedError()
|
||||||
@ -0,0 +1,63 @@
|
|||||||
|
import base64
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from libs import rsa
|
||||||
|
from models.account import Tenant
|
||||||
|
from models.tool import ToolProvider, ToolProviderName
|
||||||
|
|
||||||
|
|
||||||
|
class BaseToolProvider(ABC):
|
||||||
|
def __init__(self, tenant_id: str):
|
||||||
|
self.tenant_id = tenant_id
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_provider_name(self) -> ToolProviderName:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def encrypt_credentials(self, credentials: dict) -> Optional[dict]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def credentials_to_func_kwargs(self) -> Optional[dict]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def credentials_validate(self, credentials: dict):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_provider(self, must_enabled: bool = False) -> Optional[ToolProvider]:
|
||||||
|
"""
|
||||||
|
Returns the Provider instance for the given tenant_id and tool_name.
|
||||||
|
"""
|
||||||
|
query = db.session.query(ToolProvider).filter(
|
||||||
|
ToolProvider.tenant_id == self.tenant_id,
|
||||||
|
ToolProvider.tool_name == self.get_provider_name().value
|
||||||
|
)
|
||||||
|
|
||||||
|
if must_enabled:
|
||||||
|
query = query.filter(ToolProvider.is_enabled == True)
|
||||||
|
|
||||||
|
return query.first()
|
||||||
|
|
||||||
|
def encrypt_token(self, token) -> str:
|
||||||
|
tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
|
||||||
|
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
|
||||||
|
return base64.b64encode(encrypted_token).decode()
|
||||||
|
|
||||||
|
def decrypt_token(self, token: str, obfuscated: bool = False) -> str:
|
||||||
|
token = rsa.decrypt(base64.b64decode(token), self.tenant_id)
|
||||||
|
|
||||||
|
if obfuscated:
|
||||||
|
return self._obfuscated_token(token)
|
||||||
|
|
||||||
|
return token
|
||||||
|
|
||||||
|
def _obfuscated_token(self, token: str) -> str:
|
||||||
|
return token[:6] + '*' * (len(token) - 8) + token[-2:]
|
||||||
@ -0,0 +1,2 @@
|
|||||||
|
class ToolValidateFailedError(Exception):
|
||||||
|
description = "Tool Provider Validate failed"
|
||||||
@ -0,0 +1,77 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from core.tool.provider.base import BaseToolProvider
|
||||||
|
from core.tool.provider.errors import ToolValidateFailedError
|
||||||
|
from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper
|
||||||
|
from models.tool import ToolProviderName
|
||||||
|
|
||||||
|
|
||||||
|
class SerpAPIToolProvider(BaseToolProvider):
|
||||||
|
def get_provider_name(self) -> ToolProviderName:
|
||||||
|
"""
|
||||||
|
Returns the name of the provider.
|
||||||
|
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return ToolProviderName.SERPAPI
|
||||||
|
|
||||||
|
def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
|
||||||
|
"""
|
||||||
|
Returns the credentials for SerpAPI as a dictionary.
|
||||||
|
|
||||||
|
:param obfuscated: obfuscate credentials if True
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
tool_provider = self.get_provider(must_enabled=True)
|
||||||
|
if not tool_provider:
|
||||||
|
return None
|
||||||
|
|
||||||
|
credentials = tool_provider.credentials
|
||||||
|
if not credentials:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if credentials.get('api_key'):
|
||||||
|
credentials['api_key'] = self.decrypt_token(credentials.get('api_key'), obfuscated)
|
||||||
|
|
||||||
|
return credentials
|
||||||
|
|
||||||
|
def credentials_to_func_kwargs(self) -> Optional[dict]:
|
||||||
|
"""
|
||||||
|
Returns the credentials function kwargs as a dictionary.
|
||||||
|
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
credentials = self.get_credentials()
|
||||||
|
if not credentials:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return {
|
||||||
|
'serpapi_api_key': credentials.get('api_key')
|
||||||
|
}
|
||||||
|
|
||||||
|
def credentials_validate(self, credentials: dict):
|
||||||
|
"""
|
||||||
|
Validates the given credentials.
|
||||||
|
|
||||||
|
:param credentials:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if 'api_key' not in credentials or not credentials.get('api_key'):
|
||||||
|
raise ToolValidateFailedError("SerpAPI api_key is required.")
|
||||||
|
|
||||||
|
api_key = credentials.get('api_key')
|
||||||
|
|
||||||
|
try:
|
||||||
|
OptimizedSerpAPIWrapper(serpapi_api_key=api_key).run(query='test')
|
||||||
|
except Exception as e:
|
||||||
|
raise ToolValidateFailedError("SerpAPI api_key is invalid. {}".format(e))
|
||||||
|
|
||||||
|
def encrypt_credentials(self, credentials: dict) -> Optional[dict]:
|
||||||
|
"""
|
||||||
|
Encrypts the given credentials.
|
||||||
|
|
||||||
|
:param credentials:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
credentials['api_key'] = self.encrypt_token(credentials.get('api_key'))
|
||||||
|
return credentials
|
||||||
@ -0,0 +1,43 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from core.tool.provider.base import BaseToolProvider
|
||||||
|
from core.tool.provider.serpapi_provider import SerpAPIToolProvider
|
||||||
|
|
||||||
|
|
||||||
|
class ToolProviderService:
|
||||||
|
|
||||||
|
def __init__(self, tenant_id: str, provider_name: str):
|
||||||
|
self.provider = self._init_provider(tenant_id, provider_name)
|
||||||
|
|
||||||
|
def _init_provider(self, tenant_id: str, provider_name: str) -> BaseToolProvider:
|
||||||
|
if provider_name == 'serpapi':
|
||||||
|
return SerpAPIToolProvider(tenant_id)
|
||||||
|
else:
|
||||||
|
raise Exception('tool provider {} not found'.format(provider_name))
|
||||||
|
|
||||||
|
def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
|
||||||
|
"""
|
||||||
|
Returns the credentials for Tool as a dictionary.
|
||||||
|
|
||||||
|
:param obfuscated:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return self.provider.get_credentials(obfuscated)
|
||||||
|
|
||||||
|
def credentials_validate(self, credentials: dict):
|
||||||
|
"""
|
||||||
|
Validates the given credentials.
|
||||||
|
|
||||||
|
:param credentials:
|
||||||
|
:raises: ValidateFailedError
|
||||||
|
"""
|
||||||
|
return self.provider.credentials_validate(credentials)
|
||||||
|
|
||||||
|
def encrypt_credentials(self, credentials: dict):
|
||||||
|
"""
|
||||||
|
Encrypts the given credentials.
|
||||||
|
|
||||||
|
:param credentials:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return self.provider.encrypt_credentials(credentials)
|
||||||
@ -0,0 +1,51 @@
|
|||||||
|
from langchain import SerpAPIWrapper
|
||||||
|
from pydantic import Field, BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class OptimizedSerpAPIInput(BaseModel):
|
||||||
|
query: str = Field(..., description="search query.")
|
||||||
|
|
||||||
|
|
||||||
|
class OptimizedSerpAPIWrapper(SerpAPIWrapper):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _process_response(res: dict, num_results: int = 5) -> str:
|
||||||
|
"""Process response from SerpAPI."""
|
||||||
|
if "error" in res.keys():
|
||||||
|
raise ValueError(f"Got error from SerpAPI: {res['error']}")
|
||||||
|
if "answer_box" in res.keys() and type(res["answer_box"]) == list:
|
||||||
|
res["answer_box"] = res["answer_box"][0]
|
||||||
|
if "answer_box" in res.keys() and "answer" in res["answer_box"].keys():
|
||||||
|
toret = res["answer_box"]["answer"]
|
||||||
|
elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
|
||||||
|
toret = res["answer_box"]["snippet"]
|
||||||
|
elif (
|
||||||
|
"answer_box" in res.keys()
|
||||||
|
and "snippet_highlighted_words" in res["answer_box"].keys()
|
||||||
|
):
|
||||||
|
toret = res["answer_box"]["snippet_highlighted_words"][0]
|
||||||
|
elif (
|
||||||
|
"sports_results" in res.keys()
|
||||||
|
and "game_spotlight" in res["sports_results"].keys()
|
||||||
|
):
|
||||||
|
toret = res["sports_results"]["game_spotlight"]
|
||||||
|
elif (
|
||||||
|
"shopping_results" in res.keys()
|
||||||
|
and "title" in res["shopping_results"][0].keys()
|
||||||
|
):
|
||||||
|
toret = res["shopping_results"][:3]
|
||||||
|
elif (
|
||||||
|
"knowledge_graph" in res.keys()
|
||||||
|
and "description" in res["knowledge_graph"].keys()
|
||||||
|
):
|
||||||
|
toret = res["knowledge_graph"]["description"]
|
||||||
|
elif 'organic_results' in res.keys() and len(res['organic_results']) > 0:
|
||||||
|
toret = ""
|
||||||
|
for result in res["organic_results"][:num_results]:
|
||||||
|
if "link" in result:
|
||||||
|
toret += "----------------\nlink: " + result["link"] + "\n"
|
||||||
|
if "snippet" in result:
|
||||||
|
toret += "snippet: " + result["snippet"] + "\n"
|
||||||
|
else:
|
||||||
|
toret = "No good search result found"
|
||||||
|
return "search result:\n" + toret
|
||||||
@ -0,0 +1,419 @@
|
|||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import site
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
import unicodedata
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from bs4 import BeautifulSoup, NavigableString, Comment, CData
|
||||||
|
from langchain.base_language import BaseLanguageModel
|
||||||
|
from langchain.chains.summarize import load_summarize_chain
|
||||||
|
from langchain.schema import Document
|
||||||
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||||
|
from langchain.tools.base import BaseTool
|
||||||
|
from newspaper import Article
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from regex import regex
|
||||||
|
|
||||||
|
from core.data_loader import file_extractor
|
||||||
|
from core.data_loader.file_extractor import FileExtractor
|
||||||
|
|
||||||
|
FULL_TEMPLATE = """
|
||||||
|
TITLE: {title}
|
||||||
|
AUTHORS: {authors}
|
||||||
|
PUBLISH DATE: {publish_date}
|
||||||
|
TOP_IMAGE_URL: {top_image}
|
||||||
|
TEXT:
|
||||||
|
|
||||||
|
{text}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class WebReaderToolInput(BaseModel):
|
||||||
|
url: str = Field(..., description="URL of the website to read")
|
||||||
|
summary: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description="When the user's question requires extracting the summarizing content of the webpage, "
|
||||||
|
"set it to true."
|
||||||
|
)
|
||||||
|
cursor: int = Field(
|
||||||
|
default=0,
|
||||||
|
description="Start reading from this character."
|
||||||
|
"Use when the first response was truncated"
|
||||||
|
"and you want to continue reading the page."
|
||||||
|
"The value cannot exceed 24000.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class WebReaderTool(BaseTool):
|
||||||
|
"""Reader tool for getting website title and contents. Gives more control than SimpleReaderTool."""
|
||||||
|
|
||||||
|
name: str = "web_reader"
|
||||||
|
args_schema: Type[BaseModel] = WebReaderToolInput
|
||||||
|
description: str = "use this to read a website. " \
|
||||||
|
"If you can answer the question based on the information provided, " \
|
||||||
|
"there is no need to use."
|
||||||
|
page_contents: str = None
|
||||||
|
url: str = None
|
||||||
|
max_chunk_length: int = 4000
|
||||||
|
summary_chunk_tokens: int = 4000
|
||||||
|
summary_chunk_overlap: int = 0
|
||||||
|
summary_separators: list[str] = ["\n\n", "。", ".", " ", ""]
|
||||||
|
continue_reading: bool = True
|
||||||
|
llm: BaseLanguageModel
|
||||||
|
|
||||||
|
def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str:
|
||||||
|
try:
|
||||||
|
if not self.page_contents or self.url != url:
|
||||||
|
page_contents = get_url(url)
|
||||||
|
self.page_contents = page_contents
|
||||||
|
self.url = url
|
||||||
|
else:
|
||||||
|
page_contents = self.page_contents
|
||||||
|
except Exception as e:
|
||||||
|
return f'Read this website failed, caused by: {str(e)}.'
|
||||||
|
|
||||||
|
if summary:
|
||||||
|
character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
|
||||||
|
chunk_size=self.summary_chunk_tokens,
|
||||||
|
chunk_overlap=self.summary_chunk_overlap,
|
||||||
|
separators=self.summary_separators
|
||||||
|
)
|
||||||
|
|
||||||
|
texts = character_splitter.split_text(page_contents)
|
||||||
|
docs = [Document(page_content=t) for t in texts]
|
||||||
|
|
||||||
|
# only use first 5 docs
|
||||||
|
if len(docs) > 5:
|
||||||
|
docs = docs[:5]
|
||||||
|
|
||||||
|
chain = load_summarize_chain(self.llm, chain_type="refine", callbacks=self.callbacks)
|
||||||
|
try:
|
||||||
|
page_contents = chain.run(docs)
|
||||||
|
# todo use cache
|
||||||
|
except Exception as e:
|
||||||
|
return f'Read this website failed, caused by: {str(e)}.'
|
||||||
|
else:
|
||||||
|
page_contents = page_result(page_contents, cursor, self.max_chunk_length)
|
||||||
|
|
||||||
|
if self.continue_reading and len(page_contents) >= self.max_chunk_length:
|
||||||
|
page_contents += f"\nPAGE WAS TRUNCATED. IF YOU FIND INFORMATION THAT CAN ANSWER QUESTION " \
|
||||||
|
f"THEN DIRECT ANSWER AND STOP INVOKING web_reader TOOL, OTHERWISE USE " \
|
||||||
|
f"CURSOR={cursor+len(page_contents)} TO CONTINUE READING."
|
||||||
|
|
||||||
|
return page_contents
|
||||||
|
|
||||||
|
async def _arun(self, url: str) -> str:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def page_result(text: str, cursor: int, max_length: int) -> str:
|
||||||
|
"""Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
|
||||||
|
return text[cursor: cursor + max_length]
|
||||||
|
|
||||||
|
|
||||||
|
def get_url(url: str) -> str:
|
||||||
|
"""Fetch URL and return the contents as a string."""
|
||||||
|
headers = {
|
||||||
|
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
|
||||||
|
}
|
||||||
|
supported_content_types = file_extractor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
|
||||||
|
|
||||||
|
head_response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10))
|
||||||
|
|
||||||
|
if head_response.status_code != 200:
|
||||||
|
return "URL returned status code {}.".format(head_response.status_code)
|
||||||
|
|
||||||
|
# check content-type
|
||||||
|
main_content_type = head_response.headers.get('Content-Type').split(';')[0].strip()
|
||||||
|
if main_content_type not in supported_content_types:
|
||||||
|
return "Unsupported content-type [{}] of URL.".format(main_content_type)
|
||||||
|
|
||||||
|
if main_content_type in file_extractor.SUPPORT_URL_CONTENT_TYPES:
|
||||||
|
return FileExtractor.load_from_url(url, return_text=True)
|
||||||
|
|
||||||
|
response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 30))
|
||||||
|
a = extract_using_readabilipy(response.text)
|
||||||
|
|
||||||
|
if not a['plain_text'] or not a['plain_text'].strip():
|
||||||
|
return get_url_from_newspaper3k(url)
|
||||||
|
|
||||||
|
res = FULL_TEMPLATE.format(
|
||||||
|
title=a['title'],
|
||||||
|
authors=a['byline'],
|
||||||
|
publish_date=a['date'],
|
||||||
|
top_image="",
|
||||||
|
text=a['plain_text'] if a['plain_text'] else "",
|
||||||
|
)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def get_url_from_newspaper3k(url: str) -> str:
|
||||||
|
|
||||||
|
a = Article(url)
|
||||||
|
a.download()
|
||||||
|
a.parse()
|
||||||
|
|
||||||
|
res = FULL_TEMPLATE.format(
|
||||||
|
title=a.title,
|
||||||
|
authors=a.authors,
|
||||||
|
publish_date=a.publish_date,
|
||||||
|
top_image=a.top_image,
|
||||||
|
text=a.text,
|
||||||
|
)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def extract_using_readabilipy(html):
|
||||||
|
with tempfile.NamedTemporaryFile(delete=False, mode='w+') as f_html:
|
||||||
|
f_html.write(html)
|
||||||
|
f_html.close()
|
||||||
|
html_path = f_html.name
|
||||||
|
|
||||||
|
# Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file
|
||||||
|
article_json_path = html_path + ".json"
|
||||||
|
jsdir = os.path.join(find_module_path('readabilipy'), 'javascript')
|
||||||
|
with chdir(jsdir):
|
||||||
|
subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path])
|
||||||
|
|
||||||
|
# Read output of call to Readability.parse() from JSON file and return as Python dictionary
|
||||||
|
with open(article_json_path, "r", encoding="utf-8") as json_file:
|
||||||
|
input_json = json.loads(json_file.read())
|
||||||
|
|
||||||
|
# Deleting files after processing
|
||||||
|
os.unlink(article_json_path)
|
||||||
|
os.unlink(html_path)
|
||||||
|
|
||||||
|
article_json = {
|
||||||
|
"title": None,
|
||||||
|
"byline": None,
|
||||||
|
"date": None,
|
||||||
|
"content": None,
|
||||||
|
"plain_content": None,
|
||||||
|
"plain_text": None
|
||||||
|
}
|
||||||
|
# Populate article fields from readability fields where present
|
||||||
|
if input_json:
|
||||||
|
if "title" in input_json and input_json["title"]:
|
||||||
|
article_json["title"] = input_json["title"]
|
||||||
|
if "byline" in input_json and input_json["byline"]:
|
||||||
|
article_json["byline"] = input_json["byline"]
|
||||||
|
if "date" in input_json and input_json["date"]:
|
||||||
|
article_json["date"] = input_json["date"]
|
||||||
|
if "content" in input_json and input_json["content"]:
|
||||||
|
article_json["content"] = input_json["content"]
|
||||||
|
article_json["plain_content"] = plain_content(article_json["content"], False, False)
|
||||||
|
article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"])
|
||||||
|
if "textContent" in input_json and input_json["textContent"]:
|
||||||
|
article_json["plain_text"] = input_json["textContent"]
|
||||||
|
article_json["plain_text"] = re.sub(r'\n\s*\n', '\n', article_json["plain_text"])
|
||||||
|
|
||||||
|
return article_json
|
||||||
|
|
||||||
|
|
||||||
|
def find_module_path(module_name):
|
||||||
|
for package_path in site.getsitepackages():
|
||||||
|
potential_path = os.path.join(package_path, module_name)
|
||||||
|
if os.path.exists(potential_path):
|
||||||
|
return potential_path
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def chdir(path):
|
||||||
|
"""Change directory in context and return to original on exit"""
|
||||||
|
# From https://stackoverflow.com/a/37996581, couldn't find a built-in
|
||||||
|
original_path = os.getcwd()
|
||||||
|
os.chdir(path)
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
os.chdir(original_path)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_text_blocks_as_plain_text(paragraph_html):
|
||||||
|
# Load article as DOM
|
||||||
|
soup = BeautifulSoup(paragraph_html, 'html.parser')
|
||||||
|
# Select all lists
|
||||||
|
list_elements = soup.find_all(['ul', 'ol'])
|
||||||
|
# Prefix text in all list items with "* " and make lists paragraphs
|
||||||
|
for list_element in list_elements:
|
||||||
|
plain_items = "".join(list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all('li')])))
|
||||||
|
list_element.string = plain_items
|
||||||
|
list_element.name = "p"
|
||||||
|
# Select all text blocks
|
||||||
|
text_blocks = [s.parent for s in soup.find_all(string=True)]
|
||||||
|
text_blocks = [plain_text_leaf_node(block) for block in text_blocks]
|
||||||
|
# Drop empty paragraphs
|
||||||
|
text_blocks = list(filter(lambda p: p["text"] is not None, text_blocks))
|
||||||
|
return text_blocks
|
||||||
|
|
||||||
|
|
||||||
|
def plain_text_leaf_node(element):
|
||||||
|
# Extract all text, stripped of any child HTML elements and normalise it
|
||||||
|
plain_text = normalise_text(element.get_text())
|
||||||
|
if plain_text != "" and element.name == "li":
|
||||||
|
plain_text = "* {}, ".format(plain_text)
|
||||||
|
if plain_text == "":
|
||||||
|
plain_text = None
|
||||||
|
if "data-node-index" in element.attrs:
|
||||||
|
plain = {"node_index": element["data-node-index"], "text": plain_text}
|
||||||
|
else:
|
||||||
|
plain = {"text": plain_text}
|
||||||
|
return plain
|
||||||
|
|
||||||
|
|
||||||
|
def plain_content(readability_content, content_digests, node_indexes):
|
||||||
|
# Load article as DOM
|
||||||
|
soup = BeautifulSoup(readability_content, 'html.parser')
|
||||||
|
# Make all elements plain
|
||||||
|
elements = plain_elements(soup.contents, content_digests, node_indexes)
|
||||||
|
if node_indexes:
|
||||||
|
# Add node index attributes to nodes
|
||||||
|
elements = [add_node_indexes(element) for element in elements]
|
||||||
|
# Replace article contents with plain elements
|
||||||
|
soup.contents = elements
|
||||||
|
return str(soup)
|
||||||
|
|
||||||
|
|
||||||
|
def plain_elements(elements, content_digests, node_indexes):
|
||||||
|
# Get plain content versions of all elements
|
||||||
|
elements = [plain_element(element, content_digests, node_indexes)
|
||||||
|
for element in elements]
|
||||||
|
if content_digests:
|
||||||
|
# Add content digest attribute to nodes
|
||||||
|
elements = [add_content_digest(element) for element in elements]
|
||||||
|
return elements
|
||||||
|
|
||||||
|
|
||||||
|
def plain_element(element, content_digests, node_indexes):
|
||||||
|
# For lists, we make each item plain text
|
||||||
|
if is_leaf(element):
|
||||||
|
# For leaf node elements, extract the text content, discarding any HTML tags
|
||||||
|
# 1. Get element contents as text
|
||||||
|
plain_text = element.get_text()
|
||||||
|
# 2. Normalise the extracted text string to a canonical representation
|
||||||
|
plain_text = normalise_text(plain_text)
|
||||||
|
# 3. Update element content to be plain text
|
||||||
|
element.string = plain_text
|
||||||
|
elif is_text(element):
|
||||||
|
if is_non_printing(element):
|
||||||
|
# The simplified HTML may have come from Readability.js so might
|
||||||
|
# have non-printing text (e.g. Comment or CData). In this case, we
|
||||||
|
# keep the structure, but ensure that the string is empty.
|
||||||
|
element = type(element)("")
|
||||||
|
else:
|
||||||
|
plain_text = element.string
|
||||||
|
plain_text = normalise_text(plain_text)
|
||||||
|
element = type(element)(plain_text)
|
||||||
|
else:
|
||||||
|
# If not a leaf node or leaf type call recursively on child nodes, replacing
|
||||||
|
element.contents = plain_elements(element.contents, content_digests, node_indexes)
|
||||||
|
return element
|
||||||
|
|
||||||
|
|
||||||
|
def add_node_indexes(element, node_index="0"):
|
||||||
|
# Can't add attributes to string types
|
||||||
|
if is_text(element):
|
||||||
|
return element
|
||||||
|
# Add index to current element
|
||||||
|
element["data-node-index"] = node_index
|
||||||
|
# Add index to child elements
|
||||||
|
for local_idx, child in enumerate(
|
||||||
|
[c for c in element.contents if not is_text(c)], start=1):
|
||||||
|
# Can't add attributes to leaf string types
|
||||||
|
child_index = "{stem}.{local}".format(
|
||||||
|
stem=node_index, local=local_idx)
|
||||||
|
add_node_indexes(child, node_index=child_index)
|
||||||
|
return element
|
||||||
|
|
||||||
|
|
||||||
|
def normalise_text(text):
|
||||||
|
"""Normalise unicode and whitespace."""
|
||||||
|
# Normalise unicode first to try and standardise whitespace characters as much as possible before normalising them
|
||||||
|
text = strip_control_characters(text)
|
||||||
|
text = normalise_unicode(text)
|
||||||
|
text = normalise_whitespace(text)
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def strip_control_characters(text):
|
||||||
|
"""Strip out unicode control characters which might break the parsing."""
|
||||||
|
# Unicode control characters
|
||||||
|
# [Cc]: Other, Control [includes new lines]
|
||||||
|
# [Cf]: Other, Format
|
||||||
|
# [Cn]: Other, Not Assigned
|
||||||
|
# [Co]: Other, Private Use
|
||||||
|
# [Cs]: Other, Surrogate
|
||||||
|
control_chars = set(['Cc', 'Cf', 'Cn', 'Co', 'Cs'])
|
||||||
|
retained_chars = ['\t', '\n', '\r', '\f']
|
||||||
|
|
||||||
|
# Remove non-printing control characters
|
||||||
|
return "".join(["" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char for char in text])
|
||||||
|
|
||||||
|
|
||||||
|
def normalise_unicode(text):
|
||||||
|
"""Normalise unicode such that things that are visually equivalent map to the same unicode string where possible."""
|
||||||
|
normal_form = "NFKC"
|
||||||
|
text = unicodedata.normalize(normal_form, text)
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def normalise_whitespace(text):
|
||||||
|
"""Replace runs of whitespace characters with a single space as this is what happens when HTML text is displayed."""
|
||||||
|
text = regex.sub(r"\s+", " ", text)
|
||||||
|
# Remove leading and trailing whitespace
|
||||||
|
text = text.strip()
|
||||||
|
return text
|
||||||
|
|
||||||
|
def is_leaf(element):
|
||||||
|
return (element.name in ['p', 'li'])
|
||||||
|
|
||||||
|
|
||||||
|
def is_text(element):
|
||||||
|
return isinstance(element, NavigableString)
|
||||||
|
|
||||||
|
|
||||||
|
def is_non_printing(element):
|
||||||
|
return any(isinstance(element, _e) for _e in [Comment, CData])
|
||||||
|
|
||||||
|
|
||||||
|
def add_content_digest(element):
|
||||||
|
if not is_text(element):
|
||||||
|
element["data-content-digest"] = content_digest(element)
|
||||||
|
return element
|
||||||
|
|
||||||
|
|
||||||
|
def content_digest(element):
|
||||||
|
if is_text(element):
|
||||||
|
# Hash
|
||||||
|
trimmed_string = element.string.strip()
|
||||||
|
if trimmed_string == "":
|
||||||
|
digest = ""
|
||||||
|
else:
|
||||||
|
digest = hashlib.sha256(trimmed_string.encode('utf-8')).hexdigest()
|
||||||
|
else:
|
||||||
|
contents = element.contents
|
||||||
|
num_contents = len(contents)
|
||||||
|
if num_contents == 0:
|
||||||
|
# No hash when no child elements exist
|
||||||
|
digest = ""
|
||||||
|
elif num_contents == 1:
|
||||||
|
# If single child, use digest of child
|
||||||
|
digest = content_digest(contents[0])
|
||||||
|
else:
|
||||||
|
# Build content digest from the "non-empty" digests of child nodes
|
||||||
|
digest = hashlib.sha256()
|
||||||
|
child_digests = list(
|
||||||
|
filter(lambda x: x != "", [content_digest(content) for content in contents]))
|
||||||
|
for child in child_digests:
|
||||||
|
digest.update(child.encode('utf-8'))
|
||||||
|
digest = digest.hexdigest()
|
||||||
|
return digest
|
||||||
@ -0,0 +1,32 @@
|
|||||||
|
"""add is_universal in apps
|
||||||
|
|
||||||
|
Revision ID: 2beac44e5f5f
|
||||||
|
Revises: d3d503a3471c
|
||||||
|
Create Date: 2023-07-07 12:11:29.156057
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = '2beac44e5f5f'
|
||||||
|
down_revision = 'a5b56fb053ef'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('apps', schema=None) as batch_op:
|
||||||
|
batch_op.add_column(sa.Column('is_universal', sa.Boolean(), server_default=sa.text('false'), nullable=False))
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('apps', schema=None) as batch_op:
|
||||||
|
batch_op.drop_column('is_universal')
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
@ -0,0 +1,44 @@
|
|||||||
|
"""add tool providers
|
||||||
|
|
||||||
|
Revision ID: 7ce5a52e4eee
|
||||||
|
Revises: 2beac44e5f5f
|
||||||
|
Create Date: 2023-07-10 10:26:50.074515
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = '7ce5a52e4eee'
|
||||||
|
down_revision = '2beac44e5f5f'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table('tool_providers',
|
||||||
|
sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||||
|
sa.Column('tenant_id', postgresql.UUID(), nullable=False),
|
||||||
|
sa.Column('tool_name', sa.String(length=40), nullable=False),
|
||||||
|
sa.Column('encrypted_credentials', sa.Text(), nullable=True),
|
||||||
|
sa.Column('is_enabled', sa.Boolean(), server_default=sa.text('false'), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint('id', name='tool_provider_pkey'),
|
||||||
|
sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name')
|
||||||
|
)
|
||||||
|
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
|
||||||
|
batch_op.add_column(sa.Column('sensitive_word_avoidance', sa.Text(), nullable=True))
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
|
||||||
|
batch_op.drop_column('sensitive_word_avoidance')
|
||||||
|
|
||||||
|
op.drop_table('tool_providers')
|
||||||
|
# ### end Alembic commands ###
|
||||||
@ -0,0 +1,47 @@
|
|||||||
|
import json
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
|
||||||
|
from extensions.ext_database import db
|
||||||
|
|
||||||
|
|
||||||
|
class ToolProviderName(Enum):
|
||||||
|
SERPAPI = 'serpapi'
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def value_of(value):
|
||||||
|
for member in ToolProviderName:
|
||||||
|
if member.value == value:
|
||||||
|
return member
|
||||||
|
raise ValueError(f"No matching enum found for value '{value}'")
|
||||||
|
|
||||||
|
|
||||||
|
class ToolProvider(db.Model):
|
||||||
|
__tablename__ = 'tool_providers'
|
||||||
|
__table_args__ = (
|
||||||
|
db.PrimaryKeyConstraint('id', name='tool_provider_pkey'),
|
||||||
|
db.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name')
|
||||||
|
)
|
||||||
|
|
||||||
|
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
|
||||||
|
tenant_id = db.Column(UUID, nullable=False)
|
||||||
|
tool_name = db.Column(db.String(40), nullable=False)
|
||||||
|
encrypted_credentials = db.Column(db.Text, nullable=True)
|
||||||
|
is_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
|
||||||
|
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||||
|
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def credentials_is_set(self):
|
||||||
|
"""
|
||||||
|
Returns True if the encrypted_config is not None, indicating that the token is set.
|
||||||
|
"""
|
||||||
|
return self.encrypted_credentials is not None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def credentials(self):
|
||||||
|
"""
|
||||||
|
Returns the decrypted config.
|
||||||
|
"""
|
||||||
|
return json.loads(self.encrypted_credentials) if self.encrypted_credentials is not None else None
|
||||||
Loading…
Reference in New Issue