refactor: tool models

pull/9184/head
Yeuoly 2 years ago
parent 1fa3b9cfd8
commit cf4e9f317e
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61

@ -1,5 +1,5 @@
import os import os
from collections.abc import Mapping, Sequence from collections.abc import Iterable, Mapping
from typing import Any, Optional, TextIO, Union from typing import Any, Optional, TextIO, Union
from pydantic import BaseModel from pydantic import BaseModel
@ -55,7 +55,7 @@ class DifyAgentCallbackHandler(BaseModel):
self, self,
tool_name: str, tool_name: str,
tool_inputs: Mapping[str, Any], tool_inputs: Mapping[str, Any],
tool_outputs: Sequence[ToolInvokeMessage], tool_outputs: Iterable[ToolInvokeMessage] | str,
message_id: Optional[str] = None, message_id: Optional[str] = None,
timer: Optional[Any] = None, timer: Optional[Any] = None,
trace_manager: Optional[TraceQueueManager] = None trace_manager: Optional[TraceQueueManager] = None

@ -1,9 +1,9 @@
import json import json
from collections.abc import Generator, Mapping from collections.abc import Generator, Iterable
from copy import deepcopy from copy import deepcopy
from datetime import datetime, timezone from datetime import datetime, timezone
from mimetypes import guess_type from mimetypes import guess_type
from typing import Any, Optional, Union from typing import Any, Optional, Union, cast
from yarl import URL from yarl import URL
@ -40,7 +40,7 @@ class ToolEngine:
user_id: str, tenant_id: str, message: Message, invoke_from: InvokeFrom, user_id: str, tenant_id: str, message: Message, invoke_from: InvokeFrom,
agent_tool_callback: DifyAgentCallbackHandler, agent_tool_callback: DifyAgentCallbackHandler,
trace_manager: Optional[TraceQueueManager] = None trace_manager: Optional[TraceQueueManager] = None
) -> tuple[str, list[tuple[MessageFile, bool]], ToolInvokeMeta]: ) -> tuple[str, list[tuple[MessageFile, str]], ToolInvokeMeta]:
""" """
Agent invokes the tool with the given arguments. Agent invokes the tool with the given arguments.
""" """
@ -67,9 +67,9 @@ class ToolEngine:
) )
messages = ToolEngine._invoke(tool, tool_parameters, user_id) messages = ToolEngine._invoke(tool, tool_parameters, user_id)
invocation_meta_dict = {'meta': None} invocation_meta_dict: dict[str, ToolInvokeMeta] = {}
def message_callback(invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage, None, None]): def message_callback(invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]):
for message in messages: for message in messages:
if isinstance(message, ToolInvokeMeta): if isinstance(message, ToolInvokeMeta):
invocation_meta_dict['meta'] = message invocation_meta_dict['meta'] = message
@ -136,7 +136,7 @@ class ToolEngine:
return error_response, [], ToolInvokeMeta.error_instance(error_response) return error_response, [], ToolInvokeMeta.error_instance(error_response)
@staticmethod @staticmethod
def workflow_invoke(tool: Tool, tool_parameters: Mapping[str, Any], def workflow_invoke(tool: Tool, tool_parameters: dict[str, Any],
user_id: str, user_id: str,
workflow_tool_callback: DifyWorkflowCallbackHandler, workflow_tool_callback: DifyWorkflowCallbackHandler,
workflow_call_depth: int, workflow_call_depth: int,
@ -156,6 +156,7 @@ class ToolEngine:
if tool.runtime and tool.runtime.runtime_parameters: if tool.runtime and tool.runtime.runtime_parameters:
tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters} tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters}
response = tool.invoke(user_id=user_id, tool_parameters=tool_parameters) response = tool.invoke(user_id=user_id, tool_parameters=tool_parameters)
# hit the callback handler # hit the callback handler
@ -204,6 +205,9 @@ class ToolEngine:
""" """
Invoke the tool with the given arguments. Invoke the tool with the given arguments.
""" """
if not tool.runtime:
raise ValueError("missing runtime in tool")
started_at = datetime.now(timezone.utc) started_at = datetime.now(timezone.utc)
meta = ToolInvokeMeta(time_cost=0.0, error=None, tool_config={ meta = ToolInvokeMeta(time_cost=0.0, error=None, tool_config={
'tool_name': tool.identity.name, 'tool_name': tool.identity.name,
@ -223,42 +227,42 @@ class ToolEngine:
yield meta yield meta
@staticmethod @staticmethod
def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str: def _convert_tool_response_to_str(tool_response: Generator[ToolInvokeMessage, None, None]) -> str:
""" """
Handle tool response Handle tool response
""" """
result = '' result = ''
for response in tool_response: for response in tool_response:
if response.type == ToolInvokeMessage.MessageType.TEXT: if response.type == ToolInvokeMessage.MessageType.TEXT:
result += response.message result += cast(ToolInvokeMessage.TextMessage, response.message).text
elif response.type == ToolInvokeMessage.MessageType.LINK: elif response.type == ToolInvokeMessage.MessageType.LINK:
result += f"result link: {response.message}. please tell user to check it." result += f"result link: {cast(ToolInvokeMessage.TextMessage, response.message).text}. please tell user to check it."
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
response.type == ToolInvokeMessage.MessageType.IMAGE: response.type == ToolInvokeMessage.MessageType.IMAGE:
result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now." result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now."
elif response.type == ToolInvokeMessage.MessageType.JSON: elif response.type == ToolInvokeMessage.MessageType.JSON:
result += f"tool response: {json.dumps(response.message, ensure_ascii=False)}." result += f"tool response: {json.dumps(cast(ToolInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False)}."
else: else:
result += f"tool response: {response.message}." result += f"tool response: {response.message}."
return result return result
@staticmethod @staticmethod
def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> list[ToolInvokeMessageBinary]: def _extract_tool_response_binary(tool_response: Generator[ToolInvokeMessage, None, None]) -> Generator[ToolInvokeMessageBinary, None, None]:
""" """
Extract tool response binary Extract tool response binary
""" """
result = []
for response in tool_response: for response in tool_response:
if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
response.type == ToolInvokeMessage.MessageType.IMAGE: response.type == ToolInvokeMessage.MessageType.IMAGE:
mimetype = None mimetype = None
if not response.meta:
raise ValueError("missing meta data")
if response.meta.get('mime_type'): if response.meta.get('mime_type'):
mimetype = response.meta.get('mime_type') mimetype = response.meta.get('mime_type')
else: else:
try: try:
url = URL(response.message) url = URL(cast(ToolInvokeMessage.TextMessage, response.message).text)
extension = url.suffix extension = url.suffix
guess_type_result, _ = guess_type(f'a{extension}') guess_type_result, _ = guess_type(f'a{extension}')
if guess_type_result: if guess_type_result:
@ -269,35 +273,36 @@ class ToolEngine:
if not mimetype: if not mimetype:
mimetype = 'image/jpeg' mimetype = 'image/jpeg'
result.append(ToolInvokeMessageBinary( yield ToolInvokeMessageBinary(
mimetype=response.meta.get('mime_type', 'image/jpeg'), mimetype=response.meta.get('mime_type', 'image/jpeg'),
url=response.message, url=cast(ToolInvokeMessage.TextMessage, response.message).text,
save_as=response.save_as, save_as=response.save_as,
)) )
elif response.type == ToolInvokeMessage.MessageType.BLOB: elif response.type == ToolInvokeMessage.MessageType.BLOB:
result.append(ToolInvokeMessageBinary( if not response.meta:
raise ValueError("missing meta data")
yield ToolInvokeMessageBinary(
mimetype=response.meta.get('mime_type', 'octet/stream'), mimetype=response.meta.get('mime_type', 'octet/stream'),
url=response.message, url=cast(ToolInvokeMessage.TextMessage, response.message).text,
save_as=response.save_as, save_as=response.save_as,
)) )
elif response.type == ToolInvokeMessage.MessageType.LINK: elif response.type == ToolInvokeMessage.MessageType.LINK:
# check if there is a mime type in meta # check if there is a mime type in meta
if response.meta and 'mime_type' in response.meta: if response.meta and 'mime_type' in response.meta:
result.append(ToolInvokeMessageBinary( yield ToolInvokeMessageBinary(
mimetype=response.meta.get('mime_type', 'octet/stream') if response.meta else 'octet/stream', mimetype=response.meta.get('mime_type', 'octet/stream') if response.meta else 'octet/stream',
url=response.message, url=cast(ToolInvokeMessage.TextMessage, response.message).text,
save_as=response.save_as, save_as=response.save_as,
)) )
return result
@staticmethod @staticmethod
def _create_message_files( def _create_message_files(
tool_messages: list[ToolInvokeMessageBinary], tool_messages: Iterable[ToolInvokeMessageBinary],
agent_message: Message, agent_message: Message,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
user_id: str user_id: str
) -> list[tuple[Any, str]]: ) -> list[tuple[MessageFile, str]]:
""" """
Create message file Create message file

@ -1,4 +1,4 @@
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Sequence
from os import path from os import path
from typing import Any, cast from typing import Any, cast
@ -100,7 +100,7 @@ class ToolNode(BaseNode):
variable_pool: VariablePool, variable_pool: VariablePool,
node_data: ToolNodeData, node_data: ToolNodeData,
for_log: bool = False, for_log: bool = False,
) -> Mapping[str, Any]: ) -> dict[str, Any]:
""" """
Generate parameters based on the given tool parameters, variable pool, and node data. Generate parameters based on the given tool parameters, variable pool, and node data.
@ -110,7 +110,7 @@ class ToolNode(BaseNode):
node_data (ToolNodeData): The data associated with the tool node. node_data (ToolNodeData): The data associated with the tool node.
Returns: Returns:
Mapping[str, Any]: A dictionary containing the generated parameters. dict[str, Any]: A dictionary containing the generated parameters.
""" """
tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters} tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters}

@ -0,0 +1,5 @@
from sqlalchemy.orm import DeclarativeBase
class Base(DeclarativeBase):
pass

@ -14,6 +14,7 @@ from core.file.tool_file_parser import ToolFileParser
from core.file.upload_file_parser import UploadFileParser from core.file.upload_file_parser import UploadFileParser
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import generate_string from libs.helper import generate_string
from models.base import Base
from .account import Account, Tenant from .account import Account, Tenant
from .types import StringUUID from .types import StringUUID
@ -211,7 +212,7 @@ class App(db.Model):
return tags if tags else [] return tags if tags else []
class AppModelConfig(db.Model): class AppModelConfig(Base):
__tablename__ = 'app_model_configs' __tablename__ = 'app_model_configs'
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint('id', name='app_model_config_pkey'), db.PrimaryKeyConstraint('id', name='app_model_config_pkey'),
@ -550,6 +551,9 @@ class Conversation(db.Model):
else: else:
app_model_config = db.session.query(AppModelConfig).filter( app_model_config = db.session.query(AppModelConfig).filter(
AppModelConfig.id == self.app_model_config_id).first() AppModelConfig.id == self.app_model_config_id).first()
if not app_model_config:
raise ValueError("app config not found")
model_config = app_model_config.to_dict() model_config = app_model_config.to_dict()
@ -640,7 +644,7 @@ class Conversation(db.Model):
return self.override_model_configs is not None return self.override_model_configs is not None
class Message(db.Model): class Message(Base):
__tablename__ = 'messages' __tablename__ = 'messages'
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint('id', name='message_pkey'), db.PrimaryKeyConstraint('id', name='message_pkey'),
@ -932,7 +936,7 @@ class MessageFeedback(db.Model):
return account return account
class MessageFile(db.Model): class MessageFile(Base):
__tablename__ = 'message_files' __tablename__ = 'message_files'
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint('id', name='message_file_pkey'), db.PrimaryKeyConstraint('id', name='message_file_pkey'),
@ -940,15 +944,15 @@ class MessageFile(db.Model):
db.Index('message_file_created_by_idx', 'created_by') db.Index('message_file_created_by_idx', 'created_by')
) )
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) id: Mapped[str] = mapped_column(StringUUID, default=db.text('uuid_generate_v4()'))
message_id = db.Column(StringUUID, nullable=False) message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type = db.Column(db.String(255), nullable=False) type: Mapped[str] = mapped_column(db.String(255), nullable=False)
transfer_method = db.Column(db.String(255), nullable=False) transfer_method: Mapped[str] = mapped_column(db.String(255), nullable=False)
url = db.Column(db.Text, nullable=True) url: Mapped[str] = mapped_column(db.Text, nullable=True)
belongs_to = db.Column(db.String(255), nullable=True) belongs_to: Mapped[str] = mapped_column(db.String(255), nullable=True)
upload_file_id = db.Column(StringUUID, nullable=True) upload_file_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
created_by_role = db.Column(db.String(255), nullable=False) created_by_role: Mapped[str] = mapped_column(db.String(255), nullable=False)
created_by = db.Column(StringUUID, nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

@ -1,12 +1,13 @@
import json import json
from sqlalchemy import ForeignKey from sqlalchemy import ForeignKey
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
from extensions.ext_database import db from extensions.ext_database import db
from models.base import Base
from .model import Account, App, Tenant from .model import Account, App, Tenant
from .types import StringUUID from .types import StringUUID
@ -277,9 +278,6 @@ class ToolConversationVariables(db.Model):
@property @property
def variables(self) -> dict: def variables(self) -> dict:
return json.loads(self.variables_str) return json.loads(self.variables_str)
class Base(DeclarativeBase):
pass
class ToolFile(Base): class ToolFile(Base):
""" """

Loading…
Cancel
Save