refactor(api): remove the unnecessary FileDownloader interface.
parent
2585cb37b3
commit
f6ab99cc24
@ -1,42 +0,0 @@
|
||||
import abc
|
||||
import typing as tp
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.helper import ssrf_proxy
|
||||
|
||||
|
||||
class FileDownloadError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class HTTPStatusError(FileDownloadError):
|
||||
def __init__(self, message: str, *, status_code: int):
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class Response(BaseModel, frozen=True):
|
||||
body: bytes
|
||||
content_type: str | None = None
|
||||
|
||||
|
||||
class FileDownloader(tp.Protocol):
|
||||
@abc.abstractmethod
|
||||
def get(self, url) -> Response:
|
||||
pass
|
||||
|
||||
|
||||
class SSRFProxyFileDownloader(FileDownloader):
|
||||
def get(self, url) -> Response:
|
||||
try:
|
||||
http_response = ssrf_proxy.get(url)
|
||||
http_response.raise_for_status()
|
||||
return Response(
|
||||
body=http_response.content,
|
||||
content_type=http_response.headers.get("Content-Type"),
|
||||
)
|
||||
except httpx.TimeoutException as e:
|
||||
raise FileDownloadError(f"timeout when downloading file from {url}") from e
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise HTTPStatusError(f"Error when downloading file from {url}", status_code=e.response.status_code) from e
|
||||
@ -1,56 +0,0 @@
|
||||
from unittest import mock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from core.workflow.nodes.llm.file_downloader import (
|
||||
FileDownloadError,
|
||||
HTTPStatusError,
|
||||
Response,
|
||||
SSRFProxyFileDownloader,
|
||||
ssrf_proxy,
|
||||
)
|
||||
|
||||
_TEST_URL = "https://example.com"
|
||||
|
||||
|
||||
class TestSSRFProxyFileDownloader:
|
||||
def test(self, monkeypatch):
|
||||
mock_request = httpx.Request("GET", _TEST_URL)
|
||||
mock_response = httpx.Response(
|
||||
status_code=200,
|
||||
content=b"test-data",
|
||||
headers={"Content-Type": "text/plain"},
|
||||
request=mock_request,
|
||||
)
|
||||
mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
|
||||
monkeypatch.setattr(ssrf_proxy, "get", mock_get)
|
||||
|
||||
downloader = SSRFProxyFileDownloader()
|
||||
response = downloader.get(_TEST_URL)
|
||||
mock_get.assert_called_once_with(_TEST_URL)
|
||||
assert response == Response(body=mock_response.content, content_type="text/plain")
|
||||
|
||||
def test_should_raise_when_status_is_not_successful(self, monkeypatch):
|
||||
mock_request = httpx.Request("GET", _TEST_URL)
|
||||
mock_response = httpx.Response(
|
||||
status_code=401,
|
||||
request=mock_request,
|
||||
)
|
||||
mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
|
||||
monkeypatch.setattr(ssrf_proxy, "get", mock_get)
|
||||
|
||||
downloader = SSRFProxyFileDownloader()
|
||||
with pytest.raises(HTTPStatusError) as exc:
|
||||
response = downloader.get(_TEST_URL)
|
||||
mock_get.assert_called_once_with(_TEST_URL)
|
||||
assert exc.value.status_code == 401
|
||||
|
||||
def test_should_convert_timeout_to_file_download_error(self, monkeypatch):
|
||||
mock_get = mock.MagicMock(spec=ssrf_proxy.get, side_effect=httpx.TimeoutException("timeout"))
|
||||
monkeypatch.setattr(ssrf_proxy, "get", mock_get)
|
||||
|
||||
downloader = SSRFProxyFileDownloader()
|
||||
with pytest.raises(FileDownloadError) as exc:
|
||||
response = downloader.get(_TEST_URL)
|
||||
mock_get.assert_called_once_with(_TEST_URL)
|
||||
@ -1,129 +1,192 @@
|
||||
import uuid
|
||||
from typing import NamedTuple
|
||||
from unittest import mock
|
||||
|
||||
import pydantic
|
||||
import httpx
|
||||
import pytest
|
||||
from sqlalchemy import Engine
|
||||
|
||||
from core.file import FileTransferMethod, FileType, models
|
||||
from core.helper import ssrf_proxy
|
||||
from core.tools import signature
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.workflow.nodes.llm.file_saver import MultiModalFile, StorageFileSaver
|
||||
from core.workflow.nodes.llm.file_saver import (
|
||||
FileSaverImpl,
|
||||
_extract_content_type_and_extension,
|
||||
_get_extension,
|
||||
_validate_extension_override,
|
||||
)
|
||||
from models import ToolFile
|
||||
|
||||
_PNG_DATA = b"\x89PNG\r\n\x1a\n"
|
||||
#
|
||||
# class _MockToolFileManager:
|
||||
# def __init__(self, mock_tool_file: ToolFile, mock_signed_url: str):
|
||||
# self._mock_tool_file = mock_tool_file
|
||||
# self._mock_signed_url = mock_signed_url
|
||||
#
|
||||
# @staticmethod
|
||||
# def create_file_by_raw(
|
||||
# *,
|
||||
# user_id: str,
|
||||
# tenant_id: str,
|
||||
# conversation_id: Optional[str],
|
||||
# file_binary: bytes,
|
||||
# mimetype: str,
|
||||
# filename: Optional[str] = None,
|
||||
# ) -> ToolFile:
|
||||
# return self._mock_tool_file
|
||||
#
|
||||
# @staticmethod
|
||||
# def sign_file(tool_file_id: str, extension: str) -> str:
|
||||
# return ""
|
||||
|
||||
|
||||
def test_storage_file_saver(monkeypatch):
|
||||
def gen_id():
|
||||
|
||||
|
||||
def _gen_id():
|
||||
return str(uuid.uuid4())
|
||||
|
||||
mmf = MultiModalFile(
|
||||
user_id=gen_id(),
|
||||
tenant_id=gen_id(),
|
||||
file_type=FileType.IMAGE,
|
||||
data=_PNG_DATA,
|
||||
mime_type="image/png",
|
||||
extension_override=None,
|
||||
)
|
||||
|
||||
class TestFileSaverImpl:
|
||||
def test_save_binary_string(self, monkeypatch):
|
||||
user_id = _gen_id()
|
||||
tenant_id = _gen_id()
|
||||
file_type = FileType.IMAGE
|
||||
mime_type = "image/png"
|
||||
mock_signed_url = "https://example.com/image.png"
|
||||
mock_tool_file = ToolFile(
|
||||
id=gen_id(),
|
||||
user_id=mmf.user_id,
|
||||
tenant_id=mmf.tenant_id,
|
||||
id=_gen_id(),
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=None,
|
||||
file_key="test-file-key",
|
||||
mimetype=mmf.mime_type,
|
||||
mimetype=mime_type,
|
||||
original_url=None,
|
||||
name=f"{gen_id()}.png",
|
||||
size=len(mmf.data),
|
||||
name=f"{_gen_id()}.png",
|
||||
size=len(_PNG_DATA),
|
||||
)
|
||||
mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager)
|
||||
mocked_engine = mock.MagicMock(spec=Engine)
|
||||
|
||||
mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file
|
||||
monkeypatch.setattr(StorageFileSaver, "_get_tool_file_manager", lambda _: mocked_tool_file_manager)
|
||||
monkeypatch.setattr(FileSaverImpl, "_get_tool_file_manager", lambda _: mocked_tool_file_manager)
|
||||
# Since `File.generate_url` used `ToolFileManager.sign_file` directly, we also need to patch it here.
|
||||
mocked_sign_file = mock.MagicMock(spec=signature.sign_tool_file)
|
||||
# Since `File.generate_url` used `signature.sign_tool_file` directly, we also need to patch it here.
|
||||
monkeypatch.setattr(models, "sign_tool_file", mocked_sign_file)
|
||||
mocked_sign_file.return_value = mock_signed_url
|
||||
|
||||
storage_file_manager = StorageFileSaver(engine_factory=lambda: mocked_engine)
|
||||
storage_file_manager = FileSaverImpl(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
engine_factory=mocked_engine,
|
||||
)
|
||||
|
||||
file = storage_file_manager.save_file(mmf)
|
||||
assert file.tenant_id == mmf.tenant_id
|
||||
assert file.type == mmf.file_type
|
||||
file = storage_file_manager.save_binary_string(_PNG_DATA, mime_type, file_type)
|
||||
assert file.tenant_id == tenant_id
|
||||
assert file.type == file_type
|
||||
assert file.transfer_method == FileTransferMethod.TOOL_FILE
|
||||
assert file.extension == mmf.get_extension()
|
||||
assert file.mime_type == mmf.mime_type
|
||||
assert file.size == len(mmf.data)
|
||||
assert file.extension == ".png"
|
||||
assert file.mime_type == mime_type
|
||||
assert file.size == len(_PNG_DATA)
|
||||
assert file.related_id == mock_tool_file.id
|
||||
|
||||
assert file.generate_url() == mock_signed_url
|
||||
|
||||
mocked_tool_file_manager.create_file_by_raw.assert_called_once_with(
|
||||
user_id=mmf.user_id,
|
||||
tenant_id=mmf.tenant_id,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=None,
|
||||
file_binary=_PNG_DATA,
|
||||
mimetype=mime_type,
|
||||
)
|
||||
mocked_sign_file.assert_called_once_with(mock_tool_file.id, ".png")
|
||||
|
||||
def test_save_remote_url_request_failed(self, monkeypatch):
|
||||
_TEST_URL = "https://example.com/image.png"
|
||||
mock_request = httpx.Request("GET", _TEST_URL)
|
||||
mock_response = httpx.Response(
|
||||
status_code=401,
|
||||
request=mock_request,
|
||||
)
|
||||
file_saver = FileSaverImpl(
|
||||
user_id=_gen_id(),
|
||||
tenant_id=_gen_id(),
|
||||
)
|
||||
mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
|
||||
monkeypatch.setattr(ssrf_proxy, "get", mock_get)
|
||||
|
||||
with pytest.raises(httpx.HTTPStatusError) as exc:
|
||||
file_saver.save_remote_url(_TEST_URL, FileType.IMAGE)
|
||||
mock_get.assert_called_once_with(_TEST_URL)
|
||||
assert exc.value.response.status_code == 401
|
||||
|
||||
def test_save_remote_url_success(self, monkeypatch):
|
||||
_TEST_URL = "https://example.com/image.png"
|
||||
mime_type = "image/png"
|
||||
user_id = _gen_id()
|
||||
tenant_id = _gen_id()
|
||||
|
||||
mock_request = httpx.Request("GET", _TEST_URL)
|
||||
mock_response = httpx.Response(
|
||||
status_code=200,
|
||||
content=b"test-data",
|
||||
headers={"Content-Type": mime_type},
|
||||
request=mock_request,
|
||||
)
|
||||
|
||||
file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id)
|
||||
mock_tool_file = ToolFile(
|
||||
id=_gen_id(),
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=None,
|
||||
file_binary=mmf.data,
|
||||
mimetype=mmf.mime_type,
|
||||
file_key="test-file-key",
|
||||
mimetype=mime_type,
|
||||
original_url=None,
|
||||
name=f"{_gen_id()}.png",
|
||||
size=len(_PNG_DATA),
|
||||
)
|
||||
mocked_sign_file.assert_called_once_with(mock_tool_file.id, mmf.get_extension())
|
||||
mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
|
||||
monkeypatch.setattr(ssrf_proxy, "get", mock_get)
|
||||
mock_save_binary_string = mock.MagicMock(spec=file_saver.save_binary_string, return_value=mock_tool_file)
|
||||
monkeypatch.setattr(file_saver, "save_binary_string", mock_save_binary_string)
|
||||
|
||||
file = file_saver.save_remote_url(_TEST_URL, FileType.IMAGE)
|
||||
mock_save_binary_string.assert_called_once_with(
|
||||
mock_response.content,
|
||||
mime_type,
|
||||
FileType.IMAGE,
|
||||
extension_override=".png",
|
||||
)
|
||||
assert file == mock_tool_file
|
||||
|
||||
|
||||
def test_multi_modal_file_extension_override():
|
||||
# Test should pass if `extension_override` is not set.
|
||||
MultiModalFile(user_id="", tenant_id="", file_type=FileType.IMAGE, data=b"", mime_type="image/png")
|
||||
def test_validate_extension_override():
|
||||
class TestCase(NamedTuple):
|
||||
extension_override: str | None
|
||||
expected: str | None
|
||||
|
||||
# Test should pass if `extension_override` is explicitly set to `None`.
|
||||
MultiModalFile(
|
||||
user_id="", tenant_id="", file_type=FileType.IMAGE, data=b"", mime_type="image/png", extension_override=None
|
||||
)
|
||||
cases = [TestCase(None, None), TestCase("", ""), ".png", ".png", ".tar.gz", ".tar.gz"]
|
||||
|
||||
# Test should pass if `extension_override` is a string prefixed with `.`.
|
||||
for extension_override in [".png", ".tar.gz"]:
|
||||
MultiModalFile(
|
||||
user_id="",
|
||||
tenant_id="",
|
||||
file_type=FileType.IMAGE,
|
||||
data=b"",
|
||||
mime_type="image/png",
|
||||
extension_override=extension_override,
|
||||
)
|
||||
for valid_ext_override in [None, "", ".png", ".tar.gz"]:
|
||||
assert valid_ext_override == _validate_extension_override(valid_ext_override)
|
||||
|
||||
for invalid_ext_override in ["png", "tar.gz"]:
|
||||
with pytest.raises(pydantic.ValidationError) as exc:
|
||||
MultiModalFile(
|
||||
user_id="",
|
||||
tenant_id="",
|
||||
file_type=FileType.IMAGE,
|
||||
data=b"",
|
||||
mime_type="image/png",
|
||||
extension_override=invalid_ext_override,
|
||||
)
|
||||
|
||||
error_details = exc.value.errors()
|
||||
assert exc.value.error_count() == 1
|
||||
assert error_details[0]["loc"] == ("extension_override",)
|
||||
with pytest.raises(ValueError) as exc:
|
||||
_validate_extension_override(invalid_ext_override)
|
||||
|
||||
|
||||
class TestExtractContentTypeAndExtension:
|
||||
def test_with_both_content_type_and_extension(self):
|
||||
content_type, extension = _extract_content_type_and_extension("https://example.com/image.jpg", "image/png")
|
||||
assert content_type == "image/png"
|
||||
assert extension == ".png"
|
||||
|
||||
def test_url_with_file_extension(self):
|
||||
for content_type in [None, ""]:
|
||||
content_type, extension = _extract_content_type_and_extension("https://example.com/image.png", content_type)
|
||||
assert content_type == "image/png"
|
||||
assert extension == ".png"
|
||||
|
||||
def test_response_with_content_type(self):
|
||||
content_type, extension = _extract_content_type_and_extension("https://example.com/image", "image/png")
|
||||
assert content_type == "image/png"
|
||||
assert extension == ".png"
|
||||
|
||||
def test_no_content_type_and_no_extension(self):
|
||||
for content_type in [None, ""]:
|
||||
content_type, extension = _extract_content_type_and_extension("https://example.com/image", content_type)
|
||||
assert content_type == "application/octet-stream"
|
||||
assert extension == ".bin"
|
||||
|
||||
|
||||
class TestGetExtension:
|
||||
def test_with_extension_override(self):
|
||||
mime_type = "image/png"
|
||||
for override in [".jpg", ""]:
|
||||
extension = _get_extension(mime_type, override)
|
||||
assert extension == override
|
||||
|
||||
def test_without_extension_override(self):
|
||||
mime_type = "image/png"
|
||||
extension = _get_extension(mime_type)
|
||||
assert extension == ".png"
|
||||
|
||||
Loading…
Reference in New Issue