refactor: stream output

pull/9184/head
Yeuoly 2 years ago
parent b0d53c0ac4
commit cf73374c1b
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61

@ -1,6 +1,7 @@
from collections.abc import Generator, Iterable, Sequence from collections.abc import Generator, Iterable, Mapping, Sequence
from os import path from os import path
from typing import Any, Mapping, cast from typing import Any, cast
from urllib import response
from core.app.segments import ArrayAnySegment, ArrayAnyVariable, parser from core.app.segments import ArrayAnySegment, ArrayAnyVariable, parser
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
@ -13,6 +14,7 @@ from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResu
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunStreamChunkEvent
from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.utils.variable_template_parser import VariableTemplateParser from core.workflow.utils.variable_template_parser import VariableTemplateParser
from models import WorkflowNodeExecutionStatus from models import WorkflowNodeExecutionStatus
@ -26,7 +28,7 @@ class ToolNode(BaseNode):
_node_data_cls = ToolNodeData _node_data_cls = ToolNodeData
_node_type = NodeType.TOOL _node_type = NodeType.TOOL
def _run(self) -> NodeRunResult: def _run(self) -> Generator[RunEvent]:
""" """
Run the tool node Run the tool node
""" """
@ -45,7 +47,8 @@ class ToolNode(BaseNode):
self.tenant_id, self.app_id, self.node_id, node_data, self.invoke_from self.tenant_id, self.app_id, self.node_id, node_data, self.invoke_from
) )
except Exception as e: except Exception as e:
return NodeRunResult( yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, status=WorkflowNodeExecutionStatus.FAILED,
inputs={}, inputs={},
metadata={ metadata={
@ -53,14 +56,25 @@ class ToolNode(BaseNode):
}, },
error=f'Failed to get tool runtime: {str(e)}' error=f'Failed to get tool runtime: {str(e)}'
) )
)
return
# get parameters # get parameters
tool_parameters = tool_runtime.get_runtime_parameters() or [] tool_parameters = tool_runtime.get_runtime_parameters() or []
parameters = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data) parameters = self._generate_parameters(
parameters_for_log = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data, for_log=True) tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=node_data
)
parameters_for_log = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=node_data,
for_log=True
)
try: try:
messages = ToolEngine.workflow_invoke( message_stream = ToolEngine.workflow_invoke(
tool=tool_runtime, tool=tool_runtime,
tool_parameters=parameters, tool_parameters=parameters,
user_id=self.user_id, user_id=self.user_id,
@ -69,7 +83,8 @@ class ToolNode(BaseNode):
thread_pool_id=self.thread_pool_id, thread_pool_id=self.thread_pool_id,
) )
except Exception as e: except Exception as e:
return NodeRunResult( yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log, inputs=parameters_for_log,
metadata={ metadata={
@ -77,22 +92,24 @@ class ToolNode(BaseNode):
}, },
error=f'Failed to invoke tool: {str(e)}', error=f'Failed to invoke tool: {str(e)}',
) )
)
return
# convert tool messages # convert tool messages
plain_text, files, json = self._convert_tool_messages(messages) yield from self._transform_message(message_stream, tool_info, parameters_for_log)
return NodeRunResult( # return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, # status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={ # outputs={
'text': plain_text, # 'text': plain_text,
'files': files, # 'files': files,
'json': json # 'json': json
}, # },
metadata={ # metadata={
NodeRunMetadataKey.TOOL_INFO: tool_info # NodeRunMetadataKey.TOOL_INFO: tool_info
}, # },
inputs=parameters_for_log # inputs=parameters_for_log
) # )
def _generate_parameters( def _generate_parameters(
self, self,
@ -148,48 +165,40 @@ class ToolNode(BaseNode):
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
return list(variable.value) if variable else [] return list(variable.value) if variable else []
def _convert_tool_messages(self, messages: Generator[ToolInvokeMessage, None, None]): def _transform_message(self,
messages: Generator[ToolInvokeMessage, None, None],
tool_info: Mapping[str, Any],
parameters_for_log: dict[str, Any]) -> Generator[RunEvent, None, None]:
""" """
Convert ToolInvokeMessages into tuple[plain_text, files] Convert ToolInvokeMessages into tuple[plain_text, files]
""" """
# transform message and handle file storage # transform message and handle file storage
messages = ToolFileMessageTransformer.transform_tool_invoke_messages( message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=messages, messages=messages,
user_id=self.user_id, user_id=self.user_id,
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
conversation_id=None, conversation_id=None,
) )
result = list(messages) files: list[FileVar] = []
text = ""
# extract plain text and files json: list[dict] = []
files = self._extract_tool_response_binary(result)
plain_text = self._extract_tool_response_text(result)
json = self._extract_tool_response_json(result)
return plain_text, files, json
def _extract_tool_response_binary(self, tool_response: Iterable[ToolInvokeMessage]) -> list[FileVar]: for message in message_stream:
""" if message.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
Extract tool response binary message.type == ToolInvokeMessage.MessageType.IMAGE:
""" assert isinstance(message.message, ToolInvokeMessage.TextMessage)
result = [] assert message.meta
for response in tool_response:
if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
response.type == ToolInvokeMessage.MessageType.IMAGE:
assert isinstance(response.message, ToolInvokeMessage.TextMessage)
assert response.meta
url = response.message.text url = message.message.text
ext = path.splitext(url)[1] ext = path.splitext(url)[1]
mimetype = response.meta.get('mime_type', 'image/jpeg') mimetype = message.meta.get('mime_type', 'image/jpeg')
filename = response.save_as or url.split('/')[-1] filename = message.save_as or url.split('/')[-1]
transfer_method = response.meta.get('transfer_method', FileTransferMethod.TOOL_FILE) transfer_method = message.meta.get('transfer_method', FileTransferMethod.TOOL_FILE)
# get tool file id # get tool file id
tool_file_id = url.split('/')[-1].split('.')[0] tool_file_id = url.split('/')[-1].split('.')[0]
result.append(FileVar( files.append(FileVar(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
type=FileType.IMAGE, type=FileType.IMAGE,
transfer_method=transfer_method, transfer_method=transfer_method,
@ -199,48 +208,54 @@ class ToolNode(BaseNode):
extension=ext, extension=ext,
mime_type=mimetype, mime_type=mimetype,
)) ))
elif response.type == ToolInvokeMessage.MessageType.BLOB: elif message.type == ToolInvokeMessage.MessageType.BLOB:
# get tool file id # get tool file id
assert isinstance(response.message, ToolInvokeMessage.TextMessage) assert isinstance(message.message, ToolInvokeMessage.TextMessage)
assert response.meta assert message.meta
tool_file_id = response.message.text.split('/')[-1].split('.')[0] tool_file_id = message.message.text.split('/')[-1].split('.')[0]
result.append(FileVar( files.append(FileVar(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
type=FileType.IMAGE, type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE, transfer_method=FileTransferMethod.TOOL_FILE,
related_id=tool_file_id, related_id=tool_file_id,
filename=response.save_as, filename=message.save_as,
extension=path.splitext(response.save_as)[1], extension=path.splitext(message.save_as)[1],
mime_type=response.meta.get('mime_type', 'application/octet-stream'), mime_type=message.meta.get('mime_type', 'application/octet-stream'),
)) ))
elif response.type == ToolInvokeMessage.MessageType.LINK: elif message.type == ToolInvokeMessage.MessageType.TEXT:
pass # TODO:
return result
def _extract_tool_response_text(self, tool_response: Iterable[ToolInvokeMessage]) -> str:
"""
Extract tool response text
"""
result: list[str] = []
for message in tool_response:
if message.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(message.message, ToolInvokeMessage.TextMessage) assert isinstance(message.message, ToolInvokeMessage.TextMessage)
result.append(message.message.text) text += message.message.text + '\n'
yield RunStreamChunkEvent(
chunk_content=message.message.text,
from_variable_selector=[self.node_id, 'text']
)
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message, ToolInvokeMessage.JsonMessage)
json.append(message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK: elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage) assert isinstance(message.message, ToolInvokeMessage.TextMessage)
result.append(f'Link: {message.message.text}') stream_text = f'Link: {message.message.text}\n'
text += stream_text
return '\n'.join(result) yield RunStreamChunkEvent(
chunk_content=stream_text,
from_variable_selector=[self.node_id, 'text']
)
def _extract_tool_response_json(self, tool_response: Iterable[ToolInvokeMessage]) -> list[dict]: yield RunCompletedEvent(
result: list[dict] = [] run_result=NodeRunResult(
for message in tool_response: status=WorkflowNodeExecutionStatus.SUCCEEDED,
if message.type == ToolInvokeMessage.MessageType.JSON: outputs={
assert isinstance(message, ToolInvokeMessage.JsonMessage) 'text': text,
result.append(message.json_object) 'files': files,
return result 'json': json
},
metadata={
NodeRunMetadataKey.TOOL_INFO: tool_info
},
inputs=parameters_for_log
)
)
@classmethod @classmethod
def _extract_variable_selector_to_variable_mapping( def _extract_variable_selector_to_variable_mapping(

Loading…
Cancel
Save