refactor(api): remove the unnecessary FileDownloader interface.

pull/17372/head
QuantumGhost 1 year ago
parent 2585cb37b3
commit f6ab99cc24

@ -1,42 +0,0 @@
import abc
import typing as tp
import httpx
from pydantic import BaseModel
from core.helper import ssrf_proxy
class FileDownloadError(Exception):
    """Base class for errors raised while downloading a file."""
class HTTPStatusError(FileDownloadError):
    """Download error that carries the HTTP status code of the failure.

    :param message: human-readable description of the failure.
    :param status_code: the HTTP status code associated with the failure
        (keyword-only).
    """

    def __init__(self, message: str, *, status_code: int):
        # Forward the message explicitly so `args` / `str(exc)` are populated
        # by the regular exception initializer instead of relying on
        # `BaseException.__new__` capturing the positional arguments.
        super().__init__(message)
        self.status_code = status_code
# Immutable (frozen) value object describing a downloaded file.
class Response(BaseModel, frozen=True):
    # Raw bytes of the downloaded payload.
    body: bytes
    # Media type reported for the payload; `None` when it is not known.
    content_type: str | None = None
class FileDownloader(tp.Protocol):
    """Structural interface for objects that can fetch a file from a URL."""

    @abc.abstractmethod
    def get(self, url) -> Response:
        """Fetch `url` and return the downloaded payload as a `Response`."""
        ...
class SSRFProxyFileDownloader(FileDownloader):
    """`FileDownloader` implementation that issues requests through `ssrf_proxy`."""

    def get(self, url) -> Response:
        """Download `url` via `ssrf_proxy.get` and wrap the result in a `Response`.

        :param url: the address to download from.
        :return: a `Response` holding the body bytes and the `Content-Type`
            header value (`None` when the header is absent).
        :raises FileDownloadError: when the request times out.
        :raises HTTPStatusError: when `raise_for_status()` rejects the response;
            the original status code is exposed as `status_code`.
        """
        try:
            fetched = ssrf_proxy.get(url)
            fetched.raise_for_status()
            payload = fetched.content
            media_type = fetched.headers.get("Content-Type")
            return Response(body=payload, content_type=media_type)
        except httpx.TimeoutException as exc:
            raise FileDownloadError(f"timeout when downloading file from {url}") from exc
        except httpx.HTTPStatusError as exc:
            # Translate the httpx error into this module's own exception type
            # while preserving the status code of the failing response.
            failed_status = exc.response.status_code
            raise HTTPStatusError(f"Error when downloading file from {url}", status_code=failed_status) from exc

@ -1,83 +1,73 @@
import abc
import mimetypes import mimetypes
import typing as tp import typing as tp
from pydantic import BaseModel, field_validator
from sqlalchemy import Engine from sqlalchemy import Engine
from constants.mimetypes import DEFAULT_EXTENSION from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE
from core.file import File, FileTransferMethod, FileType from core.file import File, FileTransferMethod, FileType
from core.helper import ssrf_proxy
from core.tools.signature import sign_tool_file from core.tools.signature import sign_tool_file
from core.tools.tool_file_manager import ToolFileManager from core.tools.tool_file_manager import ToolFileManager
from models import db as global_db from models import db as global_db
class MultiModalFile(BaseModel): class LLMFileSaver(tp.Protocol):
# user_id records t """LLMFileSaver is responsible for save multimodal output returned by
user_id: str LLM.
# tenant_id """
tenant_id: str
file_type: FileType def save_binary_string(
self,
# data is the contents of the file data: bytes,
data: bytes mime_type: str,
file_type: FileType,
# mime_type is the media type of the file, specified by extension_override: str | None = None,
# rfc6838 (https://datatracker.ietf.org/doc/html/rfc6838) ) -> File:
mime_type: str """save_binary_string saves the inline file data returned by LLM.
# `extension_override` allow the user to manually specify the file extension to use Currently (2025-04-30), only some of Google Gemini models will return
# while saving this file. multimodal output as inline data.
#
# The default value is `None`, which means do not override the file extension and guessing it :param data: the contents of the file
# from the `mime_type` attribute while saving the file. :param mime_type: the media type of the file, specified by rfc6838
# (https://datatracker.ietf.org/doc/html/rfc6838)
# Setting it to values other than `None` means override the file's extension, and :param file_type: The file type of the inline file.
# will bypass the extension guessing when calling `get_extension`. :param extension_override: Override the auto-detected file extension while saving this file.
# Specially, setting it to empty string (`""`) will leave the file extension empty.
# The default value is `None`, which means do not override the file extension and guessing it
# When it is not `None` or empty string (`""`), it should be a string beginning with a from the `mime_type` attribute while saving the file.
# dot (`.`). For example, `.py` and `.tar.gz` are both valid values, while `py`
# and `tar.gz` is not. Setting it to values other than `None` means override the file's extension, and
# will bypass the extension guessing saving the file.
# Users of MultiModalFile should always use `get_extension` to access
# the files extension, instead reading this property directly. Specially, setting it to empty string (`""`) will leave the file extension empty.
extension_override: str | None = None
When it is not `None` or empty string (`""`), it should be a string beginning with a
def get_extension(self) -> str: dot (`.`). For example, `.py` and `.tar.gz` are both valid values, while `py`
"""get_extension return the extension of file. and `tar.gz` are not.
If the `extension_override` parameter is set, this method should honor it and
return its value.
""" """
if (extension := self.extension_override) is not None: pass
return extension
return mimetypes.guess_extension(self.mime_type) or DEFAULT_EXTENSION def save_remote_url(self, url: str, file_type: FileType) -> File:
"""save_remote_url saves the file from a remote url returned by LLM.
@field_validator("extension_override")
@classmethod
def _validate_extension_override(cls, extension_override: str | None) -> str | None:
# `extension_override` is allow to be `None or `""`.
if not extension_override:
return None
if not extension_override.startswith("."):
raise ValueError("extension_override should start with '.' if not None or empty.", extension_override)
return extension_override
Currently (2025-04-30), no model returns multimodal output as a URL.
class MultiModalFileSaver(tp.Protocol): :param url: the url of the file.
@abc.abstractmethod :param file_type: the file type of the file, check `FileType` enum for reference.
def save_file(self, mmf: MultiModalFile) -> File: """
pass pass
EngineFactory: tp.TypeAlias = tp.Callable[[], Engine] EngineFactory: tp.TypeAlias = tp.Callable[[], Engine]
class StorageFileSaver(MultiModalFileSaver): class FileSaverImpl(LLMFileSaver):
_engine_factory: EngineFactory _engine_factory: EngineFactory
_tenant_id: str
_user_id: str
def __init__(self, engine_factory: EngineFactory | None = None): def __init__(self, user_id: str, tenant_id: str, engine_factory: EngineFactory | None = None):
if engine_factory is None: if engine_factory is None:
def _factory(): def _factory():
@ -85,30 +75,48 @@ class StorageFileSaver(MultiModalFileSaver):
engine_factory = _factory engine_factory = _factory
self._engine_factory = engine_factory self._engine_factory = engine_factory
self._user_id = user_id
self._tenant_id = tenant_id
def _get_tool_file_manager(self): def _get_tool_file_manager(self):
return ToolFileManager(engine=self._engine_factory()) return ToolFileManager(engine=self._engine_factory())
def save_file(self, mmf: MultiModalFile) -> File: def save_remote_url(self, url: str, file_type: FileType) -> File:
http_response = ssrf_proxy.get(url)
http_response.raise_for_status()
data = http_response.content
mime_type_from_header = http_response.headers.get("Content-Type")
mime_type, extension = _extract_content_type_and_extension(url, mime_type_from_header)
return self.save_binary_string(data, mime_type, file_type, extension_override=extension)
def save_binary_string(
self,
data: bytes,
mime_type: str,
file_type: FileType,
extension_override: str | None = None,
) -> File:
tool_file_manager = self._get_tool_file_manager() tool_file_manager = self._get_tool_file_manager()
tool_file = tool_file_manager.create_file_by_raw( tool_file = tool_file_manager.create_file_by_raw(
user_id=mmf.user_id, user_id=self._user_id,
tenant_id=mmf.tenant_id, tenant_id=self._tenant_id,
# TODO(QuantumGhost): what is conversation id? # TODO(QuantumGhost): what is conversation id?
conversation_id=None, conversation_id=None,
file_binary=mmf.data, file_binary=data,
mimetype=mmf.mime_type, mimetype=mime_type,
) )
url = sign_tool_file(tool_file.id, mmf.get_extension()) extension_override = _validate_extension_override(extension_override)
extension = _get_extension(mime_type, extension_override)
url = sign_tool_file(tool_file.id, extension)
return File( return File(
tenant_id=mmf.tenant_id, tenant_id=self._tenant_id,
type=FileType.IMAGE, type=file_type,
transfer_method=FileTransferMethod.TOOL_FILE, transfer_method=FileTransferMethod.TOOL_FILE,
filename=tool_file.name, filename=tool_file.name,
extension=mmf.get_extension(), extension=extension,
mime_type=mmf.mime_type, mime_type=mime_type,
size=len(mmf.data), size=len(data),
related_id=tool_file.id, related_id=tool_file.id,
url=url, url=url,
# TODO(QuantumGhost): how should I set the following key? # TODO(QuantumGhost): how should I set the following key?
@ -116,3 +124,37 @@ class StorageFileSaver(MultiModalFileSaver):
# What's the purpose of `storage_key` and `dify_model_identity`? # What's the purpose of `storage_key` and `dify_model_identity`?
storage_key=tool_file.file_key, storage_key=tool_file.file_key,
) )
def _get_extension(mime_type: str, extension_override: str | None = None) -> str:
"""get_extension return the extension of file.
If the `extension_override` parameter is set, this function should honor it and
return its value.
"""
if extension_override is not None:
return extension_override
return mimetypes.guess_extension(mime_type) or DEFAULT_EXTENSION
def _extract_content_type_and_extension(url: str, content_type_header: str | None) -> tuple[str, str]:
"""_extract_content_type_and_extension tries to
guess content type of file from url and `Content-Type` header in response.
"""
if content_type_header:
extension = mimetypes.guess_extension(content_type_header) or DEFAULT_EXTENSION
return content_type_header, extension
content_type = mimetypes.guess_type(url)[0] or DEFAULT_MIME_TYPE
extension = mimetypes.guess_extension(content_type) or DEFAULT_EXTENSION
return content_type, extension
def _validate_extension_override(extension_override: str | None) -> str | None:
# `extension_override` is allow to be `None or `""`.
if extension_override is None:
return None
if extension_override == "":
return ""
if not extension_override.startswith("."):
raise ValueError("extension_override should start with '.' if not None or empty.", extension_override)
return extension_override

@ -2,7 +2,6 @@ import base64
import io import io
import json import json
import logging import logging
import mimetypes
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, Optional, cast from typing import TYPE_CHECKING, Any, Optional, cast
@ -10,7 +9,6 @@ from typing import TYPE_CHECKING, Any, Optional, cast
import json_repair import json_repair
from configs import dify_config from configs import dify_config
from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.model_entities import ModelStatus from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import QuotaUnit from core.entities.provider_entities import QuotaUnit
@ -98,8 +96,7 @@ from .exc import (
TemplateTypeNotSupportError, TemplateTypeNotSupportError,
VariableNotFoundError, VariableNotFoundError,
) )
from .file_downloader import FileDownloader, SSRFProxyFileDownloader from .file_saver import FileSaverImpl, LLMFileSaver
from .file_saver import MultiModalFile, MultiModalFileSaver, StorageFileSaver
if TYPE_CHECKING: if TYPE_CHECKING:
from core.file.models import File from core.file.models import File
@ -117,8 +114,8 @@ class LLMNode(BaseNode[LLMNodeData]):
# Instance attributes specific to LLMNode. # Instance attributes specific to LLMNode.
# Output variable for file # Output variable for file
_file_outputs: list["File"] _file_outputs: list["File"]
_file_downloader: FileDownloader
_multi_modal_file_saver: MultiModalFileSaver _llm_file_saver: LLMFileSaver
def __init__( def __init__(
self, self,
@ -130,8 +127,7 @@ class LLMNode(BaseNode[LLMNodeData]):
previous_node_id: Optional[str] = None, previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None, thread_pool_id: Optional[str] = None,
*, *,
file_downloader: FileDownloader | None = None, llm_file_saver: LLMFileSaver | None = None,
multi_modal_file_saver: MultiModalFileSaver | None = None,
) -> None: ) -> None:
super().__init__( super().__init__(
id=id, id=id,
@ -144,12 +140,13 @@ class LLMNode(BaseNode[LLMNodeData]):
) )
# LLM file outputs, used for MultiModal outputs. # LLM file outputs, used for MultiModal outputs.
self._file_outputs: list[File] = [] self._file_outputs: list[File] = []
if file_downloader is None:
file_downloader = SSRFProxyFileDownloader() if llm_file_saver is None:
self._file_downloader = file_downloader llm_file_saver = FileSaverImpl(
if multi_modal_file_saver is None: user_id=graph_init_params.user_id,
multi_modal_file_saver = StorageFileSaver() tenant_id=graph_init_params.tenant_id,
self._multi_modal_file_saver = multi_modal_file_saver )
self._llm_file_saver = llm_file_saver
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
def process_structured_output(text: str) -> Optional[dict[str, Any] | list[Any]]: def process_structured_output(text: str) -> Optional[dict[str, Any] | list[Any]]:
@ -1035,40 +1032,19 @@ class LLMNode(BaseNode[LLMNodeData]):
Currently, only image files are supported. Currently, only image files are supported.
""" """
# Inject the saver somehow... # Inject the saver somehow...
_saver = self._multi_modal_file_saver _saver = self._llm_file_saver
# If this # If this
if content.url != "": if content.url != "":
mmf = self._download_file(content.url, FileType.IMAGE) saved_file = _saver.save_remote_url(content.url, FileType.IMAGE)
saved_file = _saver.save_file(mmf)
self._file_outputs.append(saved_file)
return saved_file
else: else:
mmf = MultiModalFile( saved_file = _saver.save_binary_string(
user_id=self.user_id,
tenant_id=self.tenant_id,
data=base64.b64decode(content.base64_data), data=base64.b64decode(content.base64_data),
mime_type=content.mime_type, mime_type=content.mime_type,
file_type=FileType.IMAGE, file_type=FileType.IMAGE,
extension_override=None,
) )
saved_file = _saver.save_file(mmf) self._file_outputs.append(saved_file)
self._file_outputs.append(saved_file) return saved_file
return saved_file
def _download_file(self, url: str, file_type: FileType) -> MultiModalFile:
downloader = self._file_downloader
# try to download image
response = downloader.get(url)
content_type, extension = _extract_content_type_and_extension(url, response.content_type)
return MultiModalFile(
user_id=self.user_id,
tenant_id=self.tenant_id,
file_type=file_type,
data=response.body,
mime_type=content_type,
extension_override=extension,
)
def _handle_native_json_schema(self, model_parameters: dict, rules: list[ParameterRule]) -> dict: def _handle_native_json_schema(self, model_parameters: dict, rules: list[ParameterRule]) -> dict:
""" """
@ -1451,15 +1427,3 @@ def convert_boolean_to_string(schema: dict) -> None:
for item in value: for item in value:
if isinstance(item, dict): if isinstance(item, dict):
convert_boolean_to_string(item) convert_boolean_to_string(item)
def _extract_content_type_and_extension(url: str, content_type_header: str | None) -> tuple[str, str]:
"""_extract_content_type_and_extension tries to
guess content type of file from url and `Content-Type` header in response.
"""
if content_type_header:
extension = mimetypes.guess_extension(content_type_header) or DEFAULT_EXTENSION
return content_type_header, extension
content_type = mimetypes.guess_type(url)[0] or DEFAULT_MIME_TYPE
extension = mimetypes.guess_extension(content_type) or DEFAULT_EXTENSION
return content_type, extension

@ -1,56 +0,0 @@
from unittest import mock
import httpx
import pytest
from core.workflow.nodes.llm.file_downloader import (
FileDownloadError,
HTTPStatusError,
Response,
SSRFProxyFileDownloader,
ssrf_proxy,
)
_TEST_URL = "https://example.com"
class TestSSRFProxyFileDownloader:
    """Tests for `SSRFProxyFileDownloader.get` with `ssrf_proxy.get` patched out."""

    def test(self, monkeypatch):
        """A successful response is returned as a `Response` with body and content type."""
        mock_request = httpx.Request("GET", _TEST_URL)
        mock_response = httpx.Response(
            status_code=200,
            content=b"test-data",
            headers={"Content-Type": "text/plain"},
            request=mock_request,
        )
        mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
        monkeypatch.setattr(ssrf_proxy, "get", mock_get)

        downloader = SSRFProxyFileDownloader()
        response = downloader.get(_TEST_URL)

        mock_get.assert_called_once_with(_TEST_URL)
        assert response == Response(body=mock_response.content, content_type="text/plain")

    def test_should_raise_when_status_is_not_successful(self, monkeypatch):
        """A non-success status surfaces as `HTTPStatusError` carrying the status code."""
        mock_request = httpx.Request("GET", _TEST_URL)
        mock_response = httpx.Response(
            status_code=401,
            request=mock_request,
        )
        mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
        monkeypatch.setattr(ssrf_proxy, "get", mock_get)

        downloader = SSRFProxyFileDownloader()
        with pytest.raises(HTTPStatusError) as exc:
            # The return value is irrelevant: the call is expected to raise.
            downloader.get(_TEST_URL)

        mock_get.assert_called_once_with(_TEST_URL)
        assert exc.value.status_code == 401

    def test_should_convert_timeout_to_file_download_error(self, monkeypatch):
        """An `httpx.TimeoutException` from the proxy is converted to `FileDownloadError`."""
        mock_get = mock.MagicMock(spec=ssrf_proxy.get, side_effect=httpx.TimeoutException("timeout"))
        monkeypatch.setattr(ssrf_proxy, "get", mock_get)

        downloader = SSRFProxyFileDownloader()
        with pytest.raises(FileDownloadError):
            downloader.get(_TEST_URL)

        mock_get.assert_called_once_with(_TEST_URL)

@ -1,129 +1,192 @@
import uuid import uuid
from typing import NamedTuple
from unittest import mock from unittest import mock
import pydantic import httpx
import pytest import pytest
from sqlalchemy import Engine from sqlalchemy import Engine
from core.file import FileTransferMethod, FileType, models from core.file import FileTransferMethod, FileType, models
from core.helper import ssrf_proxy
from core.tools import signature from core.tools import signature
from core.tools.tool_file_manager import ToolFileManager from core.tools.tool_file_manager import ToolFileManager
from core.workflow.nodes.llm.file_saver import MultiModalFile, StorageFileSaver from core.workflow.nodes.llm.file_saver import (
FileSaverImpl,
_extract_content_type_and_extension,
_get_extension,
_validate_extension_override,
)
from models import ToolFile from models import ToolFile
_PNG_DATA = b"\x89PNG\r\n\x1a\n" _PNG_DATA = b"\x89PNG\r\n\x1a\n"
#
# class _MockToolFileManager:
# def __init__(self, mock_tool_file: ToolFile, mock_signed_url: str): def _gen_id():
# self._mock_tool_file = mock_tool_file return str(uuid.uuid4())
# self._mock_signed_url = mock_signed_url
#
# @staticmethod class TestFileSaverImpl:
# def create_file_by_raw( def test_save_binary_string(self, monkeypatch):
# *, user_id = _gen_id()
# user_id: str, tenant_id = _gen_id()
# tenant_id: str, file_type = FileType.IMAGE
# conversation_id: Optional[str], mime_type = "image/png"
# file_binary: bytes, mock_signed_url = "https://example.com/image.png"
# mimetype: str, mock_tool_file = ToolFile(
# filename: Optional[str] = None, id=_gen_id(),
# ) -> ToolFile: user_id=user_id,
# return self._mock_tool_file tenant_id=tenant_id,
# conversation_id=None,
# @staticmethod file_key="test-file-key",
# def sign_file(tool_file_id: str, extension: str) -> str: mimetype=mime_type,
# return "" original_url=None,
name=f"{_gen_id()}.png",
size=len(_PNG_DATA),
def test_storage_file_saver(monkeypatch): )
def gen_id(): mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager)
return str(uuid.uuid4()) mocked_engine = mock.MagicMock(spec=Engine)
mmf = MultiModalFile( mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file
user_id=gen_id(), monkeypatch.setattr(FileSaverImpl, "_get_tool_file_manager", lambda _: mocked_tool_file_manager)
tenant_id=gen_id(), # Since `File.generate_url` used `ToolFileManager.sign_file` directly, we also need to patch it here.
file_type=FileType.IMAGE, mocked_sign_file = mock.MagicMock(spec=signature.sign_tool_file)
data=_PNG_DATA, # Since `File.generate_url` used `signature.sign_tool_file` directly, we also need to patch it here.
mime_type="image/png", monkeypatch.setattr(models, "sign_tool_file", mocked_sign_file)
extension_override=None, mocked_sign_file.return_value = mock_signed_url
)
mock_signed_url = "https://example.com/image.png" storage_file_manager = FileSaverImpl(
mock_tool_file = ToolFile( user_id=user_id,
id=gen_id(), tenant_id=tenant_id,
user_id=mmf.user_id, engine_factory=mocked_engine,
tenant_id=mmf.tenant_id, )
conversation_id=None,
file_key="test-file-key", file = storage_file_manager.save_binary_string(_PNG_DATA, mime_type, file_type)
mimetype=mmf.mime_type, assert file.tenant_id == tenant_id
original_url=None, assert file.type == file_type
name=f"{gen_id()}.png", assert file.transfer_method == FileTransferMethod.TOOL_FILE
size=len(mmf.data), assert file.extension == ".png"
) assert file.mime_type == mime_type
mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager) assert file.size == len(_PNG_DATA)
mocked_engine = mock.MagicMock(spec=Engine) assert file.related_id == mock_tool_file.id
mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file assert file.generate_url() == mock_signed_url
monkeypatch.setattr(StorageFileSaver, "_get_tool_file_manager", lambda _: mocked_tool_file_manager)
# Since `File.generate_url` used `ToolFileManager.sign_file` directly, we also need to patch it here. mocked_tool_file_manager.create_file_by_raw.assert_called_once_with(
mocked_sign_file = mock.MagicMock(spec=signature.sign_tool_file) user_id=user_id,
# Since `File.generate_url` used `signature.sign_tool_file` directly, we also need to patch it here. tenant_id=tenant_id,
monkeypatch.setattr(models, "sign_tool_file", mocked_sign_file) conversation_id=None,
mocked_sign_file.return_value = mock_signed_url file_binary=_PNG_DATA,
mimetype=mime_type,
storage_file_manager = StorageFileSaver(engine_factory=lambda: mocked_engine) )
mocked_sign_file.assert_called_once_with(mock_tool_file.id, ".png")
file = storage_file_manager.save_file(mmf)
assert file.tenant_id == mmf.tenant_id def test_save_remote_url_request_failed(self, monkeypatch):
assert file.type == mmf.file_type _TEST_URL = "https://example.com/image.png"
assert file.transfer_method == FileTransferMethod.TOOL_FILE mock_request = httpx.Request("GET", _TEST_URL)
assert file.extension == mmf.get_extension() mock_response = httpx.Response(
assert file.mime_type == mmf.mime_type status_code=401,
assert file.size == len(mmf.data) request=mock_request,
assert file.related_id == mock_tool_file.id )
file_saver = FileSaverImpl(
assert file.generate_url() == mock_signed_url user_id=_gen_id(),
tenant_id=_gen_id(),
mocked_tool_file_manager.create_file_by_raw.assert_called_once_with( )
user_id=mmf.user_id, mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
tenant_id=mmf.tenant_id, monkeypatch.setattr(ssrf_proxy, "get", mock_get)
conversation_id=None,
file_binary=mmf.data, with pytest.raises(httpx.HTTPStatusError) as exc:
mimetype=mmf.mime_type, file_saver.save_remote_url(_TEST_URL, FileType.IMAGE)
) mock_get.assert_called_once_with(_TEST_URL)
mocked_sign_file.assert_called_once_with(mock_tool_file.id, mmf.get_extension()) assert exc.value.response.status_code == 401
def test_save_remote_url_success(self, monkeypatch):
def test_multi_modal_file_extension_override(): _TEST_URL = "https://example.com/image.png"
# Test should pass if `extension_override` is not set. mime_type = "image/png"
MultiModalFile(user_id="", tenant_id="", file_type=FileType.IMAGE, data=b"", mime_type="image/png") user_id = _gen_id()
tenant_id = _gen_id()
# Test should pass if `extension_override` is explicitly set to `None`.
MultiModalFile( mock_request = httpx.Request("GET", _TEST_URL)
user_id="", tenant_id="", file_type=FileType.IMAGE, data=b"", mime_type="image/png", extension_override=None mock_response = httpx.Response(
) status_code=200,
content=b"test-data",
# Test should pass if `extension_override` is a string prefixed with `.`. headers={"Content-Type": mime_type},
for extension_override in [".png", ".tar.gz"]: request=mock_request,
MultiModalFile(
user_id="",
tenant_id="",
file_type=FileType.IMAGE,
data=b"",
mime_type="image/png",
extension_override=extension_override,
) )
file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id)
mock_tool_file = ToolFile(
id=_gen_id(),
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
file_key="test-file-key",
mimetype=mime_type,
original_url=None,
name=f"{_gen_id()}.png",
size=len(_PNG_DATA),
)
mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
monkeypatch.setattr(ssrf_proxy, "get", mock_get)
mock_save_binary_string = mock.MagicMock(spec=file_saver.save_binary_string, return_value=mock_tool_file)
monkeypatch.setattr(file_saver, "save_binary_string", mock_save_binary_string)
file = file_saver.save_remote_url(_TEST_URL, FileType.IMAGE)
mock_save_binary_string.assert_called_once_with(
mock_response.content,
mime_type,
FileType.IMAGE,
extension_override=".png",
)
assert file == mock_tool_file
def test_validate_extension_override():
class TestCase(NamedTuple):
extension_override: str | None
expected: str | None
cases = [TestCase(None, None), TestCase("", ""), ".png", ".png", ".tar.gz", ".tar.gz"]
for valid_ext_override in [None, "", ".png", ".tar.gz"]:
assert valid_ext_override == _validate_extension_override(valid_ext_override)
for invalid_ext_override in ["png", "tar.gz"]: for invalid_ext_override in ["png", "tar.gz"]:
with pytest.raises(pydantic.ValidationError) as exc: with pytest.raises(ValueError) as exc:
MultiModalFile( _validate_extension_override(invalid_ext_override)
user_id="",
tenant_id="",
file_type=FileType.IMAGE, class TestExtractContentTypeAndExtension:
data=b"", def test_with_both_content_type_and_extension(self):
mime_type="image/png", content_type, extension = _extract_content_type_and_extension("https://example.com/image.jpg", "image/png")
extension_override=invalid_ext_override, assert content_type == "image/png"
) assert extension == ".png"
error_details = exc.value.errors() def test_url_with_file_extension(self):
assert exc.value.error_count() == 1 for content_type in [None, ""]:
assert error_details[0]["loc"] == ("extension_override",) content_type, extension = _extract_content_type_and_extension("https://example.com/image.png", content_type)
assert content_type == "image/png"
assert extension == ".png"
def test_response_with_content_type(self):
content_type, extension = _extract_content_type_and_extension("https://example.com/image", "image/png")
assert content_type == "image/png"
assert extension == ".png"
def test_no_content_type_and_no_extension(self):
for content_type in [None, ""]:
content_type, extension = _extract_content_type_and_extension("https://example.com/image", content_type)
assert content_type == "application/octet-stream"
assert extension == ".bin"
class TestGetExtension:
def test_with_extension_override(self):
mime_type = "image/png"
for override in [".jpg", ""]:
extension = _get_extension(mime_type, override)
assert extension == override
def test_without_extension_override(self):
mime_type = "image/png"
extension = _get_extension(mime_type)
assert extension == ".png"

@ -1,10 +1,10 @@
from unittest import mock
import base64 import base64
import pytest
import uuid import uuid
from collections.abc import Sequence from collections.abc import Sequence
from typing import Optional from typing import Optional
from unittest import mock
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
@ -33,9 +33,8 @@ from core.workflow.nodes.llm.entities import (
VisionConfig, VisionConfig,
VisionConfigOptions, VisionConfigOptions,
) )
from core.workflow.nodes.llm.file_downloader import FileDownloader, Response from core.workflow.nodes.llm.file_saver import LLMFileSaver
from core.workflow.nodes.llm.file_saver import MultiModalFile, MultiModalFileSaver from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.llm.node import LLMNode, _extract_content_type_and_extension
from models.enums import UserFrom from models.enums import UserFrom
from models.provider import ProviderType from models.provider import ProviderType
from models.workflow import WorkflowType from models.workflow import WorkflowType
@ -117,8 +116,7 @@ def graph_runtime_state() -> GraphRuntimeState:
def llm_node( def llm_node(
llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph: Graph, graph_runtime_state: GraphRuntimeState llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph: Graph, graph_runtime_state: GraphRuntimeState
) -> LLMNode: ) -> LLMNode:
mock_file_saver = mock.MagicMock(spec=MultiModalFileSaver) mock_file_saver = mock.MagicMock(spec=LLMFileSaver)
mock_file_downloader = mock.MagicMock(spec=FileDownloader)
node = LLMNode( node = LLMNode(
id="1", id="1",
config={ config={
@ -128,8 +126,7 @@ def llm_node(
graph_init_params=graph_init_params, graph_init_params=graph_init_params,
graph=graph, graph=graph,
graph_runtime_state=graph_runtime_state, graph_runtime_state=graph_runtime_state,
file_downloader=mock_file_downloader, llm_file_saver=mock_file_saver,
multi_modal_file_saver=mock_file_saver,
) )
return node return node
@ -500,9 +497,8 @@ def test_handle_list_messages_basic(llm_node):
@pytest.fixture @pytest.fixture
def llm_node_for_multimodal( def llm_node_for_multimodal(
llm_node_data, graph_init_params, graph, graph_runtime_state llm_node_data, graph_init_params, graph, graph_runtime_state
) -> tuple[LLMNode, FileDownloader, MultiModalFileSaver]: ) -> tuple[LLMNode, LLMFileSaver]:
mock_file_downloader: FileDownloader = mock.MagicMock(spec=FileDownloader) mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver)
mock_file_saver: MultiModalFileSaver = mock.MagicMock(spec=MultiModalFileSaver)
node = LLMNode( node = LLMNode(
id="1", id="1",
config={ config={
@ -512,29 +508,14 @@ def llm_node_for_multimodal(
graph_init_params=graph_init_params, graph_init_params=graph_init_params,
graph=graph, graph=graph,
graph_runtime_state=graph_runtime_state, graph_runtime_state=graph_runtime_state,
file_downloader=mock_file_downloader, llm_file_saver=mock_file_saver,
multi_modal_file_saver=mock_file_saver,
) )
return node, mock_file_downloader, mock_file_saver return node, mock_file_saver
class TestLLMNodeSaveMultiModalImageOutput: class TestLLMNodeSaveMultiModalImageOutput:
def test_llm_node_download_file(self, llm_node_for_multimodal): def test_llm_node_save_inline_output(self, llm_node_for_multimodal: tuple[LLMNode, LLMFileSaver]):
llm_node, mock_file_downloader, _ = llm_node_for_multimodal llm_node, mock_file_saver = llm_node_for_multimodal
mock_response = Response(body=b"test-data", content_type="image/png")
mock_file_downloader.get.return_value = mock_response
file = llm_node._download_file("https://example.com/image.png", FileType.IMAGE)
assert file.user_id == "1"
assert file.tenant_id == "1"
assert file.file_type == FileType.IMAGE
assert file.mime_type == mock_response.content_type
assert file.data == mock_response.body
assert file.get_extension() == ".png"
def test_llm_node_save_inline_output(
self, llm_node_for_multimodal: tuple[LLMNode, FileDownloader, MultiModalFileSaver]
):
llm_node, _, mock_file_saver = llm_node_for_multimodal
content = ImagePromptMessageContent( content = ImagePromptMessageContent(
format="png", format="png",
base64_data=base64.b64encode(b"test-data").decode(), base64_data=base64.b64encode(b"test-data").decode(),
@ -551,24 +532,16 @@ class TestLLMNodeSaveMultiModalImageOutput:
mime_type="image/png", mime_type="image/png",
size=9, size=9,
) )
mock_file_saver.save_file.return_value = mock_file mock_file_saver.save_binary_string.return_value = mock_file
file = llm_node._save_multimodal_image_output(content=content) file = llm_node._save_multimodal_image_output(content=content)
assert llm_node._file_outputs == [mock_file] assert llm_node._file_outputs == [mock_file]
assert file == mock_file assert file == mock_file
expected_saved_multimodal_file = MultiModalFile( mock_file_saver.save_binary_string.assert_called_once_with(
user_id="1", data=b"test-data", mime_type="image/png", file_type=FileType.IMAGE
tenant_id="1",
file_type=FileType.IMAGE,
data=b"test-data",
mime_type="image/png",
extension_override=None,
) )
mock_file_saver.save_file.assert_called_once_with(expected_saved_multimodal_file)
def test_llm_node_save_url_output( def test_llm_node_save_url_output(self, llm_node_for_multimodal: tuple[LLMNode, LLMFileSaver]):
self, llm_node_for_multimodal: tuple[LLMNode, FileDownloader, MultiModalFileSaver] llm_node, mock_file_saver = llm_node_for_multimodal
):
llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal
content = ImagePromptMessageContent( content = ImagePromptMessageContent(
format="png", format="png",
url="https://example.com/image.png", url="https://example.com/image.png",
@ -585,21 +558,11 @@ class TestLLMNodeSaveMultiModalImageOutput:
mime_type="image/png", mime_type="image/png",
size=9, size=9,
) )
mock_file_downloader.get.return_value = Response(body=b"test-data", content_type="image/png") mock_file_saver.save_remote_url.return_value = mock_file
mock_file_saver.save_file.return_value = mock_file
file = llm_node._save_multimodal_image_output(content=content) file = llm_node._save_multimodal_image_output(content=content)
assert llm_node._file_outputs == [mock_file] assert llm_node._file_outputs == [mock_file]
assert file == mock_file assert file == mock_file
expected_saved_multimodal_file = MultiModalFile( mock_file_saver.save_remote_url.assert_called_once_with(content.url, FileType.IMAGE)
user_id="1",
tenant_id="1",
file_type=FileType.IMAGE,
data=b"test-data",
mime_type="image/png",
extension_override=".png",
)
mock_file_downloader.get.assert_called_once_with(content.url)
mock_file_saver.save_file.assert_called_once_with(expected_saved_multimodal_file)
def test_llm_node_image_file_to_markdown(llm_node: LLMNode): def test_llm_node_image_file_to_markdown(llm_node: LLMNode):
@ -609,49 +572,25 @@ def test_llm_node_image_file_to_markdown(llm_node: LLMNode):
assert markdown == "![](https://example.com/image.png)" assert markdown == "![](https://example.com/image.png)"
class TestExtractContentTypeAndExtension:
def test_with_both_content_type_and_extension(self):
content_type, extension = _extract_content_type_and_extension("https://example.com/image.jpg", "image/png")
assert content_type == "image/png"
assert extension == ".png"
def test_url_with_file_extension(self):
for content_type in [None, ""]:
content_type, extension = _extract_content_type_and_extension("https://example.com/image.png", content_type)
assert content_type == "image/png"
assert extension == ".png"
def test_response_with_content_type(self):
content_type, extension = _extract_content_type_and_extension("https://example.com/image", "image/png")
assert content_type == "image/png"
assert extension == ".png"
def test_no_content_type_and_no_extension(self):
for content_type in [None, ""]:
content_type, extension = _extract_content_type_and_extension("https://example.com/image", content_type)
assert content_type == "application/octet-stream"
assert extension == ".bin"
class TestSaveMultimodalOutputAndConvertResultToMarkdown: class TestSaveMultimodalOutputAndConvertResultToMarkdown:
def test_str_content(self, llm_node_for_multimodal): def test_str_content(self, llm_node_for_multimodal):
llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown("hello world") gen = llm_node._save_multimodal_output_and_convert_result_to_markdown("hello world")
assert list(gen) == ["hello world"] assert list(gen) == ["hello world"]
mock_file_downloader.get.assert_not_called() mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_file.assert_not_called() mock_file_saver.save_remote_url.assert_not_called()
def test_text_prompt_message_content(self, llm_node_for_multimodal): def test_text_prompt_message_content(self, llm_node_for_multimodal):
llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown( gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
[TextPromptMessageContent(data="hello world")] [TextPromptMessageContent(data="hello world")]
) )
assert list(gen) == ["hello world"] assert list(gen) == ["hello world"]
mock_file_downloader.get.assert_not_called() mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_file.assert_not_called() mock_file_saver.save_remote_url.assert_not_called()
def test_image_content(self, llm_node_for_multimodal): def test_image_content_with_inline_data(self, llm_node_for_multimodal, monkeypatch):
llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal llm_node, mock_file_saver = llm_node_for_multimodal
image_raw_data = b"PNG_DATA" image_raw_data = b"PNG_DATA"
image_b64_data = base64.b64encode(image_raw_data).decode() image_b64_data = base64.b64encode(image_raw_data).decode()
@ -668,7 +607,7 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
url="https://example.com/test.png", url="https://example.com/test.png",
storage_key="test_storage_key", storage_key="test_storage_key",
) )
mock_file_saver.save_file.return_value = mock_saved_file mock_file_saver.save_binary_string.return_value = mock_saved_file
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown( gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
[ [
ImagePromptMessageContent( ImagePromptMessageContent(
@ -680,39 +619,40 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
) )
yielded_strs = list(gen) yielded_strs = list(gen)
assert len(yielded_strs) == 1 assert len(yielded_strs) == 1
# This assertion is somewhat tricky.
expected_file_url = f"http://127.0.0.1:5001/files/tools/{mock_saved_file.related_id}.png" # This assertion requires careful handling.
assert yielded_strs[0].startswith(f"![]({expected_file_url}") # `FILES_URL` settings can vary across environments, which might lead to fragile tests.
#
# Rather than asserting the complete URL returned by _save_multimodal_output_and_convert_result_to_markdown,
# we verify that the result includes the markdown image syntax and the expected file URL path.
expected_file_url_path = f"/files/tools/{mock_saved_file.related_id}.png"
assert yielded_strs[0].startswith(f"![](")
assert expected_file_url_path in yielded_strs[0]
assert yielded_strs[0].endswith(")") assert yielded_strs[0].endswith(")")
mock_file_saver.save_file.assert_called_once_with( mock_file_saver.save_binary_string.assert_called_once_with(
MultiModalFile( data=image_raw_data,
user_id="1", mime_type="image/png",
tenant_id="1", file_type=FileType.IMAGE,
file_type=FileType.IMAGE,
data=image_raw_data,
mime_type="image/png",
)
) )
mock_file_downloader.assert_not_called()
assert mock_saved_file in llm_node._file_outputs assert mock_saved_file in llm_node._file_outputs
def test_unknown_content_type(self, llm_node_for_multimodal): def test_unknown_content_type(self, llm_node_for_multimodal):
llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(frozenset(["hello world"])) gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(frozenset(["hello world"]))
assert list(gen) == ["frozenset({'hello world'})"] assert list(gen) == ["frozenset({'hello world'})"]
mock_file_downloader.get.assert_not_called() mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_file.assert_not_called() mock_file_saver.save_remote_url.assert_not_called()
def test_unknown_item_type(self, llm_node_for_multimodal): def test_unknown_item_type(self, llm_node_for_multimodal):
llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown([frozenset(["hello world"])]) gen = llm_node._save_multimodal_output_and_convert_result_to_markdown([frozenset(["hello world"])])
assert list(gen) == ["frozenset({'hello world'})"] assert list(gen) == ["frozenset({'hello world'})"]
mock_file_downloader.get.assert_not_called() mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_file.assert_not_called() mock_file_saver.save_remote_url.assert_not_called()
def test_none_content(self, llm_node_for_multimodal): def test_none_content(self, llm_node_for_multimodal):
llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(None) gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(None)
assert list(gen) == [] assert list(gen) == []
mock_file_downloader.get.assert_not_called() mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_file.assert_not_called() mock_file_saver.save_remote_url.assert_not_called()

Loading…
Cancel
Save