diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 78433ccb9d..daa3ae35a9 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -631,3 +631,88 @@ class TestExtractContentTypeAndExtension: content_type, extension = _extract_content_type_and_extension("https://example.com/image", content_type) assert content_type == "application/octet-stream" assert extension == ".bin" + + +class TestSaveMultimodalOutputAndConvertResultToMarkdown: + def test_str_content(self, llm_node_for_multimodal): + llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal + gen = llm_node._save_multimodal_output_and_convert_result_to_markdown("hello world") + assert list(gen) == ["hello world"] + mock_file_downloader.get.assert_not_called() + mock_file_saver.save_file.assert_not_called() + + def test_text_prompt_message_content(self, llm_node_for_multimodal): + llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal + gen = llm_node._save_multimodal_output_and_convert_result_to_markdown( + [TextPromptMessageContent(data="hello world")] + ) + assert list(gen) == ["hello world"] + mock_file_downloader.get.assert_not_called() + mock_file_saver.save_file.assert_not_called() + + def test_image_content(self, llm_node_for_multimodal): + llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal + + image_raw_data = b"PNG_DATA" + image_b64_data = base64.b64encode(image_raw_data).decode() + + mock_saved_file = File( + id=str(uuid.uuid4()), + tenant_id="1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.TOOL_FILE, + filename="test.png", + extension=".png", + size=len(image_raw_data), + related_id=str(uuid.uuid4()), + url="https://example.com/test.png", + storage_key="test_storage_key", + ) + mock_file_saver.save_file.return_value = mock_saved_file + gen = llm_node._save_multimodal_output_and_convert_result_to_markdown( + [ + ImagePromptMessageContent( + format="png", + base64_data=image_b64_data, + mime_type="image/png", + ) + ] + ) + yielded_strs = list(gen) + assert len(yielded_strs) == 1 + # This assertion is somewhat tricky. + expected_file_url = f"http://127.0.0.1:5001/files/tools/{mock_saved_file.related_id}.png" + assert yielded_strs[0].startswith(f"![]({expected_file_url}") + assert yielded_strs[0].endswith(")") + mock_file_saver.save_file.assert_called_once_with( + MultiModalFile( + user_id="1", + tenant_id="1", + file_type=FileType.IMAGE, + data=image_raw_data, + mime_type="image/png", + ) + ) + mock_file_downloader.assert_not_called() + assert mock_saved_file in llm_node._file_outputs + + def test_unknown_content_type(self, llm_node_for_multimodal): + llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal + gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(frozenset(["hello world"])) + assert list(gen) == ["frozenset({'hello world'})"] + mock_file_downloader.get.assert_not_called() + mock_file_saver.save_file.assert_not_called() + + def test_unknown_item_type(self, llm_node_for_multimodal): + llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal + gen = llm_node._save_multimodal_output_and_convert_result_to_markdown([frozenset(["hello world"])]) + assert list(gen) == ["frozenset({'hello world'})"] + mock_file_downloader.get.assert_not_called() + mock_file_saver.save_file.assert_not_called() + + def test_none_content(self, llm_node_for_multimodal): + llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal + gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(None) + assert list(gen) == [] + mock_file_downloader.get.assert_not_called() + mock_file_saver.save_file.assert_not_called()