refactor: tool response to generator

pull/9184/head
Yeuoly 2 years ago
parent 364df36ac4
commit 563d81277b
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61

@ -23,6 +23,8 @@ class PluginInvokeModelApi(Resource):
args = parser.parse_args() args = parser.parse_args()
class PluginInvokeToolApi(Resource): class PluginInvokeToolApi(Resource):
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only

@ -1,14 +1,16 @@
from enum import Enum from enum import Enum
from typing import Any, Literal, Optional, Union from typing import Any, Optional, Union
from pydantic import BaseModel from pydantic import BaseModel
from core.tools.entities.tool_entities import ToolProviderType
class AgentToolEntity(BaseModel): class AgentToolEntity(BaseModel):
""" """
Agent Tool Entity. Agent Tool Entity.
""" """
provider_type: Literal["builtin", "api", "workflow"] provider_type: ToolProviderType
provider_id: str provider_id: str
tool_name: str tool_name: str
tool_parameters: dict[str, Any] = {} tool_parameters: dict[str, Any] = {}

@ -0,0 +1,5 @@
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
class DifyPluginCallbackHandler(DifyAgentCallbackHandler):
"""Callback Handler that prints to std out."""

@ -1,4 +1,5 @@
import json import json
from collections.abc import Generator
from os import getenv from os import getenv
from typing import Any from typing import Any
from urllib.parse import urlencode from urllib.parse import urlencode
@ -269,7 +270,7 @@ class ApiTool(Tool):
except ValueError as e: except ValueError as e:
return value return value
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]:
""" """
invoke http request invoke http request
""" """
@ -283,4 +284,4 @@ class ApiTool(Tool):
response = self.validate_and_parse_response(response) response = self.validate_and_parse_response(response)
# assemble invoke message # assemble invoke message
return self.create_text_message(response) yield self.create_text_message(response)

@ -1,3 +1,4 @@
from collections.abc import Generator
from typing import Any from typing import Any
from core.app.app_config.entities import DatasetRetrieveConfigEntity from core.app.app_config.entities import DatasetRetrieveConfigEntity
@ -86,7 +87,7 @@ class DatasetRetrieverTool(Tool):
def tool_provider_type(self) -> ToolProviderType: def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.DATASET_RETRIEVAL return ToolProviderType.DATASET_RETRIEVAL
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]:
""" """
invoke dataset retriever tool invoke dataset retriever tool
""" """
@ -97,7 +98,7 @@ class DatasetRetrieverTool(Tool):
# invoke dataset retriever tool # invoke dataset retriever tool
result = self.retrival_tool._run(query=query) result = self.retrival_tool._run(query=query)
return self.create_text_message(text=result) yield self.create_text_message(text=result)
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None: def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
""" """

@ -1,4 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Generator
from copy import deepcopy from copy import deepcopy
from enum import Enum from enum import Enum
from typing import Any, Optional, Union from typing import Any, Optional, Union
@ -190,7 +191,7 @@ class Tool(BaseModel, ABC):
return result return result
def invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]: def invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage]:
# update tool_parameters # update tool_parameters
if self.runtime.runtime_parameters: if self.runtime.runtime_parameters:
tool_parameters.update(self.runtime.runtime_parameters) tool_parameters.update(self.runtime.runtime_parameters)
@ -203,9 +204,6 @@ class Tool(BaseModel, ABC):
tool_parameters=tool_parameters, tool_parameters=tool_parameters,
) )
if not isinstance(result, list):
result = [result]
return result return result
def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]: def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
@ -221,7 +219,7 @@ class Tool(BaseModel, ABC):
return result return result
@abstractmethod @abstractmethod
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]:
pass pass
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None: def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:

@ -1,5 +1,6 @@
import json import json
import logging import logging
from collections.abc import Generator
from copy import deepcopy from copy import deepcopy
from typing import Any, Union from typing import Any, Union
@ -34,7 +35,7 @@ class WorkflowTool(Tool):
def _invoke( def _invoke(
self, user_id: str, tool_parameters: dict[str, Any] self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: ) -> Generator[ToolInvokeMessage, None, None]:
""" """
invoke the tool invoke the tool
""" """
@ -46,6 +47,7 @@ class WorkflowTool(Tool):
from core.app.apps.workflow.app_generator import WorkflowAppGenerator from core.app.apps.workflow.app_generator import WorkflowAppGenerator
generator = WorkflowAppGenerator() generator = WorkflowAppGenerator()
result = generator.generate( result = generator.generate(
app_model=app, app_model=app,
workflow=workflow, workflow=workflow,
@ -64,16 +66,12 @@ class WorkflowTool(Tool):
if data.get('error'): if data.get('error'):
raise Exception(data.get('error')) raise Exception(data.get('error'))
result = []
outputs = data.get('outputs', {}) outputs = data.get('outputs', {})
outputs, files = self._extract_files(outputs) outputs, files = self._extract_files(outputs)
for file in files: for file in files:
result.append(self.create_file_var_message(file)) yield self.create_file_var_message(file)
result.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False)))
return result yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
def _get_user(self, user_id: str) -> Union[EndUser, Account]: def _get_user(self, user_id: str) -> Union[EndUser, Account]:
""" """

@ -1,4 +1,5 @@
import json import json
from collections.abc import Generator
from copy import deepcopy from copy import deepcopy
from datetime import datetime, timezone from datetime import datetime, timezone
from mimetypes import guess_type from mimetypes import guess_type
@ -8,6 +9,7 @@ from yarl import URL
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.plugin_tool_callback_handler import DifyPluginCallbackHandler
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file.file_obj import FileTransferMethod from core.file.file_obj import FileTransferMethod
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
@ -64,16 +66,25 @@ class ToolEngine:
tool_inputs=tool_parameters tool_inputs=tool_parameters
) )
meta, response = ToolEngine._invoke(tool, tool_parameters, user_id) messages = ToolEngine._invoke(tool, tool_parameters, user_id)
response = ToolFileMessageTransformer.transform_tool_invoke_messages( invocation_meta_dict = {'meta': None}
messages=response,
def message_callback(invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage, None, None]):
for message in messages:
if isinstance(message, ToolInvokeMeta):
invocation_meta_dict['meta'] = message
else:
yield message
messages = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=message_callback(invocation_meta_dict, messages),
user_id=user_id, user_id=user_id,
tenant_id=tenant_id, tenant_id=tenant_id,
conversation_id=message.conversation_id conversation_id=message.conversation_id
) )
# extract binary data from tool invoke message # extract binary data from tool invoke message
binary_files = ToolEngine._extract_tool_response_binary(response) binary_files = ToolEngine._extract_tool_response_binary(messages)
# create message file # create message file
message_files = ToolEngine._create_message_files( message_files = ToolEngine._create_message_files(
tool_messages=binary_files, tool_messages=binary_files,
@ -82,7 +93,9 @@ class ToolEngine:
user_id=user_id user_id=user_id
) )
plain_text = ToolEngine._convert_tool_response_to_str(response) plain_text = ToolEngine._convert_tool_response_to_str(messages)
meta = invocation_meta_dict['meta']
# hit the callback handler # hit the callback handler
agent_tool_callback.on_tool_end( agent_tool_callback.on_tool_end(
@ -127,7 +140,7 @@ class ToolEngine:
user_id: str, workflow_id: str, user_id: str, workflow_id: str,
workflow_tool_callback: DifyWorkflowCallbackHandler, workflow_tool_callback: DifyWorkflowCallbackHandler,
workflow_call_depth: int, workflow_call_depth: int,
) -> list[ToolInvokeMessage]: ) -> Generator[ToolInvokeMessage, None, None]:
""" """
Workflow invokes the tool with the given arguments. Workflow invokes the tool with the given arguments.
""" """
@ -155,9 +168,37 @@ class ToolEngine:
workflow_tool_callback.on_tool_error(e) workflow_tool_callback.on_tool_error(e)
raise e raise e
@staticmethod
def plugin_invoke(tool: Tool, tool_parameters: dict, user_id: str,
callback: DifyPluginCallbackHandler
) -> Generator[ToolInvokeMessage, None, None]:
"""
Plugin invokes the tool with the given arguments.
"""
try:
# hit the callback handler
callback.on_tool_start(
tool_name=tool.identity.name,
tool_inputs=tool_parameters
)
response = tool.invoke(user_id, tool_parameters)
# hit the callback handler
callback.on_tool_end(
tool_name=tool.identity.name,
tool_inputs=tool_parameters,
tool_outputs=response,
)
return response
except Exception as e:
callback.on_tool_error(e)
raise e
@staticmethod @staticmethod
def _invoke(tool: Tool, tool_parameters: dict, user_id: str) \ def _invoke(tool: Tool, tool_parameters: dict, user_id: str) \
-> tuple[ToolInvokeMeta, list[ToolInvokeMessage]]: -> Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]:
""" """
Invoke the tool with the given arguments. Invoke the tool with the given arguments.
""" """
@ -170,15 +211,14 @@ class ToolEngine:
'tool_icon': tool.identity.icon 'tool_icon': tool.identity.icon
}) })
try: try:
response = tool.invoke(user_id, tool_parameters) yield from tool.invoke(user_id, tool_parameters)
except Exception as e: except Exception as e:
meta.error = str(e) meta.error = str(e)
raise ToolEngineInvokeError(meta) raise ToolEngineInvokeError(meta)
finally: finally:
ended_at = datetime.now(timezone.utc) ended_at = datetime.now(timezone.utc)
meta.time_cost = (ended_at - started_at).total_seconds() meta.time_cost = (ended_at - started_at).total_seconds()
yield meta
return meta, response
@staticmethod @staticmethod
def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str: def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str:

@ -18,6 +18,7 @@ from core.tools.entities.tool_entities import (
ApiProviderAuthType, ApiProviderAuthType,
ToolInvokeFrom, ToolInvokeFrom,
ToolParameter, ToolParameter,
ToolProviderType,
) )
from core.tools.errors import ToolProviderNotFoundError from core.tools.errors import ToolProviderNotFoundError
from core.tools.provider.api_tool_provider import ApiToolProviderController from core.tools.provider.api_tool_provider import ApiToolProviderController
@ -26,6 +27,7 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
from core.tools.tool.api_tool import ApiTool from core.tools.tool.api_tool import ApiTool
from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.tool.tool import Tool from core.tools.tool.tool import Tool
from core.tools.tool.workflow_tool import WorkflowTool
from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import ( from core.tools.utils.configuration import (
ToolConfigurationManager, ToolConfigurationManager,
@ -78,37 +80,13 @@ class ToolManager:
return tool return tool
@classmethod @classmethod
def get_tool(cls, provider_type: str, provider_id: str, tool_name: str, tenant_id: str = None) \ def get_tool_runtime(cls, provider_type: ToolProviderType,
-> Union[BuiltinTool, ApiTool]:
"""
get the tool
:param provider_type: the type of the provider
:param provider_name: the name of the provider
:param tool_name: the name of the tool
:return: the tool
"""
if provider_type == 'builtin':
return cls.get_builtin_tool(provider_id, tool_name)
elif provider_type == 'api':
if tenant_id is None:
raise ValueError('tenant id is required for api provider')
api_provider, _ = cls.get_api_provider_controller(tenant_id, provider_id)
return api_provider.get_tool(tool_name)
elif provider_type == 'app':
raise NotImplementedError('app provider not implemented')
else:
raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
@classmethod
def get_tool_runtime(cls, provider_type: str,
provider_id: str, provider_id: str,
tool_name: str, tool_name: str,
tenant_id: str, tenant_id: str,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \ tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \
-> Union[BuiltinTool, ApiTool]: -> Union[BuiltinTool, ApiTool, WorkflowTool]:
""" """
get the tool runtime get the tool runtime
@ -118,7 +96,7 @@ class ToolManager:
:return: the tool :return: the tool
""" """
if provider_type == 'builtin': if provider_type == ToolProviderType.BUILT_IN:
builtin_tool = cls.get_builtin_tool(provider_id, tool_name) builtin_tool = cls.get_builtin_tool(provider_id, tool_name)
# check if the builtin tool need credentials # check if the builtin tool need credentials
@ -155,7 +133,7 @@ class ToolManager:
'tool_invoke_from': tool_invoke_from, 'tool_invoke_from': tool_invoke_from,
}) })
elif provider_type == 'api': elif provider_type == ToolProviderType.API:
if tenant_id is None: if tenant_id is None:
raise ValueError('tenant id is required for api provider') raise ValueError('tenant id is required for api provider')
@ -171,7 +149,7 @@ class ToolManager:
'invoke_from': invoke_from, 'invoke_from': invoke_from,
'tool_invoke_from': tool_invoke_from, 'tool_invoke_from': tool_invoke_from,
}) })
elif provider_type == 'workflow': elif provider_type == ToolProviderType.WORKFLOW:
workflow_provider = db.session.query(WorkflowToolProvider).filter( workflow_provider = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.id == provider_id WorkflowToolProvider.id == provider_id
@ -190,10 +168,10 @@ class ToolManager:
'invoke_from': invoke_from, 'invoke_from': invoke_from,
'tool_invoke_from': tool_invoke_from, 'tool_invoke_from': tool_invoke_from,
}) })
elif provider_type == 'app': elif provider_type == ToolProviderType.APP:
raise NotImplementedError('app provider not implemented') raise NotImplementedError('app provider not implemented')
else: else:
raise ToolProviderNotFoundError(f'provider type {provider_type} not found') raise ToolProviderNotFoundError(f'provider type {provider_type.value} not found')
@classmethod @classmethod
def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]: def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]:
@ -554,7 +532,7 @@ class ToolManager:
}) })
@classmethod @classmethod
def get_tool_icon(cls, tenant_id: str, provider_type: str, provider_id: str) -> Union[str, dict]: def get_tool_icon(cls, tenant_id: str, provider_type: ToolProviderType, provider_id: str) -> Union[str, dict]:
""" """
get the tool icon get the tool icon
@ -563,14 +541,12 @@ class ToolManager:
:param provider_id: the id of the provider :param provider_id: the id of the provider
:return: :return:
""" """
provider_type = provider_type if provider_type == ToolProviderType.BUILT_IN:
provider_id = provider_id
if provider_type == 'builtin':
return (current_app.config.get("CONSOLE_API_URL") return (current_app.config.get("CONSOLE_API_URL")
+ "/console/api/workspaces/current/tool-provider/builtin/" + "/console/api/workspaces/current/tool-provider/builtin/"
+ provider_id + provider_id
+ "/icon") + "/icon")
elif provider_type == 'api': elif provider_type == ToolProviderType.API:
try: try:
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.tenant_id == tenant_id,
@ -582,7 +558,7 @@ class ToolManager:
"background": "#252525", "background": "#252525",
"content": "\ud83d\ude01" "content": "\ud83d\ude01"
} }
elif provider_type == 'workflow': elif provider_type == ToolProviderType.WORKFLOW:
provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.id == provider_id WorkflowToolProvider.id == provider_id

@ -9,6 +9,7 @@ from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolPr
from core.tools.entities.tool_entities import ( from core.tools.entities.tool_entities import (
ToolParameter, ToolParameter,
ToolProviderCredentials, ToolProviderCredentials,
ToolProviderType,
) )
from core.tools.provider.tool_provider import ToolProviderController from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.tool import Tool from core.tools.tool.tool import Tool
@ -108,7 +109,7 @@ class ToolParameterConfigurationManager(BaseModel):
tenant_id: str tenant_id: str
tool_runtime: Tool tool_runtime: Tool
provider_name: str provider_name: str
provider_type: str provider_type: ToolProviderType
identity_id: str identity_id: str
def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]: def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
@ -191,7 +192,7 @@ class ToolParameterConfigurationManager(BaseModel):
""" """
cache = ToolParameterCache( cache = ToolParameterCache(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
provider=f'{self.provider_type}.{self.provider_name}', provider=f'{self.provider_type.value}.{self.provider_name}',
tool_name=self.tool_runtime.identity.name, tool_name=self.tool_runtime.identity.name,
cache_type=ToolParameterCacheType.PARAMETER, cache_type=ToolParameterCacheType.PARAMETER,
identity_id=self.identity_id identity_id=self.identity_id
@ -221,7 +222,7 @@ class ToolParameterConfigurationManager(BaseModel):
def delete_tool_parameters_cache(self): def delete_tool_parameters_cache(self):
cache = ToolParameterCache( cache = ToolParameterCache(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
provider=f'{self.provider_type}.{self.provider_name}', provider=f'{self.provider_type.value}.{self.provider_name}',
tool_name=self.tool_runtime.identity.name, tool_name=self.tool_runtime.identity.name,
cache_type=ToolParameterCacheType.PARAMETER, cache_type=ToolParameterCacheType.PARAMETER,
identity_id=self.identity_id identity_id=self.identity_id

@ -1,4 +1,5 @@
import logging import logging
from collections.abc import Generator
from mimetypes import guess_extension from mimetypes import guess_extension
from core.file.file_obj import FileTransferMethod, FileType, FileVar from core.file.file_obj import FileTransferMethod, FileType, FileVar
@ -9,20 +10,18 @@ logger = logging.getLogger(__name__)
class ToolFileMessageTransformer: class ToolFileMessageTransformer:
@classmethod @classmethod
def transform_tool_invoke_messages(cls, messages: list[ToolInvokeMessage], def transform_tool_invoke_messages(cls, messages: Generator[ToolInvokeMessage, None, None],
user_id: str, user_id: str,
tenant_id: str, tenant_id: str,
conversation_id: str) -> list[ToolInvokeMessage]: conversation_id: str) -> Generator[ToolInvokeMessage, None, None]:
""" """
Transform tool message and handle file download Transform tool message and handle file download
""" """
result = []
for message in messages: for message in messages:
if message.type == ToolInvokeMessage.MessageType.TEXT: if message.type == ToolInvokeMessage.MessageType.TEXT:
result.append(message) yield message
elif message.type == ToolInvokeMessage.MessageType.LINK: elif message.type == ToolInvokeMessage.MessageType.LINK:
result.append(message) yield message
elif message.type == ToolInvokeMessage.MessageType.IMAGE: elif message.type == ToolInvokeMessage.MessageType.IMAGE:
# try to download image # try to download image
try: try:
@ -35,20 +34,20 @@ class ToolFileMessageTransformer:
url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}' url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}'
result.append(ToolInvokeMessage( yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK, type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=url, message=url,
save_as=message.save_as, save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {}, meta=message.meta.copy() if message.meta is not None else {},
)) )
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
result.append(ToolInvokeMessage( yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.TEXT, type=ToolInvokeMessage.MessageType.TEXT,
message=f"Failed to download image: {message.message}, you can try to download it yourself.", message=f"Failed to download image: {message.message}, you can try to download it yourself.",
meta=message.meta.copy() if message.meta is not None else {}, meta=message.meta.copy() if message.meta is not None else {},
save_as=message.save_as, save_as=message.save_as,
)) )
elif message.type == ToolInvokeMessage.MessageType.BLOB: elif message.type == ToolInvokeMessage.MessageType.BLOB:
# get mime type and save blob to storage # get mime type and save blob to storage
mimetype = message.meta.get('mime_type', 'octet/stream') mimetype = message.meta.get('mime_type', 'octet/stream')
@ -67,42 +66,40 @@ class ToolFileMessageTransformer:
# check if file is image # check if file is image
if 'image' in mimetype: if 'image' in mimetype:
result.append(ToolInvokeMessage( yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK, type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=url, message=url,
save_as=message.save_as, save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {}, meta=message.meta.copy() if message.meta is not None else {},
)) )
else: else:
result.append(ToolInvokeMessage( yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.LINK, type=ToolInvokeMessage.MessageType.LINK,
message=url, message=url,
save_as=message.save_as, save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {}, meta=message.meta.copy() if message.meta is not None else {},
)) )
elif message.type == ToolInvokeMessage.MessageType.FILE_VAR: elif message.type == ToolInvokeMessage.MessageType.FILE_VAR:
file_var: FileVar = message.meta.get('file_var') file_var: FileVar = message.meta.get('file_var')
if file_var: if file_var:
if file_var.transfer_method == FileTransferMethod.TOOL_FILE: if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
url = cls.get_tool_file_url(file_var.related_id, file_var.extension) url = cls.get_tool_file_url(file_var.related_id, file_var.extension)
if file_var.type == FileType.IMAGE: if file_var.type == FileType.IMAGE:
result.append(ToolInvokeMessage( yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK, type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=url, message=url,
save_as=message.save_as, save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {}, meta=message.meta.copy() if message.meta is not None else {},
)) )
else: else:
result.append(ToolInvokeMessage( yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.LINK, type=ToolInvokeMessage.MessageType.LINK,
message=url, message=url,
save_as=message.save_as, save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {}, meta=message.meta.copy() if message.meta is not None else {},
)) )
else: else:
result.append(message) yield message
return result
@classmethod @classmethod
def get_tool_file_url(cls, tool_file_id: str, extension: str) -> str: def get_tool_file_url(cls, tool_file_id: str, extension: str) -> str:

@ -3,12 +3,13 @@ from typing import Any, Literal, Union
from pydantic import BaseModel, field_validator from pydantic import BaseModel, field_validator
from pydantic_core.core_schema import ValidationInfo from pydantic_core.core_schema import ValidationInfo
from core.tools.entities.tool_entities import ToolProviderType
from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.base_node_data_entities import BaseNodeData
class ToolEntity(BaseModel): class ToolEntity(BaseModel):
provider_id: str provider_id: str
provider_type: Literal['builtin', 'api', 'workflow'] provider_type: ToolProviderType
provider_name: str # redundancy provider_name: str # redundancy
tool_name: str tool_name: str
tool_label: str # redundancy tool_label: str # redundancy

@ -32,7 +32,7 @@ class ToolNode(BaseNode):
# fetch tool icon # fetch tool icon
tool_info = { tool_info = {
'provider_type': node_data.provider_type, 'provider_type': node_data.provider_type.value,
'provider_id': node_data.provider_id 'provider_id': node_data.provider_id
} }

@ -1,16 +1,49 @@
from collections.abc import Generator from collections.abc import Generator
from typing import Any from typing import Any, Union
from core.tools.entities.tool_entities import ToolInvokeMessage from core.app.entities.app_invoke_entities import InvokeFrom
from core.callback_handler.plugin_tool_callback_handler import DifyPluginCallbackHandler
from core.model_runtime.entities.model_entities import ModelType
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
from core.tools.tool_engine import ToolEngine
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.workflow.entities.node_entities import NodeType
from models.account import Tenant from models.account import Tenant
from services.tools.tools_transform_service import ToolTransformService
class PluginInvokeService: class PluginInvokeService:
@classmethod @classmethod
def invoke_tool(cls, user_id: str, tenant: Tenant, def invoke_tool(cls, user_id: str, invoke_from: InvokeFrom, tenant: Tenant,
tool_provider: str, tool_name: str, tool_provider_type: ToolProviderType, tool_provider: str, tool_name: str,
tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage]: tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage]:
""" """
Invokes a tool with the given user ID and tool parameters. Invokes a tool with the given user ID and tool parameters.
""" """
tool_runtime = ToolManager.get_tool_runtime(tool_provider_type, provider_id=tool_provider,
tool_name=tool_name, tenant_id=tenant.id,
invoke_from=invoke_from)
response = ToolEngine.plugin_invoke(tool_runtime,
tool_parameters,
user_id,
callback=DifyPluginCallbackHandler())
response = ToolFileMessageTransformer.transform_tool_invoke_messages(response)
return ToolTransformService.transform_messages_to_dict(response)
@classmethod
def invoke_model(cls, user_id: str, tenant: Tenant,
model_provider: str, model_name: str, model_type: ModelType,
model_parameters: dict[str, Any]) -> Union[dict, Generator[ToolInvokeMessage]]:
"""
Invokes a model with the given user ID and model parameters.
"""
@classmethod
def invoke_workflow_node(cls, user_id: str, tenant: Tenant,
node_type: NodeType, node_data: dict[str, Any],
inputs: dict[str, Any]) -> Generator[ToolInvokeMessage]:
"""
Invokes a workflow node with the given user ID and node parameters.
"""

@ -1,5 +1,6 @@
import json import json
import logging import logging
from collections.abc import Generator
from typing import Optional, Union from typing import Optional, Union
from flask import current_app from flask import current_app
@ -9,6 +10,7 @@ from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ( from core.tools.entities.tool_entities import (
ApiProviderAuthType, ApiProviderAuthType,
ToolInvokeMessage,
ToolParameter, ToolParameter,
ToolProviderCredentials, ToolProviderCredentials,
ToolProviderType, ToolProviderType,
@ -24,8 +26,8 @@ from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvi
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ToolTransformService: class ToolTransformService:
@staticmethod @classmethod
def get_tool_provider_icon_url(provider_type: str, provider_name: str, icon: str) -> Union[str, dict]: def get_tool_provider_icon_url(cls, provider_type: str, provider_name: str, icon: str) -> Union[str, dict]:
""" """
get tool provider icon url get tool provider icon url
""" """
@ -45,8 +47,8 @@ class ToolTransformService:
return '' return ''
@staticmethod @classmethod
def repack_provider(provider: Union[dict, UserToolProvider]): def repack_provider(cls, provider: Union[dict, UserToolProvider]):
""" """
repack provider repack provider
@ -65,8 +67,9 @@ class ToolTransformService:
icon=provider.icon icon=provider.icon
) )
@staticmethod @classmethod
def builtin_provider_to_user_provider( def builtin_provider_to_user_provider(
cls,
provider_controller: BuiltinToolProviderController, provider_controller: BuiltinToolProviderController,
db_provider: Optional[BuiltinToolProvider], db_provider: Optional[BuiltinToolProvider],
decrypt_credentials: bool = True, decrypt_credentials: bool = True,
@ -126,8 +129,9 @@ class ToolTransformService:
return result return result
@staticmethod @classmethod
def api_provider_to_controller( def api_provider_to_controller(
cls,
db_provider: ApiToolProvider, db_provider: ApiToolProvider,
) -> ApiToolProviderController: ) -> ApiToolProviderController:
""" """
@ -142,8 +146,9 @@ class ToolTransformService:
return controller return controller
@staticmethod @classmethod
def workflow_provider_to_controller( def workflow_provider_to_controller(
cls,
db_provider: WorkflowToolProvider db_provider: WorkflowToolProvider
) -> WorkflowToolProviderController: ) -> WorkflowToolProviderController:
""" """
@ -179,8 +184,9 @@ class ToolTransformService:
labels=labels or [] labels=labels or []
) )
@staticmethod @classmethod
def api_provider_to_user_provider( def api_provider_to_user_provider(
cls,
provider_controller: ApiToolProviderController, provider_controller: ApiToolProviderController,
db_provider: ApiToolProvider, db_provider: ApiToolProvider,
decrypt_credentials: bool = True, decrypt_credentials: bool = True,
@ -231,8 +237,9 @@ class ToolTransformService:
return result return result
@staticmethod @classmethod
def tool_to_user_tool( def tool_to_user_tool(
cls,
tool: Union[ApiToolBundle, WorkflowTool, Tool], tool: Union[ApiToolBundle, WorkflowTool, Tool],
credentials: dict = None, credentials: dict = None,
tenant_id: str = None, tenant_id: str = None,
@ -288,3 +295,8 @@ class ToolTransformService:
parameters=tool.parameters, parameters=tool.parameters,
labels=labels labels=labels
) )
@classmethod
def transform_messages_to_dict(cls, responses: Generator[ToolInvokeMessage, None, None]):
for response in responses:
yield response.model_dump()
Loading…
Cancel
Save