refactor(api): remove the unnecessary FileDownloader interface.

pull/17372/head
QuantumGhost 1 year ago
parent 2585cb37b3
commit f6ab99cc24

@ -1,42 +0,0 @@
import abc
import typing as tp
import httpx
from pydantic import BaseModel
from core.helper import ssrf_proxy
class FileDownloadError(Exception):
pass
class HTTPStatusError(FileDownloadError):
def __init__(self, message: str, *, status_code: int):
self.status_code = status_code
class Response(BaseModel, frozen=True):
body: bytes
content_type: str | None = None
class FileDownloader(tp.Protocol):
@abc.abstractmethod
def get(self, url) -> Response:
pass
class SSRFProxyFileDownloader(FileDownloader):
def get(self, url) -> Response:
try:
http_response = ssrf_proxy.get(url)
http_response.raise_for_status()
return Response(
body=http_response.content,
content_type=http_response.headers.get("Content-Type"),
)
except httpx.TimeoutException as e:
raise FileDownloadError(f"timeout when downloading file from {url}") from e
except httpx.HTTPStatusError as e:
raise HTTPStatusError(f"Error when downloading file from {url}", status_code=e.response.status_code) from e

@ -1,83 +1,73 @@
import abc
import mimetypes
import typing as tp
from pydantic import BaseModel, field_validator
from sqlalchemy import Engine
from constants.mimetypes import DEFAULT_EXTENSION
from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE
from core.file import File, FileTransferMethod, FileType
from core.helper import ssrf_proxy
from core.tools.signature import sign_tool_file
from core.tools.tool_file_manager import ToolFileManager
from models import db as global_db
class MultiModalFile(BaseModel):
# user_id records t
user_id: str
# tenant_id
tenant_id: str
file_type: FileType
# data is the contents of the file
data: bytes
# mime_type is the media type of the file, specified by
# rfc6838 (https://datatracker.ietf.org/doc/html/rfc6838)
mime_type: str
# `extension_override` allow the user to manually specify the file extension to use
# while saving this file.
#
# The default value is `None`, which means do not override the file extension and guessing it
# from the `mime_type` attribute while saving the file.
#
# Setting it to values other than `None` means override the file's extension, and
# will bypass the extension guessing when calling `get_extension`.
# Specially, setting it to empty string (`""`) will leave the file extension empty.
#
# When it is not `None` or empty string (`""`), it should be a string beginning with a
# dot (`.`). For example, `.py` and `.tar.gz` are both valid values, while `py`
# and `tar.gz` is not.
#
# Users of MultiModalFile should always use `get_extension` to access
# the files extension, instead reading this property directly.
extension_override: str | None = None
def get_extension(self) -> str:
"""get_extension return the extension of file.
If the `extension_override` parameter is set, this method should honor it and
return its value.
class LLMFileSaver(tp.Protocol):
"""LLMFileSaver is responsible for save multimodal output returned by
LLM.
"""
def save_binary_string(
self,
data: bytes,
mime_type: str,
file_type: FileType,
extension_override: str | None = None,
) -> File:
"""save_binary_string saves the inline file data returned by LLM.
Currently (2025-04-30), only some of Google Gemini models will return
multimodal output as inline data.
:param data: the contents of the file
:param mime_type: the media type of the file, specified by rfc6838
(https://datatracker.ietf.org/doc/html/rfc6838)
:param file_type: The file type of the inline file.
:param extension_override: Override the auto-detected file extension while saving this file.
The default value is `None`, which means do not override the file extension and guessing it
from the `mime_type` attribute while saving the file.
Setting it to values other than `None` means override the file's extension, and
will bypass the extension guessing saving the file.
Specially, setting it to empty string (`""`) will leave the file extension empty.
When it is not `None` or empty string (`""`), it should be a string beginning with a
dot (`.`). For example, `.py` and `.tar.gz` are both valid values, while `py`
and `tar.gz` are not.
"""
if (extension := self.extension_override) is not None:
return extension
return mimetypes.guess_extension(self.mime_type) or DEFAULT_EXTENSION
@field_validator("extension_override")
@classmethod
def _validate_extension_override(cls, extension_override: str | None) -> str | None:
# `extension_override` is allow to be `None or `""`.
if not extension_override:
return None
if not extension_override.startswith("."):
raise ValueError("extension_override should start with '.' if not None or empty.", extension_override)
return extension_override
pass
def save_remote_url(self, url: str, file_type: FileType) -> File:
"""save_remote_url saves the file from a remote url returned by LLM.
Currently (2025-04-30), no model returns multimodel output as a url.
class MultiModalFileSaver(tp.Protocol):
@abc.abstractmethod
def save_file(self, mmf: MultiModalFile) -> File:
:param url: the url of the file.
:param file_type: the file type of the file, check `FileType` enum for reference.
"""
pass
EngineFactory: tp.TypeAlias = tp.Callable[[], Engine]
class StorageFileSaver(MultiModalFileSaver):
class FileSaverImpl(LLMFileSaver):
_engine_factory: EngineFactory
_tenant_id: str
_user_id: str
def __init__(self, engine_factory: EngineFactory | None = None):
def __init__(self, user_id: str, tenant_id: str, engine_factory: EngineFactory | None = None):
if engine_factory is None:
def _factory():
@ -85,30 +75,48 @@ class StorageFileSaver(MultiModalFileSaver):
engine_factory = _factory
self._engine_factory = engine_factory
self._user_id = user_id
self._tenant_id = tenant_id
def _get_tool_file_manager(self):
return ToolFileManager(engine=self._engine_factory())
def save_file(self, mmf: MultiModalFile) -> File:
def save_remote_url(self, url: str, file_type: FileType) -> File:
http_response = ssrf_proxy.get(url)
http_response.raise_for_status()
data = http_response.content
mime_type_from_header = http_response.headers.get("Content-Type")
mime_type, extension = _extract_content_type_and_extension(url, mime_type_from_header)
return self.save_binary_string(data, mime_type, file_type, extension_override=extension)
def save_binary_string(
self,
data: bytes,
mime_type: str,
file_type: FileType,
extension_override: str | None = None,
) -> File:
tool_file_manager = self._get_tool_file_manager()
tool_file = tool_file_manager.create_file_by_raw(
user_id=mmf.user_id,
tenant_id=mmf.tenant_id,
user_id=self._user_id,
tenant_id=self._tenant_id,
# TODO(QuantumGhost): what is conversation id?
conversation_id=None,
file_binary=mmf.data,
mimetype=mmf.mime_type,
file_binary=data,
mimetype=mime_type,
)
url = sign_tool_file(tool_file.id, mmf.get_extension())
extension_override = _validate_extension_override(extension_override)
extension = _get_extension(mime_type, extension_override)
url = sign_tool_file(tool_file.id, extension)
return File(
tenant_id=mmf.tenant_id,
type=FileType.IMAGE,
tenant_id=self._tenant_id,
type=file_type,
transfer_method=FileTransferMethod.TOOL_FILE,
filename=tool_file.name,
extension=mmf.get_extension(),
mime_type=mmf.mime_type,
size=len(mmf.data),
extension=extension,
mime_type=mime_type,
size=len(data),
related_id=tool_file.id,
url=url,
# TODO(QuantumGhost): how should I set the following key?
@ -116,3 +124,37 @@ class StorageFileSaver(MultiModalFileSaver):
# What's the purpose of `storage_key` and `dify_model_identity`?
storage_key=tool_file.file_key,
)
def _get_extension(mime_type: str, extension_override: str | None = None) -> str:
"""get_extension return the extension of file.
If the `extension_override` parameter is set, this function should honor it and
return its value.
"""
if extension_override is not None:
return extension_override
return mimetypes.guess_extension(mime_type) or DEFAULT_EXTENSION
def _extract_content_type_and_extension(url: str, content_type_header: str | None) -> tuple[str, str]:
"""_extract_content_type_and_extension tries to
guess content type of file from url and `Content-Type` header in response.
"""
if content_type_header:
extension = mimetypes.guess_extension(content_type_header) or DEFAULT_EXTENSION
return content_type_header, extension
content_type = mimetypes.guess_type(url)[0] or DEFAULT_MIME_TYPE
extension = mimetypes.guess_extension(content_type) or DEFAULT_EXTENSION
return content_type, extension
def _validate_extension_override(extension_override: str | None) -> str | None:
# `extension_override` is allow to be `None or `""`.
if extension_override is None:
return None
if extension_override == "":
return ""
if not extension_override.startswith("."):
raise ValueError("extension_override should start with '.' if not None or empty.", extension_override)
return extension_override

@ -2,7 +2,6 @@ import base64
import io
import json
import logging
import mimetypes
from collections.abc import Generator, Mapping, Sequence
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, Optional, cast
@ -10,7 +9,6 @@ from typing import TYPE_CHECKING, Any, Optional, cast
import json_repair
from configs import dify_config
from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import QuotaUnit
@ -98,8 +96,7 @@ from .exc import (
TemplateTypeNotSupportError,
VariableNotFoundError,
)
from .file_downloader import FileDownloader, SSRFProxyFileDownloader
from .file_saver import MultiModalFile, MultiModalFileSaver, StorageFileSaver
from .file_saver import FileSaverImpl, LLMFileSaver
if TYPE_CHECKING:
from core.file.models import File
@ -117,8 +114,8 @@ class LLMNode(BaseNode[LLMNodeData]):
# Instance attributes specific to LLMNode.
# Output variable for file
_file_outputs: list["File"]
_file_downloader: FileDownloader
_multi_modal_file_saver: MultiModalFileSaver
_llm_file_saver: LLMFileSaver
def __init__(
self,
@ -130,8 +127,7 @@ class LLMNode(BaseNode[LLMNodeData]):
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None,
*,
file_downloader: FileDownloader | None = None,
multi_modal_file_saver: MultiModalFileSaver | None = None,
llm_file_saver: LLMFileSaver | None = None,
) -> None:
super().__init__(
id=id,
@ -144,12 +140,13 @@ class LLMNode(BaseNode[LLMNodeData]):
)
# LLM file outputs, used for MultiModal outputs.
self._file_outputs: list[File] = []
if file_downloader is None:
file_downloader = SSRFProxyFileDownloader()
self._file_downloader = file_downloader
if multi_modal_file_saver is None:
multi_modal_file_saver = StorageFileSaver()
self._multi_modal_file_saver = multi_modal_file_saver
if llm_file_saver is None:
llm_file_saver = FileSaverImpl(
user_id=graph_init_params.user_id,
tenant_id=graph_init_params.tenant_id,
)
self._llm_file_saver = llm_file_saver
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
def process_structured_output(text: str) -> Optional[dict[str, Any] | list[Any]]:
@ -1035,40 +1032,19 @@ class LLMNode(BaseNode[LLMNodeData]):
Currently, only image files are supported.
"""
# Inject the saver somehow...
_saver = self._multi_modal_file_saver
_saver = self._llm_file_saver
# If this
if content.url != "":
mmf = self._download_file(content.url, FileType.IMAGE)
saved_file = _saver.save_file(mmf)
self._file_outputs.append(saved_file)
return saved_file
saved_file = _saver.save_remote_url(content.url, FileType.IMAGE)
else:
mmf = MultiModalFile(
user_id=self.user_id,
tenant_id=self.tenant_id,
saved_file = _saver.save_binary_string(
data=base64.b64decode(content.base64_data),
mime_type=content.mime_type,
file_type=FileType.IMAGE,
extension_override=None,
)
saved_file = _saver.save_file(mmf)
self._file_outputs.append(saved_file)
return saved_file
def _download_file(self, url: str, file_type: FileType) -> MultiModalFile:
downloader = self._file_downloader
# try to download image
response = downloader.get(url)
content_type, extension = _extract_content_type_and_extension(url, response.content_type)
return MultiModalFile(
user_id=self.user_id,
tenant_id=self.tenant_id,
file_type=file_type,
data=response.body,
mime_type=content_type,
extension_override=extension,
)
self._file_outputs.append(saved_file)
return saved_file
def _handle_native_json_schema(self, model_parameters: dict, rules: list[ParameterRule]) -> dict:
"""
@ -1451,15 +1427,3 @@ def convert_boolean_to_string(schema: dict) -> None:
for item in value:
if isinstance(item, dict):
convert_boolean_to_string(item)
def _extract_content_type_and_extension(url: str, content_type_header: str | None) -> tuple[str, str]:
"""_extract_content_type_and_extension tries to
guess content type of file from url and `Content-Type` header in response.
"""
if content_type_header:
extension = mimetypes.guess_extension(content_type_header) or DEFAULT_EXTENSION
return content_type_header, extension
content_type = mimetypes.guess_type(url)[0] or DEFAULT_MIME_TYPE
extension = mimetypes.guess_extension(content_type) or DEFAULT_EXTENSION
return content_type, extension

@ -1,56 +0,0 @@
from unittest import mock
import httpx
import pytest
from core.workflow.nodes.llm.file_downloader import (
FileDownloadError,
HTTPStatusError,
Response,
SSRFProxyFileDownloader,
ssrf_proxy,
)
_TEST_URL = "https://example.com"
class TestSSRFProxyFileDownloader:
def test(self, monkeypatch):
mock_request = httpx.Request("GET", _TEST_URL)
mock_response = httpx.Response(
status_code=200,
content=b"test-data",
headers={"Content-Type": "text/plain"},
request=mock_request,
)
mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
monkeypatch.setattr(ssrf_proxy, "get", mock_get)
downloader = SSRFProxyFileDownloader()
response = downloader.get(_TEST_URL)
mock_get.assert_called_once_with(_TEST_URL)
assert response == Response(body=mock_response.content, content_type="text/plain")
def test_should_raise_when_status_is_not_successful(self, monkeypatch):
mock_request = httpx.Request("GET", _TEST_URL)
mock_response = httpx.Response(
status_code=401,
request=mock_request,
)
mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
monkeypatch.setattr(ssrf_proxy, "get", mock_get)
downloader = SSRFProxyFileDownloader()
with pytest.raises(HTTPStatusError) as exc:
response = downloader.get(_TEST_URL)
mock_get.assert_called_once_with(_TEST_URL)
assert exc.value.status_code == 401
def test_should_convert_timeout_to_file_download_error(self, monkeypatch):
mock_get = mock.MagicMock(spec=ssrf_proxy.get, side_effect=httpx.TimeoutException("timeout"))
monkeypatch.setattr(ssrf_proxy, "get", mock_get)
downloader = SSRFProxyFileDownloader()
with pytest.raises(FileDownloadError) as exc:
response = downloader.get(_TEST_URL)
mock_get.assert_called_once_with(_TEST_URL)

@ -1,129 +1,192 @@
import uuid
from typing import NamedTuple
from unittest import mock
import pydantic
import httpx
import pytest
from sqlalchemy import Engine
from core.file import FileTransferMethod, FileType, models
from core.helper import ssrf_proxy
from core.tools import signature
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.nodes.llm.file_saver import MultiModalFile, StorageFileSaver
from core.workflow.nodes.llm.file_saver import (
FileSaverImpl,
_extract_content_type_and_extension,
_get_extension,
_validate_extension_override,
)
from models import ToolFile
_PNG_DATA = b"\x89PNG\r\n\x1a\n"
#
# class _MockToolFileManager:
# def __init__(self, mock_tool_file: ToolFile, mock_signed_url: str):
# self._mock_tool_file = mock_tool_file
# self._mock_signed_url = mock_signed_url
#
# @staticmethod
# def create_file_by_raw(
# *,
# user_id: str,
# tenant_id: str,
# conversation_id: Optional[str],
# file_binary: bytes,
# mimetype: str,
# filename: Optional[str] = None,
# ) -> ToolFile:
# return self._mock_tool_file
#
# @staticmethod
# def sign_file(tool_file_id: str, extension: str) -> str:
# return ""
def test_storage_file_saver(monkeypatch):
def gen_id():
return str(uuid.uuid4())
mmf = MultiModalFile(
user_id=gen_id(),
tenant_id=gen_id(),
file_type=FileType.IMAGE,
data=_PNG_DATA,
mime_type="image/png",
extension_override=None,
)
mock_signed_url = "https://example.com/image.png"
mock_tool_file = ToolFile(
id=gen_id(),
user_id=mmf.user_id,
tenant_id=mmf.tenant_id,
conversation_id=None,
file_key="test-file-key",
mimetype=mmf.mime_type,
original_url=None,
name=f"{gen_id()}.png",
size=len(mmf.data),
)
mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager)
mocked_engine = mock.MagicMock(spec=Engine)
mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file
monkeypatch.setattr(StorageFileSaver, "_get_tool_file_manager", lambda _: mocked_tool_file_manager)
# Since `File.generate_url` used `ToolFileManager.sign_file` directly, we also need to patch it here.
mocked_sign_file = mock.MagicMock(spec=signature.sign_tool_file)
# Since `File.generate_url` used `signature.sign_tool_file` directly, we also need to patch it here.
monkeypatch.setattr(models, "sign_tool_file", mocked_sign_file)
mocked_sign_file.return_value = mock_signed_url
storage_file_manager = StorageFileSaver(engine_factory=lambda: mocked_engine)
file = storage_file_manager.save_file(mmf)
assert file.tenant_id == mmf.tenant_id
assert file.type == mmf.file_type
assert file.transfer_method == FileTransferMethod.TOOL_FILE
assert file.extension == mmf.get_extension()
assert file.mime_type == mmf.mime_type
assert file.size == len(mmf.data)
assert file.related_id == mock_tool_file.id
assert file.generate_url() == mock_signed_url
mocked_tool_file_manager.create_file_by_raw.assert_called_once_with(
user_id=mmf.user_id,
tenant_id=mmf.tenant_id,
conversation_id=None,
file_binary=mmf.data,
mimetype=mmf.mime_type,
)
mocked_sign_file.assert_called_once_with(mock_tool_file.id, mmf.get_extension())
def test_multi_modal_file_extension_override():
# Test should pass if `extension_override` is not set.
MultiModalFile(user_id="", tenant_id="", file_type=FileType.IMAGE, data=b"", mime_type="image/png")
# Test should pass if `extension_override` is explicitly set to `None`.
MultiModalFile(
user_id="", tenant_id="", file_type=FileType.IMAGE, data=b"", mime_type="image/png", extension_override=None
)
# Test should pass if `extension_override` is a string prefixed with `.`.
for extension_override in [".png", ".tar.gz"]:
MultiModalFile(
user_id="",
tenant_id="",
file_type=FileType.IMAGE,
data=b"",
mime_type="image/png",
extension_override=extension_override,
def _gen_id():
return str(uuid.uuid4())
class TestFileSaverImpl:
def test_save_binary_string(self, monkeypatch):
user_id = _gen_id()
tenant_id = _gen_id()
file_type = FileType.IMAGE
mime_type = "image/png"
mock_signed_url = "https://example.com/image.png"
mock_tool_file = ToolFile(
id=_gen_id(),
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
file_key="test-file-key",
mimetype=mime_type,
original_url=None,
name=f"{_gen_id()}.png",
size=len(_PNG_DATA),
)
mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager)
mocked_engine = mock.MagicMock(spec=Engine)
mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file
monkeypatch.setattr(FileSaverImpl, "_get_tool_file_manager", lambda _: mocked_tool_file_manager)
# Since `File.generate_url` used `ToolFileManager.sign_file` directly, we also need to patch it here.
mocked_sign_file = mock.MagicMock(spec=signature.sign_tool_file)
# Since `File.generate_url` used `signature.sign_tool_file` directly, we also need to patch it here.
monkeypatch.setattr(models, "sign_tool_file", mocked_sign_file)
mocked_sign_file.return_value = mock_signed_url
storage_file_manager = FileSaverImpl(
user_id=user_id,
tenant_id=tenant_id,
engine_factory=mocked_engine,
)
file = storage_file_manager.save_binary_string(_PNG_DATA, mime_type, file_type)
assert file.tenant_id == tenant_id
assert file.type == file_type
assert file.transfer_method == FileTransferMethod.TOOL_FILE
assert file.extension == ".png"
assert file.mime_type == mime_type
assert file.size == len(_PNG_DATA)
assert file.related_id == mock_tool_file.id
assert file.generate_url() == mock_signed_url
mocked_tool_file_manager.create_file_by_raw.assert_called_once_with(
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
file_binary=_PNG_DATA,
mimetype=mime_type,
)
mocked_sign_file.assert_called_once_with(mock_tool_file.id, ".png")
def test_save_remote_url_request_failed(self, monkeypatch):
_TEST_URL = "https://example.com/image.png"
mock_request = httpx.Request("GET", _TEST_URL)
mock_response = httpx.Response(
status_code=401,
request=mock_request,
)
file_saver = FileSaverImpl(
user_id=_gen_id(),
tenant_id=_gen_id(),
)
mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
monkeypatch.setattr(ssrf_proxy, "get", mock_get)
with pytest.raises(httpx.HTTPStatusError) as exc:
file_saver.save_remote_url(_TEST_URL, FileType.IMAGE)
mock_get.assert_called_once_with(_TEST_URL)
assert exc.value.response.status_code == 401
def test_save_remote_url_success(self, monkeypatch):
_TEST_URL = "https://example.com/image.png"
mime_type = "image/png"
user_id = _gen_id()
tenant_id = _gen_id()
mock_request = httpx.Request("GET", _TEST_URL)
mock_response = httpx.Response(
status_code=200,
content=b"test-data",
headers={"Content-Type": mime_type},
request=mock_request,
)
file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id)
mock_tool_file = ToolFile(
id=_gen_id(),
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
file_key="test-file-key",
mimetype=mime_type,
original_url=None,
name=f"{_gen_id()}.png",
size=len(_PNG_DATA),
)
mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
monkeypatch.setattr(ssrf_proxy, "get", mock_get)
mock_save_binary_string = mock.MagicMock(spec=file_saver.save_binary_string, return_value=mock_tool_file)
monkeypatch.setattr(file_saver, "save_binary_string", mock_save_binary_string)
file = file_saver.save_remote_url(_TEST_URL, FileType.IMAGE)
mock_save_binary_string.assert_called_once_with(
mock_response.content,
mime_type,
FileType.IMAGE,
extension_override=".png",
)
assert file == mock_tool_file
def test_validate_extension_override():
class TestCase(NamedTuple):
extension_override: str | None
expected: str | None
cases = [TestCase(None, None), TestCase("", ""), ".png", ".png", ".tar.gz", ".tar.gz"]
for valid_ext_override in [None, "", ".png", ".tar.gz"]:
assert valid_ext_override == _validate_extension_override(valid_ext_override)
for invalid_ext_override in ["png", "tar.gz"]:
with pytest.raises(pydantic.ValidationError) as exc:
MultiModalFile(
user_id="",
tenant_id="",
file_type=FileType.IMAGE,
data=b"",
mime_type="image/png",
extension_override=invalid_ext_override,
)
error_details = exc.value.errors()
assert exc.value.error_count() == 1
assert error_details[0]["loc"] == ("extension_override",)
with pytest.raises(ValueError) as exc:
_validate_extension_override(invalid_ext_override)
class TestExtractContentTypeAndExtension:
def test_with_both_content_type_and_extension(self):
content_type, extension = _extract_content_type_and_extension("https://example.com/image.jpg", "image/png")
assert content_type == "image/png"
assert extension == ".png"
def test_url_with_file_extension(self):
for content_type in [None, ""]:
content_type, extension = _extract_content_type_and_extension("https://example.com/image.png", content_type)
assert content_type == "image/png"
assert extension == ".png"
def test_response_with_content_type(self):
content_type, extension = _extract_content_type_and_extension("https://example.com/image", "image/png")
assert content_type == "image/png"
assert extension == ".png"
def test_no_content_type_and_no_extension(self):
for content_type in [None, ""]:
content_type, extension = _extract_content_type_and_extension("https://example.com/image", content_type)
assert content_type == "application/octet-stream"
assert extension == ".bin"
class TestGetExtension:
def test_with_extension_override(self):
mime_type = "image/png"
for override in [".jpg", ""]:
extension = _get_extension(mime_type, override)
assert extension == override
def test_without_extension_override(self):
mime_type = "image/png"
extension = _get_extension(mime_type)
assert extension == ".png"

@ -1,10 +1,10 @@
from unittest import mock
import base64
import pytest
import uuid
from collections.abc import Sequence
from typing import Optional
from unittest import mock
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
@ -33,9 +33,8 @@ from core.workflow.nodes.llm.entities import (
VisionConfig,
VisionConfigOptions,
)
from core.workflow.nodes.llm.file_downloader import FileDownloader, Response
from core.workflow.nodes.llm.file_saver import MultiModalFile, MultiModalFileSaver
from core.workflow.nodes.llm.node import LLMNode, _extract_content_type_and_extension
from core.workflow.nodes.llm.file_saver import LLMFileSaver
from core.workflow.nodes.llm.node import LLMNode
from models.enums import UserFrom
from models.provider import ProviderType
from models.workflow import WorkflowType
@ -117,8 +116,7 @@ def graph_runtime_state() -> GraphRuntimeState:
def llm_node(
llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph: Graph, graph_runtime_state: GraphRuntimeState
) -> LLMNode:
mock_file_saver = mock.MagicMock(spec=MultiModalFileSaver)
mock_file_downloader = mock.MagicMock(spec=FileDownloader)
mock_file_saver = mock.MagicMock(spec=LLMFileSaver)
node = LLMNode(
id="1",
config={
@ -128,8 +126,7 @@ def llm_node(
graph_init_params=graph_init_params,
graph=graph,
graph_runtime_state=graph_runtime_state,
file_downloader=mock_file_downloader,
multi_modal_file_saver=mock_file_saver,
llm_file_saver=mock_file_saver,
)
return node
@ -500,9 +497,8 @@ def test_handle_list_messages_basic(llm_node):
@pytest.fixture
def llm_node_for_multimodal(
llm_node_data, graph_init_params, graph, graph_runtime_state
) -> tuple[LLMNode, FileDownloader, MultiModalFileSaver]:
mock_file_downloader: FileDownloader = mock.MagicMock(spec=FileDownloader)
mock_file_saver: MultiModalFileSaver = mock.MagicMock(spec=MultiModalFileSaver)
) -> tuple[LLMNode, LLMFileSaver]:
mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver)
node = LLMNode(
id="1",
config={
@ -512,29 +508,14 @@ def llm_node_for_multimodal(
graph_init_params=graph_init_params,
graph=graph,
graph_runtime_state=graph_runtime_state,
file_downloader=mock_file_downloader,
multi_modal_file_saver=mock_file_saver,
llm_file_saver=mock_file_saver,
)
return node, mock_file_downloader, mock_file_saver
return node, mock_file_saver
class TestLLMNodeSaveMultiModalImageOutput:
def test_llm_node_download_file(self, llm_node_for_multimodal):
llm_node, mock_file_downloader, _ = llm_node_for_multimodal
mock_response = Response(body=b"test-data", content_type="image/png")
mock_file_downloader.get.return_value = mock_response
file = llm_node._download_file("https://example.com/image.png", FileType.IMAGE)
assert file.user_id == "1"
assert file.tenant_id == "1"
assert file.file_type == FileType.IMAGE
assert file.mime_type == mock_response.content_type
assert file.data == mock_response.body
assert file.get_extension() == ".png"
def test_llm_node_save_inline_output(
self, llm_node_for_multimodal: tuple[LLMNode, FileDownloader, MultiModalFileSaver]
):
llm_node, _, mock_file_saver = llm_node_for_multimodal
def test_llm_node_save_inline_output(self, llm_node_for_multimodal: tuple[LLMNode, LLMFileSaver]):
llm_node, mock_file_saver = llm_node_for_multimodal
content = ImagePromptMessageContent(
format="png",
base64_data=base64.b64encode(b"test-data").decode(),
@ -551,24 +532,16 @@ class TestLLMNodeSaveMultiModalImageOutput:
mime_type="image/png",
size=9,
)
mock_file_saver.save_file.return_value = mock_file
mock_file_saver.save_binary_string.return_value = mock_file
file = llm_node._save_multimodal_image_output(content=content)
assert llm_node._file_outputs == [mock_file]
assert file == mock_file
expected_saved_multimodal_file = MultiModalFile(
user_id="1",
tenant_id="1",
file_type=FileType.IMAGE,
data=b"test-data",
mime_type="image/png",
extension_override=None,
mock_file_saver.save_binary_string.assert_called_once_with(
data=b"test-data", mime_type="image/png", file_type=FileType.IMAGE
)
mock_file_saver.save_file.assert_called_once_with(expected_saved_multimodal_file)
def test_llm_node_save_url_output(
self, llm_node_for_multimodal: tuple[LLMNode, FileDownloader, MultiModalFileSaver]
):
llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal
def test_llm_node_save_url_output(self, llm_node_for_multimodal: tuple[LLMNode, LLMFileSaver]):
llm_node, mock_file_saver = llm_node_for_multimodal
content = ImagePromptMessageContent(
format="png",
url="https://example.com/image.png",
@ -585,21 +558,11 @@ class TestLLMNodeSaveMultiModalImageOutput:
mime_type="image/png",
size=9,
)
mock_file_downloader.get.return_value = Response(body=b"test-data", content_type="image/png")
mock_file_saver.save_file.return_value = mock_file
mock_file_saver.save_remote_url.return_value = mock_file
file = llm_node._save_multimodal_image_output(content=content)
assert llm_node._file_outputs == [mock_file]
assert file == mock_file
expected_saved_multimodal_file = MultiModalFile(
user_id="1",
tenant_id="1",
file_type=FileType.IMAGE,
data=b"test-data",
mime_type="image/png",
extension_override=".png",
)
mock_file_downloader.get.assert_called_once_with(content.url)
mock_file_saver.save_file.assert_called_once_with(expected_saved_multimodal_file)
mock_file_saver.save_remote_url.assert_called_once_with(content.url, FileType.IMAGE)
def test_llm_node_image_file_to_markdown(llm_node: LLMNode):
@ -609,49 +572,25 @@ def test_llm_node_image_file_to_markdown(llm_node: LLMNode):
assert markdown == "![](https://example.com/image.png)"
class TestExtractContentTypeAndExtension:
def test_with_both_content_type_and_extension(self):
content_type, extension = _extract_content_type_and_extension("https://example.com/image.jpg", "image/png")
assert content_type == "image/png"
assert extension == ".png"
def test_url_with_file_extension(self):
for content_type in [None, ""]:
content_type, extension = _extract_content_type_and_extension("https://example.com/image.png", content_type)
assert content_type == "image/png"
assert extension == ".png"
def test_response_with_content_type(self):
content_type, extension = _extract_content_type_and_extension("https://example.com/image", "image/png")
assert content_type == "image/png"
assert extension == ".png"
def test_no_content_type_and_no_extension(self):
for content_type in [None, ""]:
content_type, extension = _extract_content_type_and_extension("https://example.com/image", content_type)
assert content_type == "application/octet-stream"
assert extension == ".bin"
class TestSaveMultimodalOutputAndConvertResultToMarkdown:
def test_str_content(self, llm_node_for_multimodal):
llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal
llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown("hello world")
assert list(gen) == ["hello world"]
mock_file_downloader.get.assert_not_called()
mock_file_saver.save_file.assert_not_called()
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()
def test_text_prompt_message_content(self, llm_node_for_multimodal):
llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal
llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
[TextPromptMessageContent(data="hello world")]
)
assert list(gen) == ["hello world"]
mock_file_downloader.get.assert_not_called()
mock_file_saver.save_file.assert_not_called()
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()
def test_image_content(self, llm_node_for_multimodal):
llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal
def test_image_content_with_inline_data(self, llm_node_for_multimodal, monkeypatch):
llm_node, mock_file_saver = llm_node_for_multimodal
image_raw_data = b"PNG_DATA"
image_b64_data = base64.b64encode(image_raw_data).decode()
@ -668,7 +607,7 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
url="https://example.com/test.png",
storage_key="test_storage_key",
)
mock_file_saver.save_file.return_value = mock_saved_file
mock_file_saver.save_binary_string.return_value = mock_saved_file
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
[
ImagePromptMessageContent(
@ -680,39 +619,40 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
)
yielded_strs = list(gen)
assert len(yielded_strs) == 1
# This assertion is somewhat tricky.
expected_file_url = f"http://127.0.0.1:5001/files/tools/{mock_saved_file.related_id}.png"
assert yielded_strs[0].startswith(f"![]({expected_file_url}")
# This assertion requires careful handling.
# `FILES_URL` settings can vary across environments, which might lead to fragile tests.
#
# Rather than asserting the complete URL returned by _save_multimodal_output_and_convert_result_to_markdown,
# we verify that the result includes the markdown image syntax and the expected file URL path.
expected_file_url_path = f"/files/tools/{mock_saved_file.related_id}.png"
assert yielded_strs[0].startswith(f"![](")
assert expected_file_url_path in yielded_strs[0]
assert yielded_strs[0].endswith(")")
mock_file_saver.save_file.assert_called_once_with(
MultiModalFile(
user_id="1",
tenant_id="1",
file_type=FileType.IMAGE,
data=image_raw_data,
mime_type="image/png",
)
mock_file_saver.save_binary_string.assert_called_once_with(
data=image_raw_data,
mime_type="image/png",
file_type=FileType.IMAGE,
)
mock_file_downloader.assert_not_called()
assert mock_saved_file in llm_node._file_outputs
def test_unknown_content_type(self, llm_node_for_multimodal):
llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal
llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(frozenset(["hello world"]))
assert list(gen) == ["frozenset({'hello world'})"]
mock_file_downloader.get.assert_not_called()
mock_file_saver.save_file.assert_not_called()
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()
def test_unknown_item_type(self, llm_node_for_multimodal):
llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal
llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown([frozenset(["hello world"])])
assert list(gen) == ["frozenset({'hello world'})"]
mock_file_downloader.get.assert_not_called()
mock_file_saver.save_file.assert_not_called()
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()
def test_none_content(self, llm_node_for_multimodal):
llm_node, mock_file_downloader, mock_file_saver = llm_node_for_multimodal
llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(None)
assert list(gen) == []
mock_file_downloader.get.assert_not_called()
mock_file_saver.save_file.assert_not_called()
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()

Loading…
Cancel
Save