fix: tool node

pull/9184/head
Yeuoly 2 years ago
parent c28998a6f0
commit 531ffaec4f
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61

@ -128,8 +128,10 @@ class ToolNode(BaseNode):
else: else:
tool_input = node_data.tool_parameters[parameter_name] tool_input = node_data.tool_parameters[parameter_name]
if tool_input.type == 'variable': if tool_input.type == 'variable':
# TODO: check if the variable exists in the variable pool parameter_value_segment = variable_pool.get(tool_input.value)
parameter_value = variable_pool.get(tool_input.value).value if not parameter_value_segment:
raise Exception("input variable dose not exists")
parameter_value = parameter_value_segment.value
else: else:
segment_group = parser.convert_template( segment_group = parser.convert_template(
template=str(tool_input.value), template=str(tool_input.value),
@ -163,7 +165,7 @@ class ToolNode(BaseNode):
return plain_text, files, json return plain_text, files, json
def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[FileVar]: def _extract_tool_response_binary(self, tool_response: Generator[ToolInvokeMessage, None, None]) -> list[FileVar]:
""" """
Extract tool response binary Extract tool response binary
""" """
@ -172,7 +174,10 @@ class ToolNode(BaseNode):
for response in tool_response: for response in tool_response:
if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
response.type == ToolInvokeMessage.MessageType.IMAGE: response.type == ToolInvokeMessage.MessageType.IMAGE:
url = response.message assert isinstance(response.message, ToolInvokeMessage.TextMessage)
assert response.meta
url = response.message.text
ext = path.splitext(url)[1] ext = path.splitext(url)[1]
mimetype = response.meta.get('mime_type', 'image/jpeg') mimetype = response.meta.get('mime_type', 'image/jpeg')
filename = response.save_as or url.split('/')[-1] filename = response.save_as or url.split('/')[-1]
@ -192,7 +197,10 @@ class ToolNode(BaseNode):
)) ))
elif response.type == ToolInvokeMessage.MessageType.BLOB: elif response.type == ToolInvokeMessage.MessageType.BLOB:
# get tool file id # get tool file id
tool_file_id = response.message.split('/')[-1].split('.')[0] assert isinstance(response.message, ToolInvokeMessage.TextMessage)
assert response.meta
tool_file_id = response.message.text.split('/')[-1].split('.')[0]
result.append(FileVar( result.append(FileVar(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
type=FileType.IMAGE, type=FileType.IMAGE,
@ -207,18 +215,28 @@ class ToolNode(BaseNode):
return result return result
def _extract_tool_response_text(self, tool_response: list[ToolInvokeMessage]) -> str: def _extract_tool_response_text(self, tool_response: Generator[ToolInvokeMessage]) -> str:
""" """
Extract tool response text Extract tool response text
""" """
return '\n'.join([ result: list[str] = []
f'{message.message}' if message.type == ToolInvokeMessage.MessageType.TEXT else for message in tool_response:
f'Link: {message.message}' if message.type == ToolInvokeMessage.MessageType.LINK else '' if message.type == ToolInvokeMessage.MessageType.TEXT:
for message in tool_response assert isinstance(message.message, ToolInvokeMessage.TextMessage)
]) result.append(message.message.text)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
result.append(f'Link: {message.message.text}')
def _extract_tool_response_json(self, tool_response: list[ToolInvokeMessage]) -> list[dict]: return '\n'.join(result)
return [message.message for message in tool_response if message.type == ToolInvokeMessage.MessageType.JSON]
def _extract_tool_response_json(self, tool_response: Generator[ToolInvokeMessage]) -> list[dict]:
result: list[dict] = []
for message in tool_response:
if message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message, ToolInvokeMessage.JsonMessage)
result.append(message.json_object)
return result
@classmethod @classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: ToolNodeData) -> dict[str, list[str]]: def _extract_variable_selector_to_variable_mapping(cls, node_data: ToolNodeData) -> dict[str, list[str]]:
@ -231,6 +249,7 @@ class ToolNode(BaseNode):
for parameter_name in node_data.tool_parameters: for parameter_name in node_data.tool_parameters:
input = node_data.tool_parameters[parameter_name] input = node_data.tool_parameters[parameter_name]
if input.type == 'mixed': if input.type == 'mixed':
assert isinstance(input.value, str)
selectors = VariableTemplateParser(input.value).extract_variable_selectors() selectors = VariableTemplateParser(input.value).extract_variable_selectors()
for selector in selectors: for selector in selectors:
result[selector.variable] = selector.value_selector result[selector.variable] = selector.value_selector

Loading…
Cancel
Save