refactor(api): remove the unnecessary FileDownloader interface.

pull/17372/head
QuantumGhost 1 year ago
parent 2585cb37b3
commit f6ab99cc24

@ -1,42 +0,0 @@
import abc
import typing as tp
import httpx
from pydantic import BaseModel
from core.helper import ssrf_proxy
class FileDownloadError(Exception):
pass
class HTTPStatusError(FileDownloadError):
def __init__(self, message: str, *, status_code: int):
self.status_code = status_code
class Response(BaseModel, frozen=True):
body: bytes
content_type: str | None = None
class FileDownloader(tp.Protocol):
@abc.abstractmethod
def get(self, url) -> Response:
pass
class SSRFProxyFileDownloader(FileDownloader):
def get(self, url) -> Response:
try:
http_response = ssrf_proxy.get(url)
http_response.raise_for_status()
return Response(
body=http_response.content,
content_type=http_response.headers.get("Content-Type"),
)
except httpx.TimeoutException as e:
raise FileDownloadError(f"timeout when downloading file from {url}") from e
except httpx.HTTPStatusError as e:
raise HTTPStatusError(f"Error when downloading file from {url}", status_code=e.response.status_code) from e

@ -1,83 +1,73 @@
import abc
import mimetypes import mimetypes
import typing as tp import typing as tp
from pydantic import BaseModel, field_validator
from sqlalchemy import Engine from sqlalchemy import Engine
from constants.mimetypes import DEFAULT_EXTENSION from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE
from core.file import File, FileTransferMethod, FileType from core.file import File, FileTransferMethod, FileType
from core.helper import ssrf_proxy
from core.tools.signature import sign_tool_file from core.tools.signature import sign_tool_file
from core.tools.tool_file_manager import ToolFileManager from core.tools.tool_file_manager import ToolFileManager
from models import db as global_db from models import db as global_db
class MultiModalFile(BaseModel): class LLMFileSaver(tp.Protocol):
# user_id records t """LLMFileSaver is responsible for save multimodal output returned by
user_id: str LLM.
# tenant_id """
tenant_id: str
file_type: FileType
# data is the contents of the file
data: bytes
# mime_type is the media type of the file, specified by
# rfc6838 (https://datatracker.ietf.org/doc/html/rfc6838)
mime_type: str
# `extension_override` allow the user to manually specify the file extension to use
# while saving this file.
#
# The default value is `None`, which means do not override the file extension and guessing it
# from the `mime_type` attribute while saving the file.
#
# Setting it to values other than `None` means override the file's extension, and
# will bypass the extension guessing when calling `get_extension`.
# Specially, setting it to empty string (`""`) will leave the file extension empty.
#
# When it is not `None` or empty string (`""`), it should be a string beginning with a
# dot (`.`). For example, `.py` and `.tar.gz` are both valid values, while `py`
# and `tar.gz` is not.
#
# Users of MultiModalFile should always use `get_extension` to access
# the files extension, instead reading this property directly.
extension_override: str | None = None
def get_extension(self) -> str:
"""get_extension return the extension of file.
If the `extension_override` parameter is set, this method should honor it and def save_binary_string(
return its value. self,
data: bytes,
mime_type: str,
file_type: FileType,
extension_override: str | None = None,
) -> File:
"""save_binary_string saves the inline file data returned by LLM.
Currently (2025-04-30), only some of Google Gemini models will return
multimodal output as inline data.
:param data: the contents of the file
:param mime_type: the media type of the file, specified by rfc6838
(https://datatracker.ietf.org/doc/html/rfc6838)
:param file_type: The file type of the inline file.
:param extension_override: Override the auto-detected file extension while saving this file.
The default value is `None`, which means do not override the file extension and guessing it
from the `mime_type` attribute while saving the file.
Setting it to values other than `None` means override the file's extension, and
will bypass the extension guessing saving the file.
Specially, setting it to empty string (`""`) will leave the file extension empty.
When it is not `None` or empty string (`""`), it should be a string beginning with a
dot (`.`). For example, `.py` and `.tar.gz` are both valid values, while `py`
and `tar.gz` are not.
""" """
if (extension := self.extension_override) is not None: pass
return extension
return mimetypes.guess_extension(self.mime_type) or DEFAULT_EXTENSION
@field_validator("extension_override") def save_remote_url(self, url: str, file_type: FileType) -> File:
@classmethod """save_remote_url saves the file from a remote url returned by LLM.
def _validate_extension_override(cls, extension_override: str | None) -> str | None:
# `extension_override` is allow to be `None or `""`.
if not extension_override:
return None
if not extension_override.startswith("."):
raise ValueError("extension_override should start with '.' if not None or empty.", extension_override)
return extension_override
Currently (2025-04-30), no model returns multimodel output as a url.
class MultiModalFileSaver(tp.Protocol): :param url: the url of the file.
@abc.abstractmethod :param file_type: the file type of the file, check `FileType` enum for reference.
def save_file(self, mmf: MultiModalFile) -> File: """
pass pass
EngineFactory: tp.TypeAlias = tp.Callable[[], Engine] EngineFactory: tp.TypeAlias = tp.Callable[[], Engine]
class StorageFileSaver(MultiModalFileSaver): class FileSaverImpl(LLMFileSaver):
_engine_factory: EngineFactory _engine_factory: EngineFactory
_tenant_id: str
_user_id: str
def __init__(self, engine_factory: EngineFactory | None = None): def __init__(self, user_id: str, tenant_id: str, engine_factory: EngineFactory | None = None):
if engine_factory is None: if engine_factory is None:
def _factory(): def _factory():
@ -85,30 +75,48 @@ class StorageFileSaver(MultiModalFileSaver):
engine_factory = _factory engine_factory = _factory
self._engine_factory = engine_factory self._engine_factory = engine_factory
self._user_id = user_id
self._tenant_id = tenant_id
def _get_tool_file_manager(self): def _get_tool_file_manager(self):
return ToolFileManager(engine=self._engine_factory()) return ToolFileManager(engine=self._engine_factory())
def save_file(self, mmf: MultiModalFile) -> File: def save_remote_url(self, url: str, file_type: FileType) -> File:
http_response = ssrf_proxy.get(url)
http_response.raise_for_status()
data = http_response.content
mime_type_from_header = http_response.headers.get("Content-Type")
mime_type, extension = _extract_content_type_and_extension(url, mime_type_from_header)
return self.save_binary_string(data, mime_type, file_type, extension_override=extension)
def save_binary_string(
self,
data: bytes,
mime_type: str,
file_type: FileType,
extension_override: str | None = None,
) -> File:
tool_file_manager = self._get_tool_file_manager() tool_file_manager = self._get_tool_file_manager()
tool_file = tool_file_manager.create_file_by_raw( tool_file = tool_file_manager.create_file_by_raw(
user_id=mmf.user_id, user_id=self._user_id,
tenant_id=mmf.tenant_id, tenant_id=self._tenant_id,
# TODO(QuantumGhost): what is conversation id? # TODO(QuantumGhost): what is conversation id?
conversation_id=None, conversation_id=None,
file_binary=mmf.data, file_binary=data,
mimetype=mmf.mime_type, mimetype=mime_type,
) )
url = sign_tool_file(tool_file.id, mmf.get_extension()) extension_override = _validate_extension_override(extension_override)
extension = _get_extension(mime_type, extension_override)
url = sign_tool_file(tool_file.id, extension)
return File( return File(
tenant_id=mmf.tenant_id, tenant_id=self._tenant_id,
type=FileType.IMAGE, type=file_type,
transfer_method=FileTransferMethod.TOOL_FILE, transfer_method=FileTransferMethod.TOOL_FILE,
filename=tool_file.name, filename=tool_file.name,
extension=mmf.get_extension(), extension=extension,
mime_type=mmf.mime_type, mime_type=mime_type,
size=len(mmf.data), size=len(data),
related_id=tool_file.id, related_id=tool_file.id,
url=url, url=url,
# TODO(QuantumGhost): how should I set the following key? # TODO(QuantumGhost): how should I set the following key?
@ -116,3 +124,37 @@ class StorageFileSaver(MultiModalFileSaver):
# What's the purpose of `storage_key` and `dify_model_identity`? # What's the purpose of `storage_key` and `dify_model_identity`?
storage_key=tool_file.file_key, storage_key=tool_file.file_key,
) )
def _get_extension(mime_type: str, extension_override: str | None = None) -> str:
"""get_extension return the extension of file.
If the `extension_override` parameter is set, this function should honor it and
return its value.
"""
if extension_override is not None:
return extension_override
return mimetypes.guess_extension(mime_type) or DEFAULT_EXTENSION
def _extract_content_type_and_extension(url: str, content_type_header: str | None) -> tuple[str, str]:
"""_extract_content_type_and_extension tries to
guess content type of file from url and `Content-Type` header in response.
"""
if content_type_header:
extension = mimetypes.guess_extension(content_type_header) or DEFAULT_EXTENSION
return content_type_header, extension
content_type = mimetypes.guess_type(url)[0] or DEFAULT_MIME_TYPE
extension = mimetypes.guess_extension(content_type) or DEFAULT_EXTENSION
return content_type, extension
def _validate_extension_override(extension_override: str | None) -> str | None:
# `extension_override` is allow to be `None or `""`.
if extension_override is None:
return None
if extension_override == "":
return ""
if not extension_override.startswith("."):
raise ValueError("extension_override should start with '.' if not None or empty.", extension_override)
return extension_override

@ -2,7 +2,6 @@ import base64
import io import io
import json import json
import logging import logging
import mimetypes
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, Optional, cast from typing import TYPE_CHECKING, Any, Optional, cast
@ -10,7 +9,6 @@ from typing import TYPE_CHECKING, Any, Optional, cast
import json_repair import json_repair
from configs import dify_config from configs import dify_config
from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.model_entities import ModelStatus from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import QuotaUnit from core.entities.provider_entities import QuotaUnit
@ -98,8 +96,7 @@ from .exc import (
TemplateTypeNotSupportError, TemplateTypeNotSupportError,
VariableNotFoundError, VariableNotFoundError,
) )
from .file_downloader import FileDownloader, SSRFProxyFileDownloader from .file_saver import FileSaverImpl, LLMFileSaver
from .file_saver import MultiModalFile, MultiModalFileSaver, StorageFileSaver
if TYPE_CHECKING: if TYPE_CHECKING:
from core.file.models import File from core.file.models import File
@ -117,8 +114,8 @@ class LLMNode(BaseNode[LLMNodeData]):
# Instance attributes specific to LLMNode. # Instance attributes specific to LLMNode.
# Output variable for file # Output variable for file
_file_outputs: list["File"] _file_outputs: list["File"]
_file_downloader: FileDownloader
_multi_modal_file_saver: MultiModalFileSaver _llm_file_saver: LLMFileSaver
def __init__( def __init__(
self, self,
@ -130,8 +127,7 @@ class LLMNode(BaseNode[LLMNodeData]):
previous_node_id: Optional[str] = None, previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None, thread_pool_id: Optional[str] = None,
*, *,
file_downloader: FileDownloader | None = None, llm_file_saver: LLMFileSaver | None = None,
multi_modal_file_saver: MultiModalFileSaver | None = None,
) -> None: ) -> None:
super().__init__( super().__init__(
id=id, id=id,
@ -144,12 +140,13 @@ class LLMNode(BaseNode[LLMNodeData]):
) )
# LLM file outputs, used for MultiModal outputs. # LLM file outputs, used for MultiModal outputs.
self._file_outputs: list[File] = [] self._file_outputs: list[File] = []
if file_downloader is None:
file_downloader = SSRFProxyFileDownloader() if llm_file_saver is None:
self._file_downloader = file_downloader llm_file_saver = FileSaverImpl(
if multi_modal_file_saver is None: user_id=graph_init_params.user_id,
multi_modal_file_saver = StorageFileSaver() tenant_id=graph_init_params.tenant_id,
self._multi_modal_file_saver = multi_modal_file_saver )
self._llm_file_saver = llm_file_saver
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
def process_structured_output(text: str) -> Optional[dict[str, Any] | list[Any]]: def process_structured_output(text: str) -> Optional[dict[str, Any] | list[Any]]:
@ -1035,41 +1032,20 @@ class LLMNode(BaseNode[LLMNodeData]):
Currently, only image files are supported. Currently, only image files are supported.
""" """
# Inject the saver somehow... # Inject the saver somehow...
_saver = self._multi_modal_file_saver _saver = self._llm_file_saver
# If this # If this
if content.url != "": if content.url != "":
mmf = self._download_file(content.url, FileType.IMAGE) saved_file = _saver.save_remote_url(content.url, FileType.IMAGE)
saved_file = _saver.save_file(mmf)
self._file_outputs.append(saved_file)
return saved_file
else: else:
mmf = MultiModalFile( saved_file = _saver.save_binary_string(
user_id=self.user_id,
tenant_id=self.tenant_id,
data=base64.b64decode(content.base64_data), data=base64.b64decode(content.base64_data),
mime_type=content.mime_type, mime_type=content.mime_type,
file_type=FileType.IMAGE, file_type=FileType.IMAGE,
extension_override=None,
) )
saved_file = _saver.save_file(mmf)
self._file_outputs.append(saved_file) self._file_outputs.append(saved_file)
return saved_file return saved_file
def _download_file(self, url: str, file_type: FileType) -> MultiModalFile:
downloader = self._file_downloader
# try to download image
response = downloader.get(url)
content_type, extension = _extract_content_type_and_extension(url, response.content_type)
return MultiModalFile(
user_id=self.user_id,
tenant_id=self.tenant_id,
file_type=file_type,
data=response.body,
mime_type=content_type,
extension_override=extension,
)
def _handle_native_json_schema(self, model_parameters: dict, rules: list[ParameterRule]) -> dict: def _handle_native_json_schema(self, model_parameters: dict, rules: list[ParameterRule]) -> dict:
""" """
Handle structured output for models with native JSON schema support. Handle structured output for models with native JSON schema support.
@ -1451,15 +1427,3 @@ def convert_boolean_to_string(schema: dict) -> None:
for item in value: for item in value:
if isinstance(item, dict): if isinstance(item, dict):
convert_boolean_to_string(item) convert_boolean_to_string(item)
def _extract_content_type_and_extension(url: str, content_type_header: str | None) -> tuple[str, str]:
"""_extract_content_type_and_extension tries to
guess content type of file from url and `Content-Type` header in response.
"""
if content_type_header:
extension = mimetypes.guess_extension(content_type_header) or DEFAULT_EXTENSION
return content_type_header, extension
content_type = mimetypes.guess_type(url)[0] or DEFAULT_MIME_TYPE
extension = mimetypes.guess_extension(content_type) or DEFAULT_EXTENSION
return content_type, extension

@ -1,56 +0,0 @@
from unittest import mock
import httpx
import pytest
from core.workflow.nodes.llm.file_downloader import (
FileDownloadError,
HTTPStatusError,
Response,
SSRFProxyFileDownloader,
ssrf_proxy,
)
_TEST_URL = "https://example.com"
class TestSSRFProxyFileDownloader:
def test(self, monkeypatch):
mock_request = httpx.Request("GET", _TEST_URL)
mock_response = httpx.Response(
status_code=200,
content=b"test-data",
headers={"Content-Type": "text/plain"},
request=mock_request,
)
mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
monkeypatch.setattr(ssrf_proxy, "get", mock_get)
downloader = SSRFProxyFileDownloader()
response = downloader.get(_TEST_URL)
mock_get.assert_called_once_with(_TEST_URL)
assert response == Response(body=mock_response.content, content_type="text/plain")
def test_should_raise_when_status_is_not_successful(self, monkeypatch):
mock_request = httpx.Request("GET", _TEST_URL)
mock_response = httpx.Response(
status_code=401,
request=mock_request,
)
mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
monkeypatch.setattr(ssrf_proxy, "get", mock_get)
downloader = SSRFProxyFileDownloader()
with pytest.raises(HTTPStatusError) as exc:
response = downloader.get(_TEST_URL)
mock_get.assert_called_once_with(_TEST_URL)
assert exc.value.status_code == 401
def test_should_convert_timeout_to_file_download_error(self, monkeypatch):
mock_get = mock.MagicMock(spec=ssrf_proxy.get, side_effect=httpx.TimeoutException("timeout"))
monkeypatch.setattr(ssrf_proxy, "get", mock_get)
downloader = SSRFProxyFileDownloader()
with pytest.raises(FileDownloadError) as exc:
response = downloader.get(_TEST_URL)
mock_get.assert_called_once_with(_TEST_URL)

@ -1,129 +1,192 @@
import uuid import uuid
from typing import NamedTuple
from unittest import mock from unittest import mock
import pydantic import httpx
import pytest import pytest
from sqlalchemy import Engine from sqlalchemy import Engine
from core.file import FileTransferMethod, FileType, models from core.file import FileTransferMethod, FileType, models
from core.helper import ssrf_proxy
from core.tools import signature from core.tools import signature
from core.tools.tool_file_manager import ToolFileManager from core.tools.tool_file_manager import ToolFileManager
from core.workflow.nodes.llm.file_saver import MultiModalFile, StorageFileSaver from core.workflow.nodes.llm.file_saver import (
FileSaverImpl,
_extract_content_type_and_extension,
_get_extension,
_validate_extension_override,
)
from models import ToolFile from models import ToolFile
_PNG_DATA = b"\x89PNG\r\n\x1a\n" _PNG_DATA = b"\x89PNG\r\n\x1a\n"
#
# class _MockToolFileManager:
# def __init__(self, mock_tool_file: ToolFile, mock_signed_url: str): def _gen_id():
# self._mock_tool_file = mock_tool_file
# self._mock_signed_url = mock_signed_url
#
# @staticmethod
# def create_file_by_raw(
# *,
# user_id: str,
# tenant_id: str,
# conversation_id: Optional[str],
# file_binary: bytes,
# mimetype: str,
# filename: Optional[str] = None,
# ) -> ToolFile:
# return self._mock_tool_file
#
# @staticmethod
# def sign_file(tool_file_id: str, extension: str) -> str:
# return ""
def test_storage_file_saver(monkeypatch):
def gen_id():
return str(uuid.uuid4()) return str(uuid.uuid4())
mmf = MultiModalFile(
user_id=gen_id(), class TestFileSaverImpl:
tenant_id=gen_id(), def test_save_binary_string(self, monkeypatch):
file_type=FileType.IMAGE, user_id = _gen_id()
data=_PNG_DATA, tenant_id = _gen_id()
mime_type="image/png", file_type = FileType.IMAGE
extension_override=None, mime_type = "image/png"
)
mock_signed_url = "https://example.com/image.png" mock_signed_url = "https://example.com/image.png"
mock_tool_file = ToolFile( mock_tool_file = ToolFile(
id=gen_id(), id=_gen_id(),
user_id=mmf.user_id, user_id=user_id,
tenant_id=mmf.tenant_id, tenant_id=tenant_id,
conversation_id=None, conversation_id=None,
file_key="test-file-key", file_key="test-file-key",
mimetype=mmf.mime_type, mimetype=mime_type,
original_url=None, original_url=None,
name=f"{gen_id()}.png", name=f"{_gen_id()}.png",
size=len(mmf.data), size=len(_PNG_DATA),
) )
mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager) mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager)
mocked_engine = mock.MagicMock(spec=Engine) mocked_engine = mock.MagicMock(spec=Engine)
mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file
monkeypatch.setattr(StorageFileSaver, "_get_tool_file_manager", lambda _: mocked_tool_file_manager) monkeypatch.setattr(FileSaverImpl, "_get_tool_file_manager", lambda _: mocked_tool_file_manager)
# Since `File.generate_url` used `ToolFileManager.sign_file` directly, we also need to patch it here. # Since `File.generate_url` used `ToolFileManager.sign_file` directly, we also need to patch it here.
mocked_sign_file = mock.MagicMock(spec=signature.sign_tool_file) mocked_sign_file = mock.MagicMock(spec=signature.sign_tool_file)
# Since `File.generate_url` used `signature.sign_tool_file` directly, we also need to patch it here. # Since `File.generate_url` used `signature.sign_tool_file` directly, we also need to patch it here.
monkeypatch.setattr(models, "sign_tool_file", mocked_sign_file) monkeypatch.setattr(models, "sign_tool_file", mocked_sign_file)
mocked_sign_file.return_value = mock_signed_url mocked_sign_file.return_value = mock_signed_url
storage_file_manager = StorageFileSaver(engine_factory=lambda: mocked_engine) storage_file_manager = FileSaverImpl(
user_id=user_id,
tenant_id=tenant_id,
engine_factory=mocked_engine,
)
file = storage_file_manager.save_file(mmf) file = storage_file_manager.save_binary_string(_PNG_DATA, mime_type, file_type)
assert file.tenant_id == mmf.tenant_id assert file.tenant_id == tenant_id
assert file.type == mmf.file_type assert file.type == file_type
assert file.transfer_method == FileTransferMethod.TOOL_FILE assert file.transfer_method == FileTransferMethod.TOOL_FILE
assert file.extension == mmf.get_extension() assert file.extension == ".png"
assert file.mime_type == mmf.mime_type assert file.mime_type == mime_type
assert file.size == len(mmf.data) assert file.size == len(_PNG_DATA)
assert file.related_id == mock_tool_file.id assert file.related_id == mock_tool_file.id
assert file.generate_url() == mock_signed_url assert file.generate_url() == mock_signed_url
mocked_tool_file_manager.create_file_by_raw.assert_called_once_with( mocked_tool_file_manager.create_file_by_raw.assert_called_once_with(
user_id=mmf.user_id, user_id=user_id,
tenant_id=mmf.tenant_id, tenant_id=tenant_id,
conversation_id=None, conversation_id=None,
file_binary=mmf.data, file_binary=_PNG_DATA,
mimetype=mmf.mime_type, mimetype=mime_type,
)
mocked_sign_file.assert_called_once_with(mock_tool_file.id, ".png")
def test_save_remote_url_request_failed(self, monkeypatch):
_TEST_URL = "https://example.com/image.png"
mock_request = httpx.Request("GET", _TEST_URL)
mock_response = httpx.Response(
status_code=401,
request=mock_request,
)
file_saver = FileSaverImpl(
user_id=_gen_id(),
tenant_id=_gen_id(),
)
mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
monkeypatch.setattr(ssrf_proxy, "get", mock_get)
with pytest.raises(httpx.HTTPStatusError) as exc:
file_saver.save_remote_url(_TEST_URL, FileType.IMAGE)
mock_get.assert_called_once_with(_TEST_URL)
assert exc.value.response.status_code == 401
def test_save_remote_url_success(self, monkeypatch):
_TEST_URL = "https://example.com/image.png"
mime_type = "image/png"
user_id = _gen_id()
tenant_id = _gen_id()
mock_request = httpx.Request("GET", _TEST_URL)
mock_response = httpx.Response(
status_code=200,
content=b"test-data",
headers={"Content-Type": mime_type},
request=mock_request,
) )
mocked_sign_file.assert_called_once_with(mock_tool_file.id, mmf.get_extension())
file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id)
mock_tool_file = ToolFile(
id=_gen_id(),
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
file_key="test-file-key",
mimetype=mime_type,
original_url=None,
name=f"{_gen_id()}.png",
size=len(_PNG_DATA),
)
mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
monkeypatch.setattr(ssrf_proxy, "get", mock_get)
mock_save_binary_string = mock.MagicMock(spec=file_saver.save_binary_string, return_value=mock_tool_file)
monkeypatch.setattr(file_saver, "save_binary_string", mock_save_binary_string)
file = file_saver.save_remote_url(_TEST_URL, FileType.IMAGE)
mock_save_binary_string.assert_called_once_with(
mock_response.content,
mime_type,
FileType.IMAGE,
extension_override=".png",
)
assert file == mock_tool_file
def test_multi_modal_file_extension_override():
# Test should pass if `extension_override` is not set.
MultiModalFile(user_id="", tenant_id="", file_type=FileType.IMAGE, data=b"", mime_type="image/png")
# Test should pass if `extension_override` is explicitly set to `None`. def test_validate_extension_override():
MultiModalFile( class TestCase(NamedTuple):
user_id="", tenant_id="", file_type=FileType.IMAGE, data=b"", mime_type="image/png", extension_override=None extension_override: str | None
) expected: str | None
# Test should pass if `extension_override` is a string prefixed with `.`. cases = [TestCase(None, None), TestCase("", ""), ".png", ".png", ".tar.gz", ".tar.gz"]
for extension_override in [".png", ".tar.gz"]:
MultiModalFile(
user_id="",
tenant_id="",
file_type=FileType.IMAGE,
data=b"",
mime_type="image/png",
extension_override=extension_override,
)
for invalid_ext_override in ["png", "tar.gz"]: for valid_ext_override in [None, "", ".png", ".tar.gz"]:
with pytest.raises(pydantic.ValidationError) as exc: assert valid_ext_override == _validate_extension_override(valid_ext_override)
MultiModalFile(
user_id="",
tenant_id="",
file_type=FileType.IMAGE,
data=b"",
mime_type="image/png",
extension_override=invalid_ext_override,
)
error_details = exc.value.errors() for invalid_ext_override in ["png", "tar.gz"]:
assert exc.value.error_count() == 1 with pytest.raises(ValueError) as exc:
assert error_details[0]["loc"] == ("extension_override",) _validate_extension_override(invalid_ext_override)
class TestExtractContentTypeAndExtension:
def test_with_both_content_type_and_extension(self):
content_type, extension = _extract_content_type_and_extension("https://example.com/image.jpg", "image/png")
assert content_type == "image/png"
assert extension == ".png"
def test_url_with_file_extension(self):
for content_type in [None, ""]:
content_type, extension = _extract_content_type_and_extension("https://example.com/image.png", content_type)
assert content_type == "image/png"
assert extension == ".png"
def test_response_with_content_type(self):
content_type, extension = _extract_content_type_and_extension("https://example.com/image", "image/png")
assert content_type == "image/png"
assert extension == ".png"
def test_no_content_type_and_no_extension(self):
for content_type in [None, ""]:
content_type, extension = _extract_content_type_and_extension("https://example.com/image", content_type)
assert content_type == "application/octet-stream"
assert extension == ".bin"
class TestGetExtension:
def test_with_extension_override(self):
mime_type = "image/png"
for override in [".jpg", ""]:
extension = _get_extension(mime_type, override)
assert extension == override
def test_without_extension_override(self):
mime_type = "image/png"
extension = _get_extension(mime_type)
assert extension == ".png"

@ -1,10 +1,10 @@
from unittest import mock
import base64 import base64
import pytest
import uuid import uuid
from collections.abc import Sequence from collections.abc import Sequence
from typing import Optional from typing import Optional
from unittest import mock
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
@ -33,9 +33,8 @@ from core.workflow.nodes.llm.entities import (
VisionConfig, VisionConfig,
VisionConfigOptions, VisionConfigOptions,
) )
from core.workflow.nodes.llm.file_downloader import FileDownloader, Response from core.workflow.nodes.llm.file_saver import LLMFileSaver
from core.workflow.nodes.llm.file_saver import MultiModalFile, MultiModalFileSaver from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.llm.node import LLMNode, _extract_content_type_and_extension
from models.enums import UserFrom from models.enums import UserFrom
from models.provider import ProviderType from models.provider import ProviderType
from models.workflow import WorkflowType from models.workflow import WorkflowType
@ -117,8 +116,7 @@ def graph_runtime_state() -> GraphRuntimeState:
def llm_node( def llm_node(
llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph: Graph, graph_runtime_state: GraphRuntimeState llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph: Graph, graph_runtime_state: GraphRuntimeState
) -> LLMNode: ) -> LLMNode:
mock_file_saver = mock.MagicMock(spec=MultiModalFileSaver) mock_file_saver = mock.MagicMock(spec=LLMFileSaver)
mock_file_downloader = mock.MagicMock(spec=FileDownloader)
node = LLMNode( node = LLMNode(
id="1", id="1",
config={ config={
@ -128,8 +126,7 @@ def llm_node(
graph_init_params=graph_init_params, graph_init_params=graph_init_params,
graph=graph, graph=graph,
graph_runtime_state=graph_runtime_state, graph_runtime_state=graph_runtime_state,
file_downloader=mock_file_downloader, llm_file_saver=mock_file_saver,
multi_modal_file_saver=mock_file_saver,
) )
return node return node
@ -500,9 +497,8 @@ def test_handle_list_messages_basic(llm_node):
@pytest.fixture @pytest.fixture
def llm_node_for_multimodal( def llm_node_for_multimodal(
llm_node_data, graph_init_params, graph, graph_runtime_state llm_node_data, graph_init_params, graph, graph_runtime_state
) -> tuple[LLMNode, FileDownloader, MultiModalFileSaver]: ) -> tuple[LLMNode, LLMFileSaver]:
mock_file_downloader: FileDownloader = mock.MagicMock(spec=FileDownloader) mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver)
mock_file_saver: MultiModalFileSaver = mock.MagicMock(spec=MultiModalFileSaver)
node = LLMNode( node = LLMNode(
id="1", id="1",
config={ config={
@ -512,29 +508,14 @@ def llm_node_for_multimodal(
graph_init_params=graph_init_params, graph_init_params=graph_init_params,
graph=graph, graph=graph,
graph_runtime_state=graph_runtime_state, graph_runtime_state=graph_runtime_state,
file_downloader=mock_file_downloader, llm_file_saver=mock_file_saver,
multi_modal_file_saver=mock_file_saver,
) )
return node, mock_file_downloader, mock_file_saver return node, mock_file_saver
class TestLLMNodeSaveMultiModalImageOutput: class TestLLMNodeSaveMultiModalImageOutput:
def test_llm_node_download_file(self, llm_node_for_multimodal): def test_llm_node_save_inline_output(self, llm_node_for_multimodal: tuple[LLMNode, LLMFileSaver]):
llm_node, mock_file_downloader, _ = llm_node_for_multimodal llm_node, mock_file_saver = llm_node_for_multimodal
mock_response = Response(body=b"test-data", content_type="image/png")
mock_file_downloader.get.return_value = mock_response
file = llm_node._download_file("https://example.com/image.png", FileType.IMAGE)
assert file.user_id == "1"
assert file.tenant_id == "1"
assert file.file_type == FileType.IMAGE
assert file.mime_type == mock_response.content_type
assert file.data == mock_response.body
assert file.get_extension() == ".png"
def test_llm_node_save_inline_output(
self, llm_node_for_multimodal: tuple[LLMNode, FileDownloader, MultiModalFileSaver]
):
llm_node, _, mock_file_saver = llm_node_for_multimodal
content = ImagePromptMessageContent( content = ImagePromptMessageContent(
format="png", format="png",
base64_data=base64.b64encode(b"test-data").decode(), base64_data=base64.b64encode(b"test-data").decode(),
@ -551,24 +532,16 @@ class TestLLMNodeSaveMultiModalImageOutput:
mime_type="image/png", mime_type="image/png",
size=9, size=9,
) )
mock_file_saver.save_file.return_value = mock_file mock_file_saver.save_binary_string.return_value = mock_file
file = llm_node._save_multimodal_image_output(content=content) file = llm_node._save_multimodal_image_output(content=content)
assert llm_node._file_outputs == [mock_file] assert llm_node._file_outputs == [mock_file]
assert file == mock_file assert file == mock_file
expected_saved_multimodal_file = MultiModalFile( mock_file_saver.save_binary_string.assert_called_once_with(
user_id="1", data=b"test-data", mime_type="image/png", file_type=FileType.IMAGE
tenant_id="1",
file_type=FileType.IMAGE,
data=b"test-data",
mime_type="image/png",
extension_override=None,
) )
mock_file_saver.save_file.assert_called_once_with(expected_saved_multimodal_file)
def test_llm_node_save_url_output( def test_llm_node_save_url_output(self, llm_node_for_multimodal: tuple[LLMNode, LLMFileSaver]):
self, llm_node_for_multimodal: tuple[LLMNode, FileDownloader, MultiModalFileSaver] llm_node, mock_file_saver = llm_node_for_multimodal
):
llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal
content = ImagePromptMessageContent( content = ImagePromptMessageContent(
format="png", format="png",
url="https://example.com/image.png", url="https://example.com/image.png",
@ -585,21 +558,11 @@ class TestLLMNodeSaveMultiModalImageOutput:
mime_type="image/png", mime_type="image/png",
size=9, size=9,
) )
mock_file_downloader.get.return_value = Response(body=b"test-data", content_type="image/png") mock_file_saver.save_remote_url.return_value = mock_file
mock_file_saver.save_file.return_value = mock_file
file = llm_node._save_multimodal_image_output(content=content) file = llm_node._save_multimodal_image_output(content=content)
assert llm_node._file_outputs == [mock_file] assert llm_node._file_outputs == [mock_file]
assert file == mock_file assert file == mock_file
expected_saved_multimodal_file = MultiModalFile( mock_file_saver.save_remote_url.assert_called_once_with(content.url, FileType.IMAGE)
user_id="1",
tenant_id="1",
file_type=FileType.IMAGE,
data=b"test-data",
mime_type="image/png",
extension_override=".png",
)
mock_file_downloader.get.assert_called_once_with(content.url)
mock_file_saver.save_file.assert_called_once_with(expected_saved_multimodal_file)
def test_llm_node_image_file_to_markdown(llm_node: LLMNode): def test_llm_node_image_file_to_markdown(llm_node: LLMNode):
@ -609,49 +572,25 @@ def test_llm_node_image_file_to_markdown(llm_node: LLMNode):
assert markdown == "![](https://example.com/image.png)" assert markdown == "![](https://example.com/image.png)"
class TestExtractContentTypeAndExtension:
def test_with_both_content_type_and_extension(self):
content_type, extension = _extract_content_type_and_extension("https://example.com/image.jpg", "image/png")
assert content_type == "image/png"
assert extension == ".png"
def test_url_with_file_extension(self):
for content_type in [None, ""]:
content_type, extension = _extract_content_type_and_extension("https://example.com/image.png", content_type)
assert content_type == "image/png"
assert extension == ".png"
def test_response_with_content_type(self):
content_type, extension = _extract_content_type_and_extension("https://example.com/image", "image/png")
assert content_type == "image/png"
assert extension == ".png"
def test_no_content_type_and_no_extension(self):
for content_type in [None, ""]:
content_type, extension = _extract_content_type_and_extension("https://example.com/image", content_type)
assert content_type == "application/octet-stream"
assert extension == ".bin"
class TestSaveMultimodalOutputAndConvertResultToMarkdown: class TestSaveMultimodalOutputAndConvertResultToMarkdown:
def test_str_content(self, llm_node_for_multimodal): def test_str_content(self, llm_node_for_multimodal):
llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown("hello world") gen = llm_node._save_multimodal_output_and_convert_result_to_markdown("hello world")
assert list(gen) == ["hello world"] assert list(gen) == ["hello world"]
mock_file_downloader.get.assert_not_called() mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_file.assert_not_called() mock_file_saver.save_remote_url.assert_not_called()
def test_text_prompt_message_content(self, llm_node_for_multimodal): def test_text_prompt_message_content(self, llm_node_for_multimodal):
llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown( gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
[TextPromptMessageContent(data="hello world")] [TextPromptMessageContent(data="hello world")]
) )
assert list(gen) == ["hello world"] assert list(gen) == ["hello world"]
mock_file_downloader.get.assert_not_called() mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_file.assert_not_called() mock_file_saver.save_remote_url.assert_not_called()
def test_image_content(self, llm_node_for_multimodal): def test_image_content_with_inline_data(self, llm_node_for_multimodal, monkeypatch):
llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal llm_node, mock_file_saver = llm_node_for_multimodal
image_raw_data = b"PNG_DATA" image_raw_data = b"PNG_DATA"
image_b64_data = base64.b64encode(image_raw_data).decode() image_b64_data = base64.b64encode(image_raw_data).decode()
@ -668,7 +607,7 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
url="https://example.com/test.png", url="https://example.com/test.png",
storage_key="test_storage_key", storage_key="test_storage_key",
) )
mock_file_saver.save_file.return_value = mock_saved_file mock_file_saver.save_binary_string.return_value = mock_saved_file
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown( gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
[ [
ImagePromptMessageContent( ImagePromptMessageContent(
@ -680,39 +619,40 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
) )
yielded_strs = list(gen) yielded_strs = list(gen)
assert len(yielded_strs) == 1 assert len(yielded_strs) == 1
# This assertion is somewhat tricky.
expected_file_url = f"http://127.0.0.1:5001/files/tools/{mock_saved_file.related_id}.png" # This assertion requires careful handling.
assert yielded_strs[0].startswith(f"![]({expected_file_url}") # `FILES_URL` settings can vary across environments, which might lead to fragile tests.
#
# Rather than asserting the complete URL returned by _save_multimodal_output_and_convert_result_to_markdown,
# we verify that the result includes the markdown image syntax and the expected file URL path.
expected_file_url_path = f"/files/tools/{mock_saved_file.related_id}.png"
assert yielded_strs[0].startswith(f"![](")
assert expected_file_url_path in yielded_strs[0]
assert yielded_strs[0].endswith(")") assert yielded_strs[0].endswith(")")
mock_file_saver.save_file.assert_called_once_with( mock_file_saver.save_binary_string.assert_called_once_with(
MultiModalFile(
user_id="1",
tenant_id="1",
file_type=FileType.IMAGE,
data=image_raw_data, data=image_raw_data,
mime_type="image/png", mime_type="image/png",
file_type=FileType.IMAGE,
) )
)
mock_file_downloader.assert_not_called()
assert mock_saved_file in llm_node._file_outputs assert mock_saved_file in llm_node._file_outputs
def test_unknown_content_type(self, llm_node_for_multimodal): def test_unknown_content_type(self, llm_node_for_multimodal):
llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(frozenset(["hello world"])) gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(frozenset(["hello world"]))
assert list(gen) == ["frozenset({'hello world'})"] assert list(gen) == ["frozenset({'hello world'})"]
mock_file_downloader.get.assert_not_called() mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_file.assert_not_called() mock_file_saver.save_remote_url.assert_not_called()
def test_unknown_item_type(self, llm_node_for_multimodal): def test_unknown_item_type(self, llm_node_for_multimodal):
llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown([frozenset(["hello world"])]) gen = llm_node._save_multimodal_output_and_convert_result_to_markdown([frozenset(["hello world"])])
assert list(gen) == ["frozenset({'hello world'})"] assert list(gen) == ["frozenset({'hello world'})"]
mock_file_downloader.get.assert_not_called() mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_file.assert_not_called() mock_file_saver.save_remote_url.assert_not_called()
def test_none_content(self, llm_node_for_multimodal): def test_none_content(self, llm_node_for_multimodal):
llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(None) gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(None)
assert list(gen) == [] assert list(gen) == []
mock_file_downloader.get.assert_not_called() mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_file.assert_not_called() mock_file_saver.save_remote_url.assert_not_called()

Loading…
Cancel
Save