refactor: tool

pull/9184/head
Yeuoly 2 years ago
parent 3c1d32e3ac
commit 91cb80f795
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61

@ -2,7 +2,6 @@ import json
import logging import logging
import uuid import uuid
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from datetime import datetime, timezone
from typing import Optional, Union, cast from typing import Optional, Union, cast
from core.agent.entities import AgentEntity, AgentToolEntity from core.agent.entities import AgentEntity, AgentToolEntity
@ -23,6 +22,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
AssistantPromptMessage, AssistantPromptMessage,
PromptMessage, PromptMessage,
PromptMessageContent,
PromptMessageTool, PromptMessageTool,
SystemPromptMessage, SystemPromptMessage,
TextPromptMessageContent, TextPromptMessageContent,
@ -31,18 +31,15 @@ from core.model_runtime.entities.message_entities import (
) )
from core.model_runtime.entities.model_entities import ModelFeature from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool import Tool from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ( from core.tools.entities.tool_entities import (
ToolParameter, ToolParameter,
ToolRuntimeVariablePool,
) )
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
from core.tools.utils.tool_parameter_converter import ToolParameterConverter from core.tools.utils.tool_parameter_converter import ToolParameterConverter
from extensions.ext_database import db from extensions.ext_database import db
from models.model import Conversation, Message, MessageAgentThought from models.model import Conversation, Message, MessageAgentThought
from models.tools import ToolConversationVariables
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -59,11 +56,9 @@ class BaseAgentRunner(AppRunner):
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
message: Message, message: Message,
user_id: str, user_id: str,
model_instance: ModelInstance,
memory: Optional[TokenBufferMemory] = None, memory: Optional[TokenBufferMemory] = None,
prompt_messages: Optional[list[PromptMessage]] = None, prompt_messages: Optional[list[PromptMessage]] = None,
variables_pool: Optional[ToolRuntimeVariablePool] = None,
db_variables: Optional[ToolConversationVariables] = None,
model_instance: ModelInstance = None,
) -> None: ) -> None:
""" """
Agent runner Agent runner
@ -93,8 +88,6 @@ class BaseAgentRunner(AppRunner):
self.user_id = user_id self.user_id = user_id
self.memory = memory self.memory = memory
self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or []) self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
self.variables_pool = variables_pool
self.db_variables_pool = db_variables
self.model_instance = model_instance self.model_instance = model_instance
# init callback # init callback
@ -162,11 +155,10 @@ class BaseAgentRunner(AppRunner):
agent_tool=tool, agent_tool=tool,
invoke_from=self.application_generate_entity.invoke_from, invoke_from=self.application_generate_entity.invoke_from,
) )
tool_entity.load_variables(self.variables_pool) assert tool_entity.entity.description
message_tool = PromptMessageTool( message_tool = PromptMessageTool(
name=tool.tool_name, name=tool.tool_name,
description=tool_entity.description.llm, description=tool_entity.entity.description.llm,
parameters={ parameters={
"type": "object", "type": "object",
"properties": {}, "properties": {},
@ -201,9 +193,11 @@ class BaseAgentRunner(AppRunner):
""" """
convert dataset retriever tool to prompt message tool convert dataset retriever tool to prompt message tool
""" """
assert tool.entity.description
prompt_tool = PromptMessageTool( prompt_tool = PromptMessageTool(
name=tool.identity.name, name=tool.entity.identity.name,
description=tool.description.llm, description=tool.entity.description.llm,
parameters={ parameters={
"type": "object", "type": "object",
"properties": {}, "properties": {},
@ -232,7 +226,7 @@ class BaseAgentRunner(AppRunner):
tool_instances = {} tool_instances = {}
prompt_messages_tools = [] prompt_messages_tools = []
for tool in self.app_config.agent.tools if self.app_config.agent else []: for tool in self.app_config.agent.tools or [] if self.app_config.agent else []:
try: try:
prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool) prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
except Exception: except Exception:
@ -249,7 +243,7 @@ class BaseAgentRunner(AppRunner):
# save prompt tool # save prompt tool
prompt_messages_tools.append(prompt_tool) prompt_messages_tools.append(prompt_tool)
# save tool entity # save tool entity
tool_instances[dataset_tool.identity.name] = dataset_tool tool_instances[dataset_tool.entity.identity.name] = dataset_tool
return tool_instances, prompt_messages_tools return tool_instances, prompt_messages_tools
@ -328,25 +322,29 @@ class BaseAgentRunner(AppRunner):
def save_agent_thought( def save_agent_thought(
self, self,
agent_thought: MessageAgentThought, agent_thought: MessageAgentThought,
tool_name: str, tool_name: str | None,
tool_input: Union[str, dict], tool_input: Union[str, dict, None],
thought: str, thought: str | None,
observation: Union[str, dict], observation: Union[str, dict, None],
tool_invoke_meta: Union[str, dict], tool_invoke_meta: Union[str, dict, None],
answer: str, answer: str | None,
messages_ids: list[str], messages_ids: list[str],
llm_usage: LLMUsage = None, llm_usage: LLMUsage | None = None,
) -> MessageAgentThought: ):
""" """
Save agent thought Save agent thought
""" """
agent_thought = db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first() updated_agent_thought = (
db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
)
if not updated_agent_thought:
raise ValueError("agent thought not found")
if thought is not None: if thought is not None:
agent_thought.thought = thought updated_agent_thought.thought = thought
if tool_name is not None: if tool_name is not None:
agent_thought.tool = tool_name updated_agent_thought.tool = tool_name
if tool_input is not None: if tool_input is not None:
if isinstance(tool_input, dict): if isinstance(tool_input, dict):
@ -355,7 +353,7 @@ class BaseAgentRunner(AppRunner):
except Exception as e: except Exception as e:
tool_input = json.dumps(tool_input) tool_input = json.dumps(tool_input)
agent_thought.tool_input = tool_input updated_agent_thought.tool_input = tool_input
if observation is not None: if observation is not None:
if isinstance(observation, dict): if isinstance(observation, dict):
@ -364,27 +362,27 @@ class BaseAgentRunner(AppRunner):
except Exception as e: except Exception as e:
observation = json.dumps(observation) observation = json.dumps(observation)
agent_thought.observation = observation updated_agent_thought.observation = observation
if answer is not None: if answer is not None:
agent_thought.answer = answer updated_agent_thought.answer = answer
if messages_ids is not None and len(messages_ids) > 0: if messages_ids is not None and len(messages_ids) > 0:
agent_thought.message_files = json.dumps(messages_ids) updated_agent_thought.message_files = json.dumps(messages_ids)
if llm_usage: if llm_usage:
agent_thought.message_token = llm_usage.prompt_tokens updated_agent_thought.message_token = llm_usage.prompt_tokens
agent_thought.message_price_unit = llm_usage.prompt_price_unit updated_agent_thought.message_price_unit = llm_usage.prompt_price_unit
agent_thought.message_unit_price = llm_usage.prompt_unit_price updated_agent_thought.message_unit_price = llm_usage.prompt_unit_price
agent_thought.answer_token = llm_usage.completion_tokens updated_agent_thought.answer_token = llm_usage.completion_tokens
agent_thought.answer_price_unit = llm_usage.completion_price_unit updated_agent_thought.answer_price_unit = llm_usage.completion_price_unit
agent_thought.answer_unit_price = llm_usage.completion_unit_price updated_agent_thought.answer_unit_price = llm_usage.completion_unit_price
agent_thought.tokens = llm_usage.total_tokens updated_agent_thought.tokens = llm_usage.total_tokens
agent_thought.total_price = llm_usage.total_price updated_agent_thought.total_price = llm_usage.total_price
# check if tool labels is not empty # check if tool labels is not empty
labels = agent_thought.tool_labels or {} labels = updated_agent_thought.tool_labels or {}
tools = agent_thought.tool.split(";") if agent_thought.tool else [] tools = updated_agent_thought.tool.split(";") if updated_agent_thought.tool else []
for tool in tools: for tool in tools:
if not tool: if not tool:
continue continue
@ -395,7 +393,7 @@ class BaseAgentRunner(AppRunner):
else: else:
labels[tool] = {"en_US": tool, "zh_Hans": tool} labels[tool] = {"en_US": tool, "zh_Hans": tool}
agent_thought.tool_labels_str = json.dumps(labels) updated_agent_thought.tool_labels_str = json.dumps(labels)
if tool_invoke_meta is not None: if tool_invoke_meta is not None:
if isinstance(tool_invoke_meta, dict): if isinstance(tool_invoke_meta, dict):
@ -404,25 +402,8 @@ class BaseAgentRunner(AppRunner):
except Exception as e: except Exception as e:
tool_invoke_meta = json.dumps(tool_invoke_meta) tool_invoke_meta = json.dumps(tool_invoke_meta)
agent_thought.tool_meta_str = tool_invoke_meta updated_agent_thought.tool_meta_str = tool_invoke_meta
db.session.commit()
db.session.close()
def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
"""
convert tool variables to db variables
"""
db_variables = (
db.session.query(ToolConversationVariables)
.filter(
ToolConversationVariables.conversation_id == self.message.conversation_id,
)
.first()
)
db_variables.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
db.session.commit() db.session.commit()
db.session.close() db.session.close()
@ -515,6 +496,7 @@ class BaseAgentRunner(AppRunner):
files = message.message_files files = message.message_files
if files: if files:
assert message.app_model_config
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict()) file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
if file_extra_config: if file_extra_config:
@ -525,7 +507,7 @@ class BaseAgentRunner(AppRunner):
if not file_objs: if not file_objs:
return UserPromptMessage(content=message.query) return UserPromptMessage(content=message.query)
else: else:
prompt_message_contents = [TextPromptMessageContent(data=message.query)] prompt_message_contents: list[PromptMessageContent] = [TextPromptMessageContent(data=message.query)]
for file_obj in file_objs: for file_obj in file_objs:
prompt_message_contents.append(file_obj.prompt_message_content) prompt_message_contents.append(file_obj.prompt_message_content)

@ -1,6 +1,6 @@
import json import json
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Generator from collections.abc import Generator, Mapping, Sequence
from typing import Optional, Union from typing import Optional, Union
from core.agent.base_agent_runner import BaseAgentRunner from core.agent.base_agent_runner import BaseAgentRunner
@ -12,6 +12,7 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
AssistantPromptMessage, AssistantPromptMessage,
PromptMessage, PromptMessage,
PromptMessageTool,
ToolPromptMessage, ToolPromptMessage,
UserPromptMessage, UserPromptMessage,
) )
@ -26,11 +27,11 @@ from models.model import Message
class CotAgentRunner(BaseAgentRunner, ABC): class CotAgentRunner(BaseAgentRunner, ABC):
_is_first_iteration = True _is_first_iteration = True
_ignore_observation_providers = ["wenxin"] _ignore_observation_providers = ["wenxin"]
_historic_prompt_messages: list[PromptMessage] = None _historic_prompt_messages: list[PromptMessage]
_agent_scratchpad: list[AgentScratchpadUnit] = None _agent_scratchpad: list[AgentScratchpadUnit]
_instruction: str = None _instruction: str
_query: str = None _query: str
_prompt_messages_tools: list[PromptMessage] = None _prompt_messages_tools: Sequence[PromptMessageTool]
def run( def run(
self, self,
@ -41,6 +42,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
""" """
Run Cot agent application Run Cot agent application
""" """
app_generate_entity = self.application_generate_entity app_generate_entity = self.application_generate_entity
self._repack_app_generate_entity(app_generate_entity) self._repack_app_generate_entity(app_generate_entity)
self._init_react_state(query) self._init_react_state(query)
@ -53,9 +55,11 @@ class CotAgentRunner(BaseAgentRunner, ABC):
app_generate_entity.model_conf.stop.append("Observation") app_generate_entity.model_conf.stop.append("Observation")
app_config = self.app_config app_config = self.app_config
assert app_config.agent
# init instruction # init instruction
inputs = inputs or {} inputs = inputs or {}
assert app_config.prompt_template.simple_prompt_template
instruction = app_config.prompt_template.simple_prompt_template instruction = app_config.prompt_template.simple_prompt_template
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs) self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
@ -63,13 +67,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1 max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
# convert tools into ModelRuntime Tool format # convert tools into ModelRuntime Tool format
tool_instances, self._prompt_messages_tools = self._init_prompt_tools() tool_instances, prompt_messages_tools = self._init_prompt_tools()
self._prompt_messages_tools = prompt_messages_tools
function_call_state = True function_call_state = True
llm_usage = {"usage": None} llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
final_answer = "" final_answer = ""
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage):
if not final_llm_usage_dict["usage"]: if not final_llm_usage_dict["usage"]:
final_llm_usage_dict["usage"] = usage final_llm_usage_dict["usage"] = usage
else: else:
@ -115,10 +120,6 @@ class CotAgentRunner(BaseAgentRunner, ABC):
callbacks=[], callbacks=[],
) )
# check llm result
if not chunks:
raise ValueError("failed to invoke llm")
usage_dict = {} usage_dict = {}
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict) react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
scratchpad = AgentScratchpadUnit( scratchpad = AgentScratchpadUnit(
@ -139,11 +140,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
if isinstance(chunk, AgentScratchpadUnit.Action): if isinstance(chunk, AgentScratchpadUnit.Action):
action = chunk action = chunk
# detect action # detect action
assert scratchpad.agent_response is not None
scratchpad.agent_response += json.dumps(chunk.model_dump()) scratchpad.agent_response += json.dumps(chunk.model_dump())
scratchpad.action_str = json.dumps(chunk.model_dump()) scratchpad.action_str = json.dumps(chunk.model_dump())
scratchpad.action = action scratchpad.action = action
else: else:
assert scratchpad.agent_response is not None
scratchpad.agent_response += chunk scratchpad.agent_response += chunk
assert scratchpad.thought is not None
scratchpad.thought += chunk scratchpad.thought += chunk
yield LLMResultChunk( yield LLMResultChunk(
model=self.model_config.model, model=self.model_config.model,
@ -152,6 +156,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None), delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
) )
assert scratchpad.thought is not None
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you" scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
self._agent_scratchpad.append(scratchpad) self._agent_scratchpad.append(scratchpad)
@ -168,7 +173,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
tool_invoke_meta={}, tool_invoke_meta={},
thought=scratchpad.thought, thought=scratchpad.thought,
observation="", observation="",
answer=scratchpad.agent_response, answer=scratchpad.agent_response or "",
messages_ids=[], messages_ids=[],
llm_usage=usage_dict["usage"], llm_usage=usage_dict["usage"],
) )
@ -248,7 +253,6 @@ class CotAgentRunner(BaseAgentRunner, ABC):
messages_ids=[], messages_ids=[],
) )
self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event # publish end event
self.queue_manager.publish( self.queue_manager.publish(
QueueMessageEndEvent( QueueMessageEndEvent(
@ -266,7 +270,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
def _handle_invoke_action( def _handle_invoke_action(
self, self,
action: AgentScratchpadUnit.Action, action: AgentScratchpadUnit.Action,
tool_instances: dict[str, Tool], tool_instances: Mapping[str, Tool],
message_file_ids: list[str], message_file_ids: list[str],
trace_manager: Optional[TraceQueueManager] = None, trace_manager: Optional[TraceQueueManager] = None,
) -> tuple[str, ToolInvokeMeta]: ) -> tuple[str, ToolInvokeMeta]:
@ -307,15 +311,12 @@ class CotAgentRunner(BaseAgentRunner, ABC):
# publish files # publish files
for message_file_id, save_as in message_files: for message_file_id, save_as in message_files:
if save_as:
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)
# publish message file # publish message file
self.queue_manager.publish( self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER QueueMessageFileEvent(message_file_id=message_file_id.id), PublishFrom.APPLICATION_MANAGER
) )
# add message file ids # add message file ids
message_file_ids.append(message_file_id) message_file_ids.append(message_file_id.id)
return tool_invoke_response, tool_invoke_meta return tool_invoke_response, tool_invoke_meta
@ -369,18 +370,19 @@ class CotAgentRunner(BaseAgentRunner, ABC):
return message return message
def _organize_historic_prompt_messages( def _organize_historic_prompt_messages(
self, current_session_messages: list[PromptMessage] = None self, current_session_messages: list[PromptMessage] | None = None
) -> list[PromptMessage]: ) -> list[PromptMessage]:
""" """
organize historic prompt messages organize historic prompt messages
""" """
result: list[PromptMessage] = [] result: list[PromptMessage] = []
scratchpads: list[AgentScratchpadUnit] = [] scratchpads: list[AgentScratchpadUnit] = []
current_scratchpad: AgentScratchpadUnit = None current_scratchpad: AgentScratchpadUnit | None = None
for message in self.history_prompt_messages: for message in self.history_prompt_messages:
if isinstance(message, AssistantPromptMessage): if isinstance(message, AssistantPromptMessage):
if not current_scratchpad: if not current_scratchpad:
assert isinstance(message.content, str)
current_scratchpad = AgentScratchpadUnit( current_scratchpad = AgentScratchpadUnit(
agent_response=message.content, agent_response=message.content,
thought=message.content or "I am thinking about how to help you", thought=message.content or "I am thinking about how to help you",
@ -400,6 +402,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
pass pass
elif isinstance(message, ToolPromptMessage): elif isinstance(message, ToolPromptMessage):
if current_scratchpad: if current_scratchpad:
assert isinstance(message.content, str)
current_scratchpad.observation = message.content current_scratchpad.observation = message.content
elif isinstance(message, UserPromptMessage): elif isinstance(message, UserPromptMessage):
if scratchpads: if scratchpads:

@ -4,6 +4,7 @@ from core.agent.cot_agent_runner import CotAgentRunner
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
AssistantPromptMessage, AssistantPromptMessage,
PromptMessage, PromptMessage,
PromptMessageContent,
SystemPromptMessage, SystemPromptMessage,
TextPromptMessageContent, TextPromptMessageContent,
UserPromptMessage, UserPromptMessage,
@ -16,6 +17,9 @@ class CotChatAgentRunner(CotAgentRunner):
""" """
Organize system prompt Organize system prompt
""" """
assert self.app_config.agent
assert self.app_config.agent.prompt
prompt_entity = self.app_config.agent.prompt prompt_entity = self.app_config.agent.prompt
first_prompt = prompt_entity.first_prompt first_prompt = prompt_entity.first_prompt
@ -27,12 +31,12 @@ class CotChatAgentRunner(CotAgentRunner):
return SystemPromptMessage(content=system_prompt) return SystemPromptMessage(content=system_prompt)
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]: def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
""" """
Organize user query Organize user query
""" """
if self.files: if self.files:
prompt_message_contents = [TextPromptMessageContent(data=query)] prompt_message_contents: list[PromptMessageContent] = [TextPromptMessageContent(data=query)]
for file_obj in self.files: for file_obj in self.files:
prompt_message_contents.append(file_obj.prompt_message_content) prompt_message_contents.append(file_obj.prompt_message_content)
@ -57,8 +61,10 @@ class CotChatAgentRunner(CotAgentRunner):
assistant_message = AssistantPromptMessage(content="") assistant_message = AssistantPromptMessage(content="")
for unit in agent_scratchpad: for unit in agent_scratchpad:
if unit.is_final(): if unit.is_final():
assert isinstance(assistant_message.content, str)
assistant_message.content += f"Final Answer: {unit.agent_response}" assistant_message.content += f"Final Answer: {unit.agent_response}"
else: else:
assert isinstance(assistant_message.content, str)
assistant_message.content += f"Thought: {unit.thought}\n\n" assistant_message.content += f"Thought: {unit.thought}\n\n"
if unit.action_str: if unit.action_str:
assistant_message.content += f"Action: {unit.action_str}\n\n" assistant_message.content += f"Action: {unit.action_str}\n\n"

@ -2,7 +2,7 @@ import json
import logging import logging
from collections.abc import Generator from collections.abc import Generator
from copy import deepcopy from copy import deepcopy
from typing import Any, Union from typing import Any, Optional, Union
from core.agent.base_agent_runner import BaseAgentRunner from core.agent.base_agent_runner import BaseAgentRunner
from core.app.apps.base_app_queue_manager import PublishFrom from core.app.apps.base_app_queue_manager import PublishFrom
@ -11,6 +11,7 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
AssistantPromptMessage, AssistantPromptMessage,
PromptMessage, PromptMessage,
PromptMessageContent,
PromptMessageContentType, PromptMessageContentType,
SystemPromptMessage, SystemPromptMessage,
TextPromptMessageContent, TextPromptMessageContent,
@ -38,18 +39,20 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# convert tools into ModelRuntime Tool format # convert tools into ModelRuntime Tool format
tool_instances, prompt_messages_tools = self._init_prompt_tools() tool_instances, prompt_messages_tools = self._init_prompt_tools()
assert app_config.agent
iteration_step = 1 iteration_step = 1
max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1 max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
# continue to run until there is not any tool call # continue to run until there is not any tool call
function_call_state = True function_call_state = True
llm_usage = {"usage": None} llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
final_answer = "" final_answer = ""
# get tracing instance # get tracing instance
trace_manager = app_generate_entity.trace_manager trace_manager = app_generate_entity.trace_manager
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage):
if not final_llm_usage_dict["usage"]: if not final_llm_usage_dict["usage"]:
final_llm_usage_dict["usage"] = usage final_llm_usage_dict["usage"] = usage
else: else:
@ -99,7 +102,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
current_llm_usage = None current_llm_usage = None
if self.stream_tool_call: if isinstance(chunks, Generator):
is_first_chunk = True is_first_chunk = True
for chunk in chunks: for chunk in chunks:
if is_first_chunk: if is_first_chunk:
@ -133,7 +136,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
yield chunk yield chunk
else: else:
result: LLMResult = chunks result = chunks
# check if there is any tool call # check if there is any tool call
if self.check_blocking_tool_calls(result): if self.check_blocking_tool_calls(result):
function_call_state = True function_call_state = True
@ -236,15 +239,12 @@ class FunctionCallAgentRunner(BaseAgentRunner):
) )
# publish files # publish files
for message_file_id, save_as in message_files: for message_file_id, save_as in message_files:
if save_as:
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)
# publish message file # publish message file
self.queue_manager.publish( self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER QueueMessageFileEvent(message_file_id=message_file_id.id), PublishFrom.APPLICATION_MANAGER
) )
# add message file ids # add message file ids
message_file_ids.append(message_file_id) message_file_ids.append(message_file_id.id)
tool_response = { tool_response = {
"tool_call_id": tool_call_id, "tool_call_id": tool_call_id,
@ -290,7 +290,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
iteration_step += 1 iteration_step += 1
self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event # publish end event
self.queue_manager.publish( self.queue_manager.publish(
QueueMessageEndEvent( QueueMessageEndEvent(
@ -321,9 +320,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return True return True
return False return False
def extract_tool_calls( def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
self, llm_result_chunk: LLMResultChunk
) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
""" """
Extract tool calls from llm result chunk Extract tool calls from llm result chunk
@ -346,7 +343,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return tool_calls return tool_calls
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]: def extract_blocking_tool_calls(self, llm_result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
""" """
Extract blocking tool calls from llm result Extract blocking tool calls from llm result
@ -370,7 +367,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return tool_calls return tool_calls
def _init_system_message( def _init_system_message(
self, prompt_template: str, prompt_messages: list[PromptMessage] = None self, prompt_template: str, prompt_messages: list[PromptMessage]
) -> list[PromptMessage]: ) -> list[PromptMessage]:
""" """
Initialize system message Initialize system message
@ -385,12 +382,12 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return prompt_messages return prompt_messages
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]: def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
""" """
Organize user query Organize user query
""" """
if self.files: if self.files:
prompt_message_contents = [TextPromptMessageContent(data=query)] prompt_message_contents: list[PromptMessageContent] = [TextPromptMessageContent(data=query)]
for file_obj in self.files: for file_obj in self.files:
prompt_message_contents.append(file_obj.prompt_message_content) prompt_message_contents.append(file_obj.prompt_message_content)

@ -16,10 +16,8 @@ from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.moderation.base import ModerationError from core.moderation.base import ModerationError
from core.tools.entities.tool_entities import ToolRuntimeVariablePool
from extensions.ext_database import db from extensions.ext_database import db
from models.model import App, Conversation, Message, MessageAgentThought from models.model import App, Conversation, Message, MessageAgentThought
from models.tools import ToolConversationVariables
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -174,14 +172,6 @@ class AgentChatAppRunner(AppRunner):
agent_entity = app_config.agent agent_entity = app_config.agent
# load tool variables
tool_conversation_variables = self._load_tool_variables(
conversation_id=conversation.id, user_id=application_generate_entity.user_id, tenant_id=app_config.tenant_id
)
# convert db variables to tool variables
tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)
# init model instance # init model instance
model_instance = ModelInstance( model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
@ -234,8 +224,6 @@ class AgentChatAppRunner(AppRunner):
user_id=application_generate_entity.user_id, user_id=application_generate_entity.user_id,
memory=memory, memory=memory,
prompt_messages=prompt_message, prompt_messages=prompt_message,
variables_pool=tool_variables,
db_variables=tool_conversation_variables,
model_instance=model_instance, model_instance=model_instance,
) )
@ -253,50 +241,6 @@ class AgentChatAppRunner(AppRunner):
agent=True, agent=True,
) )
def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: str) -> ToolConversationVariables:
"""
load tool variables from database
"""
tool_variables: ToolConversationVariables = (
db.session.query(ToolConversationVariables)
.filter(
ToolConversationVariables.conversation_id == conversation_id,
ToolConversationVariables.tenant_id == tenant_id,
)
.first()
)
if tool_variables:
# save tool variables to session, so that we can update it later
db.session.add(tool_variables)
else:
# create new tool variables
tool_variables = ToolConversationVariables(
conversation_id=conversation_id,
user_id=user_id,
tenant_id=tenant_id,
variables_str="[]",
)
db.session.add(tool_variables)
db.session.commit()
return tool_variables
def _convert_db_variables_to_tool_variables(
self, db_variables: ToolConversationVariables
) -> ToolRuntimeVariablePool:
"""
convert db variables to tool variables
"""
return ToolRuntimeVariablePool(
**{
"conversation_id": db_variables.conversation_id,
"user_id": db_variables.user_id,
"tenant_id": db_variables.tenant_id,
"pool": db_variables.variables,
}
)
def _get_usage_of_all_agent_thoughts( def _get_usage_of_all_agent_thoughts(
self, model_config: ModelConfigWithCredentialsEntity, message: Message self, model_config: ModelConfigWithCredentialsEntity, message: Message
) -> LLMUsage: ) -> LLMUsage:

@ -1,7 +1,7 @@
import logging import logging
import os import os
from collections.abc import Callable, Generator, Sequence from collections.abc import Callable, Generator, Sequence
from typing import IO, Optional, Union, cast from typing import IO, Literal, Optional, Union, cast, overload
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import ModelLoadBalancingConfiguration from core.entities.provider_entities import ModelLoadBalancingConfiguration
@ -97,6 +97,42 @@ class ModelInstance:
return None return None
@overload
def invoke_llm(
self,
prompt_messages: list[PromptMessage],
model_parameters: Optional[dict] = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: Optional[list[str]] = None,
stream: Literal[True] = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
) -> Generator: ...
@overload
def invoke_llm(
self,
prompt_messages: list[PromptMessage],
model_parameters: Optional[dict] = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: Optional[list[str]] = None,
stream: Literal[False] = False,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
) -> LLMResult: ...
@overload
def invoke_llm(
self,
prompt_messages: list[PromptMessage],
model_parameters: Optional[dict] = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
) -> Union[LLMResult, Generator]: ...
def invoke_llm( def invoke_llm(
self, self,
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],

@ -1,72 +1,34 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Generator from collections.abc import Generator
from copy import deepcopy from copy import deepcopy
from enum import Enum from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Optional, Union
from pydantic import BaseModel, ConfigDict, Field, field_validator from core.tools.__base.tool_runtime import ToolRuntime
from pydantic_core.core_schema import ValidationInfo
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.entities.tool_entities import ( from core.tools.entities.tool_entities import (
ToolDescription, ToolEntity,
ToolIdentity,
ToolInvokeFrom,
ToolInvokeMessage, ToolInvokeMessage,
ToolParameter, ToolParameter,
ToolProviderType, ToolProviderType,
ToolRuntimeImageVariable,
ToolRuntimeVariable,
ToolRuntimeVariablePool,
) )
from core.tools.tool_file_manager import ToolFileManager
from core.tools.utils.tool_parameter_converter import ToolParameterConverter from core.tools.utils.tool_parameter_converter import ToolParameterConverter
if TYPE_CHECKING: if TYPE_CHECKING:
from core.file.file_obj import FileVar from core.file.file_obj import FileVar
class Tool(BaseModel, ABC): class Tool(ABC):
identity: ToolIdentity """
parameters: list[ToolParameter] = Field(default_factory=list) The base class of a tool
description: Optional[ToolDescription] = None """
is_team_authorization: bool = False
# pydantic configs entity: ToolEntity
model_config = ConfigDict(protected_namespaces=()) runtime: ToolRuntime
@field_validator("parameters", mode="before") def __init__(self, entity: ToolEntity, runtime: ToolRuntime) -> None:
@classmethod self.entity = entity
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]: self.runtime = runtime
return v or []
class Runtime(BaseModel): def fork_tool_runtime(self, runtime: ToolRuntime) -> "Tool":
"""
Meta data of a tool call processing
"""
def __init__(self, **data: Any):
super().__init__(**data)
if not self.runtime_parameters:
self.runtime_parameters = {}
tenant_id: Optional[str] = None
tool_id: Optional[str] = None
invoke_from: Optional[InvokeFrom] = None
tool_invoke_from: Optional[ToolInvokeFrom] = None
credentials: Optional[dict[str, Any]] = None
runtime_parameters: dict[str, Any] = Field(default_factory=dict)
runtime: Optional[Runtime] = None
variables: Optional[ToolRuntimeVariablePool] = None
def __init__(self, **data: Any):
super().__init__(**data)
class VariableKey(Enum):
IMAGE = "image"
def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool":
""" """
fork a new tool with meta data fork a new tool with meta data
@ -74,10 +36,8 @@ class Tool(BaseModel, ABC):
:return: the new tool :return: the new tool
""" """
return self.__class__( return self.__class__(
identity=self.identity.model_copy() if self.identity else None, entity=self.entity.model_copy(),
parameters=self.parameters.copy() if self.parameters else None, runtime=runtime,
description=self.description.model_copy() if self.description else None,
runtime=Tool.Runtime(**runtime),
) )
@abstractmethod @abstractmethod
@ -88,112 +48,6 @@ class Tool(BaseModel, ABC):
:return: the tool provider type :return: the tool provider type
""" """
def load_variables(self, variables: ToolRuntimeVariablePool):
"""
load variables from database
:param conversation_id: the conversation id
"""
self.variables = variables
def set_image_variable(self, variable_name: str, image_key: str) -> None:
"""
set an image variable
"""
if not self.variables:
return
self.variables.set_file(self.identity.name, variable_name, image_key)
def set_text_variable(self, variable_name: str, text: str) -> None:
"""
set a text variable
"""
if not self.variables:
return
self.variables.set_text(self.identity.name, variable_name, text)
def get_variable(self, name: Union[str, Enum]) -> Optional[ToolRuntimeVariable]:
"""
get a variable
:param name: the name of the variable
:return: the variable
"""
if not self.variables:
return None
if isinstance(name, Enum):
name = name.value
for variable in self.variables.pool:
if variable.name == name:
return variable
return None
def get_default_image_variable(self) -> Optional[ToolRuntimeVariable]:
"""
get the default image variable
:return: the image variable
"""
if not self.variables:
return None
return self.get_variable(self.VariableKey.IMAGE)
def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]:
"""
get a variable file
:param name: the name of the variable
:return: the variable file
"""
variable = self.get_variable(name)
if not variable:
return None
if not isinstance(variable, ToolRuntimeImageVariable):
return None
message_file_id = variable.value
# get file binary
file_binary = ToolFileManager.get_file_binary_by_message_file_id(message_file_id)
if not file_binary:
return None
return file_binary[0]
def list_variables(self) -> list[ToolRuntimeVariable]:
"""
list all variables
:return: the variables
"""
if not self.variables:
return []
return self.variables.pool
def list_default_image_variables(self) -> list[ToolRuntimeVariable]:
"""
list all image variables
:return: the image variables
"""
if not self.variables:
return []
result = []
for variable in self.variables.pool:
if variable.name.startswith(self.VariableKey.IMAGE.value):
result.append(variable)
return result
def invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage]: def invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage]:
if self.runtime and self.runtime.runtime_parameters: if self.runtime and self.runtime.runtime_parameters:
tool_parameters.update(self.runtime.runtime_parameters) tool_parameters.update(self.runtime.runtime_parameters)
@ -227,7 +81,7 @@ class Tool(BaseModel, ABC):
""" """
# Temp fix for the issue that the tool parameters will be converted to empty while validating the credentials # Temp fix for the issue that the tool parameters will be converted to empty while validating the credentials
result = deepcopy(tool_parameters) result = deepcopy(tool_parameters)
for parameter in self.parameters or []: for parameter in self.entity.parameters:
if parameter.name in tool_parameters: if parameter.name in tool_parameters:
result[parameter.name] = ToolParameterConverter.cast_parameter_by_type( result[parameter.name] = ToolParameterConverter.cast_parameter_by_type(
tool_parameters[parameter.name], parameter.type tool_parameters[parameter.name], parameter.type
@ -241,15 +95,6 @@ class Tool(BaseModel, ABC):
) -> ToolInvokeMessage | list[ToolInvokeMessage] | Generator[ToolInvokeMessage, None, None]: ) -> ToolInvokeMessage | list[ToolInvokeMessage] | Generator[ToolInvokeMessage, None, None]:
pass pass
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
"""
validate the credentials
:param credentials: the credentials
:param parameters: the parameters
"""
pass
def get_runtime_parameters(self) -> list[ToolParameter]: def get_runtime_parameters(self) -> list[ToolParameter]:
""" """
get the runtime parameters get the runtime parameters
@ -258,7 +103,7 @@ class Tool(BaseModel, ABC):
:return: the runtime parameters :return: the runtime parameters
""" """
return self.parameters or [] return self.entity.parameters
def get_all_runtime_parameters(self) -> list[ToolParameter]: def get_all_runtime_parameters(self) -> list[ToolParameter]:
""" """
@ -266,7 +111,7 @@ class Tool(BaseModel, ABC):
:return: all runtime parameters :return: all runtime parameters
""" """
parameters = self.parameters or [] parameters = self.entity.parameters
parameters = parameters.copy() parameters = parameters.copy()
user_parameters = self.get_runtime_parameters() or [] user_parameters = self.get_runtime_parameters() or []
user_parameters = user_parameters.copy() user_parameters = user_parameters.copy()
@ -274,20 +119,16 @@ class Tool(BaseModel, ABC):
# override parameters # override parameters
for parameter in user_parameters: for parameter in user_parameters:
# check if parameter in tool parameters # check if parameter in tool parameters
found = False
for tool_parameter in parameters: for tool_parameter in parameters:
if tool_parameter.name == parameter.name: if tool_parameter.name == parameter.name:
found = True # override parameter
tool_parameter.type = parameter.type
tool_parameter.form = parameter.form
tool_parameter.required = parameter.required
tool_parameter.default = parameter.default
tool_parameter.options = parameter.options
tool_parameter.llm_description = parameter.llm_description
break break
if found:
# override parameter
tool_parameter.type = parameter.type
tool_parameter.form = parameter.form
tool_parameter.required = parameter.required
tool_parameter.default = parameter.default
tool_parameter.options = parameter.options
tool_parameter.llm_description = parameter.llm_description
else: else:
# add new parameter # add new parameter
parameters.append(parameter) parameters.append(parameter)

@ -1,23 +1,22 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from core.entities.provider_entities import ProviderConfig from core.entities.provider_entities import ProviderConfig
from core.tools.__base.tool import Tool from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ( from core.tools.entities.tool_entities import (
ToolProviderIdentity, ToolProviderEntity,
ToolProviderType, ToolProviderType,
) )
from core.tools.errors import ToolProviderCredentialValidationError from core.tools.errors import ToolProviderCredentialValidationError
class ToolProviderController(BaseModel, ABC): class ToolProviderController(ABC):
identity: ToolProviderIdentity entity: ToolProviderEntity
tools: list[Tool] = Field(default_factory=list) tools: list[Tool]
credentials_schema: dict[str, ProviderConfig] = Field(default_factory=dict)
model_config = ConfigDict(validate_assignment=True) def __init__(self, entity: ToolProviderEntity) -> None:
self.entity = entity
self.tools = []
def get_credentials_schema(self) -> dict[str, ProviderConfig]: def get_credentials_schema(self) -> dict[str, ProviderConfig]:
""" """
@ -25,7 +24,7 @@ class ToolProviderController(BaseModel, ABC):
:return: the credentials schema :return: the credentials schema
""" """
return self.credentials_schema.copy() return self.entity.credentials_schema.copy()
@abstractmethod @abstractmethod
def get_tool(self, tool_name: str) -> Tool: def get_tool(self, tool_name: str) -> Tool:
@ -51,7 +50,7 @@ class ToolProviderController(BaseModel, ABC):
:param credentials: the credentials of the tool :param credentials: the credentials of the tool
""" """
credentials_schema = self.credentials_schema credentials_schema = self.entity.credentials_schema
if credentials_schema is None: if credentials_schema is None:
return return
@ -62,7 +61,7 @@ class ToolProviderController(BaseModel, ABC):
for credential_name in credentials: for credential_name in credentials:
if credential_name not in credentials_need_to_validate: if credential_name not in credentials_need_to_validate:
raise ToolProviderCredentialValidationError( raise ToolProviderCredentialValidationError(
f"credential {credential_name} not found in provider {self.identity.name}" f"credential {credential_name} not found in provider {self.entity.identity.name}"
) )
# check type # check type

@ -0,0 +1,36 @@
from typing import Any, Optional
from openai import BaseModel
from pydantic import Field
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.entities.tool_entities import ToolInvokeFrom
class ToolRuntime(BaseModel):
"""
Meta data of a tool call processing
"""
tenant_id: str
tool_id: Optional[str] = None
invoke_from: Optional[InvokeFrom] = None
tool_invoke_from: Optional[ToolInvokeFrom] = None
credentials: Optional[dict[str, Any]] = None
runtime_parameters: dict[str, Any] = Field(default_factory=dict)
class FakeToolRuntime(ToolRuntime):
"""
Fake tool runtime for testing
"""
def __init__(self):
super().__init__(
tenant_id="fake_tenant_id",
tool_id="fake_tool_id",
invoke_from=InvokeFrom.DEBUGGER,
tool_invoke_from=ToolInvokeFrom.AGENT,
credentials={},
runtime_parameters={},
)

@ -2,13 +2,12 @@ from abc import abstractmethod
from os import listdir, path from os import listdir, path
from typing import Any from typing import Any
from pydantic import Field
from core.entities.provider_entities import ProviderConfig from core.entities.provider_entities import ProviderConfig
from core.helper.module_import_helper import load_single_subclass_from_source from core.helper.module_import_helper import load_single_subclass_from_source
from core.tools.__base.tool_provider import ToolProviderController from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.tool import BuiltinTool from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import ToolProviderType from core.tools.entities.tool_entities import ToolEntity, ToolProviderEntity, ToolProviderType
from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
from core.tools.errors import ( from core.tools.errors import (
ToolProviderNotFoundError, ToolProviderNotFoundError,
@ -17,10 +16,10 @@ from core.tools.utils.yaml_utils import load_yaml_file
class BuiltinToolProviderController(ToolProviderController): class BuiltinToolProviderController(ToolProviderController):
tools: list[BuiltinTool] = Field(default_factory=list) tools: list[BuiltinTool]
def __init__(self, **data: Any) -> None: def __init__(self, **data: Any) -> None:
if self.provider_type in {ToolProviderType.API, ToolProviderType.APP}: if self.provider_type == ToolProviderType.API:
super().__init__(**data) super().__init__(**data)
return return
@ -37,10 +36,12 @@ class BuiltinToolProviderController(ToolProviderController):
for credential_name in provider_yaml["credentials_for_provider"]: for credential_name in provider_yaml["credentials_for_provider"]:
provider_yaml["credentials_for_provider"][credential_name]["name"] = credential_name provider_yaml["credentials_for_provider"][credential_name]["name"] = credential_name
super().__init__(**{ super().__init__(
'identity': provider_yaml['identity'], entity=ToolProviderEntity(
'credentials_schema': provider_yaml.get('credentials_for_provider', {}) or {}, identity=provider_yaml["identity"],
}) credentials_schema=provider_yaml.get("credentials_for_provider", {}) or {},
),
)
def _get_builtin_tools(self) -> list[BuiltinTool]: def _get_builtin_tools(self) -> list[BuiltinTool]:
""" """
@ -51,7 +52,7 @@ class BuiltinToolProviderController(ToolProviderController):
if self.tools: if self.tools:
return self.tools return self.tools
provider = self.identity.name provider = self.entity.identity.name
tool_path = path.join(path.dirname(path.realpath(__file__)), "providers", provider, "tools") tool_path = path.join(path.dirname(path.realpath(__file__)), "providers", provider, "tools")
# get all the yaml files in the tool path # get all the yaml files in the tool path
tool_files = list(filter(lambda x: x.endswith(".yaml") and not x.startswith("__"), listdir(tool_path))) tool_files = list(filter(lambda x: x.endswith(".yaml") and not x.startswith("__"), listdir(tool_path)))
@ -62,30 +63,36 @@ class BuiltinToolProviderController(ToolProviderController):
tool = load_yaml_file(path.join(tool_path, tool_file), ignore_error=False) tool = load_yaml_file(path.join(tool_path, tool_file), ignore_error=False)
# get tool class, import the module # get tool class, import the module
assistant_tool_class = load_single_subclass_from_source( assistant_tool_class: type[BuiltinTool] = load_single_subclass_from_source(
module_name=f"core.tools.builtin_tool.providers.{provider}.tools.{tool_name}", module_name=f"core.tools.builtin_tool.providers.{provider}.tools.{tool_name}",
script_path=path.join( script_path=path.join(
path.dirname(path.realpath(__file__)), path.dirname(path.realpath(__file__)),
"builtin_tool", "providers", provider, "tools", f"{tool_name}.py" "builtin_tool",
"providers",
provider,
"tools",
f"{tool_name}.py",
), ),
parent_type=BuiltinTool, parent_type=BuiltinTool,
) )
tool["identity"]["provider"] = provider tool["identity"]["provider"] = provider
tools.append(assistant_tool_class(**tool)) tools.append(assistant_tool_class(
entity=ToolEntity(**tool), runtime=ToolRuntime(tenant_id=""),
))
self.tools = tools self.tools = tools
return tools return tools
def get_credentials_schema(self) -> dict[str, ProviderConfig]: def get_credentials_schema(self) -> dict[str, ProviderConfig]:
""" """
returns the credentials schema of the provider returns the credentials schema of the provider
:return: the credentials schema :return: the credentials schema
""" """
if not self.credentials_schema: if not self.entity.credentials_schema:
return {} return {}
return self.credentials_schema.copy() return self.entity.credentials_schema.copy()
def get_tools(self) -> list[BuiltinTool]: def get_tools(self) -> list[BuiltinTool]:
""" """
@ -94,12 +101,12 @@ class BuiltinToolProviderController(ToolProviderController):
:return: list of tools :return: list of tools
""" """
return self._get_builtin_tools() return self._get_builtin_tools()
def get_tool(self, tool_name: str) -> BuiltinTool | None: def get_tool(self, tool_name: str) -> BuiltinTool | None:
""" """
returns the tool that the provider can provide returns the tool that the provider can provide
""" """
return next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None) return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None)
@property @property
def need_credentials(self) -> bool: def need_credentials(self) -> bool:
@ -108,7 +115,7 @@ class BuiltinToolProviderController(ToolProviderController):
:return: whether the provider needs credentials :return: whether the provider needs credentials
""" """
return self.credentials_schema is not None and len(self.credentials_schema) != 0 return self.entity.credentials_schema is not None and len(self.entity.credentials_schema) != 0
@property @property
def provider_type(self) -> ToolProviderType: def provider_type(self) -> ToolProviderType:
@ -133,8 +140,8 @@ class BuiltinToolProviderController(ToolProviderController):
""" """
returns the labels of the provider returns the labels of the provider
""" """
return self.identity.tags or [] return self.entity.identity.tags or []
def validate_credentials(self, credentials: dict[str, Any]) -> None: def validate_credentials(self, credentials: dict[str, Any]) -> None:
""" """
validate the credentials of the provider validate the credentials of the provider

@ -1,13 +1,8 @@
from typing import Any from typing import Any
from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers.qrcode.tools.qrcode_generator import QRCodeGeneratorTool
from core.tools.errors import ToolProviderCredentialValidationError
class QRCodeProvider(BuiltinToolProviderController): class QRCodeProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None: def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try: pass
QRCodeGeneratorTool().invoke(user_id="", tool_parameters={"content": "Dify 123 😊"})
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

@ -1,16 +1,8 @@
from typing import Any from typing import Any
from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers.time.tools.current_time import CurrentTimeTool
from core.tools.errors import ToolProviderCredentialValidationError
class WikiPediaProvider(BuiltinToolProviderController): class WikiPediaProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None: def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try: pass
CurrentTimeTool().invoke(
user_id="",
tool_parameters={},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

@ -32,9 +32,9 @@ class BuiltinTool(Tool):
# invoke model # invoke model
return ModelInvocationUtils.invoke( return ModelInvocationUtils.invoke(
user_id=user_id, user_id=user_id,
tenant_id=self.runtime.tenant_id or "", tenant_id=self.runtime.tenant_id,
tool_type="builtin", tool_type="builtin",
tool_name=self.identity.name, tool_name=self.entity.identity.name,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
) )
@ -79,6 +79,7 @@ class BuiltinTool(Tool):
stop=[], stop=[],
) )
assert isinstance(summary.message.content, str)
return summary.message.content return summary.message.content
lines = content.split("\n") lines = content.split("\n")

@ -7,6 +7,8 @@ from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ( from core.tools.entities.tool_entities import (
ApiProviderAuthType, ApiProviderAuthType,
ToolProviderEntity,
ToolProviderIdentity,
ToolProviderType, ToolProviderType,
) )
from extensions.ext_database import db from extensions.ext_database import db
@ -18,6 +20,11 @@ class ApiToolProviderController(ToolProviderController):
tenant_id: str tenant_id: str
tools: list[ApiTool] = Field(default_factory=list) tools: list[ApiTool] = Field(default_factory=list)
def __init__(self, entity: ToolProviderEntity, provider_id: str, tenant_id: str) -> None:
super().__init__(entity)
self.provider_id = provider_id
self.tenant_id = tenant_id
@staticmethod @staticmethod
def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController": def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController":
credentials_schema = { credentials_schema = {
@ -64,25 +71,23 @@ class ApiToolProviderController(ToolProviderController):
} }
elif auth_type == ApiProviderAuthType.NONE: elif auth_type == ApiProviderAuthType.NONE:
pass pass
else:
raise ValueError(f"invalid auth type {auth_type}")
user = db_provider.user user = db_provider.user
user_name = user.name if user else "" user_name = user.name if user else ""
return ApiToolProviderController( return ApiToolProviderController(
**{ entity=ToolProviderEntity(
"identity": { identity=ToolProviderIdentity(
"author": user_name, author=user_name,
"name": db_provider.name, name=db_provider.name,
"label": {"en_US": db_provider.name, "zh_Hans": db_provider.name}, label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
"description": {"en_US": db_provider.description, "zh_Hans": db_provider.description}, description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
"icon": db_provider.icon, icon=db_provider.icon,
}, ),
"credentials_schema": credentials_schema, credentials_schema=credentials_schema,
"provider_id": db_provider.id or "", ),
"tenant_id": db_provider.tenant_id or "", provider_id=db_provider.id or "",
}, tenant_id=db_provider.tenant_id or "",
) )
@property @property
@ -103,7 +108,7 @@ class ApiToolProviderController(ToolProviderController):
"author": tool_bundle.author, "author": tool_bundle.author,
"name": tool_bundle.operation_id, "name": tool_bundle.operation_id,
"label": {"en_US": tool_bundle.operation_id, "zh_Hans": tool_bundle.operation_id}, "label": {"en_US": tool_bundle.operation_id, "zh_Hans": tool_bundle.operation_id},
"icon": self.identity.icon, "icon": self.entity.identity.icon,
"provider": self.provider_id, "provider": self.provider_id,
}, },
"description": { "description": {
@ -141,7 +146,7 @@ class ApiToolProviderController(ToolProviderController):
# get tenant api providers # get tenant api providers
db_providers: list[ApiToolProvider] = ( db_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider) db.session.query(ApiToolProvider)
.filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.identity.name) .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name)
.all() .all()
) )
@ -149,7 +154,6 @@ class ApiToolProviderController(ToolProviderController):
for db_provider in db_providers: for db_provider in db_providers:
for tool in db_provider.tools: for tool in db_provider.tools:
assistant_tool = self._parse_tool_bundle(tool) assistant_tool = self._parse_tool_bundle(tool)
assistant_tool.is_team_authorization = True
tools.append(assistant_tool) tools.append(assistant_tool)
self.tools = tools self.tools = tools
@ -166,7 +170,7 @@ class ApiToolProviderController(ToolProviderController):
self.get_tools(self.tenant_id) self.get_tools(self.tenant_id)
for tool in self.tools: for tool in self.tools:
if tool.identity.name == tool_name: if tool.entity.identity.name == tool_name:
return tool return tool
raise ValueError(f"tool {tool_name} not found") raise ValueError(f"tool {tool_name} not found")

@ -8,8 +8,9 @@ import httpx
from core.helper import ssrf_proxy from core.helper import ssrf_proxy
from core.tools.__base.tool import Tool from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
API_TOOL_DEFAULT_TIMEOUT = ( API_TOOL_DEFAULT_TIMEOUT = (
@ -25,7 +26,11 @@ class ApiTool(Tool):
Api tool Api tool
""" """
def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool": def __init__(self, entity: ToolEntity, api_bundle: ApiToolBundle, runtime: ToolRuntime):
super().__init__(entity, runtime)
self.api_bundle = api_bundle
def fork_tool_runtime(self, runtime: ToolRuntime):
""" """
fork a new tool with meta data fork a new tool with meta data
@ -33,11 +38,9 @@ class ApiTool(Tool):
:return: the new tool :return: the new tool
""" """
return self.__class__( return self.__class__(
identity=self.identity.model_copy(), entity=self.entity,
parameters=self.parameters.copy() if self.parameters else [],
description=self.description.model_copy() if self.description else None,
api_bundle=self.api_bundle.model_copy(), api_bundle=self.api_bundle.model_copy(),
runtime=Tool.Runtime(**runtime), runtime=runtime,
) )
def validate_credentials( def validate_credentials(
@ -62,7 +65,7 @@ class ApiTool(Tool):
def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]: def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
if self.runtime == None: if self.runtime == None:
raise ToolProviderCredentialValidationError("runtime not initialized") raise ToolProviderCredentialValidationError("runtime not initialized")
headers = {} headers = {}
credentials = self.runtime.credentials or {} credentials = self.runtime.credentials or {}

@ -1,10 +1,11 @@
import base64 import base64
from enum import Enum from enum import Enum
from typing import Any, Optional, Union, cast from typing import Any, Optional, Union
from pydantic import BaseModel, Field, field_serializer, field_validator from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator
from core.entities.parameter_entities import AppSelectorScope, CommonParameterType, ModelConfigScope from core.entities.parameter_entities import AppSelectorScope, CommonParameterType, ModelConfigScope
from core.entities.provider_entities import ProviderConfig
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
@ -122,14 +123,14 @@ class ToolInvokeMessage(BaseModel):
""" """
if not isinstance(value, dict | list | str | int | float | bool): if not isinstance(value, dict | list | str | int | float | bool):
raise ValueError("Only basic types and lists are allowed.") raise ValueError("Only basic types and lists are allowed.")
# if stream is true, the value must be a string # if stream is true, the value must be a string
if values.get('stream'): if values.get("stream"):
if not isinstance(value, str): if not isinstance(value, str):
raise ValueError("When 'stream' is True, 'variable_value' must be a string.") raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
return value return value
@field_validator("variable_name", mode="before") @field_validator("variable_name", mode="before")
@classmethod @classmethod
def transform_variable_name(cls, value) -> str: def transform_variable_name(cls, value) -> str:
@ -158,22 +159,20 @@ class ToolInvokeMessage(BaseModel):
meta: dict[str, Any] | None = None meta: dict[str, Any] | None = None
save_as: str = "" save_as: str = ""
@field_validator('message', mode='before') @field_validator("message", mode="before")
@classmethod @classmethod
def decode_blob_message(cls, v): def decode_blob_message(cls, v):
if isinstance(v, dict) and 'blob' in v: if isinstance(v, dict) and "blob" in v:
try: try:
v['blob'] = base64.b64decode(v['blob']) v["blob"] = base64.b64decode(v["blob"])
except Exception: except Exception:
pass pass
return v return v
@field_serializer('message') @field_serializer("message")
def serialize_message(self, v): def serialize_message(self, v):
if isinstance(v, self.BlobMessage): if isinstance(v, self.BlobMessage):
return { return {"blob": base64.b64encode(v.blob).decode("utf-8")}
'blob': base64.b64encode(v.blob).decode('utf-8')
}
return v return v
@ -252,9 +251,9 @@ class ToolParameter(BaseModel):
option_objs = [] option_objs = []
return cls( return cls(
name=name, name=name,
label=I18nObject(en_US='', zh_Hans=''), label=I18nObject(en_US="", zh_Hans=""),
placeholder=None, placeholder=None,
human_description=I18nObject(en_US='', zh_Hans=''), human_description=I18nObject(en_US="", zh_Hans=""),
type=type, type=type,
form=cls.ToolParameterForm.LLM, form=cls.ToolParameterForm.LLM,
llm_description=llm_description, llm_description=llm_description,
@ -275,6 +274,11 @@ class ToolProviderIdentity(BaseModel):
) )
class ToolProviderEntity(BaseModel):
identity: ToolProviderIdentity
credentials_schema: dict[str, ProviderConfig] = Field(default_factory=dict)
class ToolDescription(BaseModel): class ToolDescription(BaseModel):
human: I18nObject = Field(..., description="The description presented to the user") human: I18nObject = Field(..., description="The description presented to the user")
llm: str = Field(..., description="The description presented to the LLM") llm: str = Field(..., description="The description presented to the LLM")
@ -288,131 +292,6 @@ class ToolIdentity(BaseModel):
icon: Optional[str] = None icon: Optional[str] = None
class ToolRuntimeVariableType(Enum):
TEXT = "text"
IMAGE = "image"
class ToolRuntimeVariable(BaseModel):
type: ToolRuntimeVariableType = Field(..., description="The type of the variable")
name: str = Field(..., description="The name of the variable")
position: int = Field(..., description="The position of the variable")
tool_name: str = Field(..., description="The name of the tool")
class ToolRuntimeTextVariable(ToolRuntimeVariable):
value: str = Field(..., description="The value of the variable")
class ToolRuntimeImageVariable(ToolRuntimeVariable):
value: str = Field(..., description="The path of the image")
class ToolRuntimeVariablePool(BaseModel):
conversation_id: str = Field(..., description="The conversation id")
user_id: str = Field(..., description="The user id")
tenant_id: str = Field(..., description="The tenant id of assistant")
pool: list[ToolRuntimeVariable] = Field(..., description="The pool of variables")
def __init__(self, **data: Any):
pool = data.get("pool", [])
# convert pool into correct type
for index, variable in enumerate(pool):
if variable["type"] == ToolRuntimeVariableType.TEXT.value:
pool[index] = ToolRuntimeTextVariable(**variable)
elif variable["type"] == ToolRuntimeVariableType.IMAGE.value:
pool[index] = ToolRuntimeImageVariable(**variable)
super().__init__(**data)
def dict(self) -> dict:
return {
"conversation_id": self.conversation_id,
"user_id": self.user_id,
"tenant_id": self.tenant_id,
"pool": [variable.model_dump() for variable in self.pool],
}
def set_text(self, tool_name: str, name: str, value: str) -> None:
"""
set a text variable
"""
for variable in self.pool:
if variable.name == name:
if variable.type == ToolRuntimeVariableType.TEXT:
variable = cast(ToolRuntimeTextVariable, variable)
variable.value = value
return
variable = ToolRuntimeTextVariable(
type=ToolRuntimeVariableType.TEXT,
name=name,
position=len(self.pool),
tool_name=tool_name,
value=value,
)
self.pool.append(variable)
def set_file(self, tool_name: str, value: str, name: Optional[str] = None) -> None:
"""
set an image variable
:param tool_name: the name of the tool
:param value: the id of the file
"""
# check how many image variables are there
image_variable_count = 0
for variable in self.pool:
if variable.type == ToolRuntimeVariableType.IMAGE:
image_variable_count += 1
if name is None:
name = f"file_{image_variable_count}"
for variable in self.pool:
if variable.name == name:
if variable.type == ToolRuntimeVariableType.IMAGE:
variable = cast(ToolRuntimeImageVariable, variable)
variable.value = value
return
variable = ToolRuntimeImageVariable(
type=ToolRuntimeVariableType.IMAGE,
name=name,
position=len(self.pool),
tool_name=tool_name,
value=value,
)
self.pool.append(variable)
class ModelToolPropertyKey(Enum):
IMAGE_PARAMETER_NAME = "image_parameter_name"
class ModelToolConfiguration(BaseModel):
"""
Model tool configuration
"""
type: str = Field(..., description="The type of the model tool")
model: str = Field(..., description="The model")
label: I18nObject = Field(..., description="The label of the model tool")
properties: dict[ModelToolPropertyKey, Any] = Field(..., description="The properties of the model tool")
class ModelToolProviderConfiguration(BaseModel):
"""
Model tool provider configuration
"""
provider: str = Field(..., description="The provider of the model tool")
models: list[ModelToolConfiguration] = Field(..., description="The models of the model tool")
label: I18nObject = Field(..., description="The label of the model tool")
class WorkflowToolParameterConfiguration(BaseModel): class WorkflowToolParameterConfiguration(BaseModel):
""" """
Workflow tool configuration Workflow tool configuration
@ -471,3 +350,17 @@ class ToolInvokeFrom(Enum):
WORKFLOW = "workflow" WORKFLOW = "workflow"
AGENT = "agent" AGENT = "agent"
class ToolEntity(BaseModel):
identity: ToolIdentity
parameters: list[ToolParameter] = Field(default_factory=list)
description: Optional[ToolDescription] = None
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
@field_validator("parameters", mode="before")
@classmethod
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]:
return v or []

@ -65,7 +65,7 @@ class ToolEngine:
# invoke the tool # invoke the tool
try: try:
# hit the callback handler # hit the callback handler
agent_tool_callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters) agent_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters)
messages = ToolEngine._invoke(tool, tool_parameters, user_id) messages = ToolEngine._invoke(tool, tool_parameters, user_id)
invocation_meta_dict: dict[str, ToolInvokeMeta] = {} invocation_meta_dict: dict[str, ToolInvokeMeta] = {}
@ -99,7 +99,7 @@ class ToolEngine:
# hit the callback handler # hit the callback handler
agent_tool_callback.on_tool_end( agent_tool_callback.on_tool_end(
tool_name=tool.identity.name, tool_name=tool.entity.identity.name,
tool_inputs=tool_parameters, tool_inputs=tool_parameters,
tool_outputs=plain_text, tool_outputs=plain_text,
message_id=message.id, message_id=message.id,
@ -112,7 +112,7 @@ class ToolEngine:
error_response = "Please check your tool provider credentials" error_response = "Please check your tool provider credentials"
agent_tool_callback.on_tool_error(e) agent_tool_callback.on_tool_error(e)
except (ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError) as e: except (ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError) as e:
error_response = f"there is not a tool named {tool.identity.name}" error_response = f"there is not a tool named {tool.entity.identity.name}"
agent_tool_callback.on_tool_error(e) agent_tool_callback.on_tool_error(e)
except ToolParameterValidationError as e: except ToolParameterValidationError as e:
error_response = f"tool parameters validation error: {e}, please check your tool parameters" error_response = f"tool parameters validation error: {e}, please check your tool parameters"
@ -145,7 +145,7 @@ class ToolEngine:
""" """
try: try:
# hit the callback handler # hit the callback handler
workflow_tool_callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters) workflow_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters)
if isinstance(tool, WorkflowTool): if isinstance(tool, WorkflowTool):
tool.workflow_call_depth = workflow_call_depth + 1 tool.workflow_call_depth = workflow_call_depth + 1
@ -158,7 +158,7 @@ class ToolEngine:
# hit the callback handler # hit the callback handler
workflow_tool_callback.on_tool_end( workflow_tool_callback.on_tool_end(
tool_name=tool.identity.name, tool_name=tool.entity.identity.name,
tool_inputs=tool_parameters, tool_inputs=tool_parameters,
tool_outputs=response, tool_outputs=response,
) )
@ -177,13 +177,13 @@ class ToolEngine:
""" """
try: try:
# hit the callback handler # hit the callback handler
callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters) callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters)
response = tool.invoke(user_id, tool_parameters) response = tool.invoke(user_id, tool_parameters)
# hit the callback handler # hit the callback handler
callback.on_tool_end( callback.on_tool_end(
tool_name=tool.identity.name, tool_name=tool.entity.identity.name,
tool_inputs=tool_parameters, tool_inputs=tool_parameters,
tool_outputs=response, tool_outputs=response,
) )
@ -208,11 +208,11 @@ class ToolEngine:
time_cost=0.0, time_cost=0.0,
error=None, error=None,
tool_config={ tool_config={
"tool_name": tool.identity.name, "tool_name": tool.entity.identity.name,
"tool_provider": tool.identity.provider, "tool_provider": tool.entity.identity.provider,
"tool_provider_type": tool.tool_provider_type().value, "tool_provider_type": tool.tool_provider_type().value,
"tool_parameters": deepcopy(tool.runtime.runtime_parameters), "tool_parameters": deepcopy(tool.runtime.runtime_parameters),
"tool_icon": tool.identity.icon, "tool_icon": tool.entity.identity.icon,
}, },
) )
try: try:

@ -6,6 +6,8 @@ from os import listdir, path
from threading import Lock from threading import Lock
from typing import TYPE_CHECKING, Any, Union, cast from typing import TYPE_CHECKING, Any, Union, cast
from core.tools.__base.tool_runtime import ToolRuntime
if TYPE_CHECKING: if TYPE_CHECKING:
from core.workflow.nodes.tool.entities import ToolEntity from core.workflow.nodes.tool.entities import ToolEntity
@ -105,12 +107,12 @@ class ToolManager:
return cast( return cast(
BuiltinTool, BuiltinTool,
builtin_tool.fork_tool_runtime( builtin_tool.fork_tool_runtime(
runtime={ runtime=ToolRuntime(
"tenant_id": tenant_id, tenant_id=tenant_id,
"credentials": {}, credentials={},
"invoke_from": invoke_from, invoke_from=invoke_from,
"tool_invoke_from": tool_invoke_from, tool_invoke_from=tool_invoke_from,
} )
), ),
) )
@ -134,7 +136,7 @@ class ToolManager:
tenant_id=tenant_id, tenant_id=tenant_id,
config=controller.get_credentials_schema(), config=controller.get_credentials_schema(),
provider_type=controller.provider_type.value, provider_type=controller.provider_type.value,
provider_identity=controller.identity.name, provider_identity=controller.entity.identity.name,
) )
decrypted_credentials = tool_configuration.decrypt(credentials) decrypted_credentials = tool_configuration.decrypt(credentials)
@ -142,13 +144,13 @@ class ToolManager:
return cast( return cast(
BuiltinTool, BuiltinTool,
builtin_tool.fork_tool_runtime( builtin_tool.fork_tool_runtime(
runtime={ runtime=ToolRuntime(
"tenant_id": tenant_id, tenant_id=tenant_id,
"credentials": decrypted_credentials, credentials=decrypted_credentials,
"runtime_parameters": {}, runtime_parameters={},
"invoke_from": invoke_from, invoke_from=invoke_from,
"tool_invoke_from": tool_invoke_from, tool_invoke_from=tool_invoke_from,
} )
), ),
) )
@ -163,19 +165,19 @@ class ToolManager:
tenant_id=tenant_id, tenant_id=tenant_id,
config=api_provider.get_credentials_schema(), config=api_provider.get_credentials_schema(),
provider_type=api_provider.provider_type.value, provider_type=api_provider.provider_type.value,
provider_identity=api_provider.identity.name, provider_identity=api_provider.entity.identity.name,
) )
decrypted_credentials = tool_configuration.decrypt(credentials) decrypted_credentials = tool_configuration.decrypt(credentials)
return cast( return cast(
ApiTool, ApiTool,
api_provider.get_tool(tool_name).fork_tool_runtime( api_provider.get_tool(tool_name).fork_tool_runtime(
runtime={ runtime=ToolRuntime(
"tenant_id": tenant_id, tenant_id=tenant_id,
"credentials": decrypted_credentials, credentials=decrypted_credentials,
"invoke_from": invoke_from, invoke_from=invoke_from,
"tool_invoke_from": tool_invoke_from, tool_invoke_from=tool_invoke_from,
} )
), ),
) )
elif provider_type == ToolProviderType.WORKFLOW: elif provider_type == ToolProviderType.WORKFLOW:
@ -193,12 +195,12 @@ class ToolManager:
return cast( return cast(
WorkflowTool, WorkflowTool,
controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime( controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
runtime={ runtime=ToolRuntime(
"tenant_id": tenant_id, tenant_id=tenant_id,
"credentials": {}, credentials={},
"invoke_from": invoke_from, invoke_from=invoke_from,
"tool_invoke_from": tool_invoke_from, tool_invoke_from=tool_invoke_from,
} )
), ),
) )
elif provider_type == ToolProviderType.APP: elif provider_type == ToolProviderType.APP:
@ -336,7 +338,7 @@ class ToolManager:
"providers", "providers",
provider, provider,
"_assets", "_assets",
provider_controller.identity.icon, provider_controller.entity.identity.icon,
) )
# check if the icon exists # check if the icon exists
if not path.exists(absolute_path): if not path.exists(absolute_path):
@ -389,9 +391,9 @@ class ToolManager:
parent_type=BuiltinToolProviderController, parent_type=BuiltinToolProviderController,
) )
provider: BuiltinToolProviderController = provider_class() provider: BuiltinToolProviderController = provider_class()
cls._builtin_providers[provider.identity.name] = provider cls._builtin_providers[provider.entity.identity.name] = provider
for tool in provider.get_tools(): for tool in provider.get_tools():
cls._builtin_tools_labels[tool.identity.name] = tool.identity.label cls._builtin_tools_labels[tool.entity.identity.name] = tool.entity.identity.label
yield provider yield provider
except Exception as e: except Exception as e:
@ -466,11 +468,11 @@ class ToolManager:
user_provider = ToolTransformService.builtin_provider_to_user_provider( user_provider = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider, provider_controller=provider,
db_provider=find_db_builtin_provider(provider.identity.name), db_provider=find_db_builtin_provider(provider.entity.identity.name),
decrypt_credentials=False, decrypt_credentials=False,
) )
result_providers[provider.identity.name] = user_provider result_providers[provider.entity.identity.name] = user_provider
# get db api providers # get db api providers
@ -589,7 +591,7 @@ class ToolManager:
tenant_id=tenant_id, tenant_id=tenant_id,
config=controller.get_credentials_schema(), config=controller.get_credentials_schema(),
provider_type=controller.provider_type.value, provider_type=controller.provider_type.value,
provider_identity=controller.identity.name, provider_identity=controller.entity.identity.name,
) )
decrypted_credentials = tool_configuration.decrypt(credentials) decrypted_credentials = tool_configuration.decrypt(credentials)

@ -59,12 +59,11 @@ class ProviderConfigEncrypter(BaseModel):
if field.type == BasicProviderConfig.Type.SECRET_INPUT: if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data: if field_name in data:
if len(data[field_name]) > 6: if len(data[field_name]) > 6:
data[field_name] = \ data[field_name] = (
data[field_name][:2] + \ data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
'*' * (len(data[field_name]) - 4) + \ )
data[field_name][-2:]
else: else:
data[field_name] = '*' * len(data[field_name]) data[field_name] = "*" * len(data[field_name])
return data return data
@ -75,9 +74,9 @@ class ProviderConfigEncrypter(BaseModel):
return a deep copy of credentials with decrypted values return a deep copy of credentials with decrypted values
""" """
cache = ToolProviderCredentialsCache( cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
identity_id=f'{self.provider_type}.{self.provider_identity}', identity_id=f"{self.provider_type}.{self.provider_identity}",
cache_type=ToolProviderCredentialsCacheType.PROVIDER cache_type=ToolProviderCredentialsCacheType.PROVIDER,
) )
cached_credentials = cache.get() cached_credentials = cache.get()
if cached_credentials: if cached_credentials:
@ -98,14 +97,14 @@ class ProviderConfigEncrypter(BaseModel):
def delete_tool_credentials_cache(self): def delete_tool_credentials_cache(self):
cache = ToolProviderCredentialsCache( cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
identity_id=f'{self.provider_type}.{self.provider_identity}', identity_id=f"{self.provider_type}.{self.provider_identity}",
cache_type=ToolProviderCredentialsCacheType.PROVIDER cache_type=ToolProviderCredentialsCacheType.PROVIDER,
) )
cache.delete() cache.delete()
class ToolParameterConfigurationManager(BaseModel): class ToolParameterConfigurationManager:
""" """
Tool parameter configuration manager Tool parameter configuration manager
""" """
@ -116,6 +115,15 @@ class ToolParameterConfigurationManager(BaseModel):
provider_type: ToolProviderType provider_type: ToolProviderType
identity_id: str identity_id: str
def __init__(
self, tenant_id: str, tool_runtime: Tool, provider_name: str, provider_type: ToolProviderType, identity_id: str
) -> None:
self.tenant_id = tenant_id
self.tool_runtime = tool_runtime
self.provider_name = provider_name
self.provider_type = provider_type
self.identity_id = identity_id
def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]: def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
""" """
deep copy parameters deep copy parameters
@ -127,7 +135,7 @@ class ToolParameterConfigurationManager(BaseModel):
merge parameters merge parameters
""" """
# get tool parameters # get tool parameters
tool_parameters = self.tool_runtime.parameters or [] tool_parameters = self.tool_runtime.entity.parameters or []
# get tool runtime parameters # get tool runtime parameters
runtime_parameters = self.tool_runtime.get_runtime_parameters() or [] runtime_parameters = self.tool_runtime.get_runtime_parameters() or []
# override parameters # override parameters
@ -203,8 +211,8 @@ class ToolParameterConfigurationManager(BaseModel):
""" """
cache = ToolParameterCache( cache = ToolParameterCache(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
provider=f'{self.provider_type.value}.{self.provider_name}', provider=f"{self.provider_type.value}.{self.provider_name}",
tool_name=self.tool_runtime.identity.name, tool_name=self.tool_runtime.entity.identity.name,
cache_type=ToolParameterCacheType.PARAMETER, cache_type=ToolParameterCacheType.PARAMETER,
identity_id=self.identity_id, identity_id=self.identity_id,
) )
@ -236,8 +244,8 @@ class ToolParameterConfigurationManager(BaseModel):
def delete_tool_parameters_cache(self): def delete_tool_parameters_cache(self):
cache = ToolParameterCache( cache = ToolParameterCache(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
provider=f'{self.provider_type.value}.{self.provider_name}', provider=f"{self.provider_type.value}.{self.provider_name}",
tool_name=self.tool_runtime.identity.name, tool_name=self.tool_runtime.entity.identity.name,
cache_type=ToolParameterCacheType.PARAMETER, cache_type=ToolParameterCacheType.PARAMETER,
identity_id=self.identity_id, identity_id=self.identity_id,
) )

@ -6,9 +6,11 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.tools.__base.tool import Tool from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ( from core.tools.entities.tool_entities import (
ToolDescription, ToolDescription,
ToolEntity,
ToolIdentity, ToolIdentity,
ToolInvokeMessage, ToolInvokeMessage,
ToolParameter, ToolParameter,
@ -20,11 +22,15 @@ from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import Datas
class DatasetRetrieverTool(Tool): class DatasetRetrieverTool(Tool):
retrieval_tool: DatasetRetrieverBaseTool retrieval_tool: DatasetRetrieverBaseTool
def __init__(self, entity: ToolEntity, runtime: ToolRuntime, retrieval_tool: DatasetRetrieverBaseTool) -> None:
super().__init__(entity, runtime)
self.retrieval_tool = retrieval_tool
@staticmethod @staticmethod
def get_dataset_tools( def get_dataset_tools(
tenant_id: str, tenant_id: str,
dataset_ids: list[str], dataset_ids: list[str],
retrieve_config: DatasetRetrieveConfigEntity, retrieve_config: DatasetRetrieveConfigEntity | None,
return_resource: bool, return_resource: bool,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler, hit_callback: DatasetIndexToolCallbackHandler,
@ -54,7 +60,7 @@ class DatasetRetrieverTool(Tool):
) )
if retrieval_tools is None or len(retrieval_tools) == 0: if retrieval_tools is None or len(retrieval_tools) == 0:
return [] return []
# restore retrieve strategy # restore retrieve strategy
retrieve_config.retrieve_strategy = original_retriever_mode retrieve_config.retrieve_strategy = original_retriever_mode
@ -63,13 +69,14 @@ class DatasetRetrieverTool(Tool):
for retrieval_tool in retrieval_tools: for retrieval_tool in retrieval_tools:
tool = DatasetRetrieverTool( tool = DatasetRetrieverTool(
retrieval_tool=retrieval_tool, retrieval_tool=retrieval_tool,
identity=ToolIdentity( entity=ToolEntity(
provider="", author="", name=retrieval_tool.name, label=I18nObject(en_US="", zh_Hans="") identity=ToolIdentity(
provider="", author="", name=retrieval_tool.name, label=I18nObject(en_US="", zh_Hans="")
),
parameters=[],
description=ToolDescription(human=I18nObject(en_US="", zh_Hans=""), llm=retrieval_tool.description),
), ),
parameters=[], runtime=ToolRuntime(tenant_id=tenant_id),
is_team_authorization=True,
description=ToolDescription(human=I18nObject(en_US="", zh_Hans=""), llm=retrieval_tool.description),
runtime=DatasetRetrieverTool.Runtime(),
) )
tools.append(tool) tools.append(tool)
@ -99,7 +106,7 @@ class DatasetRetrieverTool(Tool):
""" """
query = tool_parameters.get("query") query = tool_parameters.get("query")
if not query: if not query:
yield self.create_text_message(text='please input query') yield self.create_text_message(text="please input query")
else: else:
# invoke dataset retriever tool # invoke dataset retriever tool
result = self.retrieval_tool._run(query=query) result = self.retrieval_tool._run(query=query)

@ -6,9 +6,11 @@ from pydantic import Field
from core.app.app_config.entities import VariableEntity, VariableEntityType from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.tools.__base.tool_provider import ToolProviderController from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ( from core.tools.entities.tool_entities import (
ToolDescription, ToolDescription,
ToolEntity,
ToolIdentity, ToolIdentity,
ToolParameter, ToolParameter,
ToolParameterOption, ToolParameterOption,
@ -63,7 +65,7 @@ class WorkflowToolProviderController(ToolProviderController):
@property @property
def provider_type(self) -> ToolProviderType: def provider_type(self) -> ToolProviderType:
return ToolProviderType.WORKFLOW return ToolProviderType.WORKFLOW
def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool: def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool:
""" """
get db provider tool get db provider tool
@ -140,19 +142,23 @@ class WorkflowToolProviderController(ToolProviderController):
raise ValueError("variable not found") raise ValueError("variable not found")
return WorkflowTool( return WorkflowTool(
identity=ToolIdentity( entity=ToolEntity(
author=user.name if user else "", identity=ToolIdentity(
name=db_provider.name, author=user.name if user else "",
label=I18nObject(en_US=db_provider.label, zh_Hans=db_provider.label), name=db_provider.name,
provider=self.provider_id, label=I18nObject(en_US=db_provider.label, zh_Hans=db_provider.label),
icon=db_provider.icon, provider=self.provider_id,
icon=db_provider.icon,
),
description=ToolDescription(
human=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
llm=db_provider.description,
),
parameters=workflow_tool_parameters,
), ),
description=ToolDescription( runtime=ToolRuntime(
human=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description), tenant_id=db_provider.tenant_id,
llm=db_provider.description,
), ),
parameters=workflow_tool_parameters,
is_team_authorization=True,
workflow_app_id=app.id, workflow_app_id=app.id,
workflow_entities={ workflow_entities={
"app": app, "app": app,
@ -201,7 +207,7 @@ class WorkflowToolProviderController(ToolProviderController):
return None return None
for tool in self.tools: for tool in self.tools:
if tool.identity.name == tool_name: if tool.entity.identity.name == tool_name:
return tool return tool
return None return None

@ -1,12 +1,12 @@
import json import json
import logging import logging
from collections.abc import Generator from collections.abc import Generator
from copy import deepcopy
from typing import Any, Optional, Union from typing import Any, Optional, Union
from core.file.file_obj import FileTransferMethod, FileVar from core.file.file_obj import FileTransferMethod, FileVar
from core.tools.__base.tool import Tool from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Account from models.account import Account
from models.model import App, EndUser from models.model import App, EndUser
@ -28,6 +28,26 @@ class WorkflowTool(Tool):
Workflow tool. Workflow tool.
""" """
def __init__(
self,
workflow_app_id: str,
version: str,
workflow_entities: dict[str, Any],
workflow_call_depth: int,
entity: ToolEntity,
runtime: ToolRuntime,
label: str = "Workflow",
thread_pool_id: Optional[str] = None,
):
self.workflow_app_id = workflow_app_id
self.version = version
self.workflow_entities = workflow_entities
self.workflow_call_depth = workflow_call_depth
self.thread_pool_id = thread_pool_id
self.label = label
super().__init__(entity=entity, runtime=runtime)
def tool_provider_type(self) -> ToolProviderType: def tool_provider_type(self) -> ToolProviderType:
""" """
get the tool provider type get the tool provider type
@ -94,7 +114,7 @@ class WorkflowTool(Tool):
return user return user
def fork_tool_runtime(self, runtime: dict[str, Any]) -> "WorkflowTool": def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool":
""" """
fork a new tool with meta data fork a new tool with meta data
@ -102,10 +122,8 @@ class WorkflowTool(Tool):
:return: the new tool :return: the new tool
""" """
return self.__class__( return self.__class__(
identity=deepcopy(self.identity), entity=self.entity.model_copy(),
parameters=deepcopy(self.parameters), runtime=runtime,
description=deepcopy(self.description),
runtime=Tool.Runtime(**runtime),
workflow_app_id=self.workflow_app_id, workflow_app_id=self.workflow_app_id,
workflow_entities=self.workflow_entities, workflow_entities=self.workflow_entities,
workflow_call_depth=self.workflow_call_depth, workflow_call_depth=self.workflow_call_depth,

@ -1,207 +0,0 @@
from collections.abc import Mapping
from typing import Optional
from pydantic import Field
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ToolDescription,
ToolIdentity,
ToolParameter,
ToolParameterOption,
ToolProviderType,
)
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
from models.model import App, AppMode
from models.tools import WorkflowToolProvider
from models.workflow import Workflow
VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
VariableEntityType.TEXT_INPUT: ToolParameter.ToolParameterType.STRING,
VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING,
VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT,
VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER,
}
class WorkflowToolProviderController(ToolProviderController):
provider_id: str
tools: list[WorkflowTool] = Field(default_factory=list)
@classmethod
def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController":
app = db_provider.app
if not app:
raise ValueError("app not found")
controller = WorkflowToolProviderController(
**{
"identity": {
"author": db_provider.user.name if db_provider.user_id and db_provider.user else "",
"name": db_provider.label,
"label": {"en_US": db_provider.label, "zh_Hans": db_provider.label},
"description": {"en_US": db_provider.description, "zh_Hans": db_provider.description},
"icon": db_provider.icon,
},
"credentials_schema": {},
"provider_id": db_provider.id or "",
}
)
# init tools
controller.tools = [controller._get_db_provider_tool(db_provider, app)]
return controller
@property
def provider_type(self) -> ToolProviderType:
return ToolProviderType.WORKFLOW
def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool:
"""
get db provider tool
:param db_provider: the db provider
:param app: the app
:return: the tool
"""
workflow: Workflow | None = db.session.query(Workflow).filter(
Workflow.app_id == db_provider.app_id,
Workflow.version == db_provider.version
).first()
if not workflow:
raise ValueError("workflow not found")
# fetch start node
graph: Mapping = workflow.graph_dict
features_dict: Mapping = workflow.features_dict
features = WorkflowAppConfigManager.convert_features(
config_dict=features_dict,
app_mode=AppMode.WORKFLOW
)
parameters = db_provider.parameter_configurations
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
def fetch_workflow_variable(variable_name: str) -> VariableEntity | None:
return next(filter(lambda x: x.variable == variable_name, variables), None)
user = db_provider.user
workflow_tool_parameters = []
for parameter in parameters:
variable = fetch_workflow_variable(parameter.name)
if variable:
parameter_type = None
options = []
if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING:
raise ValueError(f"unsupported variable type {variable.type}")
parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type]
if variable.type == VariableEntityType.SELECT and variable.options:
options = [
ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
for option in variable.options
]
workflow_tool_parameters.append(
ToolParameter(
name=parameter.name,
label=I18nObject(en_US=variable.label, zh_Hans=variable.label),
human_description=I18nObject(en_US=parameter.description, zh_Hans=parameter.description),
type=parameter_type,
form=parameter.form,
llm_description=parameter.description,
required=variable.required,
options=options,
default=variable.default,
)
)
elif features.file_upload:
workflow_tool_parameters.append(
ToolParameter(
name=parameter.name,
label=I18nObject(en_US=parameter.name, zh_Hans=parameter.name),
human_description=I18nObject(en_US=parameter.description, zh_Hans=parameter.description),
type=ToolParameter.ToolParameterType.FILE,
llm_description=parameter.description,
required=False,
form=parameter.form,
)
)
else:
raise ValueError("variable not found")
return WorkflowTool(
identity=ToolIdentity(
author=user.name if user else "",
name=db_provider.name,
label=I18nObject(en_US=db_provider.label, zh_Hans=db_provider.label),
provider=self.provider_id,
icon=db_provider.icon,
),
description=ToolDescription(
human=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
llm=db_provider.description,
),
parameters=workflow_tool_parameters,
is_team_authorization=True,
workflow_app_id=app.id,
workflow_entities={
"app": app,
"workflow": workflow,
},
version=db_provider.version,
workflow_call_depth=0,
label=db_provider.label,
)
def get_tools(self, tenant_id: str) -> list[WorkflowTool]:
"""
fetch tools from database
:param user_id: the user id
:param tenant_id: the tenant id
:return: the tools
"""
if self.tools is not None:
return self.tools
db_providers: WorkflowToolProvider | None = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.app_id == self.provider_id,
).first()
if not db_providers:
return []
app = db_providers.app
if not app:
raise ValueError("can not read app of workflow")
self.tools = [self._get_db_provider_tool(db_providers, app)]
return self.tools
def get_tool(self, tool_name: str) -> Optional[WorkflowTool]:
"""
get tool by name
:param tool_name: the name of the tool
:return: the tool
"""
if self.tools is None:
return None
for tool in self.tools:
if tool.identity.name == tool_name:
return tool
return None

@ -1304,7 +1304,7 @@ class MessageChain(db.Model):
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
class MessageAgentThought(db.Model): class MessageAgentThought(Base):
__tablename__ = "message_agent_thoughts" __tablename__ = "message_agent_thoughts"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"), db.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"),

@ -5,6 +5,7 @@ from httpx import get
from core.entities.provider_entities import ProviderConfig from core.entities.provider_entities import ProviderConfig
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.custom_tool.provider import ApiToolProviderController from core.tools.custom_tool.provider import ApiToolProviderController
from core.tools.entities.api_entities import UserTool, UserToolProvider from core.tools.entities.api_entities import UserTool, UserToolProvider
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
@ -160,7 +161,7 @@ class ApiToolManageService:
tenant_id=tenant_id, tenant_id=tenant_id,
config=provider_controller.get_credentials_schema(), config=provider_controller.get_credentials_schema(),
provider_type=provider_controller.provider_type.value, provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.identity.name provider_identity=provider_controller.entity.identity.name
) )
encrypted_credentials = tool_configuration.encrypt(credentials) encrypted_credentials = tool_configuration.encrypt(credentials)
@ -222,6 +223,7 @@ class ApiToolManageService:
return [ return [
ToolTransformService.tool_to_user_tool( ToolTransformService.tool_to_user_tool(
tool_bundle, tool_bundle,
tenant_id=tenant_id,
labels=labels, labels=labels,
) )
for tool_bundle in provider.tools for tool_bundle in provider.tools
@ -291,7 +293,7 @@ class ApiToolManageService:
tenant_id=tenant_id, tenant_id=tenant_id,
config=provider_controller.get_credentials_schema(), config=provider_controller.get_credentials_schema(),
provider_type=provider_controller.provider_type.value, provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.identity.name provider_identity=provider_controller.entity.identity.name
) )
original_credentials = tool_configuration.decrypt(provider.credentials) original_credentials = tool_configuration.decrypt(provider.credentials)
@ -410,7 +412,7 @@ class ApiToolManageService:
tenant_id=tenant_id, tenant_id=tenant_id,
config=provider_controller.get_credentials_schema(), config=provider_controller.get_credentials_schema(),
provider_type=provider_controller.provider_type.value, provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.identity.name provider_identity=provider_controller.entity.identity.name
) )
decrypted_credentials = tool_configuration.decrypt(credentials) decrypted_credentials = tool_configuration.decrypt(credentials)
# check if the credential has changed, save the original credential # check if the credential has changed, save the original credential
@ -424,10 +426,10 @@ class ApiToolManageService:
# get tool # get tool
tool = provider_controller.get_tool(tool_name) tool = provider_controller.get_tool(tool_name)
tool = tool.fork_tool_runtime( tool = tool.fork_tool_runtime(
runtime={ runtime=ToolRuntime(
"credentials": credentials, credentials=credentials,
"tenant_id": tenant_id, tenant_id=tenant_id,
} )
) )
result = tool.validate_credentials(credentials, parameters) result = tool.validate_credentials(credentials, parameters)
except Exception as e: except Exception as e:

@ -32,7 +32,7 @@ class BuiltinToolManageService:
tenant_id=tenant_id, tenant_id=tenant_id,
config=provider_controller.get_credentials_schema(), config=provider_controller.get_credentials_schema(),
provider_type=provider_controller.provider_type.value, provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.identity.name, provider_identity=provider_controller.entity.identity.name,
) )
# check if user has added the provider # check if user has added the provider
builtin_provider: BuiltinToolProvider | None = ( builtin_provider: BuiltinToolProvider | None = (
@ -71,7 +71,7 @@ class BuiltinToolManageService:
:return: the list of tool providers :return: the list of tool providers
""" """
provider = ToolManager.get_builtin_provider(provider_name) provider = ToolManager.get_builtin_provider(provider_name)
return jsonable_encoder([v for _, v in (provider.credentials_schema or {}).items()]) return jsonable_encoder([v for _, v in (provider.entity.credentials_schema or {}).items()])
@staticmethod @staticmethod
def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict): def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict):
@ -97,7 +97,7 @@ class BuiltinToolManageService:
tenant_id=tenant_id, tenant_id=tenant_id,
config=provider_controller.get_credentials_schema(), config=provider_controller.get_credentials_schema(),
provider_type=provider_controller.provider_type.value, provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.identity.name, provider_identity=provider_controller.entity.identity.name,
) )
# get original credentials if exists # get original credentials if exists
@ -159,7 +159,7 @@ class BuiltinToolManageService:
tenant_id=tenant_id, tenant_id=tenant_id,
config=provider_controller.get_credentials_schema(), config=provider_controller.get_credentials_schema(),
provider_type=provider_controller.provider_type.value, provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.identity.name, provider_identity=provider_controller.entity.identity.name,
) )
credentials = tool_configuration.decrypt(provider_obj.credentials) credentials = tool_configuration.decrypt(provider_obj.credentials)
credentials = tool_configuration.mask_tool_credentials(credentials) credentials = tool_configuration.mask_tool_credentials(credentials)
@ -191,7 +191,7 @@ class BuiltinToolManageService:
tenant_id=tenant_id, tenant_id=tenant_id,
config=provider_controller.get_credentials_schema(), config=provider_controller.get_credentials_schema(),
provider_type=provider_controller.provider_type.value, provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.identity.name, provider_identity=provider_controller.entity.identity.name,
) )
tool_configuration.delete_tool_credentials_cache() tool_configuration.delete_tool_credentials_cache()
@ -241,7 +241,7 @@ class BuiltinToolManageService:
# convert provider controller to user provider # convert provider controller to user provider
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider( user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider_controller, provider_controller=provider_controller,
db_provider=find_provider(provider_controller.identity.name), db_provider=find_provider(provider_controller.entity.identity.name),
decrypt_credentials=True, decrypt_credentials=True,
) )

@ -4,6 +4,7 @@ from typing import Optional, Union
from configs import dify_config from configs import dify_config
from core.tools.__base.tool import Tool from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.custom_tool.provider import ApiToolProviderController from core.tools.custom_tool.provider import ApiToolProviderController
from core.tools.entities.api_entities import UserTool, UserToolProvider from core.tools.entities.api_entities import UserTool, UserToolProvider
@ -69,19 +70,19 @@ class ToolTransformService:
convert provider controller to user provider convert provider controller to user provider
""" """
result = UserToolProvider( result = UserToolProvider(
id=provider_controller.identity.name, id=provider_controller.entity.identity.name,
author=provider_controller.identity.author, author=provider_controller.entity.identity.author,
name=provider_controller.identity.name, name=provider_controller.entity.identity.name,
description=I18nObject( description=I18nObject(
en_US=provider_controller.identity.description.en_US, en_US=provider_controller.entity.identity.description.en_US,
zh_Hans=provider_controller.identity.description.zh_Hans, zh_Hans=provider_controller.entity.identity.description.zh_Hans,
pt_BR=provider_controller.identity.description.pt_BR, pt_BR=provider_controller.entity.identity.description.pt_BR,
), ),
icon=provider_controller.identity.icon, icon=provider_controller.entity.identity.icon,
label=I18nObject( label=I18nObject(
en_US=provider_controller.identity.label.en_US, en_US=provider_controller.entity.identity.label.en_US,
zh_Hans=provider_controller.identity.label.zh_Hans, zh_Hans=provider_controller.entity.identity.label.zh_Hans,
pt_BR=provider_controller.identity.label.pt_BR, pt_BR=provider_controller.entity.identity.label.pt_BR,
), ),
type=ToolProviderType.BUILT_IN, type=ToolProviderType.BUILT_IN,
masked_credentials={}, masked_credentials={},
@ -111,7 +112,7 @@ class ToolTransformService:
tenant_id=db_provider.tenant_id, tenant_id=db_provider.tenant_id,
config=provider_controller.get_credentials_schema(), config=provider_controller.get_credentials_schema(),
provider_type=provider_controller.provider_type.value, provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.identity.name provider_identity=provider_controller.entity.identity.name,
) )
# decrypt the credentials and mask the credentials # decrypt the credentials and mask the credentials
decrypted_credentials = tool_configuration.decrypt(data=credentials) decrypted_credentials = tool_configuration.decrypt(data=credentials)
@ -155,16 +156,16 @@ class ToolTransformService:
""" """
return UserToolProvider( return UserToolProvider(
id=provider_controller.provider_id, id=provider_controller.provider_id,
author=provider_controller.identity.author, author=provider_controller.entity.identity.author,
name=provider_controller.identity.name, name=provider_controller.entity.identity.name,
description=I18nObject( description=I18nObject(
en_US=provider_controller.identity.description.en_US, en_US=provider_controller.entity.identity.description.en_US,
zh_Hans=provider_controller.identity.description.zh_Hans, zh_Hans=provider_controller.entity.identity.description.zh_Hans,
), ),
icon=provider_controller.identity.icon, icon=provider_controller.entity.identity.icon,
label=I18nObject( label=I18nObject(
en_US=provider_controller.identity.label.en_US, en_US=provider_controller.entity.identity.label.en_US,
zh_Hans=provider_controller.identity.label.zh_Hans, zh_Hans=provider_controller.entity.identity.label.zh_Hans,
), ),
type=ToolProviderType.WORKFLOW, type=ToolProviderType.WORKFLOW,
masked_credentials={}, masked_credentials={},
@ -189,7 +190,7 @@ class ToolTransformService:
user = db_provider.user user = db_provider.user
if not user: if not user:
raise ValueError("user not found") raise ValueError("user not found")
username = user.name username = user.name
except Exception as e: except Exception as e:
logger.error(f"failed to get user name for api provider {db_provider.id}: {str(e)}") logger.error(f"failed to get user name for api provider {db_provider.id}: {str(e)}")
@ -222,7 +223,7 @@ class ToolTransformService:
tenant_id=db_provider.tenant_id, tenant_id=db_provider.tenant_id,
config=provider_controller.get_credentials_schema(), config=provider_controller.get_credentials_schema(),
provider_type=provider_controller.provider_type.value, provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.identity.name provider_identity=provider_controller.entity.identity.name,
) )
# decrypt the credentials and mask the credentials # decrypt the credentials and mask the credentials
@ -236,8 +237,8 @@ class ToolTransformService:
@staticmethod @staticmethod
def tool_to_user_tool( def tool_to_user_tool(
tool: Union[ApiToolBundle, WorkflowTool, Tool], tool: Union[ApiToolBundle, WorkflowTool, Tool],
tenant_id: str,
credentials: dict | None = None, credentials: dict | None = None,
tenant_id: str | None = None,
labels: list[str] | None = None, labels: list[str] | None = None,
) -> UserTool: ) -> UserTool:
""" """
@ -246,14 +247,14 @@ class ToolTransformService:
if isinstance(tool, Tool): if isinstance(tool, Tool):
# fork tool runtime # fork tool runtime
tool = tool.fork_tool_runtime( tool = tool.fork_tool_runtime(
runtime={ runtime=ToolRuntime(
"credentials": credentials, credentials=credentials,
"tenant_id": tenant_id, tenant_id=tenant_id,
} )
) )
# get tool parameters # get tool parameters
parameters = tool.parameters or [] parameters = tool.entity.parameters or []
# get tool runtime parameters # get tool runtime parameters
runtime_parameters = tool.get_runtime_parameters() or [] runtime_parameters = tool.get_runtime_parameters() or []
# override parameters # override parameters
@ -270,10 +271,10 @@ class ToolTransformService:
current_parameters.append(runtime_parameter) current_parameters.append(runtime_parameter)
return UserTool( return UserTool(
author=tool.identity.author, author=tool.entity.identity.author,
name=tool.identity.name, name=tool.entity.identity.name,
label=tool.identity.label, label=tool.entity.identity.label,
description=tool.description.human if tool.description else I18nObject(en_US=''), description=tool.entity.description.human if tool.entity.description else I18nObject(en_US=""),
parameters=current_parameters, parameters=current_parameters,
labels=labels or [], labels=labels or [],
) )

@ -211,7 +211,9 @@ class WorkflowToolManageService:
ToolTransformService.repack_provider(user_tool_provider) ToolTransformService.repack_provider(user_tool_provider)
user_tool_provider.tools = [ user_tool_provider.tools = [
ToolTransformService.tool_to_user_tool( ToolTransformService.tool_to_user_tool(
tool.get_tools(user_id, tenant_id)[0], labels=labels.get(tool.provider_id, []) tool=tool.get_tools(user_id, tenant_id)[0],
labels=labels.get(tool.provider_id, []),
tenant_id=tenant_id,
) )
] ]
result.append(user_tool_provider) result.append(user_tool_provider)
@ -248,7 +250,7 @@ class WorkflowToolManageService:
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first() .first()
) )
return cls._get_workflow_tool(db_tool) return cls._get_workflow_tool(tenant_id, db_tool)
@classmethod @classmethod
def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str) -> dict: def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str) -> dict:
@ -264,10 +266,10 @@ class WorkflowToolManageService:
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id) .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id)
.first() .first()
) )
return cls._get_workflow_tool(db_tool) return cls._get_workflow_tool(tenant_id, db_tool)
@classmethod @classmethod
def _get_workflow_tool(cls, db_tool: WorkflowToolProvider | None): def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None):
""" """
Get a workflow tool. Get a workflow tool.
:db_tool: the database tool :db_tool: the database tool
@ -298,7 +300,9 @@ class WorkflowToolManageService:
"description": db_tool.description, "description": db_tool.description,
"parameters": jsonable_encoder(db_tool.parameter_configurations), "parameters": jsonable_encoder(db_tool.parameter_configurations),
"tool": ToolTransformService.tool_to_user_tool( "tool": ToolTransformService.tool_to_user_tool(
tool.get_tools(db_tool.tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool) tool=tool.get_tools(db_tool.tenant_id)[0],
labels=ToolLabelManager.get_tool_labels(tool),
tenant_id=tenant_id,
), ),
"synced": workflow.version == db_tool.version, "synced": workflow.version == db_tool.version,
"privacy_policy": db_tool.privacy_policy, "privacy_policy": db_tool.privacy_policy,
@ -326,6 +330,8 @@ class WorkflowToolManageService:
return [ return [
ToolTransformService.tool_to_user_tool( ToolTransformService.tool_to_user_tool(
tool=tool.get_tools(db_tool.tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool) tool=tool.get_tools(db_tool.tenant_id)[0],
labels=ToolLabelManager.get_tool_labels(tool),
tenant_id=tenant_id,
) )
] ]

@ -1,5 +1,9 @@
from core.tools.__base.tool import Tool from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.custom_tool.tool import ApiTool from core.tools.custom_tool.tool import ApiTool
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity
from tests.integration_tests.tools.__mock.http import setup_http_mock from tests.integration_tests.tools.__mock.http import setup_http_mock
tool_bundle = { tool_bundle = {
@ -29,7 +33,13 @@ parameters = {
def test_api_tool(setup_http_mock): def test_api_tool(setup_http_mock):
tool = ApiTool(api_bundle=tool_bundle, runtime=Tool.Runtime(credentials={"auth_type": "none"})) tool = ApiTool(
entity=ToolEntity(
identity=ToolIdentity(provider="", author="", name="", label=I18nObject()),
),
api_bundle=ApiToolBundle(**tool_bundle),
runtime=ToolRuntime(tenant_id="", credentials={"auth_type": "none"}),
)
headers = tool.assembling_request(parameters) headers = tool.assembling_request(parameters)
response = tool.do_http_request(tool.api_bundle.server_url, tool.api_bundle.method, headers, parameters) response = tool.do_http_request(tool.api_bundle.server_url, tool.api_bundle.method, headers, parameters)

Loading…
Cancel
Save