refactor: tool message transformer

pull/9184/head
Yeuoly 2 years ago
parent 4b4741f7ed
commit c28998a6f0
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61

@ -96,6 +96,9 @@ class ToolInvokeMessage(BaseModel):
class JsonMessage(BaseModel): class JsonMessage(BaseModel):
json_object: dict json_object: dict
class BlobMessage(BaseModel):
blob: bytes
class MessageType(Enum): class MessageType(Enum):
TEXT = "text" TEXT = "text"
IMAGE = "image" IMAGE = "image"
@ -109,7 +112,7 @@ class ToolInvokeMessage(BaseModel):
""" """
plain text, image url or link url plain text, image url or link url
""" """
message: JsonMessage | TextMessage | None message: JsonMessage | TextMessage | BlobMessage | None
meta: dict[str, Any] | None = None meta: dict[str, Any] | None = None
save_as: str = '' save_as: str = ''
@ -321,7 +324,7 @@ class ToolRuntimeVariablePool(BaseModel):
self.pool.append(variable) self.pool.append(variable)
def set_file(self, tool_name: str, value: str, name: str = None) -> None: def set_file(self, tool_name: str, value: str, name: Optional[str] = None) -> None:
""" """
set an image variable set an image variable

@ -80,8 +80,8 @@ class ToolFileManager:
def create_file_by_url( def create_file_by_url(
user_id: str, user_id: str,
tenant_id: str, tenant_id: str,
conversation_id: str,
file_url: str, file_url: str,
conversation_id: Optional[str] = None,
) -> ToolFile: ) -> ToolFile:
""" """
create file create file
@ -131,7 +131,7 @@ class ToolFileManager:
:return: the binary of the file, mime type :return: the binary of the file, mime type
""" """
tool_file: ToolFile = ( tool_file: ToolFile | None = (
db.session.query(ToolFile) db.session.query(ToolFile)
.filter( .filter(
ToolFile.id == id, ToolFile.id == id,
@ -155,7 +155,7 @@ class ToolFileManager:
:return: the binary of the file, mime type :return: the binary of the file, mime type
""" """
message_file: MessageFile = ( message_file: MessageFile | None = (
db.session.query(MessageFile) db.session.query(MessageFile)
.filter( .filter(
MessageFile.id == id, MessageFile.id == id,
@ -173,7 +173,7 @@ class ToolFileManager:
tool_file_id = None tool_file_id = None
tool_file: ToolFile = ( tool_file: ToolFile | None = (
db.session.query(ToolFile) db.session.query(ToolFile)
.filter( .filter(
ToolFile.id == tool_file_id, ToolFile.id == tool_file_id,
@ -197,7 +197,7 @@ class ToolFileManager:
:return: the binary of the file, mime type :return: the binary of the file, mime type
""" """
tool_file: ToolFile = ( tool_file: ToolFile | None = (
db.session.query(ToolFile) db.session.query(ToolFile)
.filter( .filter(
ToolFile.id == tool_file_id, ToolFile.id == tool_file_id,

@ -1,8 +1,9 @@
import logging import logging
from collections.abc import Generator from collections.abc import Generator
from mimetypes import guess_extension from mimetypes import guess_extension
from typing import Optional
from core.file.file_obj import FileTransferMethod, FileType from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool_file_manager import ToolFileManager from core.tools.tool_file_manager import ToolFileManager
@ -13,7 +14,7 @@ class ToolFileMessageTransformer:
def transform_tool_invoke_messages(cls, messages: Generator[ToolInvokeMessage, None, None], def transform_tool_invoke_messages(cls, messages: Generator[ToolInvokeMessage, None, None],
user_id: str, user_id: str,
tenant_id: str, tenant_id: str,
conversation_id: str) -> Generator[ToolInvokeMessage, None, None]: conversation_id: Optional[str] = None) -> Generator[ToolInvokeMessage, None, None]:
""" """
Transform tool message and handle file download Transform tool message and handle file download
""" """
@ -25,18 +26,23 @@ class ToolFileMessageTransformer:
elif message.type == ToolInvokeMessage.MessageType.IMAGE: elif message.type == ToolInvokeMessage.MessageType.IMAGE:
# try to download image # try to download image
try: try:
if not conversation_id:
raise
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
file = ToolFileManager.create_file_by_url( file = ToolFileManager.create_file_by_url(
user_id=user_id, user_id=user_id,
tenant_id=tenant_id, tenant_id=tenant_id,
file_url=message.message.text,
conversation_id=conversation_id, conversation_id=conversation_id,
file_url=message.message
) )
url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}' url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}'
yield ToolInvokeMessage( yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK, type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=url, message=ToolInvokeMessage.TextMessage(text=url),
save_as=message.save_as, save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {}, meta=message.meta.copy() if message.meta is not None else {},
) )
@ -44,57 +50,67 @@ class ToolFileMessageTransformer:
logger.exception(e) logger.exception(e)
yield ToolInvokeMessage( yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.TEXT, type=ToolInvokeMessage.MessageType.TEXT,
message=f"Failed to download image: {message.message}, you can try to download it yourself.", message=ToolInvokeMessage.TextMessage(
text=f"Failed to download image: {message.message}, you can try to download it yourself."
),
meta=message.meta.copy() if message.meta is not None else {}, meta=message.meta.copy() if message.meta is not None else {},
save_as=message.save_as, save_as=message.save_as,
) )
elif message.type == ToolInvokeMessage.MessageType.BLOB: elif message.type == ToolInvokeMessage.MessageType.BLOB:
# get mime type and save blob to storage # get mime type and save blob to storage
assert message.meta
mimetype = message.meta.get('mime_type', 'octet/stream') mimetype = message.meta.get('mime_type', 'octet/stream')
# if message is str, encode it to bytes # if message is str, encode it to bytes
if isinstance(message.message, str):
message.message = message.message.encode('utf-8') if not isinstance(message.message, ToolInvokeMessage.BlobMessage):
raise ValueError("unexpected message type")
file = ToolFileManager.create_file_by_raw( file = ToolFileManager.create_file_by_raw(
user_id=user_id, tenant_id=tenant_id, user_id=user_id, tenant_id=tenant_id,
conversation_id=conversation_id, conversation_id=conversation_id,
file_binary=message.message, file_binary=message.message.blob,
mimetype=mimetype mimetype=mimetype
) )
url = cls.get_tool_file_url(file.id, guess_extension(file.mimetype)) extension = guess_extension(file.mimetype) or ".bin"
url = cls.get_tool_file_url(file.id, extension)
# check if file is image # check if file is image
if 'image' in mimetype: if 'image' in mimetype:
yield ToolInvokeMessage( yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK, type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=url, message=ToolInvokeMessage.TextMessage(text=url),
save_as=message.save_as, save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {}, meta=message.meta.copy() if message.meta is not None else {},
) )
else: else:
yield ToolInvokeMessage( yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.LINK, type=ToolInvokeMessage.MessageType.LINK,
message=url, message=ToolInvokeMessage.TextMessage(text=url),
save_as=message.save_as, save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {}, meta=message.meta.copy() if message.meta is not None else {},
) )
elif message.type == ToolInvokeMessage.MessageType.FILE_VAR: elif message.type == ToolInvokeMessage.MessageType.FILE_VAR:
file_var = message.meta.get('file_var') assert message.meta
file_var: FileVar | None = message.meta.get('file_var')
if file_var: if file_var:
if file_var.transfer_method == FileTransferMethod.TOOL_FILE: if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
assert file_var.related_id and file_var.extension
url = cls.get_tool_file_url(file_var.related_id, file_var.extension) url = cls.get_tool_file_url(file_var.related_id, file_var.extension)
if file_var.type == FileType.IMAGE: if file_var.type == FileType.IMAGE:
yield ToolInvokeMessage( yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK, type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=url, message=ToolInvokeMessage.TextMessage(text=url),
save_as=message.save_as, save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {}, meta=message.meta.copy() if message.meta is not None else {},
) )
else: else:
yield ToolInvokeMessage( yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.LINK, type=ToolInvokeMessage.MessageType.LINK,
message=url, message=ToolInvokeMessage.TextMessage(text=url),
save_as=message.save_as, save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {}, meta=message.meta.copy() if message.meta is not None else {},
) )

@ -1,4 +1,4 @@
from collections.abc import Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from os import path from os import path
from typing import Any, cast from typing import Any, cast
@ -145,7 +145,7 @@ class ToolNode(BaseNode):
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
return list(variable.value) if variable else [] return list(variable.value) if variable else []
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]): def _convert_tool_messages(self, messages: Generator[ToolInvokeMessage, None, None]):
""" """
Convert ToolInvokeMessages into tuple[plain_text, files] Convert ToolInvokeMessages into tuple[plain_text, files]
""" """

@ -1,6 +1,7 @@
import json import json
from sqlalchemy import ForeignKey from sqlalchemy import ForeignKey
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_bundle import ApiToolBundle
@ -277,7 +278,7 @@ class ToolConversationVariables(db.Model):
def variables(self) -> dict: def variables(self) -> dict:
return json.loads(self.variables_str) return json.loads(self.variables_str)
class ToolFile(db.Model): class ToolFile(DeclarativeBase):
""" """
store the file created by agent store the file created by agent
""" """
@ -288,16 +289,17 @@ class ToolFile(db.Model):
db.Index('tool_file_conversation_id_idx', 'conversation_id'), db.Index('tool_file_conversation_id_idx', 'conversation_id'),
) )
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) id: Mapped[str] = mapped_column(StringUUID, default=db.text('uuid_generate_v4()'))
# conversation user id # conversation user id
user_id = db.Column(StringUUID, nullable=False) user_id: Mapped[str] = mapped_column(StringUUID)
# tenant id # tenant id
tenant_id = db.Column(StringUUID, nullable=False) tenant_id: Mapped[StringUUID] = mapped_column(StringUUID)
# conversation id # conversation id
conversation_id = db.Column(StringUUID, nullable=True) conversation_id: Mapped[StringUUID] = mapped_column(nullable=True)
# file key # file key
file_key = db.Column(db.String(255), nullable=False) file_key: Mapped[str] = mapped_column(db.String(255), nullable=False)
# mime type # mime type
mimetype = db.Column(db.String(255), nullable=False) mimetype: Mapped[str] = mapped_column(db.String(255), nullable=False)
# original url # original url
original_url = db.Column(db.String(2048), nullable=True) original_url: Mapped[str] = mapped_column(db.String(2048), nullable=True)
Loading…
Cancel
Save