feat: backwards invoke tools

pull/12372/head
Yeuoly 2 years ago
parent 699d41deec
commit 118fa66567
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61

@ -1,5 +1,3 @@
import time
from flask_restful import Resource from flask_restful import Resource
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
@ -10,6 +8,7 @@ from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
from core.plugin.backwards_invocation.base import BaseBackwardsInvocationResponse from core.plugin.backwards_invocation.base import BaseBackwardsInvocationResponse
from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation
from core.plugin.backwards_invocation.node import PluginNodeBackwardsInvocation from core.plugin.backwards_invocation.node import PluginNodeBackwardsInvocation
from core.plugin.backwards_invocation.tool import PluginToolBackwardsInvocation
from core.plugin.encrypt import PluginEncrypter from core.plugin.encrypt import PluginEncrypter
from core.plugin.entities.request import ( from core.plugin.entities.request import (
RequestInvokeApp, RequestInvokeApp,
@ -24,7 +23,7 @@ from core.plugin.entities.request import (
RequestInvokeTool, RequestInvokeTool,
RequestInvokeTTS, RequestInvokeTTS,
) )
from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.entities.tool_entities import ToolProviderType
from libs.helper import compact_generate_response from libs.helper import compact_generate_response
from models.account import Tenant from models.account import Tenant
@ -138,17 +137,16 @@ class PluginInvokeToolApi(Resource):
@plugin_data(payload_type=RequestInvokeTool) @plugin_data(payload_type=RequestInvokeTool)
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTool): def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTool):
def generator(): def generator():
for i in range(10): return PluginToolBackwardsInvocation.convert_to_event_stream(
time.sleep(0.1) PluginToolBackwardsInvocation.invoke_tool(
yield ( tenant_id=tenant_model.id,
ToolInvokeMessage( user_id=user_id,
type=ToolInvokeMessage.MessageType.TEXT, tool_type=ToolProviderType.value_of(payload.tool_type),
message=ToolInvokeMessage.TextMessage(text="helloworld"), provider=payload.provider,
) tool_name=payload.tool,
.model_dump_json() tool_parameters=payload.tool_parameters,
.encode() ),
+ b"\n\n" )
)
return compact_generate_response(generator()) return compact_generate_response(generator())

@ -0,0 +1,45 @@
from collections.abc import Generator
from typing import Any
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
from core.tools.tool_engine import ToolEngine
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
class PluginToolBackwardsInvocation(BaseBackwardsInvocation):
"""
Backwards invocation for plugin tools.
"""
@classmethod
def invoke_tool(
cls,
tenant_id: str,
user_id: str,
tool_type: ToolProviderType,
provider: str,
tool_name: str,
tool_parameters: dict[str, Any],
) -> Generator[ToolInvokeMessage, None, None]:
"""
invoke tool
"""
# get tool runtime
try:
tool_runtime = ToolManager.get_tool_runtime_from_plugin(
tool_type, tenant_id, provider, tool_name, tool_parameters
)
response = ToolEngine.generic_invoke(
tool_runtime, tool_parameters, user_id, DifyWorkflowCallbackHandler(), workflow_call_depth=1
)
response = ToolFileMessageTransformer.transform_tool_invoke_messages(
response, user_id=user_id, tenant_id=tenant_id
)
return response
except Exception as e:
raise e

@ -32,6 +32,11 @@ class RequestInvokeTool(BaseModel):
Request to invoke a tool Request to invoke a tool
""" """
tool_type: Literal["builtin", "workflow", "api"]
provider: str
tool: str
tool_parameters: dict
class BaseRequestInvokeModel(BaseModel): class BaseRequestInvokeModel(BaseModel):
provider: str provider: str

@ -378,6 +378,7 @@ class ToolInvokeFrom(Enum):
WORKFLOW = "workflow" WORKFLOW = "workflow"
AGENT = "agent" AGENT = "agent"
PLUGIN = "plugin"
class ToolProviderID: class ToolProviderID:

@ -131,7 +131,7 @@ class ToolEngine:
return error_response, [], ToolInvokeMeta.error_instance(error_response) return error_response, [], ToolInvokeMeta.error_instance(error_response)
@staticmethod @staticmethod
def workflow_invoke( def generic_invoke(
tool: Tool, tool: Tool,
tool_parameters: dict[str, Any], tool_parameters: dict[str, Any],
user_id: str, user_id: str,

@ -365,6 +365,40 @@ class ToolManager:
tool_runtime.runtime.runtime_parameters.update(runtime_parameters) tool_runtime.runtime.runtime_parameters.update(runtime_parameters)
return tool_runtime return tool_runtime
@classmethod
def get_tool_runtime_from_plugin(
cls,
tool_type: ToolProviderType,
tenant_id: str,
provider: str,
tool_name: str,
tool_parameters: dict[str, Any],
) -> Tool:
"""
get tool runtime from plugin
"""
tool_entity = cls.get_tool_runtime(
provider_type=tool_type,
provider_id=provider,
tool_name=tool_name,
tenant_id=tenant_id,
invoke_from=InvokeFrom.SERVICE_API,
tool_invoke_from=ToolInvokeFrom.PLUGIN,
)
runtime_parameters = {}
parameters = tool_entity.get_merged_runtime_parameters()
for parameter in parameters:
if parameter.form == ToolParameter.ToolParameterForm.FORM:
# save tool parameter to tool entity memory
value = cls._init_runtime_parameter(parameter, tool_parameters)
runtime_parameters[parameter.name] = value
if not tool_entity.runtime:
raise Exception("tool missing runtime")
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
return tool_entity
@classmethod @classmethod
def get_builtin_provider_icon(cls, provider: str, tenant_id: str) -> tuple[str, str]: def get_builtin_provider_icon(cls, provider: str, tenant_id: str) -> tuple[str, str]:
""" """

@ -66,7 +66,7 @@ class ToolNode(BaseNode):
) )
try: try:
message_stream = ToolEngine.workflow_invoke( message_stream = ToolEngine.generic_invoke(
tool=tool_runtime, tool=tool_runtime,
tool_parameters=parameters, tool_parameters=parameters,
user_id=self.user_id, user_id=self.user_id,

Loading…
Cancel
Save