test(api): Add test cases for LLMNode content convertion logic

pull/17372/head
QuantumGhost 1 year ago
parent aec7dc4d23
commit 2585cb37b3

@ -631,3 +631,88 @@ class TestExtractContentTypeAndExtension:
content_type, extension = _extract_content_type_and_extension("https://example.com/image", content_type)
assert content_type == "application/octet-stream"
assert extension == ".bin"
class TestSaveMultimodalOutputAndConvertResultToMarkdown:
def test_str_content(self, llm_node_for_multimodal):
llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown("hello world")
assert list(gen) == ["hello world"]
mock_file_downloader.get.assert_not_called()
mock_file_saver.save_file.assert_not_called()
def test_text_prompt_message_content(self, llm_node_for_multimodal):
llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
[TextPromptMessageContent(data="hello world")]
)
assert list(gen) == ["hello world"]
mock_file_downloader.get.assert_not_called()
mock_file_saver.save_file.assert_not_called()
def test_image_content(self, llm_node_for_multimodal):
llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal
image_raw_data = b"PNG_DATA"
image_b64_data = base64.b64encode(image_raw_data).decode()
mock_saved_file = File(
id=str(uuid.uuid4()),
tenant_id="1",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
filename="test.png",
extension=".png",
size=len(image_raw_data),
related_id=str(uuid.uuid4()),
url="https://example.com/test.png",
storage_key="test_storage_key",
)
mock_file_saver.save_file.return_value = mock_saved_file
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
[
ImagePromptMessageContent(
format="png",
base64_data=image_b64_data,
mime_type="image/png",
)
]
)
yielded_strs = list(gen)
assert len(yielded_strs) == 1
# This assertion is somewhat tricky.
expected_file_url = f"http://127.0.0.1:5001/files/tools/{mock_saved_file.related_id}.png"
assert yielded_strs[0].startswith(f"![]({expected_file_url}")
assert yielded_strs[0].endswith(")")
mock_file_saver.save_file.assert_called_once_with(
MultiModalFile(
user_id="1",
tenant_id="1",
file_type=FileType.IMAGE,
data=image_raw_data,
mime_type="image/png",
)
)
mock_file_downloader.assert_not_called()
assert mock_saved_file in llm_node._file_outputs
def test_unknown_content_type(self, llm_node_for_multimodal):
llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(frozenset(["hello world"]))
assert list(gen) == ["frozenset({'hello world'})"]
mock_file_downloader.get.assert_not_called()
mock_file_saver.save_file.assert_not_called()
def test_unknown_item_type(self, llm_node_for_multimodal):
llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown([frozenset(["hello world"])])
assert list(gen) == ["frozenset({'hello world'})"]
mock_file_downloader.get.assert_not_called()
mock_file_saver.save_file.assert_not_called()
def test_none_content(self, llm_node_for_multimodal):
llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(None)
assert list(gen) == []
mock_file_downloader.get.assert_not_called()
mock_file_saver.save_file.assert_not_called()

Loading…
Cancel
Save