Feat/assistant app (#2086)
Co-authored-by: chenhe <guchenhe@gmail.com> Co-authored-by: Pascal M <11357019+perzeuss@users.noreply.github.com>pull/2140/head
@ -0,0 +1,26 @@
|
||||
name: Run Tool Pytest
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
cache: 'pip'
|
||||
cache-dependency-path: ./api/requirements.txt
|
||||
|
||||
- name: Install dependencies
|
||||
run: pip install -r ./api/requirements.txt
|
||||
|
||||
- name: Run pytest
|
||||
run: pytest ./api/tests/integration_tests/tools/test_all_provider.py
|
||||
@ -1,120 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Generator, Union
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import (AppUnavailableError, CompletionRequestError, ConversationCompletedError,
|
||||
ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError)
|
||||
from controllers.console.universal_chat.wraps import UniversalChatResource
|
||||
from core.application_queue_manager import ApplicationQueueManager
|
||||
from core.entities.application_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from flask import Response, stream_with_context
|
||||
from flask_login import current_user
|
||||
from flask_restful import reqparse
|
||||
from libs.helper import uuid_value
|
||||
from services.completion_service import CompletionService
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
|
||||
class UniversalChatApi(UniversalChatResource):
|
||||
def post(self, universal_app):
|
||||
app_model = universal_app
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('query', type=str, required=True, location='json')
|
||||
parser.add_argument('files', type=list, required=False, location='json')
|
||||
parser.add_argument('conversation_id', type=uuid_value, location='json')
|
||||
parser.add_argument('provider', type=str, required=True, location='json')
|
||||
parser.add_argument('model', type=str, required=True, location='json')
|
||||
parser.add_argument('tools', type=list, required=True, location='json')
|
||||
parser.add_argument('retriever_from', type=str, required=False, default='universal_app', location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
app_model_config = app_model.app_model_config
|
||||
|
||||
# update app model config
|
||||
args['model_config'] = app_model_config.to_dict()
|
||||
args['model_config']['model']['name'] = args['model']
|
||||
args['model_config']['model']['provider'] = args['provider']
|
||||
args['model_config']['agent_mode']['tools'] = args['tools']
|
||||
|
||||
if not args['model_config']['agent_mode']['tools']:
|
||||
args['model_config']['agent_mode']['tools'] = [
|
||||
{
|
||||
"current_datetime": {
|
||||
"enabled": True
|
||||
}
|
||||
}
|
||||
]
|
||||
else:
|
||||
args['model_config']['agent_mode']['tools'].append({
|
||||
"current_datetime": {
|
||||
"enabled": True
|
||||
}
|
||||
})
|
||||
|
||||
args['inputs'] = {}
|
||||
|
||||
del args['model']
|
||||
del args['tools']
|
||||
|
||||
args['auto_generate_name'] = False
|
||||
|
||||
try:
|
||||
response = CompletionService.completion(
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
streaming=True,
|
||||
is_model_config_override=True,
|
||||
)
|
||||
|
||||
return compact_response(response)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class UniversalChatStopApi(UniversalChatResource):
|
||||
def post(self, universal_app, task_id):
|
||||
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
||||
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
||||
def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
if isinstance(response, dict):
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
else:
|
||||
def generate() -> Generator:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
|
||||
return Response(stream_with_context(generate()), status=200,
|
||||
mimetype='text/event-stream')
|
||||
|
||||
|
||||
api.add_resource(UniversalChatApi, '/universal-chat/messages')
|
||||
api.add_resource(UniversalChatStopApi, '/universal-chat/messages/<string:task_id>/stop')
|
||||
@ -1,110 +0,0 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from controllers.console import api
|
||||
from controllers.console.universal_chat.wraps import UniversalChatResource
|
||||
from fields.conversation_fields import (conversation_with_model_config_fields,
|
||||
conversation_with_model_config_infinite_scroll_pagination_fields)
|
||||
from flask_login import current_user
|
||||
from flask_restful import fields, marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
|
||||
from services.web_conversation_service import WebConversationService
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
|
||||
class UniversalChatConversationListApi(UniversalChatResource):
|
||||
|
||||
@marshal_with(conversation_with_model_config_infinite_scroll_pagination_fields)
|
||||
def get(self, universal_app):
|
||||
app_model = universal_app
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('last_id', type=uuid_value, location='args')
|
||||
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
|
||||
parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
pinned = None
|
||||
if 'pinned' in args and args['pinned'] is not None:
|
||||
pinned = True if args['pinned'] == 'true' else False
|
||||
|
||||
try:
|
||||
return WebConversationService.pagination_by_last_id(
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
last_id=args['last_id'],
|
||||
limit=args['limit'],
|
||||
pinned=pinned
|
||||
)
|
||||
except LastConversationNotExistsError:
|
||||
raise NotFound("Last Conversation Not Exists.")
|
||||
|
||||
|
||||
class UniversalChatConversationApi(UniversalChatResource):
|
||||
def delete(self, universal_app, c_id):
|
||||
app_model = universal_app
|
||||
conversation_id = str(c_id)
|
||||
|
||||
try:
|
||||
ConversationService.delete(app_model, conversation_id, current_user)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
WebConversationService.unpin(app_model, conversation_id, current_user)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class UniversalChatConversationRenameApi(UniversalChatResource):
|
||||
|
||||
@marshal_with(conversation_with_model_config_fields)
|
||||
def post(self, universal_app, c_id):
|
||||
app_model = universal_app
|
||||
conversation_id = str(c_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', type=str, required=False, location='json')
|
||||
parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
return ConversationService.rename(
|
||||
app_model,
|
||||
conversation_id,
|
||||
current_user,
|
||||
args['name'],
|
||||
args['auto_generate']
|
||||
)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
|
||||
class UniversalChatConversationPinApi(UniversalChatResource):
|
||||
|
||||
def patch(self, universal_app, c_id):
|
||||
app_model = universal_app
|
||||
conversation_id = str(c_id)
|
||||
|
||||
try:
|
||||
WebConversationService.pin(app_model, conversation_id, current_user)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
class UniversalChatConversationUnPinApi(UniversalChatResource):
|
||||
def patch(self, universal_app, c_id):
|
||||
app_model = universal_app
|
||||
conversation_id = str(c_id)
|
||||
WebConversationService.unpin(app_model, conversation_id, current_user)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
api.add_resource(UniversalChatConversationRenameApi, '/universal-chat/conversations/<uuid:c_id>/name')
|
||||
api.add_resource(UniversalChatConversationListApi, '/universal-chat/conversations')
|
||||
api.add_resource(UniversalChatConversationApi, '/universal-chat/conversations/<uuid:c_id>')
|
||||
api.add_resource(UniversalChatConversationPinApi, '/universal-chat/conversations/<uuid:c_id>/pin')
|
||||
api.add_resource(UniversalChatConversationUnPinApi, '/universal-chat/conversations/<uuid:c_id>/unpin')
|
||||
@ -1,145 +0,0 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import logging
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import (CompletionRequestError, ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError, ProviderQuotaExceededError)
|
||||
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
|
||||
from controllers.console.universal_chat.wraps import UniversalChatResource
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from flask_login import current_user
|
||||
from flask_restful import fields, marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||
from services.message_service import MessageService
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
|
||||
class UniversalChatMessageListApi(UniversalChatResource):
|
||||
feedback_fields = {
|
||||
'rating': fields.String
|
||||
}
|
||||
|
||||
agent_thought_fields = {
|
||||
'id': fields.String,
|
||||
'chain_id': fields.String,
|
||||
'message_id': fields.String,
|
||||
'position': fields.Integer,
|
||||
'thought': fields.String,
|
||||
'tool': fields.String,
|
||||
'tool_input': fields.String,
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
retriever_resource_fields = {
|
||||
'id': fields.String,
|
||||
'message_id': fields.String,
|
||||
'position': fields.Integer,
|
||||
'dataset_id': fields.String,
|
||||
'dataset_name': fields.String,
|
||||
'document_id': fields.String,
|
||||
'document_name': fields.String,
|
||||
'data_source_type': fields.String,
|
||||
'segment_id': fields.String,
|
||||
'score': fields.Float,
|
||||
'hit_count': fields.Integer,
|
||||
'word_count': fields.Integer,
|
||||
'segment_position': fields.Integer,
|
||||
'index_node_hash': fields.String,
|
||||
'content': fields.String,
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
message_fields = {
|
||||
'id': fields.String,
|
||||
'conversation_id': fields.String,
|
||||
'inputs': fields.Raw,
|
||||
'query': fields.String,
|
||||
'answer': fields.String,
|
||||
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
|
||||
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
|
||||
'created_at': TimestampField,
|
||||
'agent_thoughts': fields.List(fields.Nested(agent_thought_fields))
|
||||
}
|
||||
|
||||
message_infinite_scroll_pagination_fields = {
|
||||
'limit': fields.Integer,
|
||||
'has_more': fields.Boolean,
|
||||
'data': fields.List(fields.Nested(message_fields))
|
||||
}
|
||||
|
||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||
def get(self, universal_app):
|
||||
app_model = universal_app
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
|
||||
parser.add_argument('first_id', type=uuid_value, location='args')
|
||||
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
return MessageService.pagination_by_first_id(app_model, current_user,
|
||||
args['conversation_id'], args['first_id'], args['limit'])
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.message.FirstMessageNotExistsError:
|
||||
raise NotFound("First Message Not Exists.")
|
||||
|
||||
|
||||
class UniversalChatMessageFeedbackApi(UniversalChatResource):
|
||||
def post(self, universal_app, message_id):
|
||||
app_model = universal_app
|
||||
message_id = str(message_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
MessageService.create_feedback(app_model, message_id, current_user, args['rating'])
|
||||
except services.errors.message.MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
return {'result': 'success'}
|
||||
|
||||
|
||||
class UniversalChatMessageSuggestedQuestionApi(UniversalChatResource):
|
||||
def get(self, universal_app, message_id):
|
||||
app_model = universal_app
|
||||
message_id = str(message_id)
|
||||
|
||||
try:
|
||||
questions = MessageService.get_suggested_questions_after_answer(
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
message_id=message_id
|
||||
)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message not found")
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation not found")
|
||||
except SuggestedQuestionsAfterAnswerDisabledError:
|
||||
raise AppSuggestedQuestionsAfterAnswerDisabledError()
|
||||
except ProviderTokenNotInitError:
|
||||
raise ProviderNotInitializeError()
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
return {'data': questions}
|
||||
|
||||
|
||||
api.add_resource(UniversalChatMessageListApi, '/universal-chat/messages')
|
||||
api.add_resource(UniversalChatMessageFeedbackApi, '/universal-chat/messages/<uuid:message_id>/feedbacks')
|
||||
api.add_resource(UniversalChatMessageSuggestedQuestionApi, '/universal-chat/messages/<uuid:message_id>/suggested-questions')
|
||||
@ -1,38 +0,0 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.universal_chat.wraps import UniversalChatResource
|
||||
from flask_restful import fields, marshal_with
|
||||
from models.model import App
|
||||
|
||||
|
||||
class UniversalChatParameterApi(UniversalChatResource):
|
||||
"""Resource for app variables."""
|
||||
parameters_fields = {
|
||||
'opening_statement': fields.String,
|
||||
'suggested_questions': fields.Raw,
|
||||
'suggested_questions_after_answer': fields.Raw,
|
||||
'speech_to_text': fields.Raw,
|
||||
'retriever_resource': fields.Raw,
|
||||
'annotation_reply': fields.Raw
|
||||
}
|
||||
|
||||
@marshal_with(parameters_fields)
|
||||
def get(self, universal_app: App):
|
||||
"""Retrieve app parameters."""
|
||||
app_model = universal_app
|
||||
app_model_config = app_model.app_model_config
|
||||
app_model_config.retriever_resource = json.dumps({'enabled': True})
|
||||
|
||||
return {
|
||||
'opening_statement': app_model_config.opening_statement,
|
||||
'suggested_questions': app_model_config.suggested_questions_list,
|
||||
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
|
||||
'speech_to_text': app_model_config.speech_to_text_dict,
|
||||
'retriever_resource': app_model_config.retriever_resource_dict,
|
||||
'annotation_reply': app_model_config.annotation_reply_dict,
|
||||
}
|
||||
|
||||
|
||||
api.add_resource(UniversalChatParameterApi, '/universal-chat/parameters')
|
||||
@ -1,86 +0,0 @@
|
||||
import json
|
||||
from functools import wraps
|
||||
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource
|
||||
from libs.login import login_required
|
||||
from models.model import App, AppModelConfig
|
||||
|
||||
|
||||
def universal_chat_app_required(view=None):
|
||||
def decorator(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
# get universal chat app
|
||||
universal_app = db.session.query(App).filter(
|
||||
App.tenant_id == current_user.current_tenant_id,
|
||||
App.is_universal == True
|
||||
).first()
|
||||
|
||||
if universal_app is None:
|
||||
# create universal app if not exists
|
||||
universal_app = App(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
name='Universal Chat',
|
||||
mode='chat',
|
||||
is_universal=True,
|
||||
icon='',
|
||||
icon_background='',
|
||||
api_rpm=0,
|
||||
api_rph=0,
|
||||
enable_site=False,
|
||||
enable_api=False,
|
||||
status='normal'
|
||||
)
|
||||
|
||||
db.session.add(universal_app)
|
||||
db.session.flush()
|
||||
|
||||
app_model_config = AppModelConfig(
|
||||
provider="",
|
||||
model_id="",
|
||||
configs={},
|
||||
opening_statement='',
|
||||
suggested_questions=json.dumps([]),
|
||||
suggested_questions_after_answer=json.dumps({'enabled': True}),
|
||||
speech_to_text=json.dumps({'enabled': True}),
|
||||
retriever_resource=json.dumps({'enabled': True}),
|
||||
more_like_this=None,
|
||||
sensitive_word_avoidance=None,
|
||||
model=json.dumps({
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo-16k",
|
||||
"completion_params": {
|
||||
"max_tokens": 800,
|
||||
"temperature": 0.8,
|
||||
"top_p": 1,
|
||||
"presence_penalty": 0,
|
||||
"frequency_penalty": 0
|
||||
}
|
||||
}),
|
||||
user_input_form=json.dumps([]),
|
||||
pre_prompt='',
|
||||
agent_mode=json.dumps({"enabled": True, "strategy": "function_call", "tools": []}),
|
||||
)
|
||||
|
||||
app_model_config.app_id = universal_app.id
|
||||
db.session.add(app_model_config)
|
||||
db.session.flush()
|
||||
|
||||
universal_app.app_model_config_id = app_model_config.id
|
||||
db.session.commit()
|
||||
|
||||
return view(universal_app, *args, **kwargs)
|
||||
return decorated
|
||||
|
||||
if view:
|
||||
return decorator(view)
|
||||
return decorator
|
||||
|
||||
|
||||
class UniversalChatResource(Resource):
|
||||
# must be reversed if there are multiple decorators
|
||||
method_decorators = [universal_chat_app_required, account_initialization_required, login_required, setup_required]
|
||||
@ -1,136 +1,293 @@
|
||||
import json
|
||||
|
||||
from libs.login import login_required
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, reqparse
|
||||
from flask import send_file
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.tool.provider.errors import ToolValidateFailedError
|
||||
from core.tool.provider.tool_provider_service import ToolProviderService
|
||||
from extensions.ext_database import db
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, abort, reqparse
|
||||
from libs.login import login_required
|
||||
from models.tool import ToolProvider, ToolProviderName
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from services.tools_manage_service import ToolManageService
|
||||
|
||||
class ToolProviderListApi(Resource):
|
||||
import io
|
||||
|
||||
class ToolProviderListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
tool_credential_dict = {}
|
||||
for tool_name in ToolProviderName:
|
||||
tool_credential_dict[tool_name.value] = {
|
||||
'tool_name': tool_name.value,
|
||||
'is_enabled': False,
|
||||
'credentials': None
|
||||
}
|
||||
|
||||
tool_providers = db.session.query(ToolProvider).filter(ToolProvider.tenant_id == tenant_id).all()
|
||||
|
||||
for p in tool_providers:
|
||||
if p.is_enabled:
|
||||
tool_credential_dict[p.tool_name] = {
|
||||
'tool_name': p.tool_name,
|
||||
'is_enabled': p.is_enabled,
|
||||
'credentials': ToolProviderService(tenant_id, p.tool_name).get_credentials(obfuscated=True)
|
||||
}
|
||||
|
||||
return list(tool_credential_dict.values())
|
||||
return ToolManageService.list_tool_providers(user_id, tenant_id)
|
||||
|
||||
class ToolBuiltinProviderListToolsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
class ToolProviderCredentialsApi(Resource):
|
||||
return ToolManageService.list_builtin_tool_provider_tools(
|
||||
user_id,
|
||||
tenant_id,
|
||||
provider,
|
||||
)
|
||||
|
||||
class ToolBuiltinProviderDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
if provider not in [p.value for p in ToolProviderName]:
|
||||
abort(404)
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
return ToolManageService.delete_builtin_tool_provider(
|
||||
user_id,
|
||||
tenant_id,
|
||||
provider,
|
||||
)
|
||||
|
||||
class ToolBuiltinProviderUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden(f'User {current_user.id} is not authorized to update provider token, '
|
||||
f'current_role is {current_user.current_tenant.current_role}')
|
||||
raise Forbidden()
|
||||
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return ToolManageService.update_builtin_tool_provider(
|
||||
user_id,
|
||||
tenant_id,
|
||||
provider,
|
||||
args['credentials'],
|
||||
)
|
||||
|
||||
class ToolBuiltinProviderIconApi(Resource):
|
||||
@setup_required
|
||||
def get(self, provider):
|
||||
icon_bytes, minetype = ToolManageService.get_builtin_tool_provider_icon(provider)
|
||||
return send_file(io.BytesIO(icon_bytes), mimetype=minetype)
|
||||
|
||||
|
||||
class ToolApiProviderAddApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
tool_provider_service = ToolProviderService(tenant_id, provider)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
|
||||
parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('schema', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('icon', type=dict, required=True, nullable=False, location='json')
|
||||
parser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json')
|
||||
|
||||
try:
|
||||
tool_provider_service.credentials_validate(args['credentials'])
|
||||
except ToolValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
args = parser.parse_args()
|
||||
|
||||
encrypted_credentials = json.dumps(tool_provider_service.encrypt_credentials(args['credentials']))
|
||||
return ToolManageService.create_api_tool_provider(
|
||||
user_id,
|
||||
tenant_id,
|
||||
args['provider'],
|
||||
args['icon'],
|
||||
args['credentials'],
|
||||
args['schema_type'],
|
||||
args['schema'],
|
||||
args.get('privacy_policy', ''),
|
||||
)
|
||||
|
||||
class ToolApiProviderGetRemoteSchemaApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
tenant = current_user.current_tenant
|
||||
parser.add_argument('url', type=str, required=True, nullable=False, location='args')
|
||||
|
||||
tool_provider_model = db.session.query(ToolProvider).filter(
|
||||
ToolProvider.tenant_id == tenant.id,
|
||||
ToolProvider.tool_name == provider,
|
||||
).first()
|
||||
args = parser.parse_args()
|
||||
|
||||
# Only allow updating token for CUSTOM provider type
|
||||
if tool_provider_model:
|
||||
tool_provider_model.encrypted_credentials = encrypted_credentials
|
||||
tool_provider_model.is_enabled = True
|
||||
else:
|
||||
tool_provider_model = ToolProvider(
|
||||
tenant_id=tenant.id,
|
||||
tool_name=provider,
|
||||
encrypted_credentials=encrypted_credentials,
|
||||
is_enabled=True
|
||||
)
|
||||
db.session.add(tool_provider_model)
|
||||
return ToolManageService.get_api_tool_provider_remote_schema(
|
||||
current_user.id,
|
||||
current_user.current_tenant_id,
|
||||
args['url'],
|
||||
)
|
||||
|
||||
class ToolApiProviderListToolsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
db.session.commit()
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
return {'result': 'success'}, 201
|
||||
parser.add_argument('provider', type=str, required=True, nullable=False, location='args')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
class ToolProviderCredentialsValidateApi(Resource):
|
||||
return ToolManageService.list_api_tool_provider_tools(
|
||||
user_id,
|
||||
tenant_id,
|
||||
args['provider'],
|
||||
)
|
||||
|
||||
class ToolApiProviderUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
if provider not in [p.value for p in ToolProviderName]:
|
||||
abort(404)
|
||||
def post(self):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
|
||||
parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('schema', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('original_provider', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('icon', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('privacy_policy', type=str, required=True, nullable=False, location='json')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
result = True
|
||||
error = None
|
||||
return ToolManageService.update_api_tool_provider(
|
||||
user_id,
|
||||
tenant_id,
|
||||
args['provider'],
|
||||
args['original_provider'],
|
||||
args['icon'],
|
||||
args['credentials'],
|
||||
args['schema_type'],
|
||||
args['schema'],
|
||||
args['privacy_policy'],
|
||||
)
|
||||
|
||||
class ToolApiProviderDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return ToolManageService.delete_api_tool_provider(
|
||||
user_id,
|
||||
tenant_id,
|
||||
args['provider'],
|
||||
)
|
||||
|
||||
class ToolApiProviderGetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
tool_provider_service = ToolProviderService(tenant_id, provider)
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
try:
|
||||
tool_provider_service.credentials_validate(args['credentials'])
|
||||
except ToolValidateFailedError as ex:
|
||||
result = False
|
||||
error = str(ex)
|
||||
parser.add_argument('provider', type=str, required=True, nullable=False, location='args')
|
||||
|
||||
response = {'result': 'success' if result else 'error'}
|
||||
args = parser.parse_args()
|
||||
|
||||
return ToolManageService.get_api_tool_provider(
|
||||
user_id,
|
||||
tenant_id,
|
||||
args['provider'],
|
||||
)
|
||||
|
||||
class ToolBuiltinProviderCredentialsSchemaApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
return ToolManageService.list_builtin_provider_credentials_schema(provider)
|
||||
|
||||
class ToolApiProviderSchemaApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
if not result:
|
||||
response['error'] = error
|
||||
parser.add_argument('schema', type=str, required=True, nullable=False, location='json')
|
||||
|
||||
return response
|
||||
args = parser.parse_args()
|
||||
|
||||
return ToolManageService.parser_api_schema(
|
||||
schema=args['schema'],
|
||||
)
|
||||
|
||||
class ToolApiProviderPreviousTestApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
parser.add_argument('tool_name', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
|
||||
parser.add_argument('parameters', type=dict, required=True, nullable=False, location='json')
|
||||
parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('schema', type=str, required=True, nullable=False, location='json')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return ToolManageService.test_api_tool_preview(
|
||||
current_user.current_tenant_id,
|
||||
args['tool_name'],
|
||||
args['credentials'],
|
||||
args['parameters'],
|
||||
args['schema_type'],
|
||||
args['schema'],
|
||||
)
|
||||
|
||||
api.add_resource(ToolProviderListApi, '/workspaces/current/tool-providers')
|
||||
api.add_resource(ToolProviderCredentialsApi, '/workspaces/current/tool-providers/<provider>/credentials')
|
||||
api.add_resource(ToolProviderCredentialsValidateApi,
|
||||
'/workspaces/current/tool-providers/<provider>/credentials-validate')
|
||||
api.add_resource(ToolBuiltinProviderListToolsApi, '/workspaces/current/tool-provider/builtin/<provider>/tools')
|
||||
api.add_resource(ToolBuiltinProviderDeleteApi, '/workspaces/current/tool-provider/builtin/<provider>/delete')
|
||||
api.add_resource(ToolBuiltinProviderUpdateApi, '/workspaces/current/tool-provider/builtin/<provider>/update')
|
||||
api.add_resource(ToolBuiltinProviderCredentialsSchemaApi, '/workspaces/current/tool-provider/builtin/<provider>/credentials_schema')
|
||||
api.add_resource(ToolBuiltinProviderIconApi, '/workspaces/current/tool-provider/builtin/<provider>/icon')
|
||||
api.add_resource(ToolApiProviderAddApi, '/workspaces/current/tool-provider/api/add')
|
||||
api.add_resource(ToolApiProviderGetRemoteSchemaApi, '/workspaces/current/tool-provider/api/remote')
|
||||
api.add_resource(ToolApiProviderListToolsApi, '/workspaces/current/tool-provider/api/tools')
|
||||
api.add_resource(ToolApiProviderUpdateApi, '/workspaces/current/tool-provider/api/update')
|
||||
api.add_resource(ToolApiProviderDeleteApi, '/workspaces/current/tool-provider/api/delete')
|
||||
api.add_resource(ToolApiProviderGetApi, '/workspaces/current/tool-provider/api/get')
|
||||
api.add_resource(ToolApiProviderSchemaApi, '/workspaces/current/tool-provider/api/schema')
|
||||
api.add_resource(ToolApiProviderPreviousTestApi, '/workspaces/current/tool-provider/api/test/pre')
|
||||
|
||||
@ -0,0 +1,47 @@
|
||||
from controllers.files import api
|
||||
from flask import Response
|
||||
from flask_restful import Resource, reqparse
|
||||
from libs.exception import BaseHTTPException
|
||||
from werkzeug.exceptions import NotFound, Forbidden
|
||||
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
|
||||
class ToolFilePreviewApi(Resource):
|
||||
def get(self, file_id, extension):
|
||||
file_id = str(file_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
parser.add_argument('timestamp', type=str, required=True, location='args')
|
||||
parser.add_argument('nonce', type=str, required=True, location='args')
|
||||
parser.add_argument('sign', type=str, required=True, location='args')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not ToolFileManager.verify_file(file_id=file_id,
|
||||
timestamp=args['timestamp'],
|
||||
nonce=args['nonce'],
|
||||
sign=args['sign'],
|
||||
):
|
||||
raise Forbidden('Invalid request.')
|
||||
|
||||
try:
|
||||
result = ToolFileManager.get_file_generator_by_message_file_id(
|
||||
file_id,
|
||||
)
|
||||
|
||||
if not result:
|
||||
raise NotFound(f'file is not found')
|
||||
|
||||
generator, mimetype = result
|
||||
except Exception:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
return Response(generator, mimetype=mimetype)
|
||||
|
||||
api.add_resource(ToolFilePreviewApi, '/files/tools/<uuid:file_id>.<string:extension>')
|
||||
|
||||
class UnsupportedFileTypeError(BaseHTTPException):
|
||||
error_code = 'unsupported_file_type'
|
||||
description = "File type not allowed."
|
||||
code = 415
|
||||
@ -1,251 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
from core.agent.agent.agent_llm_callback import AgentLLMCallback
|
||||
from core.app_runner.app_runner import AppRunner
|
||||
from core.application_queue_manager import ApplicationQueueManager
|
||||
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
|
||||
from core.entities.application_entities import ApplicationGenerateEntity, ModelConfigEntity, PromptTemplateEntity
|
||||
from core.features.agent_runner import AgentRunnerFeature
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, Conversation, Message, MessageAgentThought, MessageChain
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentApplicationRunner(AppRunner):
|
||||
"""
|
||||
Agent Application Runner
|
||||
"""
|
||||
|
||||
def run(self, application_generate_entity: ApplicationGenerateEntity,
|
||||
queue_manager: ApplicationQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message) -> None:
|
||||
"""
|
||||
Run agent application
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: application queue manager
|
||||
:param conversation: conversation
|
||||
:param message: message
|
||||
:return:
|
||||
"""
|
||||
app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
|
||||
if not app_record:
|
||||
raise ValueError(f"App not found")
|
||||
|
||||
app_orchestration_config = application_generate_entity.app_orchestration_config_entity
|
||||
|
||||
inputs = application_generate_entity.inputs
|
||||
query = application_generate_entity.query
|
||||
files = application_generate_entity.files
|
||||
|
||||
# Pre-calculate the number of tokens of the prompt messages,
|
||||
# and return the rest number of tokens by model context token size limit and max token size limit.
|
||||
# If the rest number of tokens is not enough, raise exception.
|
||||
# Include: prompt template, inputs, query(optional), files(optional)
|
||||
# Not Include: memory, external data, dataset context
|
||||
self.get_pre_calculate_rest_tokens(
|
||||
app_record=app_record,
|
||||
model_config=app_orchestration_config.model_config,
|
||||
prompt_template_entity=app_orchestration_config.prompt_template,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
query=query
|
||||
)
|
||||
|
||||
memory = None
|
||||
if application_generate_entity.conversation_id:
|
||||
# get memory of conversation (read-only)
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
|
||||
model=app_orchestration_config.model_config.model
|
||||
)
|
||||
|
||||
memory = TokenBufferMemory(
|
||||
conversation=conversation,
|
||||
model_instance=model_instance
|
||||
)
|
||||
|
||||
# reorganize all inputs and template to prompt messages
|
||||
# Include: prompt template, inputs, query(optional), files(optional)
|
||||
# memory(optional)
|
||||
prompt_messages, stop = self.organize_prompt_messages(
|
||||
app_record=app_record,
|
||||
model_config=app_orchestration_config.model_config,
|
||||
prompt_template_entity=app_orchestration_config.prompt_template,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
query=query,
|
||||
context=None,
|
||||
memory=memory
|
||||
)
|
||||
|
||||
# Create MessageChain
|
||||
message_chain = self._init_message_chain(
|
||||
message=message,
|
||||
query=query
|
||||
)
|
||||
|
||||
# add agent callback to record agent thoughts
|
||||
agent_callback = AgentLoopGatherCallbackHandler(
|
||||
model_config=app_orchestration_config.model_config,
|
||||
message=message,
|
||||
queue_manager=queue_manager,
|
||||
message_chain=message_chain
|
||||
)
|
||||
|
||||
# init LLM Callback
|
||||
agent_llm_callback = AgentLLMCallback(
|
||||
agent_callback=agent_callback
|
||||
)
|
||||
|
||||
agent_runner = AgentRunnerFeature(
|
||||
tenant_id=application_generate_entity.tenant_id,
|
||||
app_orchestration_config=app_orchestration_config,
|
||||
model_config=app_orchestration_config.model_config,
|
||||
config=app_orchestration_config.agent,
|
||||
queue_manager=queue_manager,
|
||||
message=message,
|
||||
user_id=application_generate_entity.user_id,
|
||||
agent_llm_callback=agent_llm_callback,
|
||||
callback=agent_callback,
|
||||
memory=memory
|
||||
)
|
||||
|
||||
# agent run
|
||||
result = agent_runner.run(
|
||||
query=query,
|
||||
invoke_from=application_generate_entity.invoke_from
|
||||
)
|
||||
|
||||
if result:
|
||||
self._save_message_chain(
|
||||
message_chain=message_chain,
|
||||
output_text=result
|
||||
)
|
||||
|
||||
if (result
|
||||
and app_orchestration_config.prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE
|
||||
and app_orchestration_config.prompt_template.simple_prompt_template
|
||||
):
|
||||
# Direct output if agent result exists and has pre prompt
|
||||
self.direct_output(
|
||||
queue_manager=queue_manager,
|
||||
app_orchestration_config=app_orchestration_config,
|
||||
prompt_messages=prompt_messages,
|
||||
stream=application_generate_entity.stream,
|
||||
text=result,
|
||||
usage=self._get_usage_of_all_agent_thoughts(
|
||||
model_config=app_orchestration_config.model_config,
|
||||
message=message
|
||||
)
|
||||
)
|
||||
else:
|
||||
# As normal LLM run, agent result as context
|
||||
context = result
|
||||
|
||||
# reorganize all inputs and template to prompt messages
|
||||
# Include: prompt template, inputs, query(optional), files(optional)
|
||||
# memory(optional), external data, dataset context(optional)
|
||||
prompt_messages, stop = self.organize_prompt_messages(
|
||||
app_record=app_record,
|
||||
model_config=app_orchestration_config.model_config,
|
||||
prompt_template_entity=app_orchestration_config.prompt_template,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
query=query,
|
||||
context=context,
|
||||
memory=memory
|
||||
)
|
||||
|
||||
# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
|
||||
self.recale_llm_max_tokens(
|
||||
model_config=app_orchestration_config.model_config,
|
||||
prompt_messages=prompt_messages
|
||||
)
|
||||
|
||||
# Invoke model
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
|
||||
model=app_orchestration_config.model_config.model
|
||||
)
|
||||
|
||||
invoke_result = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=app_orchestration_config.model_config.parameters,
|
||||
stop=stop,
|
||||
stream=application_generate_entity.stream,
|
||||
user=application_generate_entity.user_id,
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
self._handle_invoke_result(
|
||||
invoke_result=invoke_result,
|
||||
queue_manager=queue_manager,
|
||||
stream=application_generate_entity.stream
|
||||
)
|
||||
|
||||
def _init_message_chain(self, message: Message, query: str) -> MessageChain:
|
||||
"""
|
||||
Init MessageChain
|
||||
:param message: message
|
||||
:param query: query
|
||||
:return:
|
||||
"""
|
||||
message_chain = MessageChain(
|
||||
message_id=message.id,
|
||||
type="AgentExecutor",
|
||||
input=json.dumps({
|
||||
"input": query
|
||||
})
|
||||
)
|
||||
|
||||
db.session.add(message_chain)
|
||||
db.session.commit()
|
||||
|
||||
return message_chain
|
||||
|
||||
def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None:
|
||||
"""
|
||||
Save MessageChain
|
||||
:param message_chain: message chain
|
||||
:param output_text: output text
|
||||
:return:
|
||||
"""
|
||||
message_chain.output = json.dumps({
|
||||
"output": output_text
|
||||
})
|
||||
db.session.commit()
|
||||
|
||||
def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity,
|
||||
message: Message) -> LLMUsage:
|
||||
"""
|
||||
Get usage of all agent thoughts
|
||||
:param model_config: model config
|
||||
:param message: message
|
||||
:return:
|
||||
"""
|
||||
agent_thoughts = (db.session.query(MessageAgentThought)
|
||||
.filter(MessageAgentThought.message_id == message.id).all())
|
||||
|
||||
all_message_tokens = 0
|
||||
all_answer_tokens = 0
|
||||
for agent_thought in agent_thoughts:
|
||||
all_message_tokens += agent_thought.message_token
|
||||
all_answer_tokens += agent_thought.answer_token
|
||||
|
||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
return model_type_instance._calc_response_usage(
|
||||
model_config.model,
|
||||
model_config.credentials,
|
||||
all_message_tokens,
|
||||
all_answer_tokens
|
||||
)
|
||||
@ -0,0 +1,342 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
from core.app_runner.app_runner import AppRunner
|
||||
from core.features.assistant_cot_runner import AssistantCotApplicationRunner
|
||||
from core.features.assistant_fc_runner import AssistantFunctionCallApplicationRunner
|
||||
from core.entities.application_entities import ApplicationGenerateEntity, ModelConfigEntity, \
|
||||
AgentEntity
|
||||
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.moderation.base import ModerationException
|
||||
from core.tools.entities.tool_entities import ToolRuntimeVariablePool
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation, Message, App, MessageChain, MessageAgentThought
|
||||
from models.tools import ToolConversationVariables
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class AssistantApplicationRunner(AppRunner):
|
||||
"""
|
||||
Assistant Application Runner
|
||||
"""
|
||||
def run(self, application_generate_entity: ApplicationGenerateEntity,
|
||||
queue_manager: ApplicationQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message) -> None:
|
||||
"""
|
||||
Run assistant application
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: application queue manager
|
||||
:param conversation: conversation
|
||||
:param message: message
|
||||
:return:
|
||||
"""
|
||||
app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
|
||||
if not app_record:
|
||||
raise ValueError(f"App not found")
|
||||
|
||||
app_orchestration_config = application_generate_entity.app_orchestration_config_entity
|
||||
|
||||
inputs = application_generate_entity.inputs
|
||||
query = application_generate_entity.query
|
||||
files = application_generate_entity.files
|
||||
|
||||
# Pre-calculate the number of tokens of the prompt messages,
|
||||
# and return the rest number of tokens by model context token size limit and max token size limit.
|
||||
# If the rest number of tokens is not enough, raise exception.
|
||||
# Include: prompt template, inputs, query(optional), files(optional)
|
||||
# Not Include: memory, external data, dataset context
|
||||
self.get_pre_calculate_rest_tokens(
|
||||
app_record=app_record,
|
||||
model_config=app_orchestration_config.model_config,
|
||||
prompt_template_entity=app_orchestration_config.prompt_template,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
query=query
|
||||
)
|
||||
|
||||
memory = None
|
||||
if application_generate_entity.conversation_id:
|
||||
# get memory of conversation (read-only)
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
|
||||
model=app_orchestration_config.model_config.model
|
||||
)
|
||||
|
||||
memory = TokenBufferMemory(
|
||||
conversation=conversation,
|
||||
model_instance=model_instance
|
||||
)
|
||||
|
||||
# organize all inputs and template to prompt messages
|
||||
# Include: prompt template, inputs, query(optional), files(optional)
|
||||
# memory(optional)
|
||||
prompt_messages, _ = self.organize_prompt_messages(
|
||||
app_record=app_record,
|
||||
model_config=app_orchestration_config.model_config,
|
||||
prompt_template_entity=app_orchestration_config.prompt_template,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
query=query,
|
||||
memory=memory
|
||||
)
|
||||
|
||||
# moderation
|
||||
try:
|
||||
# process sensitive_word_avoidance
|
||||
_, inputs, query = self.moderation_for_inputs(
|
||||
app_id=app_record.id,
|
||||
tenant_id=application_generate_entity.tenant_id,
|
||||
app_orchestration_config_entity=app_orchestration_config,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
)
|
||||
except ModerationException as e:
|
||||
self.direct_output(
|
||||
queue_manager=queue_manager,
|
||||
app_orchestration_config=app_orchestration_config,
|
||||
prompt_messages=prompt_messages,
|
||||
text=str(e),
|
||||
stream=application_generate_entity.stream
|
||||
)
|
||||
return
|
||||
|
||||
if query:
|
||||
# annotation reply
|
||||
annotation_reply = self.query_app_annotations_to_reply(
|
||||
app_record=app_record,
|
||||
message=message,
|
||||
query=query,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from
|
||||
)
|
||||
|
||||
if annotation_reply:
|
||||
queue_manager.publish_annotation_reply(
|
||||
message_annotation_id=annotation_reply.id,
|
||||
pub_from=PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
self.direct_output(
|
||||
queue_manager=queue_manager,
|
||||
app_orchestration_config=app_orchestration_config,
|
||||
prompt_messages=prompt_messages,
|
||||
text=annotation_reply.content,
|
||||
stream=application_generate_entity.stream
|
||||
)
|
||||
return
|
||||
|
||||
# fill in variable inputs from external data tools if exists
|
||||
external_data_tools = app_orchestration_config.external_data_variables
|
||||
if external_data_tools:
|
||||
inputs = self.fill_in_inputs_from_external_data_tools(
|
||||
tenant_id=app_record.tenant_id,
|
||||
app_id=app_record.id,
|
||||
external_data_tools=external_data_tools,
|
||||
inputs=inputs,
|
||||
query=query
|
||||
)
|
||||
|
||||
# reorganize all inputs and template to prompt messages
|
||||
# Include: prompt template, inputs, query(optional), files(optional)
|
||||
# memory(optional), external data, dataset context(optional)
|
||||
prompt_messages, _ = self.organize_prompt_messages(
|
||||
app_record=app_record,
|
||||
model_config=app_orchestration_config.model_config,
|
||||
prompt_template_entity=app_orchestration_config.prompt_template,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
query=query,
|
||||
memory=memory
|
||||
)
|
||||
|
||||
# check hosting moderation
|
||||
hosting_moderation_result = self.check_hosting_moderation(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
prompt_messages=prompt_messages
|
||||
)
|
||||
|
||||
if hosting_moderation_result:
|
||||
return
|
||||
|
||||
agent_entity = app_orchestration_config.agent
|
||||
|
||||
# load tool variables
|
||||
tool_conversation_variables = self._load_tool_variables(conversation_id=conversation.id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
tanent_id=application_generate_entity.tenant_id)
|
||||
|
||||
# convert db variables to tool variables
|
||||
tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)
|
||||
|
||||
message_chain = self._init_message_chain(
|
||||
message=message,
|
||||
query=query
|
||||
)
|
||||
|
||||
# init model instance
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
|
||||
model=app_orchestration_config.model_config.model
|
||||
)
|
||||
prompt_message, _ = self.organize_prompt_messages(
|
||||
app_record=app_record,
|
||||
model_config=app_orchestration_config.model_config,
|
||||
prompt_template_entity=app_orchestration_config.prompt_template,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
query=query,
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
# start agent runner
|
||||
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
|
||||
assistant_cot_runner = AssistantCotApplicationRunner(
|
||||
tenant_id=application_generate_entity.tenant_id,
|
||||
application_generate_entity=application_generate_entity,
|
||||
app_orchestration_config=app_orchestration_config,
|
||||
model_config=app_orchestration_config.model_config,
|
||||
config=agent_entity,
|
||||
queue_manager=queue_manager,
|
||||
message=message,
|
||||
user_id=application_generate_entity.user_id,
|
||||
memory=memory,
|
||||
prompt_messages=prompt_message,
|
||||
variables_pool=tool_variables,
|
||||
db_variables=tool_conversation_variables,
|
||||
)
|
||||
invoke_result = assistant_cot_runner.run(
|
||||
model_instance=model_instance,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
query=query,
|
||||
)
|
||||
elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING:
|
||||
assistant_cot_runner = AssistantFunctionCallApplicationRunner(
|
||||
tenant_id=application_generate_entity.tenant_id,
|
||||
application_generate_entity=application_generate_entity,
|
||||
app_orchestration_config=app_orchestration_config,
|
||||
model_config=app_orchestration_config.model_config,
|
||||
config=agent_entity,
|
||||
queue_manager=queue_manager,
|
||||
message=message,
|
||||
user_id=application_generate_entity.user_id,
|
||||
memory=memory,
|
||||
prompt_messages=prompt_message,
|
||||
variables_pool=tool_variables,
|
||||
db_variables=tool_conversation_variables
|
||||
)
|
||||
invoke_result = assistant_cot_runner.run(
|
||||
model_instance=model_instance,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
query=query,
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
self._handle_invoke_result(
|
||||
invoke_result=invoke_result,
|
||||
queue_manager=queue_manager,
|
||||
stream=application_generate_entity.stream,
|
||||
agent=True
|
||||
)
|
||||
|
||||
def _load_tool_variables(self, conversation_id: str, user_id: str, tanent_id: str) -> ToolConversationVariables:
|
||||
"""
|
||||
load tool variables from database
|
||||
"""
|
||||
tool_variables: ToolConversationVariables = db.session.query(ToolConversationVariables).filter(
|
||||
ToolConversationVariables.conversation_id == conversation_id,
|
||||
ToolConversationVariables.tenant_id == tanent_id
|
||||
).first()
|
||||
|
||||
if tool_variables:
|
||||
# save tool variables to session, so that we can update it later
|
||||
db.session.add(tool_variables)
|
||||
else:
|
||||
# create new tool variables
|
||||
tool_variables = ToolConversationVariables(
|
||||
conversation_id=conversation_id,
|
||||
user_id=user_id,
|
||||
tenant_id=tanent_id,
|
||||
variables_str='[]',
|
||||
)
|
||||
db.session.add(tool_variables)
|
||||
db.session.commit()
|
||||
|
||||
return tool_variables
|
||||
|
||||
def _convert_db_variables_to_tool_variables(self, db_variables: ToolConversationVariables) -> ToolRuntimeVariablePool:
|
||||
"""
|
||||
convert db variables to tool variables
|
||||
"""
|
||||
return ToolRuntimeVariablePool(**{
|
||||
'conversation_id': db_variables.conversation_id,
|
||||
'user_id': db_variables.user_id,
|
||||
'tenant_id': db_variables.tenant_id,
|
||||
'pool': db_variables.variables
|
||||
})
|
||||
|
||||
def _init_message_chain(self, message: Message, query: str) -> MessageChain:
|
||||
"""
|
||||
Init MessageChain
|
||||
:param message: message
|
||||
:param query: query
|
||||
:return:
|
||||
"""
|
||||
message_chain = MessageChain(
|
||||
message_id=message.id,
|
||||
type="AgentExecutor",
|
||||
input=json.dumps({
|
||||
"input": query
|
||||
})
|
||||
)
|
||||
|
||||
db.session.add(message_chain)
|
||||
db.session.commit()
|
||||
|
||||
return message_chain
|
||||
|
||||
def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None:
|
||||
"""
|
||||
Save MessageChain
|
||||
:param message_chain: message chain
|
||||
:param output_text: output text
|
||||
:return:
|
||||
"""
|
||||
message_chain.output = json.dumps({
|
||||
"output": output_text
|
||||
})
|
||||
db.session.commit()
|
||||
|
||||
def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity,
|
||||
message: Message) -> LLMUsage:
|
||||
"""
|
||||
Get usage of all agent thoughts
|
||||
:param model_config: model config
|
||||
:param message: message
|
||||
:return:
|
||||
"""
|
||||
agent_thoughts = (db.session.query(MessageAgentThought)
|
||||
.filter(MessageAgentThought.message_id == message.id).all())
|
||||
|
||||
all_message_tokens = 0
|
||||
all_answer_tokens = 0
|
||||
for agent_thought in agent_thoughts:
|
||||
all_message_tokens += agent_thought.message_tokens
|
||||
all_answer_tokens += agent_thought.answer_tokens
|
||||
|
||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
return model_type_instance._calc_response_usage(
|
||||
model_config.model,
|
||||
model_config.credentials,
|
||||
all_message_tokens,
|
||||
all_answer_tokens
|
||||
)
|
||||
@ -0,0 +1,74 @@
|
||||
import os
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.input import print_text
|
||||
|
||||
class DifyAgentCallbackHandler(BaseCallbackHandler, BaseModel):
|
||||
"""Callback Handler that prints to std out."""
|
||||
color: Optional[str] = ''
|
||||
current_loop = 1
|
||||
|
||||
def __init__(self, color: Optional[str] = None) -> None:
|
||||
super().__init__()
|
||||
"""Initialize callback handler."""
|
||||
# use a specific color is not specified
|
||||
self.color = color or 'green'
|
||||
self.current_loop = 1
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_inputs: Dict[str, Any],
|
||||
) -> None:
|
||||
"""Do nothing."""
|
||||
print_text("\n[on_tool_start] ToolCall:" + tool_name + "\n" + str(tool_inputs) + "\n", color=self.color)
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_inputs: Dict[str, Any],
|
||||
tool_outputs: str,
|
||||
) -> None:
|
||||
"""If not the final action, print out observation."""
|
||||
print_text("\n[on_tool_end]\n", color=self.color)
|
||||
print_text("Tool: " + tool_name + "\n", color=self.color)
|
||||
print_text("Inputs: " + str(tool_inputs) + "\n", color=self.color)
|
||||
print_text("Outputs: " + str(tool_outputs) + "\n", color=self.color)
|
||||
print_text("\n")
|
||||
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing."""
|
||||
print_text("\n[on_tool_error] Error: " + str(error) + "\n", color='red')
|
||||
|
||||
def on_agent_start(
|
||||
self, thought: str
|
||||
) -> None:
|
||||
"""Run on agent start."""
|
||||
if thought:
|
||||
print_text("\n[on_agent_start] \nCurrent Loop: " + \
|
||||
str(self.current_loop) + \
|
||||
"\nThought: " + thought + "\n", color=self.color)
|
||||
else:
|
||||
print_text("\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\n", color=self.color)
|
||||
|
||||
def on_agent_finish(
|
||||
self, color: Optional[str] = None, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run on agent end."""
|
||||
print_text("\n[on_agent_finish]\n Loop: " + str(self.current_loop) + "\n", color=self.color)
|
||||
|
||||
self.current_loop += 1
|
||||
|
||||
@property
|
||||
def ignore_agent(self) -> bool:
|
||||
"""Whether to ignore agent callbacks."""
|
||||
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
|
||||
|
||||
@property
|
||||
def ignore_chat_model(self) -> bool:
|
||||
"""Whether to ignore chat model callbacks."""
|
||||
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
|
||||
@ -0,0 +1,558 @@
|
||||
import logging
|
||||
import json
|
||||
|
||||
from typing import Optional, List, Tuple, Union
|
||||
from datetime import datetime
|
||||
from mimetypes import guess_extension
|
||||
|
||||
from core.app_runner.app_runner import AppRunner
|
||||
from extensions.ext_database import db
|
||||
|
||||
from models.model import MessageAgentThought, Message, MessageFile
|
||||
from models.tools import ToolConversationVariables
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, \
|
||||
ToolRuntimeVariablePool, ToolParamter
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
from core.app_runner.app_runner import AppRunner
|
||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.entities.application_entities import ModelConfigEntity, AgentEntity, AgentToolEntity
|
||||
from core.application_queue_manager import ApplicationQueueManager
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.entities.application_entities import ModelConfigEntity, \
|
||||
AgentEntity, AppOrchestrationConfigEntity, ApplicationGenerateEntity, InvokeFrom
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.file.message_file_parser import FileTransferMethod
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class BaseAssistantApplicationRunner(AppRunner):
|
||||
def __init__(self, tenant_id: str,
|
||||
application_generate_entity: ApplicationGenerateEntity,
|
||||
app_orchestration_config: AppOrchestrationConfigEntity,
|
||||
model_config: ModelConfigEntity,
|
||||
config: AgentEntity,
|
||||
queue_manager: ApplicationQueueManager,
|
||||
message: Message,
|
||||
user_id: str,
|
||||
memory: Optional[TokenBufferMemory] = None,
|
||||
prompt_messages: Optional[List[PromptMessage]] = None,
|
||||
variables_pool: Optional[ToolRuntimeVariablePool] = None,
|
||||
db_variables: Optional[ToolConversationVariables] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Agent runner
|
||||
:param tenant_id: tenant id
|
||||
:param app_orchestration_config: app orchestration config
|
||||
:param model_config: model config
|
||||
:param config: dataset config
|
||||
:param queue_manager: queue manager
|
||||
:param message: message
|
||||
:param user_id: user id
|
||||
:param agent_llm_callback: agent llm callback
|
||||
:param callback: callback
|
||||
:param memory: memory
|
||||
"""
|
||||
self.tenant_id = tenant_id
|
||||
self.application_generate_entity = application_generate_entity
|
||||
self.app_orchestration_config = app_orchestration_config
|
||||
self.model_config = model_config
|
||||
self.config = config
|
||||
self.queue_manager = queue_manager
|
||||
self.message = message
|
||||
self.user_id = user_id
|
||||
self.memory = memory
|
||||
self.history_prompt_messages = prompt_messages
|
||||
self.variables_pool = variables_pool
|
||||
self.db_variables_pool = db_variables
|
||||
|
||||
# init callback
|
||||
self.agent_callback = DifyAgentCallbackHandler()
|
||||
# init dataset tools
|
||||
hit_callback = DatasetIndexToolCallbackHandler(
|
||||
queue_manager=queue_manager,
|
||||
app_id=self.application_generate_entity.app_id,
|
||||
message_id=message.id,
|
||||
user_id=user_id,
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
)
|
||||
self.dataset_tools = DatasetRetrieverTool.get_dataset_tools(
|
||||
tenant_id=tenant_id,
|
||||
dataset_ids=app_orchestration_config.dataset.dataset_ids if app_orchestration_config.dataset else [],
|
||||
retrieve_config=app_orchestration_config.dataset.retrieve_config if app_orchestration_config.dataset else None,
|
||||
return_resource=app_orchestration_config.show_retrieve_source,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
hit_callback=hit_callback
|
||||
)
|
||||
# get how many agent thoughts have been created
|
||||
self.agent_thought_count = db.session.query(MessageAgentThought).filter(
|
||||
MessageAgentThought.message_id == self.message.id,
|
||||
).count()
|
||||
|
||||
def _repacket_app_orchestration_config(self, app_orchestration_config: AppOrchestrationConfigEntity) -> AppOrchestrationConfigEntity:
|
||||
"""
|
||||
Repacket app orchestration config
|
||||
"""
|
||||
if app_orchestration_config.prompt_template.simple_prompt_template is None:
|
||||
app_orchestration_config.prompt_template.simple_prompt_template = ''
|
||||
|
||||
return app_orchestration_config
|
||||
|
||||
def _convert_tool_response_to_str(self, tool_response: List[ToolInvokeMessage]) -> str:
|
||||
"""
|
||||
Handle tool response
|
||||
"""
|
||||
result = ''
|
||||
for response in tool_response:
|
||||
if response.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
result += response.message
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
result += f"result link: {response.message}. please dirct user to check it."
|
||||
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
|
||||
response.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
result += f"image has been created and sent to user already, you should tell user to check it now."
|
||||
else:
|
||||
result += f"tool response: {response.message}."
|
||||
|
||||
return result
|
||||
|
||||
def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> Tuple[PromptMessageTool, Tool]:
|
||||
"""
|
||||
convert tool to prompt message tool
|
||||
"""
|
||||
tool_entity = ToolManager.get_tool_runtime(
|
||||
provider_type=tool.provider_type, provider_name=tool.provider_id, tool_name=tool.tool_name,
|
||||
tanent_id=self.application_generate_entity.tenant_id,
|
||||
agent_callback=self.agent_callback
|
||||
)
|
||||
tool_entity.load_variables(self.variables_pool)
|
||||
|
||||
message_tool = PromptMessageTool(
|
||||
name=tool.tool_name,
|
||||
description=tool_entity.description.llm,
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
}
|
||||
)
|
||||
|
||||
runtime_parameters = {}
|
||||
|
||||
parameters = tool_entity.parameters or []
|
||||
user_parameters = tool_entity.get_runtime_parameters() or []
|
||||
|
||||
# override parameters
|
||||
for parameter in user_parameters:
|
||||
# check if parameter in tool parameters
|
||||
found = False
|
||||
for tool_parameter in parameters:
|
||||
if tool_parameter.name == parameter.name:
|
||||
found = True
|
||||
break
|
||||
|
||||
if found:
|
||||
# override parameter
|
||||
tool_parameter.type = parameter.type
|
||||
tool_parameter.form = parameter.form
|
||||
tool_parameter.required = parameter.required
|
||||
tool_parameter.default = parameter.default
|
||||
tool_parameter.options = parameter.options
|
||||
tool_parameter.llm_description = parameter.llm_description
|
||||
else:
|
||||
# add new parameter
|
||||
parameters.append(parameter)
|
||||
|
||||
for parameter in parameters:
|
||||
parameter_type = 'string'
|
||||
enum = []
|
||||
if parameter.type == ToolParamter.ToolParameterType.STRING:
|
||||
parameter_type = 'string'
|
||||
elif parameter.type == ToolParamter.ToolParameterType.BOOLEAN:
|
||||
parameter_type = 'boolean'
|
||||
elif parameter.type == ToolParamter.ToolParameterType.NUMBER:
|
||||
parameter_type = 'number'
|
||||
elif parameter.type == ToolParamter.ToolParameterType.SELECT:
|
||||
for option in parameter.options:
|
||||
enum.append(option.value)
|
||||
parameter_type = 'string'
|
||||
else:
|
||||
raise ValueError(f"parameter type {parameter.type} is not supported")
|
||||
|
||||
if parameter.form == ToolParamter.ToolParameterForm.FORM:
|
||||
# get tool parameter from form
|
||||
tool_parameter_config = tool.tool_parameters.get(parameter.name)
|
||||
if not tool_parameter_config:
|
||||
# get default value
|
||||
tool_parameter_config = parameter.default
|
||||
if not tool_parameter_config and parameter.required:
|
||||
raise ValueError(f"tool parameter {parameter.name} not found in tool config")
|
||||
|
||||
if parameter.type == ToolParamter.ToolParameterType.SELECT:
|
||||
# check if tool_parameter_config in options
|
||||
options = list(map(lambda x: x.value, parameter.options))
|
||||
if tool_parameter_config not in options:
|
||||
raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} not in options {options}")
|
||||
|
||||
# convert tool parameter config to correct type
|
||||
try:
|
||||
if parameter.type == ToolParamter.ToolParameterType.NUMBER:
|
||||
# check if tool parameter is integer
|
||||
if isinstance(tool_parameter_config, int):
|
||||
tool_parameter_config = tool_parameter_config
|
||||
elif isinstance(tool_parameter_config, float):
|
||||
tool_parameter_config = tool_parameter_config
|
||||
elif isinstance(tool_parameter_config, str):
|
||||
if '.' in tool_parameter_config:
|
||||
tool_parameter_config = float(tool_parameter_config)
|
||||
else:
|
||||
tool_parameter_config = int(tool_parameter_config)
|
||||
elif parameter.type == ToolParamter.ToolParameterType.BOOLEAN:
|
||||
tool_parameter_config = bool(tool_parameter_config)
|
||||
elif parameter.type not in [ToolParamter.ToolParameterType.SELECT, ToolParamter.ToolParameterType.STRING]:
|
||||
tool_parameter_config = str(tool_parameter_config)
|
||||
elif parameter.type == ToolParamter.ToolParameterType:
|
||||
tool_parameter_config = str(tool_parameter_config)
|
||||
except Exception as e:
|
||||
raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not correct type")
|
||||
|
||||
# save tool parameter to tool entity memory
|
||||
runtime_parameters[parameter.name] = tool_parameter_config
|
||||
|
||||
elif parameter.form == ToolParamter.ToolParameterForm.LLM:
|
||||
message_tool.parameters['properties'][parameter.name] = {
|
||||
"type": parameter_type,
|
||||
"description": parameter.llm_description or '',
|
||||
}
|
||||
|
||||
if len(enum) > 0:
|
||||
message_tool.parameters['properties'][parameter.name]['enum'] = enum
|
||||
|
||||
if parameter.required:
|
||||
message_tool.parameters['required'].append(parameter.name)
|
||||
|
||||
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
|
||||
|
||||
return message_tool, tool_entity
|
||||
|
||||
def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
|
||||
"""
|
||||
convert dataset retriever tool to prompt message tool
|
||||
"""
|
||||
prompt_tool = PromptMessageTool(
|
||||
name=tool.identity.name,
|
||||
description=tool.description.llm,
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
}
|
||||
)
|
||||
|
||||
for parameter in tool.get_runtime_parameters():
|
||||
parameter_type = 'string'
|
||||
|
||||
prompt_tool.parameters['properties'][parameter.name] = {
|
||||
"type": parameter_type,
|
||||
"description": parameter.llm_description or '',
|
||||
}
|
||||
|
||||
if parameter.required:
|
||||
if parameter.name not in prompt_tool.parameters['required']:
|
||||
prompt_tool.parameters['required'].append(parameter.name)
|
||||
|
||||
return prompt_tool
|
||||
|
||||
def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:
|
||||
"""
|
||||
update prompt message tool
|
||||
"""
|
||||
# try to get tool runtime parameters
|
||||
tool_runtime_parameters = tool.get_runtime_parameters() or []
|
||||
|
||||
for parameter in tool_runtime_parameters:
|
||||
parameter_type = 'string'
|
||||
enum = []
|
||||
if parameter.type == ToolParamter.ToolParameterType.STRING:
|
||||
parameter_type = 'string'
|
||||
elif parameter.type == ToolParamter.ToolParameterType.BOOLEAN:
|
||||
parameter_type = 'boolean'
|
||||
elif parameter.type == ToolParamter.ToolParameterType.NUMBER:
|
||||
parameter_type = 'number'
|
||||
elif parameter.type == ToolParamter.ToolParameterType.SELECT:
|
||||
for option in parameter.options:
|
||||
enum.append(option.value)
|
||||
parameter_type = 'string'
|
||||
else:
|
||||
raise ValueError(f"parameter type {parameter.type} is not supported")
|
||||
|
||||
if parameter.form == ToolParamter.ToolParameterForm.LLM:
|
||||
prompt_tool.parameters['properties'][parameter.name] = {
|
||||
"type": parameter_type,
|
||||
"description": parameter.llm_description or '',
|
||||
}
|
||||
|
||||
if len(enum) > 0:
|
||||
prompt_tool.parameters['properties'][parameter.name]['enum'] = enum
|
||||
|
||||
if parameter.required:
|
||||
if parameter.name not in prompt_tool.parameters['required']:
|
||||
prompt_tool.parameters['required'].append(parameter.name)
|
||||
|
||||
return prompt_tool
|
||||
|
||||
def extract_tool_response_binary(self, tool_response: List[ToolInvokeMessage]) -> List[ToolInvokeMessageBinary]:
|
||||
"""
|
||||
Extract tool response binary
|
||||
"""
|
||||
result = []
|
||||
|
||||
for response in tool_response:
|
||||
if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
|
||||
response.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
result.append(ToolInvokeMessageBinary(
|
||||
mimetype=response.meta.get('mime_type', 'octet/stream'),
|
||||
url=response.message,
|
||||
save_as=response.save_as,
|
||||
))
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
result.append(ToolInvokeMessageBinary(
|
||||
mimetype=response.meta.get('mime_type', 'octet/stream'),
|
||||
url=response.message,
|
||||
save_as=response.save_as,
|
||||
))
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
# check if there is a mime type in meta
|
||||
if response.meta and 'mime_type' in response.meta:
|
||||
result.append(ToolInvokeMessageBinary(
|
||||
mimetype=response.meta.get('mime_type', 'octet/stream') if response.meta else 'octet/stream',
|
||||
url=response.message,
|
||||
save_as=response.save_as,
|
||||
))
|
||||
|
||||
return result
|
||||
|
||||
def create_message_files(self, messages: List[ToolInvokeMessageBinary]) -> List[Tuple[MessageFile, bool]]:
|
||||
"""
|
||||
Create message file
|
||||
|
||||
:param messages: messages
|
||||
:return: message files, should save as variable
|
||||
"""
|
||||
result = []
|
||||
|
||||
for message in messages:
|
||||
file_type = 'bin'
|
||||
if 'image' in message.mimetype:
|
||||
file_type = 'image'
|
||||
elif 'video' in message.mimetype:
|
||||
file_type = 'video'
|
||||
elif 'audio' in message.mimetype:
|
||||
file_type = 'audio'
|
||||
elif 'text' in message.mimetype:
|
||||
file_type = 'text'
|
||||
elif 'pdf' in message.mimetype:
|
||||
file_type = 'pdf'
|
||||
elif 'zip' in message.mimetype:
|
||||
file_type = 'archive'
|
||||
# ...
|
||||
|
||||
invoke_from = self.application_generate_entity.invoke_from
|
||||
|
||||
message_file = MessageFile(
|
||||
message_id=self.message.id,
|
||||
type=file_type,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE.value,
|
||||
belongs_to='assistant',
|
||||
url=message.url,
|
||||
upload_file_id=None,
|
||||
created_by_role=('account'if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'),
|
||||
created_by=self.user_id,
|
||||
)
|
||||
db.session.add(message_file)
|
||||
result.append((
|
||||
message_file,
|
||||
message.save_as
|
||||
))
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return result
|
||||
|
||||
def create_agent_thought(self, message_id: str, message: str,
|
||||
tool_name: str, tool_input: str, messages_ids: List[str]
|
||||
) -> MessageAgentThought:
|
||||
"""
|
||||
Create agent thought
|
||||
"""
|
||||
thought = MessageAgentThought(
|
||||
message_id=message_id,
|
||||
message_chain_id=None,
|
||||
thought='',
|
||||
tool=tool_name,
|
||||
tool_input=tool_input,
|
||||
message=message,
|
||||
message_token=0,
|
||||
message_unit_price=0,
|
||||
message_price_unit=0,
|
||||
message_files=json.dumps(messages_ids) if messages_ids else '',
|
||||
answer='',
|
||||
observation='',
|
||||
answer_token=0,
|
||||
answer_unit_price=0,
|
||||
answer_price_unit=0,
|
||||
tokens=0,
|
||||
total_price=0,
|
||||
position=self.agent_thought_count + 1,
|
||||
currency='USD',
|
||||
latency=0,
|
||||
created_by_role='account',
|
||||
created_by=self.user_id,
|
||||
)
|
||||
|
||||
db.session.add(thought)
|
||||
db.session.commit()
|
||||
|
||||
self.agent_thought_count += 1
|
||||
|
||||
return thought
|
||||
|
||||
def save_agent_thought(self,
|
||||
agent_thought: MessageAgentThought,
|
||||
tool_name: str,
|
||||
tool_input: Union[str, dict],
|
||||
thought: str,
|
||||
observation: str,
|
||||
answer: str,
|
||||
messages_ids: List[str],
|
||||
llm_usage: LLMUsage = None) -> MessageAgentThought:
|
||||
"""
|
||||
Save agent thought
|
||||
"""
|
||||
if thought is not None:
|
||||
agent_thought.thought = thought
|
||||
|
||||
if tool_name is not None:
|
||||
agent_thought.tool = tool_name
|
||||
|
||||
if tool_input is not None:
|
||||
if isinstance(tool_input, dict):
|
||||
try:
|
||||
tool_input = json.dumps(tool_input, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
tool_input = json.dumps(tool_input)
|
||||
|
||||
agent_thought.tool_input = tool_input
|
||||
|
||||
if observation is not None:
|
||||
agent_thought.observation = observation
|
||||
|
||||
if answer is not None:
|
||||
agent_thought.answer = answer
|
||||
|
||||
if messages_ids is not None and len(messages_ids) > 0:
|
||||
agent_thought.message_files = json.dumps(messages_ids)
|
||||
|
||||
if llm_usage:
|
||||
agent_thought.message_token = llm_usage.prompt_tokens
|
||||
agent_thought.message_price_unit = llm_usage.prompt_price_unit
|
||||
agent_thought.message_unit_price = llm_usage.prompt_unit_price
|
||||
agent_thought.answer_token = llm_usage.completion_tokens
|
||||
agent_thought.answer_price_unit = llm_usage.completion_price_unit
|
||||
agent_thought.answer_unit_price = llm_usage.completion_unit_price
|
||||
agent_thought.tokens = llm_usage.total_tokens
|
||||
agent_thought.total_price = llm_usage.total_price
|
||||
|
||||
db.session.commit()
|
||||
|
||||
def get_history_prompt_messages(self) -> List[PromptMessage]:
|
||||
"""
|
||||
Get history prompt messages
|
||||
"""
|
||||
if self.history_prompt_messages is None:
|
||||
self.history_prompt_messages = db.session.query(PromptMessage).filter(
|
||||
PromptMessage.message_id == self.message.id,
|
||||
).order_by(PromptMessage.position.asc()).all()
|
||||
|
||||
return self.history_prompt_messages
|
||||
|
||||
def transform_tool_invoke_messages(self, messages: List[ToolInvokeMessage]) -> List[ToolInvokeMessage]:
|
||||
"""
|
||||
Transform tool message into agent thought
|
||||
"""
|
||||
result = []
|
||||
|
||||
for message in messages:
|
||||
if message.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
result.append(message)
|
||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||
result.append(message)
|
||||
elif message.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
# try to download image
|
||||
try:
|
||||
file = ToolFileManager.create_file_by_url(user_id=self.user_id, tenant_id=self.tenant_id,
|
||||
conversation_id=self.message.conversation_id,
|
||||
file_url=message.message)
|
||||
|
||||
url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}'
|
||||
|
||||
result.append(ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
message=url,
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
))
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
result.append(ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.TEXT,
|
||||
message=f"Failed to download image: {message.message}, you can try to download it yourself.",
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
save_as=message.save_as,
|
||||
))
|
||||
elif message.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
# get mime type and save blob to storage
|
||||
mimetype = message.meta.get('mime_type', 'octet/stream')
|
||||
# if message is str, encode it to bytes
|
||||
if isinstance(message.message, str):
|
||||
message.message = message.message.encode('utf-8')
|
||||
file = ToolFileManager.create_file_by_raw(user_id=self.user_id, tenant_id=self.tenant_id,
|
||||
conversation_id=self.message.conversation_id,
|
||||
file_binary=message.message,
|
||||
mimetype=mimetype)
|
||||
|
||||
url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".bin"}'
|
||||
|
||||
# check if file is image
|
||||
if 'image' in mimetype:
|
||||
result.append(ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
message=url,
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
))
|
||||
else:
|
||||
result.append(ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.LINK,
|
||||
message=url,
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
))
|
||||
else:
|
||||
result.append(message)
|
||||
|
||||
return result
|
||||
|
||||
def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
|
||||
"""
|
||||
convert tool variables to db variables
|
||||
"""
|
||||
db_variables.updated_at = datetime.utcnow()
|
||||
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
|
||||
db.session.commit()
|
||||
@ -0,0 +1,578 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Literal, Union, Generator, Dict, List
|
||||
|
||||
from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit
|
||||
from core.application_queue_manager import PublishFrom
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage, \
|
||||
UserPromptMessage, SystemPromptMessage, AssistantPromptMessage
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_manager import ModelInstance
|
||||
|
||||
from core.tools.errors import ToolInvokeError, ToolNotFoundError, \
|
||||
ToolNotSupportedError, ToolProviderNotFoundError, ToolParamterValidationError, \
|
||||
ToolProviderCredentialValidationError
|
||||
|
||||
from core.features.assistant_base_runner import BaseAssistantApplicationRunner
|
||||
|
||||
from models.model import Conversation, Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
||||
def run(self, model_instance: ModelInstance,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
query: str,
|
||||
) -> Union[Generator, LLMResult]:
|
||||
"""
|
||||
Run Cot agent application
|
||||
"""
|
||||
app_orchestration_config = self.app_orchestration_config
|
||||
self._repacket_app_orchestration_config(app_orchestration_config)
|
||||
|
||||
agent_scratchpad: List[AgentScratchpadUnit] = []
|
||||
|
||||
# check model mode
|
||||
if self.app_orchestration_config.model_config.mode == "completion":
|
||||
# TODO: stop words
|
||||
if 'Observation' not in app_orchestration_config.model_config.stop:
|
||||
app_orchestration_config.model_config.stop.append('Observation')
|
||||
|
||||
iteration_step = 1
|
||||
max_iteration_steps = min(self.app_orchestration_config.agent.max_iteration, 5) + 1
|
||||
|
||||
prompt_messages = self.history_prompt_messages
|
||||
|
||||
# convert tools into ModelRuntime Tool format
|
||||
prompt_messages_tools: List[PromptMessageTool] = []
|
||||
tool_instances = {}
|
||||
for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []:
|
||||
try:
|
||||
prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
|
||||
except Exception:
|
||||
# api tool may be deleted
|
||||
continue
|
||||
# save tool entity
|
||||
tool_instances[tool.tool_name] = tool_entity
|
||||
# save prompt tool
|
||||
prompt_messages_tools.append(prompt_tool)
|
||||
|
||||
# convert dataset tools into ModelRuntime Tool format
|
||||
for dataset_tool in self.dataset_tools:
|
||||
prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
|
||||
# save prompt tool
|
||||
prompt_messages_tools.append(prompt_tool)
|
||||
# save tool entity
|
||||
tool_instances[dataset_tool.identity.name] = dataset_tool
|
||||
|
||||
function_call_state = True
|
||||
llm_usage = {
|
||||
'usage': None
|
||||
}
|
||||
final_answer = ''
|
||||
|
||||
def increse_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage):
|
||||
if not final_llm_usage_dict['usage']:
|
||||
final_llm_usage_dict['usage'] = usage
|
||||
else:
|
||||
llm_usage = final_llm_usage_dict['usage']
|
||||
llm_usage.prompt_tokens += usage.prompt_tokens
|
||||
llm_usage.completion_tokens += usage.completion_tokens
|
||||
llm_usage.prompt_price += usage.prompt_price
|
||||
llm_usage.completion_price += usage.completion_price
|
||||
|
||||
while function_call_state and iteration_step <= max_iteration_steps:
|
||||
# continue to run until there is not any tool call
|
||||
function_call_state = False
|
||||
|
||||
if iteration_step == max_iteration_steps:
|
||||
# the last iteration, remove all tools
|
||||
prompt_messages_tools = []
|
||||
|
||||
message_file_ids = []
|
||||
agent_thought = self.create_agent_thought(
|
||||
message_id=message.id,
|
||||
message='',
|
||||
tool_name='',
|
||||
tool_input='',
|
||||
messages_ids=message_file_ids
|
||||
)
|
||||
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
# update prompt messages
|
||||
prompt_messages = self._originze_cot_prompt_messages(
|
||||
mode=app_orchestration_config.model_config.mode,
|
||||
prompt_messages=prompt_messages,
|
||||
tools=prompt_messages_tools,
|
||||
agent_scratchpad=agent_scratchpad,
|
||||
agent_prompt_message=app_orchestration_config.agent.prompt,
|
||||
instruction=app_orchestration_config.prompt_template.simple_prompt_template,
|
||||
input=query
|
||||
)
|
||||
|
||||
# recale llm max tokens
|
||||
self.recale_llm_max_tokens(self.model_config, prompt_messages)
|
||||
# invoke model
|
||||
llm_result: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=app_orchestration_config.model_config.parameters,
|
||||
tools=[],
|
||||
stop=app_orchestration_config.model_config.stop,
|
||||
stream=False,
|
||||
user=self.user_id,
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
# check llm result
|
||||
if not llm_result:
|
||||
raise ValueError("failed to invoke llm")
|
||||
|
||||
# get scratchpad
|
||||
scratchpad = self._extract_response_scratchpad(llm_result.message.content)
|
||||
agent_scratchpad.append(scratchpad)
|
||||
|
||||
# get llm usage
|
||||
if llm_result.usage:
|
||||
increse_usage(llm_usage, llm_result.usage)
|
||||
|
||||
self.save_agent_thought(agent_thought=agent_thought,
|
||||
tool_name=scratchpad.action.action_name if scratchpad.action else '',
|
||||
tool_input=scratchpad.action.action_input if scratchpad.action else '',
|
||||
thought=scratchpad.thought,
|
||||
observation='',
|
||||
answer=llm_result.message.content,
|
||||
messages_ids=[],
|
||||
llm_usage=llm_result.usage)
|
||||
|
||||
if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
|
||||
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
# publish agent thought if it's not empty and there is a action
|
||||
if scratchpad.thought and scratchpad.action:
|
||||
# check if final answer
|
||||
if not scratchpad.action.action_name.lower() == "final answer":
|
||||
yield LLMResultChunk(
|
||||
model=model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=scratchpad.thought
|
||||
),
|
||||
usage=llm_result.usage,
|
||||
),
|
||||
system_fingerprint=''
|
||||
)
|
||||
|
||||
if not scratchpad.action:
|
||||
# failed to extract action, return final answer directly
|
||||
final_answer = scratchpad.agent_response or ''
|
||||
else:
|
||||
if scratchpad.action.action_name.lower() == "final answer":
|
||||
# action is final answer, return final answer directly
|
||||
try:
|
||||
final_answer = scratchpad.action.action_input if \
|
||||
isinstance(scratchpad.action.action_input, str) else \
|
||||
json.dumps(scratchpad.action.action_input)
|
||||
except json.JSONDecodeError:
|
||||
final_answer = f'{scratchpad.action.action_input}'
|
||||
else:
|
||||
function_call_state = True
|
||||
|
||||
# action is tool call, invoke tool
|
||||
tool_call_name = scratchpad.action.action_name
|
||||
tool_call_args = scratchpad.action.action_input
|
||||
tool_instance = tool_instances.get(tool_call_name)
|
||||
if not tool_instance:
|
||||
logger.error(f"failed to find tool instance: {tool_call_name}")
|
||||
answer = f"there is not a tool named {tool_call_name}"
|
||||
self.save_agent_thought(agent_thought=agent_thought,
|
||||
tool_name='',
|
||||
tool_input='',
|
||||
thought=None,
|
||||
observation=answer,
|
||||
answer=answer,
|
||||
messages_ids=[])
|
||||
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
|
||||
else:
|
||||
# invoke tool
|
||||
error_response = None
|
||||
try:
|
||||
tool_response = tool_instance.invoke(
|
||||
user_id=self.user_id,
|
||||
tool_paramters=tool_call_args if isinstance(tool_call_args, dict) else json.loads(tool_call_args)
|
||||
)
|
||||
# transform tool response to llm friendly response
|
||||
tool_response = self.transform_tool_invoke_messages(tool_response)
|
||||
# extract binary data from tool invoke message
|
||||
binary_files = self.extract_tool_response_binary(tool_response)
|
||||
# create message file
|
||||
message_files = self.create_message_files(binary_files)
|
||||
# publish files
|
||||
for message_file, save_as in message_files:
|
||||
if save_as:
|
||||
self.variables_pool.set_file(tool_name=tool_call_name,
|
||||
value=message_file.id,
|
||||
name=save_as)
|
||||
self.queue_manager.publish_message_file(message_file, PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
message_file_ids = [message_file.id for message_file, _ in message_files]
|
||||
except ToolProviderCredentialValidationError as e:
|
||||
error_response = f"Plese check your tool provider credentials"
|
||||
except (
|
||||
ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError
|
||||
) as e:
|
||||
error_response = f"there is not a tool named {tool_call_name}"
|
||||
except (
|
||||
ToolParamterValidationError
|
||||
) as e:
|
||||
error_response = f"tool paramters validation error: {e}, please check your tool paramters"
|
||||
except ToolInvokeError as e:
|
||||
error_response = f"tool invoke error: {e}"
|
||||
except Exception as e:
|
||||
error_response = f"unknown error: {e}"
|
||||
|
||||
if error_response:
|
||||
observation = error_response
|
||||
logger.error(error_response)
|
||||
else:
|
||||
observation = self._convert_tool_response_to_str(tool_response)
|
||||
|
||||
# save scratchpad
|
||||
scratchpad.observation = observation
|
||||
scratchpad.agent_response = llm_result.message.content
|
||||
|
||||
# save agent thought
|
||||
self.save_agent_thought(
|
||||
agent_thought=agent_thought,
|
||||
tool_name=tool_call_name,
|
||||
tool_input=tool_call_args,
|
||||
thought=None,
|
||||
observation=observation,
|
||||
answer=llm_result.message.content,
|
||||
messages_ids=message_file_ids,
|
||||
)
|
||||
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
# update prompt tool message
|
||||
for prompt_tool in prompt_messages_tools:
|
||||
self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
|
||||
|
||||
iteration_step += 1
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=final_answer
|
||||
),
|
||||
usage=llm_usage['usage']
|
||||
),
|
||||
system_fingerprint=''
|
||||
)
|
||||
|
||||
# save agent thought
|
||||
self.save_agent_thought(
|
||||
agent_thought=agent_thought,
|
||||
tool_name='',
|
||||
tool_input='',
|
||||
thought=final_answer,
|
||||
observation='',
|
||||
answer=final_answer,
|
||||
messages_ids=[]
|
||||
)
|
||||
|
||||
self.update_db_variables(self.variables_pool, self.db_variables_pool)
|
||||
# publish end event
|
||||
self.queue_manager.publish_message_end(LLMResult(
|
||||
model=model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(
|
||||
content=final_answer
|
||||
),
|
||||
usage=llm_usage['usage'],
|
||||
system_fingerprint=''
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
def _extract_response_scratchpad(self, content: str) -> AgentScratchpadUnit:
|
||||
"""
|
||||
extract response from llm response
|
||||
"""
|
||||
def extra_quotes() -> AgentScratchpadUnit:
|
||||
agent_response = content
|
||||
# try to extract all quotes
|
||||
pattern = re.compile(r'```(.*?)```', re.DOTALL)
|
||||
quotes = pattern.findall(content)
|
||||
|
||||
# try to extract action from end to start
|
||||
for i in range(len(quotes) - 1, 0, -1):
|
||||
"""
|
||||
1. use json load to parse action
|
||||
2. use plain text `Action: xxx` to parse action
|
||||
"""
|
||||
try:
|
||||
action = json.loads(quotes[i].replace('```', ''))
|
||||
action_name = action.get("action")
|
||||
action_input = action.get("action_input")
|
||||
agent_thought = agent_response.replace(quotes[i], '')
|
||||
|
||||
if action_name and action_input:
|
||||
return AgentScratchpadUnit(
|
||||
agent_response=content,
|
||||
thought=agent_thought,
|
||||
action_str=quotes[i],
|
||||
action=AgentScratchpadUnit.Action(
|
||||
action_name=action_name,
|
||||
action_input=action_input,
|
||||
)
|
||||
)
|
||||
except:
|
||||
# try to parse action from plain text
|
||||
action_name = re.findall(r'action: (.*)', quotes[i], re.IGNORECASE)
|
||||
action_input = re.findall(r'action input: (.*)', quotes[i], re.IGNORECASE)
|
||||
# delete action from agent response
|
||||
agent_thought = agent_response.replace(quotes[i], '')
|
||||
# remove extra quotes
|
||||
agent_thought = re.sub(r'```(json)*\n*```', '', agent_thought, flags=re.DOTALL)
|
||||
# remove Action: xxx from agent thought
|
||||
agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)
|
||||
|
||||
if action_name and action_input:
|
||||
return AgentScratchpadUnit(
|
||||
agent_response=content,
|
||||
thought=agent_thought,
|
||||
action_str=quotes[i],
|
||||
action=AgentScratchpadUnit.Action(
|
||||
action_name=action_name[0],
|
||||
action_input=action_input[0],
|
||||
)
|
||||
)
|
||||
|
||||
def extra_json():
|
||||
agent_response = content
|
||||
# try to extract all json
|
||||
structures, pair_match_stack = [], []
|
||||
started_at, end_at = 0, 0
|
||||
for i in range(len(content)):
|
||||
if content[i] == '{':
|
||||
pair_match_stack.append(i)
|
||||
if len(pair_match_stack) == 1:
|
||||
started_at = i
|
||||
elif content[i] == '}':
|
||||
begin = pair_match_stack.pop()
|
||||
if not pair_match_stack:
|
||||
end_at = i + 1
|
||||
structures.append((content[begin:i+1], (started_at, end_at)))
|
||||
|
||||
# handle the last character
|
||||
if pair_match_stack:
|
||||
end_at = len(content)
|
||||
structures.append((content[pair_match_stack[0]:], (started_at, end_at)))
|
||||
|
||||
for i in range(len(structures), 0, -1):
|
||||
try:
|
||||
json_content, (started_at, end_at) = structures[i - 1]
|
||||
action = json.loads(json_content)
|
||||
action_name = action.get("action")
|
||||
action_input = action.get("action_input")
|
||||
# delete json content from agent response
|
||||
agent_thought = agent_response[:started_at] + agent_response[end_at:]
|
||||
# remove extra quotes like ```(json)*\n\n```
|
||||
agent_thought = re.sub(r'```(json)*\n*```', '', agent_thought, flags=re.DOTALL)
|
||||
# remove Action: xxx from agent thought
|
||||
agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)
|
||||
|
||||
if action_name and action_input:
|
||||
return AgentScratchpadUnit(
|
||||
agent_response=content,
|
||||
thought=agent_thought,
|
||||
action_str=json_content,
|
||||
action=AgentScratchpadUnit.Action(
|
||||
action_name=action_name,
|
||||
action_input=action_input,
|
||||
)
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
agent_scratchpad = extra_quotes()
|
||||
if agent_scratchpad:
|
||||
return agent_scratchpad
|
||||
agent_scratchpad = extra_json()
|
||||
if agent_scratchpad:
|
||||
return agent_scratchpad
|
||||
|
||||
return AgentScratchpadUnit(
|
||||
agent_response=content,
|
||||
thought=content,
|
||||
action_str='',
|
||||
action=None
|
||||
)
|
||||
|
||||
def _check_cot_prompt_messages(self, mode: Literal["completion", "chat"],
|
||||
agent_prompt_message: AgentPromptEntity,
|
||||
):
|
||||
"""
|
||||
check chain of thought prompt messages, a standard prompt message is like:
|
||||
Respond to the human as helpfully and accurately as possible.
|
||||
|
||||
{{instruction}}
|
||||
|
||||
You have access to the following tools:
|
||||
|
||||
{{tools}}
|
||||
|
||||
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
||||
Valid action values: "Final Answer" or {{tool_names}}
|
||||
|
||||
Provide only ONE action per $JSON_BLOB, as shown:
|
||||
|
||||
```
|
||||
{
|
||||
"action": $TOOL_NAME,
|
||||
"action_input": $ACTION_INPUT
|
||||
}
|
||||
```
|
||||
"""
|
||||
|
||||
# parse agent prompt message
|
||||
first_prompt = agent_prompt_message.first_prompt
|
||||
next_iteration = agent_prompt_message.next_iteration
|
||||
|
||||
if not isinstance(first_prompt, str) or not isinstance(next_iteration, str):
|
||||
raise ValueError(f"first_prompt or next_iteration is required in CoT agent mode")
|
||||
|
||||
# check instruction, tools, and tool_names slots
|
||||
if not first_prompt.find("{{instruction}}") >= 0:
|
||||
raise ValueError("{{instruction}} is required in first_prompt")
|
||||
if not first_prompt.find("{{tools}}") >= 0:
|
||||
raise ValueError("{{tools}} is required in first_prompt")
|
||||
if not first_prompt.find("{{tool_names}}") >= 0:
|
||||
raise ValueError("{{tool_names}} is required in first_prompt")
|
||||
|
||||
if mode == "completion":
|
||||
if not first_prompt.find("{{query}}") >= 0:
|
||||
raise ValueError("{{query}} is required in first_prompt")
|
||||
if not first_prompt.find("{{agent_scratchpad}}") >= 0:
|
||||
raise ValueError("{{agent_scratchpad}} is required in first_prompt")
|
||||
|
||||
if mode == "completion":
|
||||
if not next_iteration.find("{{observation}}") >= 0:
|
||||
raise ValueError("{{observation}} is required in next_iteration")
|
||||
|
||||
def _convert_strachpad_list_to_str(self, agent_scratchpad: List[AgentScratchpadUnit]) -> str:
|
||||
"""
|
||||
convert agent scratchpad list to str
|
||||
"""
|
||||
next_iteration = self.app_orchestration_config.agent.prompt.next_iteration
|
||||
|
||||
result = ''
|
||||
for scratchpad in agent_scratchpad:
|
||||
result += scratchpad.thought + next_iteration.replace("{{observation}}", scratchpad.observation) + "\n"
|
||||
|
||||
return result
|
||||
|
||||
def _originze_cot_prompt_messages(self, mode: Literal["completion", "chat"],
|
||||
prompt_messages: List[PromptMessage],
|
||||
tools: List[PromptMessageTool],
|
||||
agent_scratchpad: List[AgentScratchpadUnit],
|
||||
agent_prompt_message: AgentPromptEntity,
|
||||
instruction: str,
|
||||
input: str,
|
||||
) -> List[PromptMessage]:
|
||||
"""
|
||||
originze chain of thought prompt messages, a standard prompt message is like:
|
||||
Respond to the human as helpfully and accurately as possible.
|
||||
|
||||
{{instruction}}
|
||||
|
||||
You have access to the following tools:
|
||||
|
||||
{{tools}}
|
||||
|
||||
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
||||
Valid action values: "Final Answer" or {{tool_names}}
|
||||
|
||||
Provide only ONE action per $JSON_BLOB, as shown:
|
||||
|
||||
```
|
||||
{{{{
|
||||
"action": $TOOL_NAME,
|
||||
"action_input": $ACTION_INPUT
|
||||
}}}}
|
||||
```
|
||||
"""
|
||||
|
||||
self._check_cot_prompt_messages(mode, agent_prompt_message)
|
||||
|
||||
# parse agent prompt message
|
||||
first_prompt = agent_prompt_message.first_prompt
|
||||
|
||||
# parse tools
|
||||
tools_str = self._jsonify_tool_prompt_messages(tools)
|
||||
|
||||
# parse tools name
|
||||
tool_names = '"' + '","'.join([tool.name for tool in tools]) + '"'
|
||||
|
||||
# get system message
|
||||
system_message = first_prompt.replace("{{instruction}}", instruction) \
|
||||
.replace("{{tools}}", tools_str) \
|
||||
.replace("{{tool_names}}", tool_names)
|
||||
|
||||
# originze prompt messages
|
||||
if mode == "chat":
|
||||
# override system message
|
||||
overrided = False
|
||||
prompt_messages = prompt_messages.copy()
|
||||
for prompt_message in prompt_messages:
|
||||
if isinstance(prompt_message, SystemPromptMessage):
|
||||
prompt_message.content = system_message
|
||||
overrided = True
|
||||
break
|
||||
|
||||
if not overrided:
|
||||
prompt_messages.insert(0, SystemPromptMessage(
|
||||
content=system_message,
|
||||
))
|
||||
|
||||
# add assistant message
|
||||
if len(agent_scratchpad) > 0:
|
||||
prompt_messages.append(AssistantPromptMessage(
|
||||
content=agent_scratchpad[-1].thought + "\n" + agent_scratchpad[-1].observation
|
||||
))
|
||||
|
||||
# add user message
|
||||
if len(agent_scratchpad) > 0:
|
||||
prompt_messages.append(UserPromptMessage(
|
||||
content=input,
|
||||
))
|
||||
|
||||
return prompt_messages
|
||||
elif mode == "completion":
|
||||
# parse agent scratchpad
|
||||
agent_scratchpad_str = self._convert_strachpad_list_to_str(agent_scratchpad)
|
||||
# parse prompt messages
|
||||
return [UserPromptMessage(
|
||||
content=first_prompt.replace("{{instruction}}", instruction)
|
||||
.replace("{{tools}}", tools_str)
|
||||
.replace("{{tool_names}}", tool_names)
|
||||
.replace("{{query}}", input)
|
||||
.replace("{{agent_scratchpad}}", agent_scratchpad_str),
|
||||
)]
|
||||
else:
|
||||
raise ValueError(f"mode {mode} is not supported")
|
||||
|
||||
def _jsonify_tool_prompt_messages(self, tools: list[PromptMessageTool]) -> str:
|
||||
"""
|
||||
jsonify tool prompt messages
|
||||
"""
|
||||
tools = jsonable_encoder(tools)
|
||||
try:
|
||||
return json.dumps(tools, ensure_ascii=False)
|
||||
except json.JSONDecodeError:
|
||||
return json.dumps(tools)
|
||||
@ -0,0 +1,335 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from typing import Union, Generator, Dict, Any, Tuple, List
|
||||
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, UserPromptMessage,\
|
||||
SystemPromptMessage, AssistantPromptMessage, ToolPromptMessage, PromptMessageTool
|
||||
from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult, LLMUsage
|
||||
from core.model_manager import ModelInstance
|
||||
from core.application_queue_manager import PublishFrom
|
||||
|
||||
from core.tools.errors import ToolInvokeError, ToolNotFoundError, \
|
||||
ToolNotSupportedError, ToolProviderNotFoundError, ToolParamterValidationError, \
|
||||
ToolProviderCredentialValidationError
|
||||
|
||||
from core.features.assistant_base_runner import BaseAssistantApplicationRunner
|
||||
|
||||
from models.model import Conversation, Message, MessageAgentThought
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
||||
def run(self, model_instance: ModelInstance,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
query: str,
|
||||
) -> Generator[LLMResultChunk, None, None]:
|
||||
"""
|
||||
Run FunctionCall agent application
|
||||
"""
|
||||
app_orchestration_config = self.app_orchestration_config
|
||||
|
||||
prompt_template = self.app_orchestration_config.prompt_template.simple_prompt_template or ''
|
||||
prompt_messages = self.history_prompt_messages
|
||||
prompt_messages = self.organize_prompt_messages(
|
||||
prompt_template=prompt_template,
|
||||
query=query,
|
||||
prompt_messages=prompt_messages
|
||||
)
|
||||
|
||||
# convert tools into ModelRuntime Tool format
|
||||
prompt_messages_tools: List[PromptMessageTool] = []
|
||||
tool_instances = {}
|
||||
for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []:
|
||||
try:
|
||||
prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
|
||||
except Exception:
|
||||
# api tool may be deleted
|
||||
continue
|
||||
# save tool entity
|
||||
tool_instances[tool.tool_name] = tool_entity
|
||||
# save prompt tool
|
||||
prompt_messages_tools.append(prompt_tool)
|
||||
|
||||
# convert dataset tools into ModelRuntime Tool format
|
||||
for dataset_tool in self.dataset_tools:
|
||||
prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
|
||||
# save prompt tool
|
||||
prompt_messages_tools.append(prompt_tool)
|
||||
# save tool entity
|
||||
tool_instances[dataset_tool.identity.name] = dataset_tool
|
||||
|
||||
iteration_step = 1
|
||||
max_iteration_steps = min(app_orchestration_config.agent.max_iteration, 5) + 1
|
||||
|
||||
# continue to run until there is not any tool call
|
||||
function_call_state = True
|
||||
agent_thoughts: List[MessageAgentThought] = []
|
||||
llm_usage = {
|
||||
'usage': None
|
||||
}
|
||||
final_answer = ''
|
||||
|
||||
def increase_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage):
|
||||
if not final_llm_usage_dict['usage']:
|
||||
final_llm_usage_dict['usage'] = usage
|
||||
else:
|
||||
llm_usage = final_llm_usage_dict['usage']
|
||||
llm_usage.prompt_tokens += usage.prompt_tokens
|
||||
llm_usage.completion_tokens += usage.completion_tokens
|
||||
llm_usage.prompt_price += usage.prompt_price
|
||||
llm_usage.completion_price += usage.completion_price
|
||||
|
||||
while function_call_state and iteration_step <= max_iteration_steps:
|
||||
function_call_state = False
|
||||
|
||||
if iteration_step == max_iteration_steps:
|
||||
# the last iteration, remove all tools
|
||||
prompt_messages_tools = []
|
||||
|
||||
message_file_ids = []
|
||||
agent_thought = self.create_agent_thought(
|
||||
message_id=message.id,
|
||||
message='',
|
||||
tool_name='',
|
||||
tool_input='',
|
||||
messages_ids=message_file_ids
|
||||
)
|
||||
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
# recale llm max tokens
|
||||
self.recale_llm_max_tokens(self.model_config, prompt_messages)
|
||||
# invoke model
|
||||
chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=app_orchestration_config.model_config.parameters,
|
||||
tools=prompt_messages_tools,
|
||||
stop=app_orchestration_config.model_config.stop,
|
||||
stream=True,
|
||||
user=self.user_id,
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
tool_calls: List[Tuple[str, str, Dict[str, Any]]] = []
|
||||
|
||||
# save full response
|
||||
response = ''
|
||||
|
||||
# save tool call names and inputs
|
||||
tool_call_names = ''
|
||||
tool_call_inputs = ''
|
||||
|
||||
current_llm_usage = None
|
||||
|
||||
for chunk in chunks:
|
||||
# check if there is any tool call
|
||||
if self.check_tool_calls(chunk):
|
||||
function_call_state = True
|
||||
tool_calls.extend(self.extract_tool_calls(chunk))
|
||||
tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
|
||||
try:
|
||||
tool_call_inputs = json.dumps({
|
||||
tool_call[1]: tool_call[2] for tool_call in tool_calls
|
||||
}, ensure_ascii=False)
|
||||
except json.JSONDecodeError as e:
|
||||
# ensure ascii to avoid encoding error
|
||||
tool_call_inputs = json.dumps({
|
||||
tool_call[1]: tool_call[2] for tool_call in tool_calls
|
||||
})
|
||||
|
||||
if chunk.delta.message and chunk.delta.message.content:
|
||||
if isinstance(chunk.delta.message.content, list):
|
||||
for content in chunk.delta.message.content:
|
||||
response += content.data
|
||||
else:
|
||||
response += chunk.delta.message.content
|
||||
|
||||
if chunk.delta.usage:
|
||||
increase_usage(llm_usage, chunk.delta.usage)
|
||||
current_llm_usage = chunk.delta.usage
|
||||
|
||||
yield chunk
|
||||
|
||||
# save thought
|
||||
self.save_agent_thought(
|
||||
agent_thought=agent_thought,
|
||||
tool_name=tool_call_names,
|
||||
tool_input=tool_call_inputs,
|
||||
thought=response,
|
||||
observation=None,
|
||||
answer=response,
|
||||
messages_ids=[],
|
||||
llm_usage=current_llm_usage
|
||||
)
|
||||
|
||||
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
final_answer += response + '\n'
|
||||
|
||||
# call tools
|
||||
tool_responses = []
|
||||
for tool_call_id, tool_call_name, tool_call_args in tool_calls:
|
||||
tool_instance = tool_instances.get(tool_call_name)
|
||||
if not tool_instance:
|
||||
logger.error(f"failed to find tool instance: {tool_call_name}")
|
||||
tool_response = {
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_call_name": tool_call_name,
|
||||
"tool_response": f"there is not a tool named {tool_call_name}"
|
||||
}
|
||||
tool_responses.append(tool_response)
|
||||
else:
|
||||
# invoke tool
|
||||
error_response = None
|
||||
try:
|
||||
tool_invoke_message = tool_instance.invoke(
|
||||
user_id=self.user_id,
|
||||
tool_paramters=tool_call_args,
|
||||
)
|
||||
# transform tool invoke message to get LLM friendly message
|
||||
tool_invoke_message = self.transform_tool_invoke_messages(tool_invoke_message)
|
||||
# extract binary data from tool invoke message
|
||||
binary_files = self.extract_tool_response_binary(tool_invoke_message)
|
||||
# create message file
|
||||
message_files = self.create_message_files(binary_files)
|
||||
# publish files
|
||||
for message_file, save_as in message_files:
|
||||
if save_as:
|
||||
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as)
|
||||
|
||||
# publish message file
|
||||
self.queue_manager.publish_message_file(message_file, PublishFrom.APPLICATION_MANAGER)
|
||||
# add message file ids
|
||||
message_file_ids.append(message_file.id)
|
||||
|
||||
except ToolProviderCredentialValidationError as e:
|
||||
error_response = f"Plese check your tool provider credentials"
|
||||
except (
|
||||
ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError
|
||||
) as e:
|
||||
error_response = f"there is not a tool named {tool_call_name}"
|
||||
except (
|
||||
ToolParamterValidationError
|
||||
) as e:
|
||||
error_response = f"tool paramters validation error: {e}, please check your tool paramters"
|
||||
except ToolInvokeError as e:
|
||||
error_response = f"tool invoke error: {e}"
|
||||
except Exception as e:
|
||||
error_response = f"unknown error: {e}"
|
||||
|
||||
if error_response:
|
||||
observation = error_response
|
||||
logger.error(error_response)
|
||||
tool_response = {
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_call_name": tool_call_name,
|
||||
"tool_response": error_response
|
||||
}
|
||||
tool_responses.append(tool_response)
|
||||
else:
|
||||
observation = self._convert_tool_response_to_str(tool_invoke_message)
|
||||
tool_response = {
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_call_name": tool_call_name,
|
||||
"tool_response": observation
|
||||
}
|
||||
tool_responses.append(tool_response)
|
||||
|
||||
prompt_messages = self.organize_prompt_messages(
|
||||
prompt_template=prompt_template,
|
||||
query=None,
|
||||
tool_call_id=tool_call_id,
|
||||
tool_call_name=tool_call_name,
|
||||
tool_response=tool_response['tool_response'],
|
||||
prompt_messages=prompt_messages,
|
||||
)
|
||||
|
||||
if len(tool_responses) > 0:
|
||||
# save agent thought
|
||||
self.save_agent_thought(
|
||||
agent_thought=agent_thought,
|
||||
tool_name=None,
|
||||
tool_input=None,
|
||||
thought=None,
|
||||
observation=tool_response['tool_response'],
|
||||
answer=None,
|
||||
messages_ids=message_file_ids
|
||||
)
|
||||
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
# update prompt messages
|
||||
if response.strip():
|
||||
prompt_messages.append(AssistantPromptMessage(
|
||||
content=response,
|
||||
))
|
||||
|
||||
# update prompt tool
|
||||
for prompt_tool in prompt_messages_tools:
|
||||
self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
|
||||
|
||||
iteration_step += 1
|
||||
|
||||
self.update_db_variables(self.variables_pool, self.db_variables_pool)
|
||||
# publish end event
|
||||
self.queue_manager.publish_message_end(LLMResult(
|
||||
model=model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(
|
||||
content=final_answer,
|
||||
),
|
||||
usage=llm_usage['usage'],
|
||||
system_fingerprint=''
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
|
||||
"""
|
||||
Check if there is any tool call in llm result chunk
|
||||
"""
|
||||
if llm_result_chunk.delta.message.tool_calls:
|
||||
return True
|
||||
return False
|
||||
|
||||
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
|
||||
"""
|
||||
Extract tool calls from llm result chunk
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
|
||||
"""
|
||||
tool_calls = []
|
||||
for prompt_message in llm_result_chunk.delta.message.tool_calls:
|
||||
tool_calls.append((
|
||||
prompt_message.id,
|
||||
prompt_message.function.name,
|
||||
json.loads(prompt_message.function.arguments),
|
||||
))
|
||||
|
||||
return tool_calls
|
||||
|
||||
def organize_prompt_messages(self, prompt_template: str,
|
||||
query: str = None,
|
||||
tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None,
|
||||
prompt_messages: list[PromptMessage] = None
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize prompt messages
|
||||
"""
|
||||
|
||||
if not prompt_messages:
|
||||
prompt_messages = [
|
||||
SystemPromptMessage(content=prompt_template),
|
||||
UserPromptMessage(content=query),
|
||||
]
|
||||
else:
|
||||
if tool_response:
|
||||
prompt_messages = prompt_messages.copy()
|
||||
prompt_messages.append(
|
||||
ToolPromptMessage(
|
||||
content=tool_response,
|
||||
tool_call_id=tool_call_id,
|
||||
name=tool_call_name,
|
||||
)
|
||||
)
|
||||
|
||||
return prompt_messages
|
||||
@ -0,0 +1,8 @@
|
||||
tool_file_manager = {
|
||||
'manager': None
|
||||
}
|
||||
|
||||
class ToolFileParser:
|
||||
@staticmethod
|
||||
def get_tool_file_manager() -> 'ToolFileManager':
|
||||
return tool_file_manager['manager']
|
||||
@ -0,0 +1,25 @@
|
||||
# Tools
|
||||
|
||||
This module implements built-in tools used in Agent Assistants and Workflows within Dify. You could define and display your own tools in this module, without modifying the frontend logic. This decoupling allows for easier horizontal scaling of Dify's capabilities.
|
||||
|
||||
## Feature Introduction
|
||||
|
||||
The tools provided for Agents and Workflows are currently divided into two categories:
|
||||
- `Built-in Tools` are internally implemented within our product and are hardcoded for use in Agents and Workflows.
|
||||
- `Api-Based Tools` leverage third-party APIs for implementation. You don't need to code to integrate these -- simply provide interface definitions in formats like `OpenAPI` , `Swagger`, or the `OpenAI-plugin` on the front-end.
|
||||
|
||||
### Built-in Tool Providers
|
||||

|
||||
|
||||
### API Tool Providers
|
||||

|
||||
|
||||
## Tool Integration
|
||||
|
||||
To enable developers to build flexible and powerful tools, we provide two guides:
|
||||
|
||||
### [Quick Integration 👈🏻](./docs/en_US/tool_scale_out.md)
|
||||
Quick integration aims at quickly getting you up to speed with tool integration by walking over an example Google Search tool.
|
||||
|
||||
### [Advanced Integration 👈🏻](./docs/en_US/advanced_scale_out.md)
|
||||
Advanced integration will offer a deeper dive into the module interfaces, and explain how to implement more complex capabilities, such as generating images, combining multiple tools, and managing the flow of parameters, images, and files between different tools.
|
||||
@ -0,0 +1,266 @@
|
||||
# Advanced Tool Integration
|
||||
|
||||
Before starting with this advanced guide, please make sure you have a basic understanding of the tool integration process in Dify. Check out [Quick Integration](./tool_scale_out.md) for a quick runthrough.
|
||||
|
||||
## Tool Interface
|
||||
|
||||
We have defined a series of helper methods in the `Tool` class to help developers quickly build more complex tools.
|
||||
|
||||
### Message Return
|
||||
|
||||
Dify supports various message types such as `text`, `link`, `image`, and `file BLOB`. You can return different types of messages to the LLM and users through the following interfaces.
|
||||
|
||||
Please note, some parameters in the following interfaces will be introduced in later sections.
|
||||
|
||||
#### Image URL
|
||||
You only need to pass the URL of the image, and Dify will automatically download the image and return it to the user.
|
||||
|
||||
```python
|
||||
def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessage:
|
||||
"""
|
||||
create an image message
|
||||
|
||||
:param image: the url of the image
|
||||
:return: the image message
|
||||
"""
|
||||
```
|
||||
|
||||
#### Link
|
||||
If you need to return a link, you can use the following interface.
|
||||
|
||||
```python
|
||||
def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage:
|
||||
"""
|
||||
create a link message
|
||||
|
||||
:param link: the url of the link
|
||||
:return: the link message
|
||||
"""
|
||||
```
|
||||
|
||||
#### Text
|
||||
If you need to return a text message, you can use the following interface.
|
||||
|
||||
```python
|
||||
def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage:
|
||||
"""
|
||||
create a text message
|
||||
|
||||
:param text: the text of the message
|
||||
:return: the text message
|
||||
"""
|
||||
```
|
||||
|
||||
#### File BLOB
|
||||
If you need to return the raw data of a file, such as images, audio, video, PPT, Word, Excel, etc., you can use the following interface.
|
||||
|
||||
- `blob` The raw data of the file, of bytes type
|
||||
- `meta` The metadata of the file, if you know the type of the file, it is best to pass a `mime_type`, otherwise Dify will use `octet/stream` as the default type
|
||||
|
||||
```python
|
||||
def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage:
|
||||
"""
|
||||
create a blob message
|
||||
|
||||
:param blob: the blob
|
||||
:return: the blob message
|
||||
"""
|
||||
```
|
||||
|
||||
### Shortcut Tools
|
||||
|
||||
In large model applications, we have two common needs:
|
||||
- First, summarize a long text in advance, and then pass the summary content to the LLM to prevent the original text from being too long for the LLM to handle
|
||||
- The content obtained by the tool is a link, and the web page information needs to be crawled before it can be returned to the LLM
|
||||
|
||||
To help developers quickly implement these two needs, we provide the following two shortcut tools.
|
||||
|
||||
#### Text Summary Tool
|
||||
|
||||
This tool takes in an user_id and the text to be summarized, and returns the summarized text. Dify will use the default model of the current workspace to summarize the long text.
|
||||
|
||||
```python
|
||||
def summary(self, user_id: str, content: str) -> str:
|
||||
"""
|
||||
summary the content
|
||||
|
||||
:param user_id: the user id
|
||||
:param content: the content
|
||||
:return: the summary
|
||||
"""
|
||||
```
|
||||
|
||||
#### Web Page Crawling Tool
|
||||
|
||||
This tool takes in web page link to be crawled and a user_agent (which can be empty), and returns a string containing the information of the web page. The `user_agent` is an optional parameter that can be used to identify the tool. If not passed, Dify will use the default `user_agent`.
|
||||
|
||||
```python
|
||||
def get_url(self, url: str, user_agent: str = None) -> str:
|
||||
"""
|
||||
get url
|
||||
""" the crawled result
|
||||
```
|
||||
|
||||
### Variable Pool
|
||||
|
||||
We have introduced a variable pool in `Tool` to store variables, files, etc. generated during the tool's operation. These variables can be used by other tools during the tool's operation.
|
||||
|
||||
Next, we will use `DallE3` and `Vectorizer.AI` as examples to introduce how to use the variable pool.
|
||||
|
||||
- `DallE3` is an image generation tool that can generate images based on text. Here, we will let `DallE3` generate a logo for a coffee shop
|
||||
- `Vectorizer.AI` is a vector image conversion tool that can convert images into vector images, so that the images can be infinitely enlarged without distortion. Here, we will convert the PNG icon generated by `DallE3` into a vector image, so that it can be truly used by designers.
|
||||
|
||||
#### DallE3
|
||||
First, we use DallE3. After creating the image, we save the image to the variable pool. The code is as follows:
|
||||
|
||||
```python
|
||||
from typing import Any, Dict, List, Union
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
from base64 import b64decode
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
class DallE3Tool(BuiltinTool):
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_paramters: Dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
client = OpenAI(
|
||||
api_key=self.runtime.credentials['openai_api_key'],
|
||||
)
|
||||
|
||||
# prompt
|
||||
prompt = tool_paramters.get('prompt', '')
|
||||
if not prompt:
|
||||
return self.create_text_message('Please input prompt')
|
||||
|
||||
# call openapi dalle3
|
||||
response = client.images.generate(
|
||||
prompt=prompt, model='dall-e-3',
|
||||
size='1024x1024', n=1, style='vivid', quality='standard',
|
||||
response_format='b64_json'
|
||||
)
|
||||
|
||||
result = []
|
||||
for image in response.data:
|
||||
# Save all images to the variable pool through the save_as parameter. The variable name is self.VARIABLE_KEY.IMAGE.value. If new images are generated later, they will overwrite the previous images.
|
||||
result.append(self.create_blob_message(blob=b64decode(image.b64_json),
|
||||
meta={ 'mime_type': 'image/png' },
|
||||
save_as=self.VARIABLE_KEY.IMAGE.value))
|
||||
|
||||
return result
|
||||
```
|
||||
|
||||
Note that we used `self.VARIABLE_KEY.IMAGE.value` as the variable name of the image. In order for developers' tools to cooperate with each other, we defined this `KEY`. You can use it freely, or you can choose not to use this `KEY`. Passing a custom KEY is also acceptable.
|
||||
|
||||
#### Vectorizer.AI
|
||||
Next, we use Vectorizer.AI to convert the PNG icon generated by DallE3 into a vector image. Let's go through the functions we defined here. The code is as follows:
|
||||
|
||||
```python
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParamter
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
from httpx import post
|
||||
from base64 import b64decode
|
||||
|
||||
class VectorizerTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \
|
||||
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
"""
|
||||
Tool invocation, the image variable name needs to be passed in from here, so that we can get the image from the variable pool
|
||||
"""
|
||||
|
||||
|
||||
def get_runtime_parameters(self) -> List[ToolParamter]:
|
||||
"""
|
||||
Override the tool parameter list, we can dynamically generate the parameter list based on the actual situation in the current variable pool, so that the LLM can generate the form based on the parameter list
|
||||
"""
|
||||
|
||||
|
||||
def is_tool_avaliable(self) -> bool:
|
||||
"""
|
||||
Whether the current tool is available, if there is no image in the current variable pool, then we don't need to display this tool, just return False here
|
||||
"""
|
||||
```
|
||||
|
||||
Next, let's implement these three functions
|
||||
|
||||
```python
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParamter
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
from httpx import post
|
||||
from base64 import b64decode
|
||||
|
||||
class VectorizerTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \
|
||||
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
api_key_name = self.runtime.credentials.get('api_key_name', None)
|
||||
api_key_value = self.runtime.credentials.get('api_key_value', None)
|
||||
|
||||
if not api_key_name or not api_key_value:
|
||||
raise ToolProviderCredentialValidationError('Please input api key name and value')
|
||||
|
||||
# Get image_id, the definition of image_id can be found in get_runtime_parameters
|
||||
image_id = tool_paramters.get('image_id', '')
|
||||
if not image_id:
|
||||
return self.create_text_message('Please input image id')
|
||||
|
||||
# Get the image generated by DallE from the variable pool
|
||||
image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE)
|
||||
if not image_binary:
|
||||
return self.create_text_message('Image not found, please request user to generate image firstly.')
|
||||
|
||||
# Generate vector image
|
||||
response = post(
|
||||
'https://vectorizer.ai/api/v1/vectorize',
|
||||
files={ 'image': image_binary },
|
||||
data={ 'mode': 'test' },
|
||||
auth=(api_key_name, api_key_value),
|
||||
timeout=30
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(response.text)
|
||||
|
||||
return [
|
||||
self.create_text_message('the vectorized svg is saved as an image.'),
|
||||
self.create_blob_message(blob=response.content,
|
||||
meta={'mime_type': 'image/svg+xml'})
|
||||
]
|
||||
|
||||
def get_runtime_parameters(self) -> List[ToolParamter]:
|
||||
"""
|
||||
override the runtime parameters
|
||||
"""
|
||||
# Here, we override the tool parameter list, define the image_id, and set its option list to all images in the current variable pool. The configuration here is consistent with the configuration in yaml.
|
||||
return [
|
||||
ToolParamter.get_simple_instance(
|
||||
name='image_id',
|
||||
llm_description=f'the image id that you want to vectorize, \
|
||||
and the image id should be specified in \
|
||||
{[i.name for i in self.list_default_image_variables()]}',
|
||||
type=ToolParamter.ToolParameterType.SELECT,
|
||||
required=True,
|
||||
options=[i.name for i in self.list_default_image_variables()]
|
||||
)
|
||||
]
|
||||
|
||||
def is_tool_avaliable(self) -> bool:
|
||||
# Only when there are images in the variable pool, the LLM needs to use this tool
|
||||
return len(self.list_default_image_variables()) > 0
|
||||
```
|
||||
|
||||
It's worth noting that we didn't actually use `image_id` here. We assumed that there must be an image in the default variable pool when calling this tool, so we directly used `image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE)` to get the image. In cases where the model's capabilities are weak, we recommend developers to do the same, which can effectively improve fault tolerance and avoid the model passing incorrect parameters.
|
||||
@ -0,0 +1,212 @@
|
||||
# Quick Tool Integration
|
||||
|
||||
Here, we will use GoogleSearch as an example to demonstrate how to quickly integrate a tool.
|
||||
|
||||
## 1. Prepare the Tool Provider yaml
|
||||
|
||||
### Introduction
|
||||
This yaml declares a new tool provider, and includes information like the provider's name, icon, author, and other details that are fetched by the frontend for display.
|
||||
|
||||
### Example
|
||||
|
||||
We need to create a `google` module (folder) under `core/tools/provider/builtin`, and create `google.yaml`. The name must be consistent with the module name.
|
||||
|
||||
Subsequently, all operations related to this tool will be carried out under this module.
|
||||
|
||||
```yaml
|
||||
identity: # Basic information of the tool provider
|
||||
author: Dify # Author
|
||||
name: google # Name, unique, no duplication with other providers
|
||||
label: # Label for frontend display
|
||||
en_US: Google # English label
|
||||
zh_Hans: Google # Chinese label
|
||||
description: # Description for frontend display
|
||||
en_US: Google # English description
|
||||
zh_Hans: Google # Chinese description
|
||||
icon: icon.svg # Icon, needs to be placed in the _assets folder of the current module
|
||||
|
||||
```
|
||||
- The `identity` field is mandatory, it contains the basic information of the tool provider, including author, name, label, description, icon, etc.
|
||||
- The icon needs to be placed in the `_assets` folder of the current module, you can refer to [here](../../provider/builtin/google/_assets/icon.svg).
|
||||
|
||||
## 2. Prepare Provider Credentials
|
||||
|
||||
Google, as a third-party tool, uses the API provided by SerpApi, which requires an API Key to use. This means that this tool needs a credential to use. For tools like `wikipedia`, there is no need to fill in the credential field, you can refer to [here](../../provider/builtin/wikipedia/wikipedia.yaml).
|
||||
|
||||
After configuring the credential field, the effect is as follows:
|
||||
```yaml
|
||||
identity:
|
||||
author: Dify
|
||||
name: google
|
||||
label:
|
||||
en_US: Google
|
||||
zh_Hans: Google
|
||||
description:
|
||||
en_US: Google
|
||||
zh_Hans: Google
|
||||
icon: icon.svg
|
||||
credentails_for_provider: # Credential field
|
||||
serpapi_api_key: # Credential field name
|
||||
type: secret-input # Credential field type
|
||||
required: true # Required or not
|
||||
label: # Credential field label
|
||||
en_US: SerpApi API key # English label
|
||||
zh_Hans: SerpApi API key # Chinese label
|
||||
placeholder: # Credential field placeholder
|
||||
en_US: Please input your SerpApi API key # English placeholder
|
||||
zh_Hans: 请输入你的 SerpApi API key # Chinese placeholder
|
||||
help: # Credential field help text
|
||||
en_US: Get your SerpApi API key from SerpApi # English help text
|
||||
zh_Hans: 从 SerpApi 获取您的 SerpApi API key # Chinese help text
|
||||
url: https://serpapi.com/manage-api-key # Credential field help link
|
||||
|
||||
```
|
||||
|
||||
- `type`: Credential field type, currently can be either `secret-input`, `text-input`, or `select` , corresponding to password input box, text input box, and drop-down box, respectively. If set to `secret-input`, it will mask the input content on the frontend, and the backend will encrypt the input content.
|
||||
|
||||
## 3. Prepare Tool yaml
|
||||
A provider can have multiple tools, each tool needs a yaml file to describe, this file contains the basic information, parameters, output, etc. of the tool.
|
||||
|
||||
Still taking GoogleSearch as an example, we need to create a `tools` module under the `google` module, and create `tools/google_search.yaml`, the content is as follows.
|
||||
|
||||
```yaml
|
||||
identity: # Basic information of the tool
|
||||
name: google_search # Tool name, unique, no duplication with other tools
|
||||
author: Dify # Author
|
||||
label: # Label for frontend display
|
||||
en_US: GoogleSearch # English label
|
||||
zh_Hans: 谷歌搜索 # Chinese label
|
||||
description: # Description for frontend display
|
||||
human: # Introduction for frontend display, supports multiple languages
|
||||
en_US: A tool for performing a Google SERP search and extracting snippets and webpages.Input should be a search query.
|
||||
zh_Hans: 一个用于执行 Google SERP 搜索并提取片段和网页的工具。输入应该是一个搜索查询。
|
||||
llm: A tool for performing a Google SERP search and extracting snippets and webpages.Input should be a search query. # Introduction passed to LLM, in order to make LLM better understand this tool, we suggest to write as detailed information about this tool as possible here, so that LLM can understand and use this tool
|
||||
parameters: # Parameter list
|
||||
- name: query # Parameter name
|
||||
type: string # Parameter type
|
||||
required: true # Required or not
|
||||
label: # Parameter label
|
||||
en_US: Query string # English label
|
||||
zh_Hans: 查询语句 # Chinese label
|
||||
human_description: # Introduction for frontend display, supports multiple languages
|
||||
en_US: used for searching
|
||||
zh_Hans: 用于搜索网页内容
|
||||
llm_description: key words for searching # Introduction passed to LLM, similarly, in order to make LLM better understand this parameter, we suggest to write as detailed information about this parameter as possible here, so that LLM can understand this parameter
|
||||
form: llm # Form type, llm means this parameter needs to be inferred by Agent, the frontend will not display this parameter
|
||||
- name: result_type
|
||||
type: select # Parameter type
|
||||
required: true
|
||||
options: # Drop-down box options
|
||||
- value: text
|
||||
label:
|
||||
en_US: text
|
||||
zh_Hans: 文本
|
||||
- value: link
|
||||
label:
|
||||
en_US: link
|
||||
zh_Hans: 链接
|
||||
default: link
|
||||
label:
|
||||
en_US: Result type
|
||||
zh_Hans: 结果类型
|
||||
human_description:
|
||||
en_US: used for selecting the result type, text or link
|
||||
zh_Hans: 用于选择结果类型,使用文本还是链接进行展示
|
||||
form: form # Form type, form means this parameter needs to be filled in by the user on the frontend before the conversation starts
|
||||
|
||||
```
|
||||
|
||||
- The `identity` field is mandatory, it contains the basic information of the tool, including name, author, label, description, etc.
|
||||
- `parameters` Parameter list
|
||||
- `name` Parameter name, unique, no duplication with other parameters
|
||||
- `type` Parameter type, currently supports `string`, `number`, `boolean`, `select` four types, corresponding to string, number, boolean, drop-down box
|
||||
- `required` Required or not
|
||||
- In `llm` mode, if the parameter is required, the Agent is required to infer this parameter
|
||||
- In `form` mode, if the parameter is required, the user is required to fill in this parameter on the frontend before the conversation starts
|
||||
- `options` Parameter options
|
||||
- In `llm` mode, Dify will pass all options to LLM, LLM can infer based on these options
|
||||
- In `form` mode, when `type` is `select`, the frontend will display these options
|
||||
- `default` Default value
|
||||
- `label` Parameter label, for frontend display
|
||||
- `human_description` Introduction for frontend display, supports multiple languages
|
||||
- `llm_description` Introduction passed to LLM, in order to make LLM better understand this parameter, we suggest to write as detailed information about this parameter as possible here, so that LLM can understand this parameter
|
||||
- `form` Form type, currently supports `llm`, `form` two types, corresponding to Agent self-inference and frontend filling
|
||||
|
||||
## 4. Add Tool Logic
|
||||
After completing the tool configuration, we can start writing the tool code that defines how it is invoked.
|
||||
|
||||
Create `google_search.py` under the `google/tools` module, the content is as follows.
|
||||
|
||||
```python
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
class GoogleSearchTool(BuiltinTool):
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_paramters: Dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
query = tool_paramters['query']
|
||||
result_type = tool_paramters['result_type']
|
||||
api_key = self.runtime.credentials['serpapi_api_key']
|
||||
# TODO: search with serpapi
|
||||
result = SerpAPI(api_key).run(query, result_type=result_type)
|
||||
|
||||
if result_type == 'text':
|
||||
return self.create_text_message(text=result)
|
||||
return self.create_link_message(link=result)
|
||||
```
|
||||
|
||||
### Parameters
|
||||
The overall logic of the tool is in the `_invoke` method, this method accepts two parameters: `user_id` and `tool_paramters`, which represent the user ID and tool parameters respectively
|
||||
|
||||
### Return Data
|
||||
When the tool returns, you can choose to return one message or multiple messages, here we return one message, using `create_text_message` and `create_link_message` can create a text message or a link message.
|
||||
|
||||
## 5. Add Provider Code
|
||||
Finally, we need to create a provider class under the provider module to implement the provider's credential verification logic. If the credential verification fails, it will throw a `ToolProviderCredentialValidationError` exception.
|
||||
|
||||
Create `google.py` under the `google` module, the content is as follows.
|
||||
|
||||
```python
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
from core.tools.provider.builtin.google.tools.google_search import GoogleSearchTool
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
class GoogleProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
|
||||
try:
|
||||
# 1. Here you need to instantiate a GoogleSearchTool with GoogleSearchTool(), it will automatically load the yaml configuration of GoogleSearchTool, but at this time it does not have credential information inside
|
||||
# 2. Then you need to use the fork_tool_runtime method to pass the current credential information to GoogleSearchTool
|
||||
# 3. Finally, invoke it, the parameters need to be passed according to the parameter rules configured in the yaml of GoogleSearchTool
|
||||
GoogleSearchTool().fork_tool_runtime(
|
||||
meta={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
tool_paramters={
|
||||
"query": "test",
|
||||
"result_type": "link"
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
```
|
||||
|
||||
## Completion
|
||||
After the above steps are completed, we can see this tool on the frontend, and it can be used in the Agent.
|
||||
|
||||
Of course, because google_search needs a credential, before using it, you also need to input your credentials on the frontend.
|
||||
|
||||

|
||||
|
After Width: | Height: | Size: 242 KiB |
|
After Width: | Height: | Size: 407 KiB |
|
After Width: | Height: | Size: 266 KiB |
@ -0,0 +1,22 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class I18nObject(BaseModel):
|
||||
"""
|
||||
Model class for i18n object.
|
||||
"""
|
||||
zh_Hans: Optional[str] = None
|
||||
en_US: str
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
if not self.zh_Hans:
|
||||
self.zh_Hans = self.en_US
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
'zh_Hans': self.zh_Hans,
|
||||
'en_US': self.en_US,
|
||||
}
|
||||
@ -0,0 +1,3 @@
|
||||
class DEFAULT_PROVIDERS:
|
||||
API_BASED = '__api_based'
|
||||
APP_BASED = '__app_based'
|
||||
@ -0,0 +1,34 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Dict, Optional, Any, List
|
||||
|
||||
from core.tools.entities.tool_entities import ToolProviderType, ToolParamter
|
||||
|
||||
class ApiBasedToolBundle(BaseModel):
|
||||
"""
|
||||
This class is used to store the schema information of an api based tool. such as the url, the method, the parameters, etc.
|
||||
"""
|
||||
# server_url
|
||||
server_url: str
|
||||
# method
|
||||
method: str
|
||||
# summary
|
||||
summary: Optional[str] = None
|
||||
# operation_id
|
||||
operation_id: str = None
|
||||
# parameters
|
||||
parameters: Optional[List[ToolParamter]] = None
|
||||
# author
|
||||
author: str
|
||||
# icon
|
||||
icon: Optional[str] = None
|
||||
# openapi operation
|
||||
openapi: dict
|
||||
|
||||
class AppToolBundle(BaseModel):
|
||||
"""
|
||||
This class is used to store the schema information of an tool for an app.
|
||||
"""
|
||||
type: ToolProviderType
|
||||
credential: Optional[Dict[str, Any]] = None
|
||||
provider_id: str
|
||||
tool_name: str
|
||||
@ -0,0 +1,305 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from enum import Enum
|
||||
from typing import Optional, List, Dict, Any, Union, cast
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
|
||||
class ToolProviderType(Enum):
|
||||
"""
|
||||
Enum class for tool provider
|
||||
"""
|
||||
BUILT_IN = "built-in"
|
||||
APP_BASED = "app-based"
|
||||
API_BASED = "api-based"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'ToolProviderType':
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f'invalid mode value {value}')
|
||||
|
||||
class ApiProviderSchemaType(Enum):
|
||||
"""
|
||||
Enum class for api provider schema type.
|
||||
"""
|
||||
OPENAPI = "openapi"
|
||||
SWAGGER = "swagger"
|
||||
OPENAI_PLUGIN = "openai_plugin"
|
||||
OPENAI_ACTIONS = "openai_actions"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'ApiProviderSchemaType':
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f'invalid mode value {value}')
|
||||
|
||||
class ApiProviderAuthType(Enum):
|
||||
"""
|
||||
Enum class for api provider auth type.
|
||||
"""
|
||||
NONE = "none"
|
||||
API_KEY = "api_key"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'ApiProviderAuthType':
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f'invalid mode value {value}')
|
||||
|
||||
class ToolInvokeMessage(BaseModel):
|
||||
class MessageType(Enum):
|
||||
TEXT = "text"
|
||||
IMAGE = "image"
|
||||
LINK = "link"
|
||||
BLOB = "blob"
|
||||
IMAGE_LINK = "image_link"
|
||||
|
||||
type: MessageType = MessageType.TEXT
|
||||
"""
|
||||
plain text, image url or link url
|
||||
"""
|
||||
message: Union[str, bytes] = None
|
||||
meta: Dict[str, Any] = None
|
||||
save_as: str = ''
|
||||
|
||||
class ToolInvokeMessageBinary(BaseModel):
|
||||
mimetype: str = Field(..., description="The mimetype of the binary")
|
||||
url: str = Field(..., description="The url of the binary")
|
||||
save_as: str = ''
|
||||
|
||||
class ToolParamterOption(BaseModel):
|
||||
value: str = Field(..., description="The value of the option")
|
||||
label: I18nObject = Field(..., description="The label of the option")
|
||||
|
||||
class ToolParamter(BaseModel):
|
||||
class ToolParameterType(Enum):
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
BOOLEAN = "boolean"
|
||||
SELECT = "select"
|
||||
|
||||
class ToolParameterForm(Enum):
|
||||
SCHEMA = "schema" # should be set while adding tool
|
||||
FORM = "form" # should be set before invoking tool
|
||||
LLM = "llm" # will be set by LLM
|
||||
|
||||
name: str = Field(..., description="The name of the parameter")
|
||||
label: I18nObject = Field(..., description="The label presented to the user")
|
||||
human_description: I18nObject = Field(..., description="The description presented to the user")
|
||||
type: ToolParameterType = Field(..., description="The type of the parameter")
|
||||
form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
|
||||
llm_description: Optional[str] = None
|
||||
required: Optional[bool] = False
|
||||
default: Optional[str] = None
|
||||
min: Optional[Union[float, int]] = None
|
||||
max: Optional[Union[float, int]] = None
|
||||
options: Optional[List[ToolParamterOption]] = None
|
||||
|
||||
@classmethod
|
||||
def get_simple_instance(cls,
|
||||
name: str, llm_description: str, type: ToolParameterType,
|
||||
required: bool, options: Optional[List[str]] = None) -> 'ToolParamter':
|
||||
"""
|
||||
get a simple tool parameter
|
||||
|
||||
:param name: the name of the parameter
|
||||
:param llm_description: the description presented to the LLM
|
||||
:param type: the type of the parameter
|
||||
:param required: if the parameter is required
|
||||
:param options: the options of the parameter
|
||||
"""
|
||||
# convert options to ToolParamterOption
|
||||
if options:
|
||||
options = [ToolParamterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options]
|
||||
return cls(
|
||||
name=name,
|
||||
label=I18nObject(en_US='', zh_Hans=''),
|
||||
human_description=I18nObject(en_US='', zh_Hans=''),
|
||||
type=type,
|
||||
form=cls.ToolParameterForm.LLM,
|
||||
llm_description=llm_description,
|
||||
required=required,
|
||||
options=options,
|
||||
)
|
||||
|
||||
class ToolProviderIdentity(BaseModel):
|
||||
author: str = Field(..., description="The author of the tool")
|
||||
name: str = Field(..., description="The name of the tool")
|
||||
description: I18nObject = Field(..., description="The description of the tool")
|
||||
icon: str = Field(..., description="The icon of the tool")
|
||||
label: I18nObject = Field(..., description="The label of the tool")
|
||||
|
||||
class ToolDescription(BaseModel):
|
||||
human: I18nObject = Field(..., description="The description presented to the user")
|
||||
llm: str = Field(..., description="The description presented to the LLM")
|
||||
|
||||
class ToolIdentity(BaseModel):
|
||||
author: str = Field(..., description="The author of the tool")
|
||||
name: str = Field(..., description="The name of the tool")
|
||||
label: I18nObject = Field(..., description="The label of the tool")
|
||||
|
||||
class ToolCredentialsOption(BaseModel):
|
||||
value: str = Field(..., description="The value of the option")
|
||||
label: I18nObject = Field(..., description="The label of the option")
|
||||
|
||||
class ToolProviderCredentials(BaseModel):
|
||||
class CredentialsType(Enum):
|
||||
SECRET_INPUT = "secret-input"
|
||||
TEXT_INPUT = "text-input"
|
||||
SELECT = "select"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "ToolProviderCredentials.CredentialsType":
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f'invalid mode value {value}')
|
||||
|
||||
@staticmethod
|
||||
def defaut(value: str) -> str:
|
||||
return ""
|
||||
|
||||
name: str = Field(..., description="The name of the credentials")
|
||||
type: CredentialsType = Field(..., description="The type of the credentials")
|
||||
required: bool = False
|
||||
default: Optional[str] = None
|
||||
options: Optional[List[ToolCredentialsOption]] = None
|
||||
label: Optional[I18nObject] = None
|
||||
help: Optional[I18nObject] = None
|
||||
url: Optional[str] = None
|
||||
placeholder: Optional[I18nObject] = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
'name': self.name,
|
||||
'type': self.type.value,
|
||||
'required': self.required,
|
||||
'default': self.default,
|
||||
'options': self.options,
|
||||
'help': self.help.to_dict() if self.help else None,
|
||||
'label': self.label.to_dict(),
|
||||
'url': self.url,
|
||||
'placeholder': self.placeholder.to_dict() if self.placeholder else None,
|
||||
}
|
||||
|
||||
class ToolRuntimeVariableType(Enum):
|
||||
TEXT = "text"
|
||||
IMAGE = "image"
|
||||
|
||||
class ToolRuntimeVariable(BaseModel):
|
||||
type: ToolRuntimeVariableType = Field(..., description="The type of the variable")
|
||||
name: str = Field(..., description="The name of the variable")
|
||||
position: int = Field(..., description="The position of the variable")
|
||||
tool_name: str = Field(..., description="The name of the tool")
|
||||
|
||||
class ToolRuntimeTextVariable(ToolRuntimeVariable):
|
||||
value: str = Field(..., description="The value of the variable")
|
||||
|
||||
class ToolRuntimeImageVariable(ToolRuntimeVariable):
|
||||
value: str = Field(..., description="The path of the image")
|
||||
|
||||
class ToolRuntimeVariablePool(BaseModel):
|
||||
conversation_id: str = Field(..., description="The conversation id")
|
||||
user_id: str = Field(..., description="The user id")
|
||||
tenant_id: str = Field(..., description="The tenant id of assistant")
|
||||
|
||||
pool: List[ToolRuntimeVariable] = Field(..., description="The pool of variables")
|
||||
|
||||
def __init__(self, **data: Any):
|
||||
pool = data.get('pool', [])
|
||||
# convert pool into correct type
|
||||
for index, variable in enumerate(pool):
|
||||
if variable['type'] == ToolRuntimeVariableType.TEXT.value:
|
||||
pool[index] = ToolRuntimeTextVariable(**variable)
|
||||
elif variable['type'] == ToolRuntimeVariableType.IMAGE.value:
|
||||
pool[index] = ToolRuntimeImageVariable(**variable)
|
||||
super().__init__(**data)
|
||||
|
||||
def dict(self) -> dict:
|
||||
return {
|
||||
'conversation_id': self.conversation_id,
|
||||
'user_id': self.user_id,
|
||||
'tenant_id': self.tenant_id,
|
||||
'pool': [variable.dict() for variable in self.pool],
|
||||
}
|
||||
|
||||
def set_text(self, tool_name: str, name: str, value: str) -> None:
|
||||
"""
|
||||
set a text variable
|
||||
"""
|
||||
for variable in self.pool:
|
||||
if variable.name == name:
|
||||
if variable.type == ToolRuntimeVariableType.TEXT:
|
||||
variable = cast(ToolRuntimeTextVariable, variable)
|
||||
variable.value = value
|
||||
return
|
||||
|
||||
variable = ToolRuntimeTextVariable(
|
||||
type=ToolRuntimeVariableType.TEXT,
|
||||
name=name,
|
||||
position=len(self.pool),
|
||||
tool_name=tool_name,
|
||||
value=value,
|
||||
)
|
||||
|
||||
self.pool.append(variable)
|
||||
|
||||
def set_file(self, tool_name: str, value: str, name: str = None) -> None:
|
||||
"""
|
||||
set an image variable
|
||||
|
||||
:param tool_name: the name of the tool
|
||||
:param value: the id of the file
|
||||
"""
|
||||
# check how many image variables are there
|
||||
image_variable_count = 0
|
||||
for variable in self.pool:
|
||||
if variable.type == ToolRuntimeVariableType.IMAGE:
|
||||
image_variable_count += 1
|
||||
|
||||
if name is None:
|
||||
name = f"file_{image_variable_count}"
|
||||
|
||||
for variable in self.pool:
|
||||
if variable.name == name:
|
||||
if variable.type == ToolRuntimeVariableType.IMAGE:
|
||||
variable = cast(ToolRuntimeImageVariable, variable)
|
||||
variable.value = value
|
||||
return
|
||||
|
||||
variable = ToolRuntimeImageVariable(
|
||||
type=ToolRuntimeVariableType.IMAGE,
|
||||
name=name,
|
||||
position=len(self.pool),
|
||||
tool_name=tool_name,
|
||||
value=value,
|
||||
)
|
||||
|
||||
self.pool.append(variable)
|
||||
@ -0,0 +1,48 @@
|
||||
from pydantic import BaseModel
|
||||
from enum import Enum
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderCredentials
|
||||
from core.tools.tool.tool import ToolParamter
|
||||
|
||||
class UserToolProvider(BaseModel):
|
||||
class ProviderType(Enum):
|
||||
BUILTIN = "builtin"
|
||||
APP = "app"
|
||||
API = "api"
|
||||
|
||||
id: str
|
||||
author: str
|
||||
name: str # identifier
|
||||
description: I18nObject
|
||||
icon: str
|
||||
label: I18nObject # label
|
||||
type: ProviderType
|
||||
team_credentials: dict = None
|
||||
is_team_authorization: bool = False
|
||||
allow_delete: bool = True
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
'id': self.id,
|
||||
'author': self.author,
|
||||
'name': self.name,
|
||||
'description': self.description.to_dict(),
|
||||
'icon': self.icon,
|
||||
'label': self.label.to_dict(),
|
||||
'type': self.type.value,
|
||||
'team_credentials': self.team_credentials,
|
||||
'is_team_authorization': self.is_team_authorization,
|
||||
'allow_delete': self.allow_delete
|
||||
}
|
||||
|
||||
class UserToolProviderCredentials(BaseModel):
|
||||
credentails: Dict[str, ToolProviderCredentials]
|
||||
|
||||
class UserTool(BaseModel):
|
||||
author: str
|
||||
name: str # identifier
|
||||
label: I18nObject # label
|
||||
description: I18nObject
|
||||
parameters: Optional[List[ToolParamter]]
|
||||
@ -0,0 +1,20 @@
|
||||
class ToolProviderNotFoundError(ValueError):
|
||||
pass
|
||||
|
||||
class ToolNotFoundError(ValueError):
|
||||
pass
|
||||
|
||||
class ToolParamterValidationError(ValueError):
|
||||
pass
|
||||
|
||||
class ToolProviderCredentialValidationError(ValueError):
|
||||
pass
|
||||
|
||||
class ToolNotSupportedError(ValueError):
|
||||
pass
|
||||
|
||||
class ToolInvokeError(ValueError):
|
||||
pass
|
||||
|
||||
class ToolApiSchemaError(ValueError):
|
||||
pass
|
||||
@ -0,0 +1,2 @@
|
||||
class InvokeModelError(Exception):
|
||||
pass
|
||||
@ -0,0 +1,174 @@
|
||||
"""
|
||||
For some reason, model will be used in tools like WebScraperTool, WikipediaSearchTool etc.
|
||||
|
||||
Therefore, a model manager is needed to list/invoke/validate models.
|
||||
"""
|
||||
|
||||
from core.model_runtime.entities.message_entities import PromptMessage
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel, ModelPropertyKey
|
||||
from core.model_runtime.errors.invoke import InvokeRateLimitError, InvokeBadRequestError, \
|
||||
InvokeConnectionError, InvokeAuthorizationError, InvokeServerUnavailableError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.model_manager import ModelManager
|
||||
|
||||
from core.tools.model.errors import InvokeModelError
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
from models.tools import ToolModelInvoke
|
||||
|
||||
from typing import List, cast
|
||||
import json
|
||||
|
||||
class ToolModelManager:
|
||||
@staticmethod
|
||||
def get_max_llm_context_tokens(
|
||||
tenant_id: str,
|
||||
) -> int:
|
||||
"""
|
||||
get max llm context tokens of the model
|
||||
"""
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id, model_type=ModelType.LLM,
|
||||
)
|
||||
|
||||
if not model_instance:
|
||||
raise InvokeModelError(f'Model not found')
|
||||
|
||||
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
|
||||
schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
|
||||
|
||||
if not schema:
|
||||
raise InvokeModelError(f'No model schema found')
|
||||
|
||||
max_tokens = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
|
||||
if max_tokens is None:
|
||||
return 2048
|
||||
|
||||
return max_tokens
|
||||
|
||||
@staticmethod
|
||||
def calculate_tokens(
|
||||
tenant_id: str,
|
||||
prompt_messages: List[PromptMessage]
|
||||
) -> int:
|
||||
"""
|
||||
calculate tokens from prompt messages and model parameters
|
||||
"""
|
||||
|
||||
# get model instance
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id, model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
if not model_instance:
|
||||
raise InvokeModelError(f'Model not found')
|
||||
|
||||
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
|
||||
|
||||
# get tokens
|
||||
tokens = llm_model.get_num_tokens(model_instance.model, model_instance.credentials, prompt_messages)
|
||||
|
||||
return tokens
|
||||
|
||||
@staticmethod
|
||||
def invoke(
|
||||
user_id: str, tenant_id: str,
|
||||
tool_type: str, tool_name: str,
|
||||
prompt_messages: List[PromptMessage]
|
||||
) -> LLMResult:
|
||||
"""
|
||||
invoke model with parameters in user's own context
|
||||
|
||||
:param user_id: user id
|
||||
:param tenant_id: tenant id, the tenant id of the creator of the tool
|
||||
:param tool_provider: tool provider
|
||||
:param tool_id: tool id
|
||||
:param tool_name: tool name
|
||||
:param provider: model provider
|
||||
:param model: model name
|
||||
:param model_parameters: model parameters
|
||||
:param prompt_messages: prompt messages
|
||||
:return: AssistantPromptMessage
|
||||
"""
|
||||
|
||||
# get model manager
|
||||
model_manager = ModelManager()
|
||||
# get model instance
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id, model_type=ModelType.LLM,
|
||||
)
|
||||
|
||||
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
|
||||
|
||||
# get model credentials
|
||||
model_credentials = model_instance.credentials
|
||||
|
||||
# get prompt tokens
|
||||
prompt_tokens = llm_model.get_num_tokens(model_instance.model, model_credentials, prompt_messages)
|
||||
|
||||
model_parameters = {
|
||||
'temperature': 0.8,
|
||||
'top_p': 0.8,
|
||||
}
|
||||
|
||||
# create tool model invoke
|
||||
tool_model_invoke = ToolModelInvoke(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
provider=model_instance.provider,
|
||||
tool_type=tool_type,
|
||||
tool_name=tool_name,
|
||||
model_parameters=json.dumps(model_parameters),
|
||||
prompt_messages=json.dumps(jsonable_encoder(prompt_messages)),
|
||||
model_response='',
|
||||
prompt_tokens=prompt_tokens,
|
||||
answer_tokens=0,
|
||||
answer_unit_price=0,
|
||||
answer_price_unit=0,
|
||||
provider_response_latency=0,
|
||||
total_price=0,
|
||||
currency='USD',
|
||||
)
|
||||
|
||||
db.session.add(tool_model_invoke)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
response: LLMResult = llm_model.invoke(
|
||||
model=model_instance.model,
|
||||
credentials=model_credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=[], stop=[], stream=False, user=user_id, callbacks=[]
|
||||
)
|
||||
except InvokeRateLimitError as e:
|
||||
raise InvokeModelError(f'Invoke rate limit error: {e}')
|
||||
except InvokeBadRequestError as e:
|
||||
raise InvokeModelError(f'Invoke bad request error: {e}')
|
||||
except InvokeConnectionError as e:
|
||||
raise InvokeModelError(f'Invoke connection error: {e}')
|
||||
except InvokeAuthorizationError as e:
|
||||
raise InvokeModelError(f'Invoke authorization error')
|
||||
except InvokeServerUnavailableError as e:
|
||||
raise InvokeModelError(f'Invoke server unavailable error: {e}')
|
||||
except Exception as e:
|
||||
raise InvokeModelError(f'Invoke error: {e}')
|
||||
|
||||
# update tool model invoke
|
||||
tool_model_invoke.model_response = response.message.content
|
||||
if response.usage:
|
||||
tool_model_invoke.answer_tokens = response.usage.completion_tokens
|
||||
tool_model_invoke.answer_unit_price = response.usage.completion_unit_price
|
||||
tool_model_invoke.answer_price_unit = response.usage.completion_price_unit
|
||||
tool_model_invoke.provider_response_latency = response.usage.latency
|
||||
tool_model_invoke.total_price = response.usage.total_price
|
||||
tool_model_invoke.currency = response.usage.currency
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return response
|
||||
@ -0,0 +1,102 @@
|
||||
ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES = """Respond to the human as helpfully and accurately as possible.
|
||||
|
||||
{{instruction}}
|
||||
|
||||
You have access to the following tools:
|
||||
|
||||
{{tools}}
|
||||
|
||||
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
||||
Valid "action" values: "Final Answer" or {{tool_names}}
|
||||
|
||||
Provide only ONE action per $JSON_BLOB, as shown:
|
||||
|
||||
```
|
||||
{
|
||||
"action": $TOOL_NAME,
|
||||
"action_input": $ACTION_INPUT
|
||||
}
|
||||
```
|
||||
|
||||
Follow this format:
|
||||
|
||||
Question: input question to answer
|
||||
Thought: consider previous and subsequent steps
|
||||
Action:
|
||||
```
|
||||
$JSON_BLOB
|
||||
```
|
||||
Observation: action result
|
||||
... (repeat Thought/Action/Observation N times)
|
||||
Thought: I know what to respond
|
||||
Action:
|
||||
```
|
||||
{
|
||||
"action": "Final Answer",
|
||||
"action_input": "Final response to human"
|
||||
}
|
||||
```
|
||||
|
||||
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
|
||||
Question: {{query}}
|
||||
Thought: {{agent_scratchpad}}"""
|
||||
|
||||
ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES = """Observation: {{observation}}
|
||||
Thought:"""
|
||||
|
||||
ENGLISH_REACT_CHAT_PROMPT_TEMPLATES = """Respond to the human as helpfully and accurately as possible.
|
||||
|
||||
{{instruction}}
|
||||
|
||||
You have access to the following tools:
|
||||
|
||||
{{tools}}
|
||||
|
||||
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
||||
Valid "action" values: "Final Answer" or {{tool_names}}
|
||||
|
||||
Provide only ONE action per $JSON_BLOB, as shown:
|
||||
|
||||
```
|
||||
{
|
||||
"action": $TOOL_NAME,
|
||||
"action_input": $ACTION_INPUT
|
||||
}
|
||||
```
|
||||
|
||||
Follow this format:
|
||||
|
||||
Question: input question to answer
|
||||
Thought: consider previous and subsequent steps
|
||||
Action:
|
||||
```
|
||||
$JSON_BLOB
|
||||
```
|
||||
Observation: action result
|
||||
... (repeat Thought/Action/Observation N times)
|
||||
Thought: I know what to respond
|
||||
Action:
|
||||
```
|
||||
{
|
||||
"action": "Final Answer",
|
||||
"action_input": "Final response to human"
|
||||
}
|
||||
```
|
||||
|
||||
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
|
||||
"""
|
||||
|
||||
ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES = ""
|
||||
|
||||
REACT_PROMPT_TEMPLATES = {
|
||||
'english': {
|
||||
'chat': {
|
||||
'prompt': ENGLISH_REACT_CHAT_PROMPT_TEMPLATES,
|
||||
'agent_scratchpad': ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES
|
||||
},
|
||||
'completion': {
|
||||
'prompt': ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES,
|
||||
'agent_scratchpad': ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,169 @@
|
||||
from typing import Any, Dict, List
|
||||
from core.tools.entities.tool_entities import ToolProviderType, ApiProviderAuthType, ToolProviderCredentials, ToolCredentialsOption
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiBasedToolBundle
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.tool.api_tool import ApiTool
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
from models.tools import ApiToolProvider
|
||||
|
||||
class ApiBasedToolProviderController(ToolProviderController):
|
||||
@staticmethod
|
||||
def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiBasedToolProviderController':
|
||||
credentials_schema = {
|
||||
'auth_type': ToolProviderCredentials(
|
||||
name='auth_type',
|
||||
required=True,
|
||||
type=ToolProviderCredentials.CredentialsType.SELECT,
|
||||
options=[
|
||||
ToolCredentialsOption(value='none', label=I18nObject(en_US='None', zh_Hans='无')),
|
||||
ToolCredentialsOption(value='api_key', label=I18nObject(en_US='api_key', zh_Hans='api_key'))
|
||||
],
|
||||
default='none',
|
||||
help=I18nObject(
|
||||
en_US='The auth type of the api provider',
|
||||
zh_Hans='api provider 的认证类型'
|
||||
)
|
||||
)
|
||||
}
|
||||
if auth_type == ApiProviderAuthType.API_KEY:
|
||||
credentials_schema = {
|
||||
**credentials_schema,
|
||||
'api_key_header': ToolProviderCredentials(
|
||||
name='api_key_header',
|
||||
required=False,
|
||||
default='api_key',
|
||||
type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
|
||||
help=I18nObject(
|
||||
en_US='The header name of the api key',
|
||||
zh_Hans='携带 api key 的 header 名称'
|
||||
)
|
||||
),
|
||||
'api_key_value': ToolProviderCredentials(
|
||||
name='api_key_value',
|
||||
required=True,
|
||||
type=ToolProviderCredentials.CredentialsType.SECRET_INPUT,
|
||||
help=I18nObject(
|
||||
en_US='The api key',
|
||||
zh_Hans='api key的值'
|
||||
)
|
||||
)
|
||||
}
|
||||
elif auth_type == ApiProviderAuthType.NONE:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f'invalid auth type {auth_type}')
|
||||
|
||||
return ApiBasedToolProviderController(**{
|
||||
'identity': {
|
||||
'author': db_provider.user.name if db_provider.user_id and db_provider.user else '',
|
||||
'name': db_provider.name,
|
||||
'label': {
|
||||
'en_US': db_provider.name,
|
||||
'zh_Hans': db_provider.name
|
||||
},
|
||||
'description': {
|
||||
'en_US': db_provider.description,
|
||||
'zh_Hans': db_provider.description
|
||||
},
|
||||
'icon': db_provider.icon
|
||||
},
|
||||
'credentials_schema': credentials_schema
|
||||
})
|
||||
|
||||
@property
|
||||
def app_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.API_BASED
|
||||
|
||||
def _validate_credentials(self, tool_name: str, credentials: Dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
def validate_parameters(self, tool_name: str, tool_parameters: Dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
def _parse_tool_bundle(self, tool_bundle: ApiBasedToolBundle) -> ApiTool:
|
||||
"""
|
||||
parse tool bundle to tool
|
||||
|
||||
:param tool_bundle: the tool bundle
|
||||
:return: the tool
|
||||
"""
|
||||
return ApiTool(**{
|
||||
'api_bundle': tool_bundle,
|
||||
'identity' : {
|
||||
'author': tool_bundle.author,
|
||||
'name': tool_bundle.operation_id,
|
||||
'label': {
|
||||
'en_US': tool_bundle.operation_id,
|
||||
'zh_Hans': tool_bundle.operation_id
|
||||
},
|
||||
'icon': tool_bundle.icon if tool_bundle.icon else ''
|
||||
},
|
||||
'description': {
|
||||
'human': {
|
||||
'en_US': tool_bundle.summary or '',
|
||||
'zh_Hans': tool_bundle.summary or ''
|
||||
},
|
||||
'llm': tool_bundle.summary or ''
|
||||
},
|
||||
'parameters' : tool_bundle.parameters if tool_bundle.parameters else [],
|
||||
})
|
||||
|
||||
def load_bundled_tools(self, tools: List[ApiBasedToolBundle]) -> List[ApiTool]:
|
||||
"""
|
||||
load bundled tools
|
||||
|
||||
:param tools: the bundled tools
|
||||
:return: the tools
|
||||
"""
|
||||
self.tools = [self._parse_tool_bundle(tool) for tool in tools]
|
||||
|
||||
return self.tools
|
||||
|
||||
def get_tools(self, user_id: str, tanent_id: str) -> List[ApiTool]:
|
||||
"""
|
||||
fetch tools from database
|
||||
|
||||
:param user_id: the user id
|
||||
:param tanent_id: the tanent id
|
||||
:return: the tools
|
||||
"""
|
||||
if self.tools is not None:
|
||||
return self.tools
|
||||
|
||||
tools: List[Tool] = []
|
||||
|
||||
# get tanent api providers
|
||||
db_providers: List[ApiToolProvider] = db.session.query(ApiToolProvider).filter(
|
||||
ApiToolProvider.tenant_id == tanent_id,
|
||||
ApiToolProvider.name == self.identity.name
|
||||
).all()
|
||||
|
||||
if db_providers and len(db_providers) != 0:
|
||||
for db_provider in db_providers:
|
||||
for tool in db_provider.tools:
|
||||
assistant_tool = self._parse_tool_bundle(tool)
|
||||
assistant_tool.is_team_authorization = True
|
||||
tools.append(assistant_tool)
|
||||
|
||||
self.tools = tools
|
||||
return tools
|
||||
|
||||
def get_tool(self, tool_name: str) -> ApiTool:
|
||||
"""
|
||||
get tool by name
|
||||
|
||||
:param tool_name: the name of the tool
|
||||
:return: the tool
|
||||
"""
|
||||
if self.tools is None:
|
||||
self.get_tools()
|
||||
|
||||
for tool in self.tools:
|
||||
if tool.identity.name == tool_name:
|
||||
return tool
|
||||
|
||||
raise ValueError(f'tool {tool_name} not found')
|
||||
@ -0,0 +1,116 @@
|
||||
from typing import Any, Dict, List
|
||||
from core.tools.entities.tool_entities import ToolProviderType, ToolParamter, ToolParamterOption
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.tools import PublishedAppTool
|
||||
from models.model import App, AppModelConfig
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class AppBasedToolProviderEntity(ToolProviderController):
|
||||
@property
|
||||
def app_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.APP_BASED
|
||||
|
||||
def _validate_credentials(self, tool_name: str, credentials: Dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
def validate_parameters(self, tool_name: str, tool_parameters: Dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
def get_tools(self, user_id: str) -> List[Tool]:
|
||||
db_tools: List[PublishedAppTool] = db.session.query(PublishedAppTool).filter(
|
||||
PublishedAppTool.user_id == user_id,
|
||||
).all()
|
||||
|
||||
if not db_tools or len(db_tools) == 0:
|
||||
return []
|
||||
|
||||
tools: List[Tool] = []
|
||||
|
||||
for db_tool in db_tools:
|
||||
tool = {
|
||||
'identity': {
|
||||
'author': db_tool.author,
|
||||
'name': db_tool.tool_name,
|
||||
'label': {
|
||||
'en_US': db_tool.tool_name,
|
||||
'zh_Hans': db_tool.tool_name
|
||||
},
|
||||
'icon': ''
|
||||
},
|
||||
'description': {
|
||||
'human': {
|
||||
'en_US': db_tool.description_i18n.en_US,
|
||||
'zh_Hans': db_tool.description_i18n.zh_Hans
|
||||
},
|
||||
'llm': db_tool.llm_description
|
||||
},
|
||||
'parameters': []
|
||||
}
|
||||
# get app from db
|
||||
app: App = db_tool.app
|
||||
|
||||
if not app:
|
||||
logger.error(f"app {db_tool.app_id} not found")
|
||||
continue
|
||||
|
||||
app_model_config: AppModelConfig = app.app_model_config
|
||||
user_input_form_list = app_model_config.user_input_form_list
|
||||
for input_form in user_input_form_list:
|
||||
# get type
|
||||
form_type = input_form.keys()[0]
|
||||
default = input_form[form_type]['default']
|
||||
required = input_form[form_type]['required']
|
||||
label = input_form[form_type]['label']
|
||||
variable_name = input_form[form_type]['variable_name']
|
||||
options = input_form[form_type].get('options', [])
|
||||
if form_type == 'paragraph' or form_type == 'text-input':
|
||||
tool['parameters'].append(ToolParamter(
|
||||
name=variable_name,
|
||||
label=I18nObject(
|
||||
en_US=label,
|
||||
zh_Hans=label
|
||||
),
|
||||
human_description=I18nObject(
|
||||
en_US=label,
|
||||
zh_Hans=label
|
||||
),
|
||||
llm_description=label,
|
||||
form=ToolParamter.ToolParameterForm.FORM,
|
||||
type=ToolParamter.ToolParameterType.STRING,
|
||||
required=required,
|
||||
default=default
|
||||
))
|
||||
elif form_type == 'select':
|
||||
tool['parameters'].append(ToolParamter(
|
||||
name=variable_name,
|
||||
label=I18nObject(
|
||||
en_US=label,
|
||||
zh_Hans=label
|
||||
),
|
||||
human_description=I18nObject(
|
||||
en_US=label,
|
||||
zh_Hans=label
|
||||
),
|
||||
llm_description=label,
|
||||
form=ToolParamter.ToolParameterForm.FORM,
|
||||
type=ToolParamter.ToolParameterType.SELECT,
|
||||
required=required,
|
||||
default=default,
|
||||
options=[ToolParamterOption(
|
||||
value=option,
|
||||
label=I18nObject(
|
||||
en_US=option,
|
||||
zh_Hans=option
|
||||
)
|
||||
) for option in options]
|
||||
))
|
||||
|
||||
tools.append(Tool(**tool))
|
||||
return tools
|
||||
@ -0,0 +1,26 @@
|
||||
from core.tools.entities.user_entities import UserToolProvider
|
||||
from typing import List
|
||||
|
||||
position = {
|
||||
'google': 1,
|
||||
'wikipedia': 2,
|
||||
'dalle': 3,
|
||||
'webscraper': 4,
|
||||
'wolframalpha': 5,
|
||||
'chart': 6,
|
||||
'time': 7,
|
||||
'yahoo': 8,
|
||||
'stablediffusion': 9,
|
||||
'vectorizer': 10,
|
||||
'youtube': 11,
|
||||
}
|
||||
|
||||
class BuiltinToolProviderSort:
|
||||
@staticmethod
|
||||
def sort(providers: List[UserToolProvider]) -> List[UserToolProvider]:
|
||||
def sort_compare(provider: UserToolProvider) -> int:
|
||||
return position.get(provider.name, 10000)
|
||||
|
||||
sorted_providers = sorted(providers, key=sort_compare)
|
||||
|
||||
return sorted_providers
|
||||
|
After Width: | Height: | Size: 1.3 KiB |
@ -0,0 +1,24 @@
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
from core.tools.provider.builtin.chart.tools.line import LinearChartTool
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
# use a business theme
|
||||
plt.style.use('seaborn-v0_8-darkgrid')
|
||||
|
||||
class ChartProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
LinearChartTool().fork_tool_runtime(
|
||||
meta={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
tool_paramters={
|
||||
"data": "1,3,5,7,9,2,4,6,8,10",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
@ -0,0 +1,11 @@
|
||||
identity:
|
||||
author: Dify
|
||||
name: chart
|
||||
label:
|
||||
en_US: ChartGenerator
|
||||
zh_Hans: 图表生成
|
||||
description:
|
||||
en_US: Chart Generator is a tool for generating statistical charts like bar chart, line chart, pie chart, etc.
|
||||
zh_Hans: 图表生成是一个用于生成可视化图表的工具,你可以通过它来生成柱状图、折线图、饼图等各类图表
|
||||
icon: icon.png
|
||||
credentails_for_provider:
|
||||
@ -0,0 +1,47 @@
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
import matplotlib.pyplot as plt
|
||||
import io
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
class BarChartTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \
|
||||
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
data = tool_paramters.get('data', '')
|
||||
if not data:
|
||||
return self.create_text_message('Please input data')
|
||||
data = data.split(';')
|
||||
|
||||
# if all data is int, convert to int
|
||||
if all([i.isdigit() for i in data]):
|
||||
data = [int(i) for i in data]
|
||||
else:
|
||||
data = [float(i) for i in data]
|
||||
|
||||
axis = tool_paramters.get('x_axis', None) or None
|
||||
if axis:
|
||||
axis = axis.split(';')
|
||||
if len(axis) != len(data):
|
||||
axis = None
|
||||
|
||||
flg, ax = plt.subplots(figsize=(10, 8))
|
||||
|
||||
if axis:
|
||||
axis = [label[:10] + '...' if len(label) > 10 else label for label in axis]
|
||||
ax.set_xticklabels(axis, rotation=45, ha='right')
|
||||
ax.bar(axis, data)
|
||||
else:
|
||||
ax.bar(range(len(data)), data)
|
||||
|
||||
buf = io.BytesIO()
|
||||
flg.savefig(buf, format='png')
|
||||
buf.seek(0)
|
||||
plt.close(flg)
|
||||
|
||||
return [
|
||||
self.create_text_message('the bar chart is saved as an image.'),
|
||||
self.create_blob_message(blob=buf.read(),
|
||||
meta={'mime_type': 'image/png'})
|
||||
]
|
||||
|
||||
@ -0,0 +1,35 @@
|
||||
identity:
|
||||
name: bar_chart
|
||||
author: Dify
|
||||
label:
|
||||
en_US: Bar Chart
|
||||
zh_Hans: 柱状图
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: Bar chart
|
||||
zh_Hans: 柱状图
|
||||
llm: generate a bar chart with input data
|
||||
parameters:
|
||||
- name: data
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: data
|
||||
zh_Hans: 数据
|
||||
human_description:
|
||||
en_US: data for generating bar chart
|
||||
zh_Hans: 用于生成柱状图的数据
|
||||
llm_description: data for generating bar chart, data should be a string contains a list of numbers like "1;2;3;4;5"
|
||||
form: llm
|
||||
- name: x_axis
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: X Axis
|
||||
zh_Hans: x 轴
|
||||
human_description:
|
||||
en_US: X axis for bar chart
|
||||
zh_Hans: 柱状图的 x 轴
|
||||
llm_description: x axis for bar chart, x axis should be a string contains a list of texts like "a;b;c;1;2" in order to match the data
|
||||
form: llm
|
||||
@ -0,0 +1,49 @@
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
import matplotlib.pyplot as plt
|
||||
import io
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
class LinearChartTool(BuiltinTool):
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_paramters: Dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
data = tool_paramters.get('data', '')
|
||||
if not data:
|
||||
return self.create_text_message('Please input data')
|
||||
data = data.split(';')
|
||||
|
||||
axis = tool_paramters.get('x_axis', None) or None
|
||||
if axis:
|
||||
axis = axis.split(';')
|
||||
if len(axis) != len(data):
|
||||
axis = None
|
||||
|
||||
# if all data is int, convert to int
|
||||
if all([i.isdigit() for i in data]):
|
||||
data = [int(i) for i in data]
|
||||
else:
|
||||
data = [float(i) for i in data]
|
||||
|
||||
flg, ax = plt.subplots(figsize=(10, 8))
|
||||
|
||||
if axis:
|
||||
axis = [label[:10] + '...' if len(label) > 10 else label for label in axis]
|
||||
ax.set_xticklabels(axis, rotation=45, ha='right')
|
||||
ax.plot(axis, data)
|
||||
else:
|
||||
ax.plot(data)
|
||||
|
||||
buf = io.BytesIO()
|
||||
flg.savefig(buf, format='png')
|
||||
buf.seek(0)
|
||||
plt.close(flg)
|
||||
|
||||
return [
|
||||
self.create_text_message('the linear chart is saved as an image.'),
|
||||
self.create_blob_message(blob=buf.read(),
|
||||
meta={'mime_type': 'image/png'})
|
||||
]
|
||||
|
||||
@ -0,0 +1,35 @@
|
||||
identity:
|
||||
name: line_chart
|
||||
author: Dify
|
||||
label:
|
||||
en_US: Linear Chart
|
||||
zh_Hans: 线性图表
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: linear chart
|
||||
zh_Hans: 线性图表
|
||||
llm: generate a linear chart with input data
|
||||
parameters:
|
||||
- name: data
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: data
|
||||
zh_Hans: 数据
|
||||
human_description:
|
||||
en_US: data for generating linear chart
|
||||
zh_Hans: 用于生成线性图表的数据
|
||||
llm_description: data for generating linear chart, data should be a string contains a list of numbers like "1;2;3;4;5"
|
||||
form: llm
|
||||
- name: x_axis
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: X Axis
|
||||
zh_Hans: x 轴
|
||||
human_description:
|
||||
en_US: X axis for linear chart
|
||||
zh_Hans: 线性图表的 x 轴
|
||||
llm_description: x axis for linear chart, x axis should be a string contains a list of texts like "a;b;c;1;2" in order to match the data
|
||||
form: llm
|
||||
@ -0,0 +1,46 @@
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
import matplotlib.pyplot as plt
|
||||
import io
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
class PieChartTool(BuiltinTool):
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_paramters: Dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
data = tool_paramters.get('data', '')
|
||||
if not data:
|
||||
return self.create_text_message('Please input data')
|
||||
data = data.split(';')
|
||||
categories = tool_paramters.get('categories', None) or None
|
||||
|
||||
# if all data is int, convert to int
|
||||
if all([i.isdigit() for i in data]):
|
||||
data = [int(i) for i in data]
|
||||
else:
|
||||
data = [float(i) for i in data]
|
||||
|
||||
flg, ax = plt.subplots()
|
||||
|
||||
if categories:
|
||||
categories = categories.split(';')
|
||||
if len(categories) != len(data):
|
||||
categories = None
|
||||
|
||||
if categories:
|
||||
ax.pie(data, labels=categories)
|
||||
else:
|
||||
ax.pie(data)
|
||||
|
||||
buf = io.BytesIO()
|
||||
flg.savefig(buf, format='png')
|
||||
buf.seek(0)
|
||||
plt.close(flg)
|
||||
|
||||
return [
|
||||
self.create_text_message('the pie chart is saved as an image.'),
|
||||
self.create_blob_message(blob=buf.read(),
|
||||
meta={'mime_type': 'image/png'})
|
||||
]
|
||||
@ -0,0 +1,35 @@
|
||||
identity:
|
||||
name: pie_chart
|
||||
author: Dify
|
||||
label:
|
||||
en_US: Pie Chart
|
||||
zh_Hans: 饼图
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: Pie chart
|
||||
zh_Hans: 饼图
|
||||
llm: generate a pie chart with input data
|
||||
parameters:
|
||||
- name: data
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: data
|
||||
zh_Hans: 数据
|
||||
human_description:
|
||||
en_US: data for generating pie chart
|
||||
zh_Hans: 用于生成饼图的数据
|
||||
llm_description: data for generating pie chart, data should be a string contains a list of numbers like "1;2;3;4;5"
|
||||
form: llm
|
||||
- name: categories
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Categories
|
||||
zh_Hans: 分类
|
||||
human_description:
|
||||
en_US: Categories for pie chart
|
||||
zh_Hans: 饼图的分类
|
||||
llm_description: categories for pie chart, categories should be a string contains a list of texts like "a;b;c;1;2" in order to match the data, each category should be split by ";"
|
||||
form: llm
|
||||
|
After Width: | Height: | Size: 153 KiB |
@ -0,0 +1,23 @@
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
from core.tools.provider.builtin.dalle.tools.dalle2 import DallE2Tool
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
class DALLEProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
|
||||
try:
|
||||
DallE2Tool().fork_tool_runtime(
|
||||
meta={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
tool_paramters={
|
||||
"prompt": "cute girl, blue eyes, white hair, anime style",
|
||||
"size": "small",
|
||||
"n": 1
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
@ -0,0 +1,47 @@
|
||||
identity:
|
||||
author: Dify
|
||||
name: dalle
|
||||
label:
|
||||
en_US: DALL-E
|
||||
zh_Hans: DALL-E 绘画
|
||||
description:
|
||||
en_US: DALL-E art
|
||||
zh_Hans: DALL-E 绘画
|
||||
icon: icon.png
|
||||
credentails_for_provider:
|
||||
openai_api_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: OpenAI API key
|
||||
zh_Hans: OpenAI API key
|
||||
help:
|
||||
en_US: Please input your OpenAI API key
|
||||
zh_Hans: 请输入你的 OpenAI API key
|
||||
placeholder:
|
||||
en_US: Please input your OpenAI API key
|
||||
zh_Hans: 请输入你的 OpenAI API key
|
||||
openai_organizaion_id:
|
||||
type: text-input
|
||||
required: false
|
||||
label:
|
||||
en_US: OpenAI organization ID
|
||||
zh_Hans: OpenAI organization ID
|
||||
help:
|
||||
en_US: Please input your OpenAI organization ID
|
||||
zh_Hans: 请输入你的 OpenAI organization ID
|
||||
placeholder:
|
||||
en_US: Please input your OpenAI organization ID
|
||||
zh_Hans: 请输入你的 OpenAI organization ID
|
||||
openai_base_url:
|
||||
type: text-input
|
||||
required: false
|
||||
label:
|
||||
en_US: OpenAI base URL
|
||||
zh_Hans: OpenAI base URL
|
||||
help:
|
||||
en_US: Please input your OpenAI base URL
|
||||
zh_Hans: 请输入你的 OpenAI base URL
|
||||
placeholder:
|
||||
en_US: Please input your OpenAI base URL
|
||||
zh_Hans: 请输入你的 OpenAI base URL
|
||||
@ -0,0 +1,66 @@
|
||||
from typing import Any, Dict, List, Union
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
from base64 import b64decode
|
||||
from os.path import join
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
class DallE2Tool(BuiltinTool):
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_paramters: Dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
openai_organization = self.runtime.credentials.get('openai_organizaion_id', None)
|
||||
if not openai_organization:
|
||||
openai_organization = None
|
||||
openai_base_url = self.runtime.credentials.get('openai_base_url', None)
|
||||
if not openai_base_url:
|
||||
openai_base_url = None
|
||||
else:
|
||||
openai_base_url = join(openai_base_url, 'v1')
|
||||
|
||||
client = OpenAI(
|
||||
api_key=self.runtime.credentials['openai_api_key'],
|
||||
base_url=openai_base_url,
|
||||
organization=openai_organization
|
||||
)
|
||||
|
||||
SIZE_MAPPING = {
|
||||
'small': '256x256',
|
||||
'medium': '512x512',
|
||||
'large': '1024x1024',
|
||||
}
|
||||
|
||||
# prompt
|
||||
prompt = tool_paramters.get('prompt', '')
|
||||
if not prompt:
|
||||
return self.create_text_message('Please input prompt')
|
||||
|
||||
# get size
|
||||
size = SIZE_MAPPING[tool_paramters.get('size', 'large')]
|
||||
|
||||
# get n
|
||||
n = tool_paramters.get('n', 1)
|
||||
|
||||
# call openapi dalle2
|
||||
response = client.images.generate(
|
||||
prompt=prompt,
|
||||
model='dall-e-2',
|
||||
size=size,
|
||||
n=n,
|
||||
response_format='b64_json'
|
||||
)
|
||||
|
||||
result = []
|
||||
|
||||
for image in response.data:
|
||||
result.append(self.create_blob_message(blob=b64decode(image.b64_json),
|
||||
meta={ 'mime_type': 'image/png' },
|
||||
save_as=self.VARIABLE_KEY.IMAGE.value))
|
||||
|
||||
return result
|
||||
@ -0,0 +1,74 @@
|
||||
from typing import Any, Dict, List, Union
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
from base64 import b64decode
|
||||
from os.path import join
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
class DallE3Tool(BuiltinTool):
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_paramters: Dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
openai_organization = self.runtime.credentials.get('openai_organizaion_id', None)
|
||||
if not openai_organization:
|
||||
openai_organization = None
|
||||
openai_base_url = self.runtime.credentials.get('openai_base_url', None)
|
||||
if not openai_base_url:
|
||||
openai_base_url = None
|
||||
else:
|
||||
openai_base_url = join(openai_base_url, 'v1')
|
||||
|
||||
client = OpenAI(
|
||||
api_key=self.runtime.credentials['openai_api_key'],
|
||||
base_url=openai_base_url,
|
||||
organization=openai_organization
|
||||
)
|
||||
|
||||
SIZE_MAPPING = {
|
||||
'square': '1024x1024',
|
||||
'vertical': '1024x1792',
|
||||
'horizontal': '1792x1024',
|
||||
}
|
||||
|
||||
# prompt
|
||||
prompt = tool_paramters.get('prompt', '')
|
||||
if not prompt:
|
||||
return self.create_text_message('Please input prompt')
|
||||
# get size
|
||||
size = SIZE_MAPPING[tool_paramters.get('size', 'square')]
|
||||
# get n
|
||||
n = tool_paramters.get('n', 1)
|
||||
# get quality
|
||||
quality = tool_paramters.get('quality', 'standard')
|
||||
if quality not in ['standard', 'hd']:
|
||||
return self.create_text_message('Invalid quality')
|
||||
# get style
|
||||
style = tool_paramters.get('style', 'vivid')
|
||||
if style not in ['natural', 'vivid']:
|
||||
return self.create_text_message('Invalid style')
|
||||
|
||||
# call openapi dalle3
|
||||
response = client.images.generate(
|
||||
prompt=prompt,
|
||||
model='dall-e-3',
|
||||
size=size,
|
||||
n=n,
|
||||
style=style,
|
||||
quality=quality,
|
||||
response_format='b64_json'
|
||||
)
|
||||
|
||||
result = []
|
||||
|
||||
for image in response.data:
|
||||
result.append(self.create_blob_message(blob=b64decode(image.b64_json),
|
||||
meta={ 'mime_type': 'image/png' },
|
||||
save_as=self.VARIABLE_KEY.IMAGE.value))
|
||||
|
||||
return result
|
||||
@ -0,0 +1,6 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="25" viewBox="0 0 24 25" fill="none">
|
||||
<path d="M22.501 12.7332C22.501 11.8699 22.4296 11.2399 22.2748 10.5865H12.2153V14.4832H18.12C18.001 15.4515 17.3582 16.9099 15.9296 17.8898L15.9096 18.0203L19.0902 20.435L19.3106 20.4565C21.3343 18.6249 22.501 15.9298 22.501 12.7332Z" fill="#4285F4"/>
|
||||
<path d="M12.214 23C15.1068 23 17.5353 22.0666 19.3092 20.4567L15.9282 17.8899C15.0235 18.5083 13.8092 18.9399 12.214 18.9399C9.38069 18.9399 6.97596 17.1083 6.11874 14.5766L5.99309 14.5871L2.68583 17.0954L2.64258 17.2132C4.40446 20.6433 8.0235 23 12.214 23Z" fill="#34A853"/>
|
||||
<path d="M6.12046 14.5766C5.89428 13.9233 5.76337 13.2233 5.76337 12.5C5.76337 11.7766 5.89428 11.0766 6.10856 10.4233L6.10257 10.2841L2.75386 7.7355L2.64429 7.78658C1.91814 9.20993 1.50146 10.8083 1.50146 12.5C1.50146 14.1916 1.91814 15.7899 2.64429 17.2132L6.12046 14.5766Z" fill="#FBBC05"/>
|
||||
<path d="M12.2141 6.05997C14.2259 6.05997 15.583 6.91163 16.3569 7.62335L19.3807 4.73C17.5236 3.03834 15.1069 2 12.2141 2C8.02353 2 4.40447 4.35665 2.64258 7.78662L6.10686 10.4233C6.97598 7.89166 9.38073 6.05997 12.2141 6.05997Z" fill="#EB4335"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.2 KiB |
@ -0,0 +1,23 @@
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
from core.tools.provider.builtin.google.tools.google_search import GoogleSearchTool
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
class GoogleProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
|
||||
try:
|
||||
GoogleSearchTool().fork_tool_runtime(
|
||||
meta={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
tool_paramters={
|
||||
"query": "test",
|
||||
"result_type": "link"
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
@ -0,0 +1,24 @@
|
||||
identity:
|
||||
author: Dify
|
||||
name: google
|
||||
label:
|
||||
en_US: Google
|
||||
zh_Hans: Google
|
||||
description:
|
||||
en_US: Google
|
||||
zh_Hans: GoogleSearch
|
||||
icon: icon.svg
|
||||
credentails_for_provider:
|
||||
serpapi_api_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: SerpApi API key
|
||||
zh_Hans: SerpApi API key
|
||||
placeholder:
|
||||
en_US: Please input your SerpApi API key
|
||||
zh_Hans: 请输入你的 SerpApi API key
|
||||
help:
|
||||
en_US: Get your SerpApi API key from SerpApi
|
||||
zh_Hans: 从 SerpApi 获取您的 SerpApi API key
|
||||
url: https://serpapi.com/manage-api-key
|
||||
@ -0,0 +1,163 @@
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
from serpapi import GoogleSearch
|
||||
|
||||
class HiddenPrints:
|
||||
"""Context manager to hide prints."""
|
||||
|
||||
def __enter__(self) -> None:
|
||||
"""Open file to pipe stdout to."""
|
||||
self._original_stdout = sys.stdout
|
||||
sys.stdout = open(os.devnull, "w")
|
||||
|
||||
def __exit__(self, *_: Any) -> None:
|
||||
"""Close file that stdout was piped to."""
|
||||
sys.stdout.close()
|
||||
sys.stdout = self._original_stdout
|
||||
|
||||
|
||||
class SerpAPI:
|
||||
"""
|
||||
SerpAPI tool provider.
|
||||
"""
|
||||
|
||||
search_engine: Any #: :meta private:
|
||||
serpapi_api_key: str = None
|
||||
|
||||
def __init__(self, api_key: str) -> None:
|
||||
"""Initialize SerpAPI tool provider."""
|
||||
self.serpapi_api_key = api_key
|
||||
self.search_engine = GoogleSearch
|
||||
|
||||
def run(self, query: str, **kwargs: Any) -> str:
|
||||
"""Run query through SerpAPI and parse result."""
|
||||
typ = kwargs.get("result_type", "text")
|
||||
return self._process_response(self.results(query), typ=typ)
|
||||
|
||||
def results(self, query: str) -> dict:
|
||||
"""Run query through SerpAPI and return the raw result."""
|
||||
params = self.get_params(query)
|
||||
with HiddenPrints():
|
||||
search = self.search_engine(params)
|
||||
res = search.get_dict()
|
||||
return res
|
||||
|
||||
def get_params(self, query: str) -> Dict[str, str]:
|
||||
"""Get parameters for SerpAPI."""
|
||||
_params = {
|
||||
"api_key": self.serpapi_api_key,
|
||||
"q": query,
|
||||
}
|
||||
params = {
|
||||
"engine": "google",
|
||||
"google_domain": "google.com",
|
||||
"gl": "us",
|
||||
"hl": "en",
|
||||
**_params
|
||||
}
|
||||
return params
|
||||
|
||||
@staticmethod
|
||||
def _process_response(res: dict, typ: str) -> str:
|
||||
"""Process response from SerpAPI."""
|
||||
if "error" in res.keys():
|
||||
raise ValueError(f"Got error from SerpAPI: {res['error']}")
|
||||
|
||||
if typ == "text":
|
||||
if "answer_box" in res.keys() and type(res["answer_box"]) == list:
|
||||
res["answer_box"] = res["answer_box"][0]
|
||||
if "answer_box" in res.keys() and "answer" in res["answer_box"].keys():
|
||||
toret = res["answer_box"]["answer"]
|
||||
elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
|
||||
toret = res["answer_box"]["snippet"]
|
||||
elif (
|
||||
"answer_box" in res.keys()
|
||||
and "snippet_highlighted_words" in res["answer_box"].keys()
|
||||
):
|
||||
toret = res["answer_box"]["snippet_highlighted_words"][0]
|
||||
elif (
|
||||
"sports_results" in res.keys()
|
||||
and "game_spotlight" in res["sports_results"].keys()
|
||||
):
|
||||
toret = res["sports_results"]["game_spotlight"]
|
||||
elif (
|
||||
"shopping_results" in res.keys()
|
||||
and "title" in res["shopping_results"][0].keys()
|
||||
):
|
||||
toret = res["shopping_results"][:3]
|
||||
elif (
|
||||
"knowledge_graph" in res.keys()
|
||||
and "description" in res["knowledge_graph"].keys()
|
||||
):
|
||||
toret = res["knowledge_graph"]["description"]
|
||||
elif "snippet" in res["organic_results"][0].keys():
|
||||
toret = res["organic_results"][0]["snippet"]
|
||||
elif "link" in res["organic_results"][0].keys():
|
||||
toret = res["organic_results"][0]["link"]
|
||||
elif (
|
||||
"images_results" in res.keys()
|
||||
and "thumbnail" in res["images_results"][0].keys()
|
||||
):
|
||||
thumbnails = [item["thumbnail"] for item in res["images_results"][:10]]
|
||||
toret = thumbnails
|
||||
else:
|
||||
toret = "No good search result found"
|
||||
elif typ == "link":
|
||||
if "knowledge_graph" in res.keys() and "title" in res["knowledge_graph"].keys() \
|
||||
and "description_link" in res["knowledge_graph"].keys():
|
||||
toret = res["knowledge_graph"]["description_link"]
|
||||
elif "knowledge_graph" in res.keys() and "see_results_about" in res["knowledge_graph"].keys() \
|
||||
and len(res["knowledge_graph"]["see_results_about"]) > 0:
|
||||
see_result_about = res["knowledge_graph"]["see_results_about"]
|
||||
toret = ""
|
||||
for item in see_result_about:
|
||||
if "name" not in item.keys() or "link" not in item.keys():
|
||||
continue
|
||||
toret += f"[{item['name']}]({item['link']})\n"
|
||||
elif "organic_results" in res.keys() and len(res["organic_results"]) > 0:
|
||||
organic_results = res["organic_results"]
|
||||
toret = ""
|
||||
for item in organic_results:
|
||||
if "title" not in item.keys() or "link" not in item.keys():
|
||||
continue
|
||||
toret += f"[{item['title']}]({item['link']})\n"
|
||||
elif "related_questions" in res.keys() and len(res["related_questions"]) > 0:
|
||||
related_questions = res["related_questions"]
|
||||
toret = ""
|
||||
for item in related_questions:
|
||||
if "question" not in item.keys() or "link" not in item.keys():
|
||||
continue
|
||||
toret += f"[{item['question']}]({item['link']})\n"
|
||||
elif "related_searches" in res.keys() and len(res["related_searches"]) > 0:
|
||||
related_searches = res["related_searches"]
|
||||
toret = ""
|
||||
for item in related_searches:
|
||||
if "query" not in item.keys() or "link" not in item.keys():
|
||||
continue
|
||||
toret += f"[{item['query']}]({item['link']})\n"
|
||||
else:
|
||||
toret = "No good search result found"
|
||||
return toret
|
||||
|
||||
class GoogleSearchTool(BuiltinTool):
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_paramters: Dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
query = tool_paramters['query']
|
||||
result_type = tool_paramters['result_type']
|
||||
api_key = self.runtime.credentials['serpapi_api_key']
|
||||
result = SerpAPI(api_key).run(query, result_type=result_type)
|
||||
if result_type == 'text':
|
||||
return self.create_text_message(text=result)
|
||||
return self.create_link_message(link=result)
|
||||
|
||||
@ -0,0 +1,43 @@
|
||||
identity:
|
||||
name: google_search
|
||||
author: Dify
|
||||
label:
|
||||
en_US: GoogleSearch
|
||||
zh_Hans: 谷歌搜索
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for performing a Google SERP search and extracting snippets and webpages.Input should be a search query.
|
||||
zh_Hans: 一个用于执行 Google SERP 搜索并提取片段和网页的工具。输入应该是一个搜索查询。
|
||||
llm: A tool for performing a Google SERP search and extracting snippets and webpages.Input should be a search query.
|
||||
parameters:
|
||||
- name: query
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Query string
|
||||
zh_Hans: 查询语句
|
||||
human_description:
|
||||
en_US: used for searching
|
||||
zh_Hans: 用于搜索网页内容
|
||||
llm_description: key words for searching
|
||||
form: llm
|
||||
- name: result_type
|
||||
type: select
|
||||
required: true
|
||||
options:
|
||||
- value: text
|
||||
label:
|
||||
en_US: text
|
||||
zh_Hans: 文本
|
||||
- value: link
|
||||
label:
|
||||
en_US: link
|
||||
zh_Hans: 链接
|
||||
default: link
|
||||
label:
|
||||
en_US: Result type
|
||||
zh_Hans: 结果类型
|
||||
human_description:
|
||||
en_US: used for selecting the result type, text or link
|
||||
zh_Hans: 用于选择结果类型,使用文本还是链接进行展示
|
||||
form: form
|
||||
|
After Width: | Height: | Size: 16 KiB |
@ -0,0 +1,26 @@
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
from core.tools.provider.builtin.stablediffusion.tools.stable_diffusion import StableDiffusionTool
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
class StableDiffusionProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
|
||||
try:
|
||||
StableDiffusionTool().fork_tool_runtime(
|
||||
meta={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
tool_paramters={
|
||||
"prompt": "cat",
|
||||
"lora": "",
|
||||
"steps": 1,
|
||||
"width": 512,
|
||||
"height": 512,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
@ -0,0 +1,29 @@
|
||||
identity:
|
||||
author: Dify
|
||||
name: stablediffusion
|
||||
label:
|
||||
en_US: Stable Diffusion
|
||||
zh_Hans: Stable Diffusion
|
||||
description:
|
||||
en_US: Stable Diffusion is a tool for generating images which can be deployed locally.
|
||||
zh_Hans: Stable Diffusion 是一个可以在本地部署的图片生成的工具。
|
||||
icon: icon.png
|
||||
credentails_for_provider:
|
||||
base_url:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: Base URL
|
||||
zh_Hans: StableDiffusion服务器的Base URL
|
||||
placeholder:
|
||||
en_US: Please input your StableDiffusion server's Base URL
|
||||
zh_Hans: 请输入你的 StableDiffusion 服务器的 Base URL
|
||||
model:
|
||||
type: text-input
|
||||
required: true
|
||||
label:
|
||||
en_US: Model
|
||||
zh_Hans: 模型
|
||||
placeholder:
|
||||
en_US: Please input your model
|
||||
zh_Hans: 请输入你的模型名称
|
||||
@ -0,0 +1,77 @@
|
||||
identity:
|
||||
name: stable_diffusion
|
||||
author: Dify
|
||||
label:
|
||||
en_US: Stable Diffusion WebUI
|
||||
zh_Hans: Stable Diffusion WebUI
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for generating images which can be deployed locally, you can use stable-diffusion-webui to deploy it.
|
||||
zh_Hans: 一个可以在本地部署的图片生成的工具,您可以使用 stable-diffusion-webui 来部署它。
|
||||
llm: draw the image you want based on your prompt.
|
||||
parameters:
|
||||
- name: prompt
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Prompt
|
||||
zh_Hans: 提示词
|
||||
human_description:
|
||||
en_US: Image prompt, you can check the official documentation of Stable Diffusion
|
||||
zh_Hans: 图像提示词,您可以查看 Stable Diffusion 的官方文档
|
||||
llm_description: Image prompt of Stable Diffusion, you should describe the image you want to generate as a list of words as possible as detailed, the prompt must be written in English.
|
||||
form: llm
|
||||
- name: lora
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Lora
|
||||
zh_Hans: Lora
|
||||
human_description:
|
||||
en_US: Lora
|
||||
zh_Hans: Lora
|
||||
form: form
|
||||
- name: steps
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Steps
|
||||
zh_Hans: Steps
|
||||
human_description:
|
||||
en_US: Steps
|
||||
zh_Hans: Steps
|
||||
form: form
|
||||
default: 10
|
||||
- name: width
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Width
|
||||
zh_Hans: Width
|
||||
human_description:
|
||||
en_US: Width
|
||||
zh_Hans: Width
|
||||
form: form
|
||||
default: 1024
|
||||
- name: height
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Height
|
||||
zh_Hans: Height
|
||||
human_description:
|
||||
en_US: Height
|
||||
zh_Hans: Height
|
||||
form: form
|
||||
default: 1024
|
||||
- name: negative_prompt
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Negative prompt
|
||||
zh_Hans: Negative prompt
|
||||
human_description:
|
||||
en_US: Negative prompt
|
||||
zh_Hans: Negative prompt
|
||||
form: form
|
||||
default: bad art, ugly, deformed, watermark, duplicated, discontinuous lines
|
||||
@ -0,0 +1,3 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M0.666992 8.00008C0.666992 3.94999 3.95024 0.666748 8.00033 0.666748C12.0504 0.666748 15.3337 3.94999 15.3337 8.00008C15.3337 12.0502 12.0504 15.3334 8.00033 15.3334C3.95024 15.3334 0.666992 12.0502 0.666992 8.00008ZM8.66699 4.00008C8.66699 3.63189 8.36852 3.33341 8.00033 3.33341C7.63213 3.33341 7.33366 3.63189 7.33366 4.00008V8.00008C7.33366 8.2526 7.47633 8.48344 7.70218 8.59637L10.3688 9.9297C10.6982 10.0944 11.0986 9.96088 11.2633 9.63156C11.4279 9.30224 11.2945 8.90179 10.9651 8.73713L8.66699 7.58806V4.00008Z" fill="#EC4A0A"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 691 B |
@ -0,0 +1,16 @@
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
from core.tools.provider.builtin.time.tools.current_time import CurrentTimeTool
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
class WikiPediaProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
|
||||
try:
|
||||
CurrentTimeTool().invoke(
|
||||
user_id='',
|
||||
tool_paramters={},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
@ -0,0 +1,11 @@
|
||||
identity:
|
||||
author: Dify
|
||||
name: time
|
||||
label:
|
||||
en_US: CurrentTime
|
||||
zh_Hans: 时间
|
||||
description:
|
||||
en_US: A tool for getting the current time.
|
||||
zh_Hans: 一个用于获取当前时间的工具。
|
||||
icon: icon.svg
|
||||
credentails_for_provider:
|
||||
@ -0,0 +1,17 @@
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
class CurrentTimeTool(BuiltinTool):
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_paramters: Dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
return self.create_text_message(f'{datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S %Z")}')
|
||||
|
||||
@ -0,0 +1,12 @@
|
||||
identity:
|
||||
name: current_time
|
||||
author: Dify
|
||||
label:
|
||||
en_US: Current Time
|
||||
zh_Hans: 获取当前时间
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for getting the current time.
|
||||
zh_Hans: 一个用于获取当前时间的工具。
|
||||
llm: A tool for getting the current time.
|
||||
parameters:
|
||||
|
After Width: | Height: | Size: 1.8 KiB |
@ -0,0 +1 @@
|
||||
VECTORIZER_ICON_PNG = 'iVBORw0KGgoAAAANSUhEUgAAAGAAAABgCAYAAADimHc4AAAACXBIWXMAACxLAAAsSwGlPZapAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAAboSURBVHgB7Z09bBxFFMffRoAvcQqbguBUxu4wCUikMCZ0TmQK4NLQJCJOlQIkokgEGhQ7NCFIKEhQuIqNnIaGMxRY2GVwmlggDHS+pIHELmIXMTEULPP3eeXz7e7szO7MvE1ufpKV03nuNn7/mfcxH7tEHo/H42lXgqwG1bGw65+/aTQM6K0gpJdCoi7ypCIMui5s9Qv9R1OVTqrVxoL1jPbpvH4hrIp/rnmj5+YOhTQ++1kwmdZgT9ovRi6EF4Xhv/XGL0Sv6OLXYMu0BokjYOSDcBQfJI8xhKFP/HAlqCW8v5vqubBr8yn6maCexxiIDR376LnWmBBzQZtPEvx+L3mMAleOZKb1/XgM2EOnyWMFZJKt78UEQKpJHisk2TYmgM967JFk2z3kYcULwIwXgBkvADNeAGa8AMw8Qcwc6N55/eAh0cYmGaOzQtR/kOhQX+M6+/c23r+3RlT/i2ipTrSyRqw4F+CwMMbgANHQwG7jRywLw/wqDDNzI79xYPjqa2L262jjtYzaT0QT3xEbsck4MXUakgWOvUx08liy0ZPYEKNhel4Y6AZpgR7/8Tvq1wEQ+sMJN6Nh9kqwy+bWYwAM8elZovNv6xmlU7iLs280RNO9ls51os/h/8eBVQEig8Dt5OXUsNrno2tluZw0cI3qUXKONQHy9sYkVHqnjntLA2LnFTAv1gSA+zBhfIDvkfVO/B4xRgWZn4fbe2WAnGJFAAxn03+I7PtUXdzE90Sjl4ne+6L4d5nCigAyYyHPn7tFdPN30uJwX/qI6jtISkQZFVLdhd9SrtNPTrFSB6QZBAaYntsptpAyfvk+KYOCamVR/XrNtLqepduiFnkh3g4iIw6YLAhlOJmKwB9zaarhApr/MPREjAZVisSU1s/KYsGzhmKXClYEWLm/8xpV7btXhcv5I7lt2vtJFA3q/T07r1HopdG5l5xhxQVdn28YFn8kBJCBOZmiPHio1m5QuJzlu9ntXApgZwSsNYJslvGjtjrfm8Sq4neceFUtz3dZCzwW09Gqo2hreuPN7HZRnNqa1BP1x8lhczVNK+zT0TqkjYAF4e7Okxoo2PZX5K4IrhNpb/P8FTK2S1+TcUq1HpBFmquJYo1qEYU6RVarJE0c2ooL7C5IRwBZ5nJ9joyRtk5hA3YBdHqWzG1gBKgE/bzMaK5LqMIugKrbUDHu59/YWVRBsWhrsYZdANV5HBUXYGNlC9dFBW8LdgH6FQVYUnQvkQgm3NH8YuO7bM4LsWZBfT3qRY9OxRyJgJRz+Ij+FDPEQ1C3GVMiWAVQ7f31u/ncytxi4wdZTbRGgdcHnpYLD/FcwSrAoOKizfKfVAiIF4kBMPK+Opfe1iWsMUB1BJh2BRgBabSNAOiFqkXYbcNFUF9P+u82FGdWTcEmgGrvh0FUppB1kC073muXEaDq/21kIjLxV9tFAC7/n5X6tkUM0PH/dcP+P0v41fvkFBYBVHs/MD0CDmVsOzEdb7JgEYDT/8uq4rpj44NSjwDTc/CyzV1gxbH7Ac4F0PH/S4ZHAOaFZLiY+2nFuQA6/t9kQMTCz1CG66tbWvWS4VwAVf9vugAbel6efqrsYbKBcwFeVNz8ajobyTppw2F84FQAnfl/kwER6wJZcWdBc7e2KZwKoOP/TVakWb0f7md+kVhwOwI0BDCFyq42rt4PSiuAiRGAEXdK4ZQlV+8HTgVwefwHvR7nhbOA0FwBGDgTIM/Z3SLXUj2hOW1wR10eSrs7Ou9eTB3jo/dzuh/gTABdn35c8dhpM3BxOmeTuXs/cDoCdDY4qe7l32pbaZxL1jF+GXo/cLotBcWVTiZU3T7RMn8rHiijW9FgauP4Ef1TLdhHWgacCgAj6tYCqGKjU/DNbqxIkMYZNs7MpxmnLuhmwYJna1dbdzHjY42hDL4/wqkA6HWuDkAngRH0iYVjRkVwnoZO/0gsuLwpkw7OBcAtwlwvfESHxctmfMBSiOG0oStj4HCF7T3+RWARwIU7QK/HbWlqls52mYJtezqMj3v34C5VOveFy8Ll4QoTsJ8Txp0RsW8/Os2im2LCtSC1RIqLw3RldTVplOKkPEYDhMAPqttnune2rzTv5Y+WKdEem2ixkWqZYSeDSUp3qwIYNOrR7cBjcbOORxkvADNeAGa8AMx4AZjxAjATf5Ab0Tp5rJBk2/iD3PAwYo8Vkmyb9CjDGfLYIaCp1rdiAnT8S5PeDVkgoDuVCsWeJxwToHZ163m3Z8hjloDGk54vn5gFbT/5eZw8phifvZz8XPlA9qmRj8JRCumi+OkljzbbrvxM0qPMm9rIqY6FXZubVBUinMbzcP3jbuXA6Mh2kMx07KPJJLfj8Xg8Hg/4H+KfFYb2WM4MAAAAAElFTkSuQmCC'
|
||||
@ -0,0 +1,74 @@
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParamter
|
||||
from core.tools.provider.builtin.vectorizer.tools.test_data import VECTORIZER_ICON_PNG
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
from httpx import post
|
||||
from base64 import b64decode
|
||||
|
||||
class VectorizerTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \
|
||||
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
api_key_name = self.runtime.credentials.get('api_key_name', None)
|
||||
api_key_value = self.runtime.credentials.get('api_key_value', None)
|
||||
mode = tool_paramters.get('mode', 'test')
|
||||
if mode == 'production':
|
||||
mode = 'preview'
|
||||
|
||||
if not api_key_name or not api_key_value:
|
||||
raise ToolProviderCredentialValidationError('Please input api key name and value')
|
||||
|
||||
image_id = tool_paramters.get('image_id', '')
|
||||
if not image_id:
|
||||
return self.create_text_message('Please input image id')
|
||||
|
||||
if image_id.startswith('__test_'):
|
||||
image_binary = b64decode(VECTORIZER_ICON_PNG)
|
||||
else:
|
||||
image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE)
|
||||
if not image_binary:
|
||||
return self.create_text_message('Image not found, please request user to generate image firstly.')
|
||||
|
||||
response = post(
|
||||
'https://vectorizer.ai/api/v1/vectorize',
|
||||
files={
|
||||
'image': image_binary
|
||||
},
|
||||
data={
|
||||
'mode': mode
|
||||
} if mode == 'test' else {},
|
||||
auth=(api_key_name, api_key_value),
|
||||
timeout=30
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(response.text)
|
||||
|
||||
return [
|
||||
self.create_text_message('the vectorized svg is saved as an image.'),
|
||||
self.create_blob_message(blob=response.content,
|
||||
meta={'mime_type': 'image/svg+xml'})
|
||||
]
|
||||
|
||||
def get_runtime_parameters(self) -> List[ToolParamter]:
|
||||
"""
|
||||
override the runtime parameters
|
||||
"""
|
||||
return [
|
||||
ToolParamter.get_simple_instance(
|
||||
name='image_id',
|
||||
llm_description=f'the image id that you want to vectorize, \
|
||||
and the image id should be specified in \
|
||||
{[i.name for i in self.list_default_image_variables()]}',
|
||||
type=ToolParamter.ToolParameterType.SELECT,
|
||||
required=True,
|
||||
options=[i.name for i in self.list_default_image_variables()]
|
||||
)
|
||||
]
|
||||
|
||||
def is_tool_avaliable(self) -> bool:
|
||||
return len(self.list_default_image_variables()) > 0
|
||||