From f6ab99cc24b87d0a1ef797e0f4b56180b07ce7b6 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Wed, 30 Apr 2025 17:05:23 +0800 Subject: [PATCH] refactor(api): remove the unnecessary FileDownloader interface. --- .../workflow/nodes/llm/file_downloader.py | 42 --- api/core/workflow/nodes/llm/file_saver.py | 180 ++++++----- api/core/workflow/nodes/llm/node.py | 68 +--- .../nodes/llm/test_file_downloader.py | 56 ---- .../workflow/nodes/llm/test_file_saver.py | 291 +++++++++++------- .../core/workflow/nodes/llm/test_node.py | 162 +++------- 6 files changed, 355 insertions(+), 444 deletions(-) delete mode 100644 api/core/workflow/nodes/llm/file_downloader.py delete mode 100644 api/tests/unit_tests/core/workflow/nodes/llm/test_file_downloader.py diff --git a/api/core/workflow/nodes/llm/file_downloader.py b/api/core/workflow/nodes/llm/file_downloader.py deleted file mode 100644 index ceaef28a6f..0000000000 --- a/api/core/workflow/nodes/llm/file_downloader.py +++ /dev/null @@ -1,42 +0,0 @@ -import abc -import typing as tp - -import httpx -from pydantic import BaseModel - -from core.helper import ssrf_proxy - - -class FileDownloadError(Exception): - pass - - -class HTTPStatusError(FileDownloadError): - def __init__(self, message: str, *, status_code: int): - self.status_code = status_code - - -class Response(BaseModel, frozen=True): - body: bytes - content_type: str | None = None - - -class FileDownloader(tp.Protocol): - @abc.abstractmethod - def get(self, url) -> Response: - pass - - -class SSRFProxyFileDownloader(FileDownloader): - def get(self, url) -> Response: - try: - http_response = ssrf_proxy.get(url) - http_response.raise_for_status() - return Response( - body=http_response.content, - content_type=http_response.headers.get("Content-Type"), - ) - except httpx.TimeoutException as e: - raise FileDownloadError(f"timeout when downloading file from {url}") from e - except httpx.HTTPStatusError as e: - raise HTTPStatusError(f"Error when downloading file from {url}", status_code=e.response.status_code) from e diff --git a/api/core/workflow/nodes/llm/file_saver.py b/api/core/workflow/nodes/llm/file_saver.py index ba140440c6..c85baade03 100644 --- a/api/core/workflow/nodes/llm/file_saver.py +++ b/api/core/workflow/nodes/llm/file_saver.py @@ -1,83 +1,73 @@ -import abc import mimetypes import typing as tp -from pydantic import BaseModel, field_validator from sqlalchemy import Engine -from constants.mimetypes import DEFAULT_EXTENSION +from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE from core.file import File, FileTransferMethod, FileType +from core.helper import ssrf_proxy from core.tools.signature import sign_tool_file from core.tools.tool_file_manager import ToolFileManager from models import db as global_db -class MultiModalFile(BaseModel): - # user_id records t - user_id: str - # tenant_id - tenant_id: str - file_type: FileType - - # data is the contents of the file - data: bytes - - # mime_type is the media type of the file, specified by - # rfc6838 (https://datatracker.ietf.org/doc/html/rfc6838) - mime_type: str - - # `extension_override` allow the user to manually specify the file extension to use - # while saving this file. - # - # The default value is `None`, which means do not override the file extension and guessing it - # from the `mime_type` attribute while saving the file. - # - # Setting it to values other than `None` means override the file's extension, and - # will bypass the extension guessing when calling `get_extension`. - # Specially, setting it to empty string (`""`) will leave the file extension empty. - # - # When it is not `None` or empty string (`""`), it should be a string beginning with a - # dot (`.`). For example, `.py` and `.tar.gz` are both valid values, while `py` - # and `tar.gz` is not. - # - # Users of MultiModalFile should always use `get_extension` to access - # the files extension, instead reading this property directly. - extension_override: str | None = None - - def get_extension(self) -> str: - """get_extension return the extension of file. - - If the `extension_override` parameter is set, this method should honor it and - return its value. +class LLMFileSaver(tp.Protocol): + """LLMFileSaver is responsible for save multimodal output returned by + LLM. + """ + + def save_binary_string( + self, + data: bytes, + mime_type: str, + file_type: FileType, + extension_override: str | None = None, + ) -> File: + """save_binary_string saves the inline file data returned by LLM. + + Currently (2025-04-30), only some of Google Gemini models will return + multimodal output as inline data. + + :param data: the contents of the file + :param mime_type: the media type of the file, specified by rfc6838 + (https://datatracker.ietf.org/doc/html/rfc6838) + :param file_type: The file type of the inline file. + :param extension_override: Override the auto-detected file extension while saving this file. + + The default value is `None`, which means do not override the file extension and guessing it + from the `mime_type` attribute while saving the file. + + Setting it to values other than `None` means override the file's extension, and + will bypass the extension guessing saving the file. + + Specially, setting it to empty string (`""`) will leave the file extension empty. + + When it is not `None` or empty string (`""`), it should be a string beginning with a + dot (`.`). For example, `.py` and `.tar.gz` are both valid values, while `py` + and `tar.gz` are not. """ - if (extension := self.extension_override) is not None: - return extension - return mimetypes.guess_extension(self.mime_type) or DEFAULT_EXTENSION - - @field_validator("extension_override") - @classmethod - def _validate_extension_override(cls, extension_override: str | None) -> str | None: - # `extension_override` is allow to be `None or `""`. - if not extension_override: - return None - if not extension_override.startswith("."): - raise ValueError("extension_override should start with '.' if not None or empty.", extension_override) - return extension_override + pass + + def save_remote_url(self, url: str, file_type: FileType) -> File: + """save_remote_url saves the file from a remote url returned by LLM. + Currently (2025-04-30), no model returns multimodel output as a url. -class MultiModalFileSaver(tp.Protocol): - @abc.abstractmethod - def save_file(self, mmf: MultiModalFile) -> File: + :param url: the url of the file. + :param file_type: the file type of the file, check `FileType` enum for reference. + """ pass EngineFactory: tp.TypeAlias = tp.Callable[[], Engine] -class StorageFileSaver(MultiModalFileSaver): +class FileSaverImpl(LLMFileSaver): _engine_factory: EngineFactory + _tenant_id: str + _user_id: str - def __init__(self, engine_factory: EngineFactory | None = None): + def __init__(self, user_id: str, tenant_id: str, engine_factory: EngineFactory | None = None): if engine_factory is None: def _factory(): @@ -85,30 +75,48 @@ class StorageFileSaver(MultiModalFileSaver): engine_factory = _factory self._engine_factory = engine_factory + self._user_id = user_id + self._tenant_id = tenant_id def _get_tool_file_manager(self): return ToolFileManager(engine=self._engine_factory()) - def save_file(self, mmf: MultiModalFile) -> File: + def save_remote_url(self, url: str, file_type: FileType) -> File: + http_response = ssrf_proxy.get(url) + http_response.raise_for_status() + data = http_response.content + mime_type_from_header = http_response.headers.get("Content-Type") + mime_type, extension = _extract_content_type_and_extension(url, mime_type_from_header) + return self.save_binary_string(data, mime_type, file_type, extension_override=extension) + + def save_binary_string( + self, + data: bytes, + mime_type: str, + file_type: FileType, + extension_override: str | None = None, + ) -> File: tool_file_manager = self._get_tool_file_manager() tool_file = tool_file_manager.create_file_by_raw( - user_id=mmf.user_id, - tenant_id=mmf.tenant_id, + user_id=self._user_id, + tenant_id=self._tenant_id, # TODO(QuantumGhost): what is conversation id? conversation_id=None, - file_binary=mmf.data, - mimetype=mmf.mime_type, + file_binary=data, + mimetype=mime_type, ) - url = sign_tool_file(tool_file.id, mmf.get_extension()) + extension_override = _validate_extension_override(extension_override) + extension = _get_extension(mime_type, extension_override) + url = sign_tool_file(tool_file.id, extension) return File( - tenant_id=mmf.tenant_id, - type=FileType.IMAGE, + tenant_id=self._tenant_id, + type=file_type, transfer_method=FileTransferMethod.TOOL_FILE, filename=tool_file.name, - extension=mmf.get_extension(), - mime_type=mmf.mime_type, - size=len(mmf.data), + extension=extension, + mime_type=mime_type, + size=len(data), related_id=tool_file.id, url=url, # TODO(QuantumGhost): how should I set the following key? @@ -116,3 +124,37 @@ class StorageFileSaver(MultiModalFileSaver): # What's the purpose of `storage_key` and `dify_model_identity`? storage_key=tool_file.file_key, ) + + +def _get_extension(mime_type: str, extension_override: str | None = None) -> str: + """get_extension return the extension of file. + + If the `extension_override` parameter is set, this function should honor it and + return its value. + """ + if extension_override is not None: + return extension_override + return mimetypes.guess_extension(mime_type) or DEFAULT_EXTENSION + + +def _extract_content_type_and_extension(url: str, content_type_header: str | None) -> tuple[str, str]: + """_extract_content_type_and_extension tries to + guess content type of file from url and `Content-Type` header in response. + """ + if content_type_header: + extension = mimetypes.guess_extension(content_type_header) or DEFAULT_EXTENSION + return content_type_header, extension + content_type = mimetypes.guess_type(url)[0] or DEFAULT_MIME_TYPE + extension = mimetypes.guess_extension(content_type) or DEFAULT_EXTENSION + return content_type, extension + + +def _validate_extension_override(extension_override: str | None) -> str | None: + # `extension_override` is allow to be `None or `""`. + if extension_override is None: + return None + if extension_override == "": + return "" + if not extension_override.startswith("."): + raise ValueError("extension_override should start with '.' if not None or empty.", extension_override) + return extension_override diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index b3efa2c211..5481bd383a 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -2,7 +2,6 @@ import base64 import io import json import logging -import mimetypes from collections.abc import Generator, Mapping, Sequence from datetime import UTC, datetime from typing import TYPE_CHECKING, Any, Optional, cast @@ -10,7 +9,6 @@ from typing import TYPE_CHECKING, Any, Optional, cast import json_repair from configs import dify_config -from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.model_entities import ModelStatus from core.entities.provider_entities import QuotaUnit @@ -98,8 +96,7 @@ from .exc import ( TemplateTypeNotSupportError, VariableNotFoundError, ) -from .file_downloader import FileDownloader, SSRFProxyFileDownloader -from .file_saver import MultiModalFile, MultiModalFileSaver, StorageFileSaver +from .file_saver import FileSaverImpl, LLMFileSaver if TYPE_CHECKING: from core.file.models import File @@ -117,8 +114,8 @@ class LLMNode(BaseNode[LLMNodeData]): # Instance attributes specific to LLMNode. # Output variable for file _file_outputs: list["File"] - _file_downloader: FileDownloader - _multi_modal_file_saver: MultiModalFileSaver + + _llm_file_saver: LLMFileSaver def __init__( self, @@ -130,8 +127,7 @@ class LLMNode(BaseNode[LLMNodeData]): previous_node_id: Optional[str] = None, thread_pool_id: Optional[str] = None, *, - file_downloader: FileDownloader | None = None, - multi_modal_file_saver: MultiModalFileSaver | None = None, + llm_file_saver: LLMFileSaver | None = None, ) -> None: super().__init__( id=id, @@ -144,12 +140,13 @@ class LLMNode(BaseNode[LLMNodeData]): ) # LLM file outputs, used for MultiModal outputs. self._file_outputs: list[File] = [] - if file_downloader is None: - file_downloader = SSRFProxyFileDownloader() - self._file_downloader = file_downloader - if multi_modal_file_saver is None: - multi_modal_file_saver = StorageFileSaver() - self._multi_modal_file_saver = multi_modal_file_saver + + if llm_file_saver is None: + llm_file_saver = FileSaverImpl( + user_id=graph_init_params.user_id, + tenant_id=graph_init_params.tenant_id, + ) + self._llm_file_saver = llm_file_saver def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: def process_structured_output(text: str) -> Optional[dict[str, Any] | list[Any]]: @@ -1035,40 +1032,19 @@ class LLMNode(BaseNode[LLMNodeData]): Currently, only image files are supported. """ # Inject the saver somehow... - _saver = self._multi_modal_file_saver + _saver = self._llm_file_saver # If this if content.url != "": - mmf = self._download_file(content.url, FileType.IMAGE) - saved_file = _saver.save_file(mmf) - self._file_outputs.append(saved_file) - return saved_file + saved_file = _saver.save_remote_url(content.url, FileType.IMAGE) else: - mmf = MultiModalFile( - user_id=self.user_id, - tenant_id=self.tenant_id, + saved_file = _saver.save_binary_string( data=base64.b64decode(content.base64_data), mime_type=content.mime_type, file_type=FileType.IMAGE, - extension_override=None, ) - saved_file = _saver.save_file(mmf) - self._file_outputs.append(saved_file) - return saved_file - - def _download_file(self, url: str, file_type: FileType) -> MultiModalFile: - downloader = self._file_downloader - # try to download image - response = downloader.get(url) - content_type, extension = _extract_content_type_and_extension(url, response.content_type) - return MultiModalFile( - user_id=self.user_id, - tenant_id=self.tenant_id, - file_type=file_type, - data=response.body, - mime_type=content_type, - extension_override=extension, - ) + self._file_outputs.append(saved_file) + return saved_file def _handle_native_json_schema(self, model_parameters: dict, rules: list[ParameterRule]) -> dict: """ @@ -1451,15 +1427,3 @@ def convert_boolean_to_string(schema: dict) -> None: for item in value: if isinstance(item, dict): convert_boolean_to_string(item) - - -def _extract_content_type_and_extension(url: str, content_type_header: str | None) -> tuple[str, str]: - """_extract_content_type_and_extension tries to - guess content type of file from url and `Content-Type` header in response. - """ - if content_type_header: - extension = mimetypes.guess_extension(content_type_header) or DEFAULT_EXTENSION - return content_type_header, extension - content_type = mimetypes.guess_type(url)[0] or DEFAULT_MIME_TYPE - extension = mimetypes.guess_extension(content_type) or DEFAULT_EXTENSION - return content_type, extension diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_downloader.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_downloader.py deleted file mode 100644 index e09263f371..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_downloader.py +++ /dev/null @@ -1,56 +0,0 @@ -from unittest import mock - -import httpx -import pytest - -from core.workflow.nodes.llm.file_downloader import ( - FileDownloadError, - HTTPStatusError, - Response, - SSRFProxyFileDownloader, - ssrf_proxy, -) - -_TEST_URL = "https://example.com" - - -class TestSSRFProxyFileDownloader: - def test(self, monkeypatch): - mock_request = httpx.Request("GET", _TEST_URL) - mock_response = httpx.Response( - status_code=200, - content=b"test-data", - headers={"Content-Type": "text/plain"}, - request=mock_request, - ) - mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response) - monkeypatch.setattr(ssrf_proxy, "get", mock_get) - - downloader = SSRFProxyFileDownloader() - response = downloader.get(_TEST_URL) - mock_get.assert_called_once_with(_TEST_URL) - assert response == Response(body=mock_response.content, content_type="text/plain") - - def test_should_raise_when_status_is_not_successful(self, monkeypatch): - mock_request = httpx.Request("GET", _TEST_URL) - mock_response = httpx.Response( - status_code=401, - request=mock_request, - ) - mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response) - monkeypatch.setattr(ssrf_proxy, "get", mock_get) - - downloader = SSRFProxyFileDownloader() - with pytest.raises(HTTPStatusError) as exc: - response = downloader.get(_TEST_URL) - mock_get.assert_called_once_with(_TEST_URL) - assert exc.value.status_code == 401 - - def test_should_convert_timeout_to_file_download_error(self, monkeypatch): - mock_get = mock.MagicMock(spec=ssrf_proxy.get, side_effect=httpx.TimeoutException("timeout")) - monkeypatch.setattr(ssrf_proxy, "get", mock_get) - - downloader = SSRFProxyFileDownloader() - with pytest.raises(FileDownloadError) as exc: - response = downloader.get(_TEST_URL) - mock_get.assert_called_once_with(_TEST_URL) diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py index 732ca98449..7c722660bc 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py @@ -1,129 +1,192 @@ import uuid +from typing import NamedTuple from unittest import mock -import pydantic +import httpx import pytest from sqlalchemy import Engine from core.file import FileTransferMethod, FileType, models +from core.helper import ssrf_proxy from core.tools import signature from core.tools.tool_file_manager import ToolFileManager -from core.workflow.nodes.llm.file_saver import MultiModalFile, StorageFileSaver +from core.workflow.nodes.llm.file_saver import ( + FileSaverImpl, + _extract_content_type_and_extension, + _get_extension, + _validate_extension_override, +) from models import ToolFile _PNG_DATA = b"\x89PNG\r\n\x1a\n" -# -# class _MockToolFileManager: -# def __init__(self, mock_tool_file: ToolFile, mock_signed_url: str): -# self._mock_tool_file = mock_tool_file -# self._mock_signed_url = mock_signed_url -# -# @staticmethod -# def create_file_by_raw( -# *, -# user_id: str, -# tenant_id: str, -# conversation_id: Optional[str], -# file_binary: bytes, -# mimetype: str, -# filename: Optional[str] = None, -# ) -> ToolFile: -# return self._mock_tool_file -# -# @staticmethod -# def sign_file(tool_file_id: str, extension: str) -> str: -# return "" - - -def test_storage_file_saver(monkeypatch): - def gen_id(): - return str(uuid.uuid4()) - - mmf = MultiModalFile( - user_id=gen_id(), - tenant_id=gen_id(), - file_type=FileType.IMAGE, - data=_PNG_DATA, - mime_type="image/png", - extension_override=None, - ) - mock_signed_url = "https://example.com/image.png" - mock_tool_file = ToolFile( - id=gen_id(), - user_id=mmf.user_id, - tenant_id=mmf.tenant_id, - conversation_id=None, - file_key="test-file-key", - mimetype=mmf.mime_type, - original_url=None, - name=f"{gen_id()}.png", - size=len(mmf.data), - ) - mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager) - mocked_engine = mock.MagicMock(spec=Engine) - - mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file - monkeypatch.setattr(StorageFileSaver, "_get_tool_file_manager", lambda _: mocked_tool_file_manager) - # Since `File.generate_url` used `ToolFileManager.sign_file` directly, we also need to patch it here. - mocked_sign_file = mock.MagicMock(spec=signature.sign_tool_file) - # Since `File.generate_url` used `signature.sign_tool_file` directly, we also need to patch it here. - monkeypatch.setattr(models, "sign_tool_file", mocked_sign_file) - mocked_sign_file.return_value = mock_signed_url - - storage_file_manager = StorageFileSaver(engine_factory=lambda: mocked_engine) - - file = storage_file_manager.save_file(mmf) - assert file.tenant_id == mmf.tenant_id - assert file.type == mmf.file_type - assert file.transfer_method == FileTransferMethod.TOOL_FILE - assert file.extension == mmf.get_extension() - assert file.mime_type == mmf.mime_type - assert file.size == len(mmf.data) - assert file.related_id == mock_tool_file.id - - assert file.generate_url() == mock_signed_url - - mocked_tool_file_manager.create_file_by_raw.assert_called_once_with( - user_id=mmf.user_id, - tenant_id=mmf.tenant_id, - conversation_id=None, - file_binary=mmf.data, - mimetype=mmf.mime_type, - ) - mocked_sign_file.assert_called_once_with(mock_tool_file.id, mmf.get_extension()) - - -def test_multi_modal_file_extension_override(): - # Test should pass if `extension_override` is not set. - MultiModalFile(user_id="", tenant_id="", file_type=FileType.IMAGE, data=b"", mime_type="image/png") - - # Test should pass if `extension_override` is explicitly set to `None`. - MultiModalFile( - user_id="", tenant_id="", file_type=FileType.IMAGE, data=b"", mime_type="image/png", extension_override=None - ) - - # Test should pass if `extension_override` is a string prefixed with `.`. - for extension_override in [".png", ".tar.gz"]: - MultiModalFile( - user_id="", - tenant_id="", - file_type=FileType.IMAGE, - data=b"", - mime_type="image/png", - extension_override=extension_override, + + +def _gen_id(): + return str(uuid.uuid4()) + + +class TestFileSaverImpl: + def test_save_binary_string(self, monkeypatch): + user_id = _gen_id() + tenant_id = _gen_id() + file_type = FileType.IMAGE + mime_type = "image/png" + mock_signed_url = "https://example.com/image.png" + mock_tool_file = ToolFile( + id=_gen_id(), + user_id=user_id, + tenant_id=tenant_id, + conversation_id=None, + file_key="test-file-key", + mimetype=mime_type, + original_url=None, + name=f"{_gen_id()}.png", + size=len(_PNG_DATA), + ) + mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager) + mocked_engine = mock.MagicMock(spec=Engine) + + mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file + monkeypatch.setattr(FileSaverImpl, "_get_tool_file_manager", lambda _: mocked_tool_file_manager) + # Since `File.generate_url` used `ToolFileManager.sign_file` directly, we also need to patch it here. + mocked_sign_file = mock.MagicMock(spec=signature.sign_tool_file) + # Since `File.generate_url` used `signature.sign_tool_file` directly, we also need to patch it here. + monkeypatch.setattr(models, "sign_tool_file", mocked_sign_file) + mocked_sign_file.return_value = mock_signed_url + + storage_file_manager = FileSaverImpl( + user_id=user_id, + tenant_id=tenant_id, + engine_factory=mocked_engine, + ) + + file = storage_file_manager.save_binary_string(_PNG_DATA, mime_type, file_type) + assert file.tenant_id == tenant_id + assert file.type == file_type + assert file.transfer_method == FileTransferMethod.TOOL_FILE + assert file.extension == ".png" + assert file.mime_type == mime_type + assert file.size == len(_PNG_DATA) + assert file.related_id == mock_tool_file.id + + assert file.generate_url() == mock_signed_url + + mocked_tool_file_manager.create_file_by_raw.assert_called_once_with( + user_id=user_id, + tenant_id=tenant_id, + conversation_id=None, + file_binary=_PNG_DATA, + mimetype=mime_type, + ) + mocked_sign_file.assert_called_once_with(mock_tool_file.id, ".png") + + def test_save_remote_url_request_failed(self, monkeypatch): + _TEST_URL = "https://example.com/image.png" + mock_request = httpx.Request("GET", _TEST_URL) + mock_response = httpx.Response( + status_code=401, + request=mock_request, + ) + file_saver = FileSaverImpl( + user_id=_gen_id(), + tenant_id=_gen_id(), + ) + mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response) + monkeypatch.setattr(ssrf_proxy, "get", mock_get) + + with pytest.raises(httpx.HTTPStatusError) as exc: + file_saver.save_remote_url(_TEST_URL, FileType.IMAGE) + mock_get.assert_called_once_with(_TEST_URL) + assert exc.value.response.status_code == 401 + + def test_save_remote_url_success(self, monkeypatch): + _TEST_URL = "https://example.com/image.png" + mime_type = "image/png" + user_id = _gen_id() + tenant_id = _gen_id() + + mock_request = httpx.Request("GET", _TEST_URL) + mock_response = httpx.Response( + status_code=200, + content=b"test-data", + headers={"Content-Type": mime_type}, + request=mock_request, ) + file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id) + mock_tool_file = ToolFile( + id=_gen_id(), + user_id=user_id, + tenant_id=tenant_id, + conversation_id=None, + file_key="test-file-key", + mimetype=mime_type, + original_url=None, + name=f"{_gen_id()}.png", + size=len(_PNG_DATA), + ) + mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response) + monkeypatch.setattr(ssrf_proxy, "get", mock_get) + mock_save_binary_string = mock.MagicMock(spec=file_saver.save_binary_string, return_value=mock_tool_file) + monkeypatch.setattr(file_saver, "save_binary_string", mock_save_binary_string) + + file = file_saver.save_remote_url(_TEST_URL, FileType.IMAGE) + mock_save_binary_string.assert_called_once_with( + mock_response.content, + mime_type, + FileType.IMAGE, + extension_override=".png", + ) + assert file == mock_tool_file + + +def test_validate_extension_override(): + class TestCase(NamedTuple): + extension_override: str | None + expected: str | None + + cases = [TestCase(None, None), TestCase("", ""), ".png", ".png", ".tar.gz", ".tar.gz"] + + for valid_ext_override in [None, "", ".png", ".tar.gz"]: + assert valid_ext_override == _validate_extension_override(valid_ext_override) + for invalid_ext_override in ["png", "tar.gz"]: - with pytest.raises(pydantic.ValidationError) as exc: - MultiModalFile( - user_id="", - tenant_id="", - file_type=FileType.IMAGE, - data=b"", - mime_type="image/png", - extension_override=invalid_ext_override, - ) - - error_details = exc.value.errors() - assert exc.value.error_count() == 1 - assert error_details[0]["loc"] == ("extension_override",) + with pytest.raises(ValueError) as exc: + _validate_extension_override(invalid_ext_override) + + +class TestExtractContentTypeAndExtension: + def test_with_both_content_type_and_extension(self): + content_type, extension = _extract_content_type_and_extension("https://example.com/image.jpg", "image/png") + assert content_type == "image/png" + assert extension == ".png" + + def test_url_with_file_extension(self): + for content_type in [None, ""]: + content_type, extension = _extract_content_type_and_extension("https://example.com/image.png", content_type) + assert content_type == "image/png" + assert extension == ".png" + + def test_response_with_content_type(self): + content_type, extension = _extract_content_type_and_extension("https://example.com/image", "image/png") + assert content_type == "image/png" + assert extension == ".png" + + def test_no_content_type_and_no_extension(self): + for content_type in [None, ""]: + content_type, extension = _extract_content_type_and_extension("https://example.com/image", content_type) + assert content_type == "application/octet-stream" + assert extension == ".bin" + + +class TestGetExtension: + def test_with_extension_override(self): + mime_type = "image/png" + for override in [".jpg", ""]: + extension = _get_extension(mime_type, override) + assert extension == override + + def test_without_extension_override(self): + mime_type = "image/png" + extension = _get_extension(mime_type) + assert extension == ".png" diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index daa3ae35a9..f92ef71595 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -1,10 +1,10 @@ +from unittest import mock + import base64 +import pytest import uuid from collections.abc import Sequence from typing import Optional -from unittest import mock - -import pytest from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle @@ -33,9 +33,8 @@ from core.workflow.nodes.llm.entities import ( VisionConfig, VisionConfigOptions, ) -from core.workflow.nodes.llm.file_downloader import FileDownloader, Response -from core.workflow.nodes.llm.file_saver import MultiModalFile, MultiModalFileSaver -from core.workflow.nodes.llm.node import LLMNode, _extract_content_type_and_extension +from core.workflow.nodes.llm.file_saver import LLMFileSaver +from core.workflow.nodes.llm.node import LLMNode from models.enums import UserFrom from models.provider import ProviderType from models.workflow import WorkflowType @@ -117,8 +116,7 @@ def graph_runtime_state() -> GraphRuntimeState: def llm_node( llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph: Graph, graph_runtime_state: GraphRuntimeState ) -> LLMNode: - mock_file_saver = mock.MagicMock(spec=MultiModalFileSaver) - mock_file_downloader = mock.MagicMock(spec=FileDownloader) + mock_file_saver = mock.MagicMock(spec=LLMFileSaver) node = LLMNode( id="1", config={ @@ -128,8 +126,7 @@ def llm_node( graph_init_params=graph_init_params, graph=graph, graph_runtime_state=graph_runtime_state, - file_downloader=mock_file_downloader, - multi_modal_file_saver=mock_file_saver, + llm_file_saver=mock_file_saver, ) return node @@ -500,9 +497,8 @@ def test_handle_list_messages_basic(llm_node): @pytest.fixture def llm_node_for_multimodal( llm_node_data, graph_init_params, graph, graph_runtime_state -) -> tuple[LLMNode, FileDownloader, MultiModalFileSaver]: - mock_file_downloader: FileDownloader = mock.MagicMock(spec=FileDownloader) - mock_file_saver: MultiModalFileSaver = mock.MagicMock(spec=MultiModalFileSaver) +) -> tuple[LLMNode, LLMFileSaver]: + mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver) node = LLMNode( id="1", config={ @@ -512,29 +508,14 @@ def llm_node_for_multimodal( graph_init_params=graph_init_params, graph=graph, graph_runtime_state=graph_runtime_state, - file_downloader=mock_file_downloader, - multi_modal_file_saver=mock_file_saver, + llm_file_saver=mock_file_saver, ) - return node, mock_file_downloader, mock_file_saver + return node, mock_file_saver class TestLLMNodeSaveMultiModalImageOutput: - def test_llm_node_download_file(self, llm_node_for_multimodal): - llm_node, mock_file_downloader, _ = llm_node_for_multimodal - mock_response = Response(body=b"test-data", content_type="image/png") - mock_file_downloader.get.return_value = mock_response - file = llm_node._download_file("https://example.com/image.png", FileType.IMAGE) - assert file.user_id == "1" - assert file.tenant_id == "1" - assert file.file_type == FileType.IMAGE - assert file.mime_type == mock_response.content_type - assert file.data == mock_response.body - assert file.get_extension() == ".png" - - def test_llm_node_save_inline_output( - self, llm_node_for_multimodal: tuple[LLMNode, FileDownloader, MultiModalFileSaver] - ): - llm_node, _, mock_file_saver = llm_node_for_multimodal + def test_llm_node_save_inline_output(self, llm_node_for_multimodal: tuple[LLMNode, LLMFileSaver]): + llm_node, mock_file_saver = llm_node_for_multimodal content = ImagePromptMessageContent( format="png", base64_data=base64.b64encode(b"test-data").decode(), @@ -551,24 +532,16 @@ class TestLLMNodeSaveMultiModalImageOutput: mime_type="image/png", size=9, ) - mock_file_saver.save_file.return_value = mock_file + mock_file_saver.save_binary_string.return_value = mock_file file = llm_node._save_multimodal_image_output(content=content) assert llm_node._file_outputs == [mock_file] assert file == mock_file - expected_saved_multimodal_file = MultiModalFile( - user_id="1", - tenant_id="1", - file_type=FileType.IMAGE, - data=b"test-data", - mime_type="image/png", - extension_override=None, + mock_file_saver.save_binary_string.assert_called_once_with( + data=b"test-data", mime_type="image/png", file_type=FileType.IMAGE ) - mock_file_saver.save_file.assert_called_once_with(expected_saved_multimodal_file) - def test_llm_node_save_url_output( - self, llm_node_for_multimodal: tuple[LLMNode, FileDownloader, MultiModalFileSaver] - ): - llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal + def test_llm_node_save_url_output(self, llm_node_for_multimodal: tuple[LLMNode, LLMFileSaver]): + llm_node, mock_file_saver = llm_node_for_multimodal content = ImagePromptMessageContent( format="png", url="https://example.com/image.png", @@ -585,21 +558,11 @@ class TestLLMNodeSaveMultiModalImageOutput: mime_type="image/png", size=9, ) - mock_file_downloader.get.return_value = Response(body=b"test-data", content_type="image/png") - mock_file_saver.save_file.return_value = mock_file + mock_file_saver.save_remote_url.return_value = mock_file file = llm_node._save_multimodal_image_output(content=content) assert llm_node._file_outputs == [mock_file] assert file == mock_file - expected_saved_multimodal_file = MultiModalFile( - user_id="1", - tenant_id="1", - file_type=FileType.IMAGE, - data=b"test-data", - mime_type="image/png", - extension_override=".png", - ) - mock_file_downloader.get.assert_called_once_with(content.url) - mock_file_saver.save_file.assert_called_once_with(expected_saved_multimodal_file) + mock_file_saver.save_remote_url.assert_called_once_with(content.url, FileType.IMAGE) def test_llm_node_image_file_to_markdown(llm_node: LLMNode): @@ -609,49 +572,25 @@ def test_llm_node_image_file_to_markdown(llm_node: LLMNode): assert markdown == "![](https://example.com/image.png)" -class TestExtractContentTypeAndExtension: - def test_with_both_content_type_and_extension(self): - content_type, extension = _extract_content_type_and_extension("https://example.com/image.jpg", "image/png") - assert content_type == "image/png" - assert extension == ".png" - - def test_url_with_file_extension(self): - for content_type in [None, ""]: - content_type, extension = _extract_content_type_and_extension("https://example.com/image.png", content_type) - assert content_type == "image/png" - assert extension == ".png" - - def test_response_with_content_type(self): - content_type, extension = _extract_content_type_and_extension("https://example.com/image", "image/png") - assert content_type == "image/png" - assert extension == ".png" - - def test_no_content_type_and_no_extension(self): - for content_type in [None, ""]: - content_type, extension = _extract_content_type_and_extension("https://example.com/image", content_type) - assert content_type == "application/octet-stream" - assert extension == ".bin" - - class TestSaveMultimodalOutputAndConvertResultToMarkdown: def test_str_content(self, llm_node_for_multimodal): - llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal + llm_node, mock_file_saver = llm_node_for_multimodal gen = llm_node._save_multimodal_output_and_convert_result_to_markdown("hello world") assert list(gen) == ["hello world"] - mock_file_downloader.get.assert_not_called() - mock_file_saver.save_file.assert_not_called() + mock_file_saver.save_binary_string.assert_not_called() + mock_file_saver.save_remote_url.assert_not_called() def test_text_prompt_message_content(self, llm_node_for_multimodal): - llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal + llm_node, mock_file_saver = llm_node_for_multimodal gen = llm_node._save_multimodal_output_and_convert_result_to_markdown( [TextPromptMessageContent(data="hello world")] ) assert list(gen) == ["hello world"] - mock_file_downloader.get.assert_not_called() - mock_file_saver.save_file.assert_not_called() + mock_file_saver.save_binary_string.assert_not_called() + mock_file_saver.save_remote_url.assert_not_called() - def test_image_content(self, llm_node_for_multimodal): - llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal + def test_image_content_with_inline_data(self, llm_node_for_multimodal, monkeypatch): + llm_node, mock_file_saver = llm_node_for_multimodal image_raw_data = b"PNG_DATA" image_b64_data = base64.b64encode(image_raw_data).decode() @@ -668,7 +607,7 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown: url="https://example.com/test.png", storage_key="test_storage_key", ) - mock_file_saver.save_file.return_value = mock_saved_file + mock_file_saver.save_binary_string.return_value = mock_saved_file gen = llm_node._save_multimodal_output_and_convert_result_to_markdown( [ ImagePromptMessageContent( @@ -680,39 +619,40 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown: ) yielded_strs = list(gen) assert len(yielded_strs) == 1 - # This assertion is somewhat tricky. - expected_file_url = f"http://127.0.0.1:5001/files/tools/{mock_saved_file.related_id}.png" - assert yielded_strs[0].startswith(f"![]({expected_file_url}") + + # This assertion requires careful handling. + # `FILES_URL` settings can vary across environments, which might lead to fragile tests. + # + # Rather than asserting the complete URL returned by _save_multimodal_output_and_convert_result_to_markdown, + # we verify that the result includes the markdown image syntax and the expected file URL path. + expected_file_url_path = f"/files/tools/{mock_saved_file.related_id}.png" + assert yielded_strs[0].startswith(f"![](") + assert expected_file_url_path in yielded_strs[0] assert yielded_strs[0].endswith(")") - mock_file_saver.save_file.assert_called_once_with( - MultiModalFile( - user_id="1", - tenant_id="1", - file_type=FileType.IMAGE, - data=image_raw_data, - mime_type="image/png", - ) + mock_file_saver.save_binary_string.assert_called_once_with( + data=image_raw_data, + mime_type="image/png", + file_type=FileType.IMAGE, ) - mock_file_downloader.assert_not_called() assert mock_saved_file in llm_node._file_outputs def test_unknown_content_type(self, llm_node_for_multimodal): - llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal + llm_node, mock_file_saver = llm_node_for_multimodal gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(frozenset(["hello world"])) assert list(gen) == ["frozenset({'hello world'})"] - mock_file_downloader.get.assert_not_called() - mock_file_saver.save_file.assert_not_called() + mock_file_saver.save_binary_string.assert_not_called() + mock_file_saver.save_remote_url.assert_not_called() def test_unknown_item_type(self, llm_node_for_multimodal): - llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal + llm_node, mock_file_saver = llm_node_for_multimodal gen = llm_node._save_multimodal_output_and_convert_result_to_markdown([frozenset(["hello world"])]) assert list(gen) == ["frozenset({'hello world'})"] - mock_file_downloader.get.assert_not_called() - mock_file_saver.save_file.assert_not_called() + mock_file_saver.save_binary_string.assert_not_called() + mock_file_saver.save_remote_url.assert_not_called() def test_none_content(self, llm_node_for_multimodal): - llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal + llm_node, mock_file_saver = llm_node_for_multimodal gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(None) assert list(gen) == [] - mock_file_downloader.get.assert_not_called() - mock_file_saver.save_file.assert_not_called() + mock_file_saver.save_binary_string.assert_not_called() + mock_file_saver.save_remote_url.assert_not_called()