feat(api): Add image multimodal support for LLMNode (#17372)
Enhance `LLMNode` with multimodal capability, introducing support for image outputs. This implementation extracts base64-encoded images from LLM responses, saves them to the storage service, and records the file metadata in the `ToolFile` table. In conversations, these images are rendered as markdown-based inline images. Additionally, the images are included in the LLMNode's output as file variables, enabling subsequent nodes in the workflow to utilize them. To integrate file outputs into workflows, adjustments to the frontend code are necessary. For multimodal output functionality, updates to related model configurations are required. Currently, this capability has been applied exclusively to Google's Gemini models. Close #15814. Signed-off-by: -LAN- <laipz8200@outlook.com> Co-authored-by: -LAN- <laipz8200@outlook.com>pull/19166/head
parent
6c9a9d344a
commit
349c3cf7b8
@ -0,0 +1,7 @@
|
||||
# The two constants below should keep in sync.
|
||||
# Default content type for files which have no explicit content type.
|
||||
|
||||
DEFAULT_MIME_TYPE = "application/octet-stream"
|
||||
# Default file extension for files which have no explicit content type, should
|
||||
# correspond to the `DEFAULT_MIME_TYPE` above.
|
||||
DEFAULT_EXTENSION = ".bin"
|
||||
@ -1,12 +1,19 @@
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
|
||||
tool_file_manager: dict[str, Any] = {"manager": None}
|
||||
_tool_file_manager_factory: Callable[[], "ToolFileManager"] | None = None
|
||||
|
||||
|
||||
class ToolFileParser:
|
||||
@staticmethod
|
||||
def get_tool_file_manager() -> "ToolFileManager":
|
||||
return cast("ToolFileManager", tool_file_manager["manager"])
|
||||
assert _tool_file_manager_factory is not None
|
||||
return _tool_file_manager_factory()
|
||||
|
||||
|
||||
def set_tool_file_manager_factory(factory: Callable[[], "ToolFileManager"]) -> None:
|
||||
global _tool_file_manager_factory
|
||||
_tool_file_manager_factory = factory
|
||||
|
||||
@ -0,0 +1,41 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
import time
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
|
||||
def sign_tool_file(tool_file_id: str, extension: str) -> str:
|
||||
"""
|
||||
sign file to get a temporary url
|
||||
"""
|
||||
base_url = dify_config.FILES_URL
|
||||
file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}"
|
||||
|
||||
timestamp = str(int(time.time()))
|
||||
nonce = os.urandom(16).hex()
|
||||
data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}"
|
||||
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
|
||||
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||
|
||||
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
||||
|
||||
|
||||
def verify_tool_file_signature(file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||
"""
|
||||
verify signature
|
||||
"""
|
||||
data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}"
|
||||
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
|
||||
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
|
||||
|
||||
# verify signature
|
||||
if sign != recalculated_encoded_sign:
|
||||
return False
|
||||
|
||||
current_time = int(time.time())
|
||||
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||
@ -0,0 +1,160 @@
|
||||
import mimetypes
|
||||
import typing as tp
|
||||
|
||||
from sqlalchemy import Engine
|
||||
|
||||
from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.helper import ssrf_proxy
|
||||
from core.tools.signature import sign_tool_file
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from models import db as global_db
|
||||
|
||||
|
||||
class LLMFileSaver(tp.Protocol):
|
||||
"""LLMFileSaver is responsible for save multimodal output returned by
|
||||
LLM.
|
||||
"""
|
||||
|
||||
def save_binary_string(
|
||||
self,
|
||||
data: bytes,
|
||||
mime_type: str,
|
||||
file_type: FileType,
|
||||
extension_override: str | None = None,
|
||||
) -> File:
|
||||
"""save_binary_string saves the inline file data returned by LLM.
|
||||
|
||||
Currently (2025-04-30), only some of Google Gemini models will return
|
||||
multimodal output as inline data.
|
||||
|
||||
:param data: the contents of the file
|
||||
:param mime_type: the media type of the file, specified by rfc6838
|
||||
(https://datatracker.ietf.org/doc/html/rfc6838)
|
||||
:param file_type: The file type of the inline file.
|
||||
:param extension_override: Override the auto-detected file extension while saving this file.
|
||||
|
||||
The default value is `None`, which means do not override the file extension and guessing it
|
||||
from the `mime_type` attribute while saving the file.
|
||||
|
||||
Setting it to values other than `None` means override the file's extension, and
|
||||
will bypass the extension guessing saving the file.
|
||||
|
||||
Specially, setting it to empty string (`""`) will leave the file extension empty.
|
||||
|
||||
When it is not `None` or empty string (`""`), it should be a string beginning with a
|
||||
dot (`.`). For example, `.py` and `.tar.gz` are both valid values, while `py`
|
||||
and `tar.gz` are not.
|
||||
"""
|
||||
pass
|
||||
|
||||
def save_remote_url(self, url: str, file_type: FileType) -> File:
|
||||
"""save_remote_url saves the file from a remote url returned by LLM.
|
||||
|
||||
Currently (2025-04-30), no model returns multimodel output as a url.
|
||||
|
||||
:param url: the url of the file.
|
||||
:param file_type: the file type of the file, check `FileType` enum for reference.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
EngineFactory: tp.TypeAlias = tp.Callable[[], Engine]
|
||||
|
||||
|
||||
class FileSaverImpl(LLMFileSaver):
|
||||
_engine_factory: EngineFactory
|
||||
_tenant_id: str
|
||||
_user_id: str
|
||||
|
||||
def __init__(self, user_id: str, tenant_id: str, engine_factory: EngineFactory | None = None):
|
||||
if engine_factory is None:
|
||||
|
||||
def _factory():
|
||||
return global_db.engine
|
||||
|
||||
engine_factory = _factory
|
||||
self._engine_factory = engine_factory
|
||||
self._user_id = user_id
|
||||
self._tenant_id = tenant_id
|
||||
|
||||
def _get_tool_file_manager(self):
|
||||
return ToolFileManager(engine=self._engine_factory())
|
||||
|
||||
def save_remote_url(self, url: str, file_type: FileType) -> File:
|
||||
http_response = ssrf_proxy.get(url)
|
||||
http_response.raise_for_status()
|
||||
data = http_response.content
|
||||
mime_type_from_header = http_response.headers.get("Content-Type")
|
||||
mime_type, extension = _extract_content_type_and_extension(url, mime_type_from_header)
|
||||
return self.save_binary_string(data, mime_type, file_type, extension_override=extension)
|
||||
|
||||
def save_binary_string(
|
||||
self,
|
||||
data: bytes,
|
||||
mime_type: str,
|
||||
file_type: FileType,
|
||||
extension_override: str | None = None,
|
||||
) -> File:
|
||||
tool_file_manager = self._get_tool_file_manager()
|
||||
tool_file = tool_file_manager.create_file_by_raw(
|
||||
user_id=self._user_id,
|
||||
tenant_id=self._tenant_id,
|
||||
# TODO(QuantumGhost): what is conversation id?
|
||||
conversation_id=None,
|
||||
file_binary=data,
|
||||
mimetype=mime_type,
|
||||
)
|
||||
extension_override = _validate_extension_override(extension_override)
|
||||
extension = _get_extension(mime_type, extension_override)
|
||||
url = sign_tool_file(tool_file.id, extension)
|
||||
|
||||
return File(
|
||||
tenant_id=self._tenant_id,
|
||||
type=file_type,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
filename=tool_file.name,
|
||||
extension=extension,
|
||||
mime_type=mime_type,
|
||||
size=len(data),
|
||||
related_id=tool_file.id,
|
||||
url=url,
|
||||
# TODO(QuantumGhost): how should I set the following key?
|
||||
# What's the difference between `remote_url` and `url`?
|
||||
# What's the purpose of `storage_key` and `dify_model_identity`?
|
||||
storage_key=tool_file.file_key,
|
||||
)
|
||||
|
||||
|
||||
def _get_extension(mime_type: str, extension_override: str | None = None) -> str:
|
||||
"""get_extension return the extension of file.
|
||||
|
||||
If the `extension_override` parameter is set, this function should honor it and
|
||||
return its value.
|
||||
"""
|
||||
if extension_override is not None:
|
||||
return extension_override
|
||||
return mimetypes.guess_extension(mime_type) or DEFAULT_EXTENSION
|
||||
|
||||
|
||||
def _extract_content_type_and_extension(url: str, content_type_header: str | None) -> tuple[str, str]:
|
||||
"""_extract_content_type_and_extension tries to
|
||||
guess content type of file from url and `Content-Type` header in response.
|
||||
"""
|
||||
if content_type_header:
|
||||
extension = mimetypes.guess_extension(content_type_header) or DEFAULT_EXTENSION
|
||||
return content_type_header, extension
|
||||
content_type = mimetypes.guess_type(url)[0] or DEFAULT_MIME_TYPE
|
||||
extension = mimetypes.guess_extension(content_type) or DEFAULT_EXTENSION
|
||||
return content_type, extension
|
||||
|
||||
|
||||
def _validate_extension_override(extension_override: str | None) -> str | None:
|
||||
# `extension_override` is allow to be `None or `""`.
|
||||
if extension_override is None:
|
||||
return None
|
||||
if extension_override == "":
|
||||
return ""
|
||||
if not extension_override.startswith("."):
|
||||
raise ValueError("extension_override should start with '.' if not None or empty.", extension_override)
|
||||
return extension_override
|
||||
@ -0,0 +1,192 @@
|
||||
import uuid
|
||||
from typing import NamedTuple
|
||||
from unittest import mock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from sqlalchemy import Engine
|
||||
|
||||
from core.file import FileTransferMethod, FileType, models
|
||||
from core.helper import ssrf_proxy
|
||||
from core.tools import signature
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.workflow.nodes.llm.file_saver import (
|
||||
FileSaverImpl,
|
||||
_extract_content_type_and_extension,
|
||||
_get_extension,
|
||||
_validate_extension_override,
|
||||
)
|
||||
from models import ToolFile
|
||||
|
||||
_PNG_DATA = b"\x89PNG\r\n\x1a\n"
|
||||
|
||||
|
||||
def _gen_id():
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
class TestFileSaverImpl:
|
||||
def test_save_binary_string(self, monkeypatch):
|
||||
user_id = _gen_id()
|
||||
tenant_id = _gen_id()
|
||||
file_type = FileType.IMAGE
|
||||
mime_type = "image/png"
|
||||
mock_signed_url = "https://example.com/image.png"
|
||||
mock_tool_file = ToolFile(
|
||||
id=_gen_id(),
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=None,
|
||||
file_key="test-file-key",
|
||||
mimetype=mime_type,
|
||||
original_url=None,
|
||||
name=f"{_gen_id()}.png",
|
||||
size=len(_PNG_DATA),
|
||||
)
|
||||
mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager)
|
||||
mocked_engine = mock.MagicMock(spec=Engine)
|
||||
|
||||
mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file
|
||||
monkeypatch.setattr(FileSaverImpl, "_get_tool_file_manager", lambda _: mocked_tool_file_manager)
|
||||
# Since `File.generate_url` used `ToolFileManager.sign_file` directly, we also need to patch it here.
|
||||
mocked_sign_file = mock.MagicMock(spec=signature.sign_tool_file)
|
||||
# Since `File.generate_url` used `signature.sign_tool_file` directly, we also need to patch it here.
|
||||
monkeypatch.setattr(models, "sign_tool_file", mocked_sign_file)
|
||||
mocked_sign_file.return_value = mock_signed_url
|
||||
|
||||
storage_file_manager = FileSaverImpl(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
engine_factory=mocked_engine,
|
||||
)
|
||||
|
||||
file = storage_file_manager.save_binary_string(_PNG_DATA, mime_type, file_type)
|
||||
assert file.tenant_id == tenant_id
|
||||
assert file.type == file_type
|
||||
assert file.transfer_method == FileTransferMethod.TOOL_FILE
|
||||
assert file.extension == ".png"
|
||||
assert file.mime_type == mime_type
|
||||
assert file.size == len(_PNG_DATA)
|
||||
assert file.related_id == mock_tool_file.id
|
||||
|
||||
assert file.generate_url() == mock_signed_url
|
||||
|
||||
mocked_tool_file_manager.create_file_by_raw.assert_called_once_with(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=None,
|
||||
file_binary=_PNG_DATA,
|
||||
mimetype=mime_type,
|
||||
)
|
||||
mocked_sign_file.assert_called_once_with(mock_tool_file.id, ".png")
|
||||
|
||||
def test_save_remote_url_request_failed(self, monkeypatch):
|
||||
_TEST_URL = "https://example.com/image.png"
|
||||
mock_request = httpx.Request("GET", _TEST_URL)
|
||||
mock_response = httpx.Response(
|
||||
status_code=401,
|
||||
request=mock_request,
|
||||
)
|
||||
file_saver = FileSaverImpl(
|
||||
user_id=_gen_id(),
|
||||
tenant_id=_gen_id(),
|
||||
)
|
||||
mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
|
||||
monkeypatch.setattr(ssrf_proxy, "get", mock_get)
|
||||
|
||||
with pytest.raises(httpx.HTTPStatusError) as exc:
|
||||
file_saver.save_remote_url(_TEST_URL, FileType.IMAGE)
|
||||
mock_get.assert_called_once_with(_TEST_URL)
|
||||
assert exc.value.response.status_code == 401
|
||||
|
||||
def test_save_remote_url_success(self, monkeypatch):
|
||||
_TEST_URL = "https://example.com/image.png"
|
||||
mime_type = "image/png"
|
||||
user_id = _gen_id()
|
||||
tenant_id = _gen_id()
|
||||
|
||||
mock_request = httpx.Request("GET", _TEST_URL)
|
||||
mock_response = httpx.Response(
|
||||
status_code=200,
|
||||
content=b"test-data",
|
||||
headers={"Content-Type": mime_type},
|
||||
request=mock_request,
|
||||
)
|
||||
|
||||
file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id)
|
||||
mock_tool_file = ToolFile(
|
||||
id=_gen_id(),
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=None,
|
||||
file_key="test-file-key",
|
||||
mimetype=mime_type,
|
||||
original_url=None,
|
||||
name=f"{_gen_id()}.png",
|
||||
size=len(_PNG_DATA),
|
||||
)
|
||||
mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
|
||||
monkeypatch.setattr(ssrf_proxy, "get", mock_get)
|
||||
mock_save_binary_string = mock.MagicMock(spec=file_saver.save_binary_string, return_value=mock_tool_file)
|
||||
monkeypatch.setattr(file_saver, "save_binary_string", mock_save_binary_string)
|
||||
|
||||
file = file_saver.save_remote_url(_TEST_URL, FileType.IMAGE)
|
||||
mock_save_binary_string.assert_called_once_with(
|
||||
mock_response.content,
|
||||
mime_type,
|
||||
FileType.IMAGE,
|
||||
extension_override=".png",
|
||||
)
|
||||
assert file == mock_tool_file
|
||||
|
||||
|
||||
def test_validate_extension_override():
|
||||
class TestCase(NamedTuple):
|
||||
extension_override: str | None
|
||||
expected: str | None
|
||||
|
||||
cases = [TestCase(None, None), TestCase("", ""), ".png", ".png", ".tar.gz", ".tar.gz"]
|
||||
|
||||
for valid_ext_override in [None, "", ".png", ".tar.gz"]:
|
||||
assert valid_ext_override == _validate_extension_override(valid_ext_override)
|
||||
|
||||
for invalid_ext_override in ["png", "tar.gz"]:
|
||||
with pytest.raises(ValueError) as exc:
|
||||
_validate_extension_override(invalid_ext_override)
|
||||
|
||||
|
||||
class TestExtractContentTypeAndExtension:
|
||||
def test_with_both_content_type_and_extension(self):
|
||||
content_type, extension = _extract_content_type_and_extension("https://example.com/image.jpg", "image/png")
|
||||
assert content_type == "image/png"
|
||||
assert extension == ".png"
|
||||
|
||||
def test_url_with_file_extension(self):
|
||||
for content_type in [None, ""]:
|
||||
content_type, extension = _extract_content_type_and_extension("https://example.com/image.png", content_type)
|
||||
assert content_type == "image/png"
|
||||
assert extension == ".png"
|
||||
|
||||
def test_response_with_content_type(self):
|
||||
content_type, extension = _extract_content_type_and_extension("https://example.com/image", "image/png")
|
||||
assert content_type == "image/png"
|
||||
assert extension == ".png"
|
||||
|
||||
def test_no_content_type_and_no_extension(self):
|
||||
for content_type in [None, ""]:
|
||||
content_type, extension = _extract_content_type_and_extension("https://example.com/image", content_type)
|
||||
assert content_type == "application/octet-stream"
|
||||
assert extension == ".bin"
|
||||
|
||||
|
||||
class TestGetExtension:
|
||||
def test_with_extension_override(self):
|
||||
mime_type = "image/png"
|
||||
for override in [".jpg", ""]:
|
||||
extension = _get_extension(mime_type, override)
|
||||
assert extension == override
|
||||
|
||||
def test_without_extension_override(self):
|
||||
mime_type = "image/png"
|
||||
extension = _get_extension(mime_type)
|
||||
assert extension == ".png"
|
||||
Loading…
Reference in New Issue