Merge branch 'main' into e-300
commit
0301bd3ac1
@ -0,0 +1,22 @@
|
|||||||
|
{
|
||||||
|
"Verbose": false,
|
||||||
|
"Debug": false,
|
||||||
|
"IgnoreDefaults": false,
|
||||||
|
"SpacesAfterTabs": false,
|
||||||
|
"NoColor": false,
|
||||||
|
"Exclude": [
|
||||||
|
"^web/public/vs/",
|
||||||
|
"^web/public/pdf.worker.min.mjs$",
|
||||||
|
"web/app/components/base/icons/src/vender/"
|
||||||
|
],
|
||||||
|
"AllowedContentTypes": [],
|
||||||
|
"PassedFiles": [],
|
||||||
|
"Disable": {
|
||||||
|
"EndOfLine": false,
|
||||||
|
"Indentation": false,
|
||||||
|
"IndentSize": true,
|
||||||
|
"InsertFinalNewline": false,
|
||||||
|
"TrimTrailingWhitespace": false,
|
||||||
|
"MaxLineLength": false
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -0,0 +1,7 @@
|
|||||||
|
# The two constants below should keep in sync.
|
||||||
|
# Default content type for files which have no explicit content type.
|
||||||
|
|
||||||
|
DEFAULT_MIME_TYPE = "application/octet-stream"
|
||||||
|
# Default file extension for files which have no explicit content type, should
|
||||||
|
# correspond to the `DEFAULT_MIME_TYPE` above.
|
||||||
|
DEFAULT_EXTENSION = ".bin"
|
||||||
@ -1,12 +1,19 @@
|
|||||||
from typing import TYPE_CHECKING, Any, cast
|
from collections.abc import Callable
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from core.tools.tool_file_manager import ToolFileManager
|
from core.tools.tool_file_manager import ToolFileManager
|
||||||
|
|
||||||
tool_file_manager: dict[str, Any] = {"manager": None}
|
_tool_file_manager_factory: Callable[[], "ToolFileManager"] | None = None
|
||||||
|
|
||||||
|
|
||||||
class ToolFileParser:
|
class ToolFileParser:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_tool_file_manager() -> "ToolFileManager":
|
def get_tool_file_manager() -> "ToolFileManager":
|
||||||
return cast("ToolFileManager", tool_file_manager["manager"])
|
assert _tool_file_manager_factory is not None
|
||||||
|
return _tool_file_manager_factory()
|
||||||
|
|
||||||
|
|
||||||
|
def set_tool_file_manager_factory(factory: Callable[[], "ToolFileManager"]) -> None:
|
||||||
|
global _tool_file_manager_factory
|
||||||
|
_tool_file_manager_factory = factory
|
||||||
|
|||||||
@ -0,0 +1,41 @@
|
|||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
|
||||||
|
|
||||||
|
def sign_tool_file(tool_file_id: str, extension: str) -> str:
|
||||||
|
"""
|
||||||
|
sign file to get a temporary url
|
||||||
|
"""
|
||||||
|
base_url = dify_config.FILES_URL
|
||||||
|
file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}"
|
||||||
|
|
||||||
|
timestamp = str(int(time.time()))
|
||||||
|
nonce = os.urandom(16).hex()
|
||||||
|
data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}"
|
||||||
|
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
|
||||||
|
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||||
|
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||||
|
|
||||||
|
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
||||||
|
|
||||||
|
|
||||||
|
def verify_tool_file_signature(file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||||
|
"""
|
||||||
|
verify signature
|
||||||
|
"""
|
||||||
|
data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}"
|
||||||
|
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
|
||||||
|
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||||
|
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
|
||||||
|
|
||||||
|
# verify signature
|
||||||
|
if sign != recalculated_encoded_sign:
|
||||||
|
return False
|
||||||
|
|
||||||
|
current_time = int(time.time())
|
||||||
|
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||||
@ -0,0 +1,160 @@
|
|||||||
|
import mimetypes
|
||||||
|
import typing as tp
|
||||||
|
|
||||||
|
from sqlalchemy import Engine
|
||||||
|
|
||||||
|
from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE
|
||||||
|
from core.file import File, FileTransferMethod, FileType
|
||||||
|
from core.helper import ssrf_proxy
|
||||||
|
from core.tools.signature import sign_tool_file
|
||||||
|
from core.tools.tool_file_manager import ToolFileManager
|
||||||
|
from models import db as global_db
|
||||||
|
|
||||||
|
|
||||||
|
class LLMFileSaver(tp.Protocol):
|
||||||
|
"""LLMFileSaver is responsible for save multimodal output returned by
|
||||||
|
LLM.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def save_binary_string(
|
||||||
|
self,
|
||||||
|
data: bytes,
|
||||||
|
mime_type: str,
|
||||||
|
file_type: FileType,
|
||||||
|
extension_override: str | None = None,
|
||||||
|
) -> File:
|
||||||
|
"""save_binary_string saves the inline file data returned by LLM.
|
||||||
|
|
||||||
|
Currently (2025-04-30), only some of Google Gemini models will return
|
||||||
|
multimodal output as inline data.
|
||||||
|
|
||||||
|
:param data: the contents of the file
|
||||||
|
:param mime_type: the media type of the file, specified by rfc6838
|
||||||
|
(https://datatracker.ietf.org/doc/html/rfc6838)
|
||||||
|
:param file_type: The file type of the inline file.
|
||||||
|
:param extension_override: Override the auto-detected file extension while saving this file.
|
||||||
|
|
||||||
|
The default value is `None`, which means do not override the file extension and guessing it
|
||||||
|
from the `mime_type` attribute while saving the file.
|
||||||
|
|
||||||
|
Setting it to values other than `None` means override the file's extension, and
|
||||||
|
will bypass the extension guessing saving the file.
|
||||||
|
|
||||||
|
Specially, setting it to empty string (`""`) will leave the file extension empty.
|
||||||
|
|
||||||
|
When it is not `None` or empty string (`""`), it should be a string beginning with a
|
||||||
|
dot (`.`). For example, `.py` and `.tar.gz` are both valid values, while `py`
|
||||||
|
and `tar.gz` are not.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def save_remote_url(self, url: str, file_type: FileType) -> File:
|
||||||
|
"""save_remote_url saves the file from a remote url returned by LLM.
|
||||||
|
|
||||||
|
Currently (2025-04-30), no model returns multimodel output as a url.
|
||||||
|
|
||||||
|
:param url: the url of the file.
|
||||||
|
:param file_type: the file type of the file, check `FileType` enum for reference.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
EngineFactory: tp.TypeAlias = tp.Callable[[], Engine]
|
||||||
|
|
||||||
|
|
||||||
|
class FileSaverImpl(LLMFileSaver):
|
||||||
|
_engine_factory: EngineFactory
|
||||||
|
_tenant_id: str
|
||||||
|
_user_id: str
|
||||||
|
|
||||||
|
def __init__(self, user_id: str, tenant_id: str, engine_factory: EngineFactory | None = None):
|
||||||
|
if engine_factory is None:
|
||||||
|
|
||||||
|
def _factory():
|
||||||
|
return global_db.engine
|
||||||
|
|
||||||
|
engine_factory = _factory
|
||||||
|
self._engine_factory = engine_factory
|
||||||
|
self._user_id = user_id
|
||||||
|
self._tenant_id = tenant_id
|
||||||
|
|
||||||
|
def _get_tool_file_manager(self):
|
||||||
|
return ToolFileManager(engine=self._engine_factory())
|
||||||
|
|
||||||
|
def save_remote_url(self, url: str, file_type: FileType) -> File:
|
||||||
|
http_response = ssrf_proxy.get(url)
|
||||||
|
http_response.raise_for_status()
|
||||||
|
data = http_response.content
|
||||||
|
mime_type_from_header = http_response.headers.get("Content-Type")
|
||||||
|
mime_type, extension = _extract_content_type_and_extension(url, mime_type_from_header)
|
||||||
|
return self.save_binary_string(data, mime_type, file_type, extension_override=extension)
|
||||||
|
|
||||||
|
def save_binary_string(
|
||||||
|
self,
|
||||||
|
data: bytes,
|
||||||
|
mime_type: str,
|
||||||
|
file_type: FileType,
|
||||||
|
extension_override: str | None = None,
|
||||||
|
) -> File:
|
||||||
|
tool_file_manager = self._get_tool_file_manager()
|
||||||
|
tool_file = tool_file_manager.create_file_by_raw(
|
||||||
|
user_id=self._user_id,
|
||||||
|
tenant_id=self._tenant_id,
|
||||||
|
# TODO(QuantumGhost): what is conversation id?
|
||||||
|
conversation_id=None,
|
||||||
|
file_binary=data,
|
||||||
|
mimetype=mime_type,
|
||||||
|
)
|
||||||
|
extension_override = _validate_extension_override(extension_override)
|
||||||
|
extension = _get_extension(mime_type, extension_override)
|
||||||
|
url = sign_tool_file(tool_file.id, extension)
|
||||||
|
|
||||||
|
return File(
|
||||||
|
tenant_id=self._tenant_id,
|
||||||
|
type=file_type,
|
||||||
|
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||||
|
filename=tool_file.name,
|
||||||
|
extension=extension,
|
||||||
|
mime_type=mime_type,
|
||||||
|
size=len(data),
|
||||||
|
related_id=tool_file.id,
|
||||||
|
url=url,
|
||||||
|
# TODO(QuantumGhost): how should I set the following key?
|
||||||
|
# What's the difference between `remote_url` and `url`?
|
||||||
|
# What's the purpose of `storage_key` and `dify_model_identity`?
|
||||||
|
storage_key=tool_file.file_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_extension(mime_type: str, extension_override: str | None = None) -> str:
|
||||||
|
"""get_extension return the extension of file.
|
||||||
|
|
||||||
|
If the `extension_override` parameter is set, this function should honor it and
|
||||||
|
return its value.
|
||||||
|
"""
|
||||||
|
if extension_override is not None:
|
||||||
|
return extension_override
|
||||||
|
return mimetypes.guess_extension(mime_type) or DEFAULT_EXTENSION
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_content_type_and_extension(url: str, content_type_header: str | None) -> tuple[str, str]:
|
||||||
|
"""_extract_content_type_and_extension tries to
|
||||||
|
guess content type of file from url and `Content-Type` header in response.
|
||||||
|
"""
|
||||||
|
if content_type_header:
|
||||||
|
extension = mimetypes.guess_extension(content_type_header) or DEFAULT_EXTENSION
|
||||||
|
return content_type_header, extension
|
||||||
|
content_type = mimetypes.guess_type(url)[0] or DEFAULT_MIME_TYPE
|
||||||
|
extension = mimetypes.guess_extension(content_type) or DEFAULT_EXTENSION
|
||||||
|
return content_type, extension
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_extension_override(extension_override: str | None) -> str | None:
|
||||||
|
# `extension_override` is allow to be `None or `""`.
|
||||||
|
if extension_override is None:
|
||||||
|
return None
|
||||||
|
if extension_override == "":
|
||||||
|
return ""
|
||||||
|
if not extension_override.startswith("."):
|
||||||
|
raise ValueError("extension_override should start with '.' if not None or empty.", extension_override)
|
||||||
|
return extension_override
|
||||||
@ -0,0 +1,192 @@
|
|||||||
|
import uuid
|
||||||
|
from typing import NamedTuple
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import Engine
|
||||||
|
|
||||||
|
from core.file import FileTransferMethod, FileType, models
|
||||||
|
from core.helper import ssrf_proxy
|
||||||
|
from core.tools import signature
|
||||||
|
from core.tools.tool_file_manager import ToolFileManager
|
||||||
|
from core.workflow.nodes.llm.file_saver import (
|
||||||
|
FileSaverImpl,
|
||||||
|
_extract_content_type_and_extension,
|
||||||
|
_get_extension,
|
||||||
|
_validate_extension_override,
|
||||||
|
)
|
||||||
|
from models import ToolFile
|
||||||
|
|
||||||
|
_PNG_DATA = b"\x89PNG\r\n\x1a\n"
|
||||||
|
|
||||||
|
|
||||||
|
def _gen_id():
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
class TestFileSaverImpl:
|
||||||
|
def test_save_binary_string(self, monkeypatch):
|
||||||
|
user_id = _gen_id()
|
||||||
|
tenant_id = _gen_id()
|
||||||
|
file_type = FileType.IMAGE
|
||||||
|
mime_type = "image/png"
|
||||||
|
mock_signed_url = "https://example.com/image.png"
|
||||||
|
mock_tool_file = ToolFile(
|
||||||
|
id=_gen_id(),
|
||||||
|
user_id=user_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
conversation_id=None,
|
||||||
|
file_key="test-file-key",
|
||||||
|
mimetype=mime_type,
|
||||||
|
original_url=None,
|
||||||
|
name=f"{_gen_id()}.png",
|
||||||
|
size=len(_PNG_DATA),
|
||||||
|
)
|
||||||
|
mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager)
|
||||||
|
mocked_engine = mock.MagicMock(spec=Engine)
|
||||||
|
|
||||||
|
mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file
|
||||||
|
monkeypatch.setattr(FileSaverImpl, "_get_tool_file_manager", lambda _: mocked_tool_file_manager)
|
||||||
|
# Since `File.generate_url` used `ToolFileManager.sign_file` directly, we also need to patch it here.
|
||||||
|
mocked_sign_file = mock.MagicMock(spec=signature.sign_tool_file)
|
||||||
|
# Since `File.generate_url` used `signature.sign_tool_file` directly, we also need to patch it here.
|
||||||
|
monkeypatch.setattr(models, "sign_tool_file", mocked_sign_file)
|
||||||
|
mocked_sign_file.return_value = mock_signed_url
|
||||||
|
|
||||||
|
storage_file_manager = FileSaverImpl(
|
||||||
|
user_id=user_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
engine_factory=mocked_engine,
|
||||||
|
)
|
||||||
|
|
||||||
|
file = storage_file_manager.save_binary_string(_PNG_DATA, mime_type, file_type)
|
||||||
|
assert file.tenant_id == tenant_id
|
||||||
|
assert file.type == file_type
|
||||||
|
assert file.transfer_method == FileTransferMethod.TOOL_FILE
|
||||||
|
assert file.extension == ".png"
|
||||||
|
assert file.mime_type == mime_type
|
||||||
|
assert file.size == len(_PNG_DATA)
|
||||||
|
assert file.related_id == mock_tool_file.id
|
||||||
|
|
||||||
|
assert file.generate_url() == mock_signed_url
|
||||||
|
|
||||||
|
mocked_tool_file_manager.create_file_by_raw.assert_called_once_with(
|
||||||
|
user_id=user_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
conversation_id=None,
|
||||||
|
file_binary=_PNG_DATA,
|
||||||
|
mimetype=mime_type,
|
||||||
|
)
|
||||||
|
mocked_sign_file.assert_called_once_with(mock_tool_file.id, ".png")
|
||||||
|
|
||||||
|
def test_save_remote_url_request_failed(self, monkeypatch):
|
||||||
|
_TEST_URL = "https://example.com/image.png"
|
||||||
|
mock_request = httpx.Request("GET", _TEST_URL)
|
||||||
|
mock_response = httpx.Response(
|
||||||
|
status_code=401,
|
||||||
|
request=mock_request,
|
||||||
|
)
|
||||||
|
file_saver = FileSaverImpl(
|
||||||
|
user_id=_gen_id(),
|
||||||
|
tenant_id=_gen_id(),
|
||||||
|
)
|
||||||
|
mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
|
||||||
|
monkeypatch.setattr(ssrf_proxy, "get", mock_get)
|
||||||
|
|
||||||
|
with pytest.raises(httpx.HTTPStatusError) as exc:
|
||||||
|
file_saver.save_remote_url(_TEST_URL, FileType.IMAGE)
|
||||||
|
mock_get.assert_called_once_with(_TEST_URL)
|
||||||
|
assert exc.value.response.status_code == 401
|
||||||
|
|
||||||
|
def test_save_remote_url_success(self, monkeypatch):
|
||||||
|
_TEST_URL = "https://example.com/image.png"
|
||||||
|
mime_type = "image/png"
|
||||||
|
user_id = _gen_id()
|
||||||
|
tenant_id = _gen_id()
|
||||||
|
|
||||||
|
mock_request = httpx.Request("GET", _TEST_URL)
|
||||||
|
mock_response = httpx.Response(
|
||||||
|
status_code=200,
|
||||||
|
content=b"test-data",
|
||||||
|
headers={"Content-Type": mime_type},
|
||||||
|
request=mock_request,
|
||||||
|
)
|
||||||
|
|
||||||
|
file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id)
|
||||||
|
mock_tool_file = ToolFile(
|
||||||
|
id=_gen_id(),
|
||||||
|
user_id=user_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
conversation_id=None,
|
||||||
|
file_key="test-file-key",
|
||||||
|
mimetype=mime_type,
|
||||||
|
original_url=None,
|
||||||
|
name=f"{_gen_id()}.png",
|
||||||
|
size=len(_PNG_DATA),
|
||||||
|
)
|
||||||
|
mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
|
||||||
|
monkeypatch.setattr(ssrf_proxy, "get", mock_get)
|
||||||
|
mock_save_binary_string = mock.MagicMock(spec=file_saver.save_binary_string, return_value=mock_tool_file)
|
||||||
|
monkeypatch.setattr(file_saver, "save_binary_string", mock_save_binary_string)
|
||||||
|
|
||||||
|
file = file_saver.save_remote_url(_TEST_URL, FileType.IMAGE)
|
||||||
|
mock_save_binary_string.assert_called_once_with(
|
||||||
|
mock_response.content,
|
||||||
|
mime_type,
|
||||||
|
FileType.IMAGE,
|
||||||
|
extension_override=".png",
|
||||||
|
)
|
||||||
|
assert file == mock_tool_file
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_extension_override():
|
||||||
|
class TestCase(NamedTuple):
|
||||||
|
extension_override: str | None
|
||||||
|
expected: str | None
|
||||||
|
|
||||||
|
cases = [TestCase(None, None), TestCase("", ""), ".png", ".png", ".tar.gz", ".tar.gz"]
|
||||||
|
|
||||||
|
for valid_ext_override in [None, "", ".png", ".tar.gz"]:
|
||||||
|
assert valid_ext_override == _validate_extension_override(valid_ext_override)
|
||||||
|
|
||||||
|
for invalid_ext_override in ["png", "tar.gz"]:
|
||||||
|
with pytest.raises(ValueError) as exc:
|
||||||
|
_validate_extension_override(invalid_ext_override)
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractContentTypeAndExtension:
|
||||||
|
def test_with_both_content_type_and_extension(self):
|
||||||
|
content_type, extension = _extract_content_type_and_extension("https://example.com/image.jpg", "image/png")
|
||||||
|
assert content_type == "image/png"
|
||||||
|
assert extension == ".png"
|
||||||
|
|
||||||
|
def test_url_with_file_extension(self):
|
||||||
|
for content_type in [None, ""]:
|
||||||
|
content_type, extension = _extract_content_type_and_extension("https://example.com/image.png", content_type)
|
||||||
|
assert content_type == "image/png"
|
||||||
|
assert extension == ".png"
|
||||||
|
|
||||||
|
def test_response_with_content_type(self):
|
||||||
|
content_type, extension = _extract_content_type_and_extension("https://example.com/image", "image/png")
|
||||||
|
assert content_type == "image/png"
|
||||||
|
assert extension == ".png"
|
||||||
|
|
||||||
|
def test_no_content_type_and_no_extension(self):
|
||||||
|
for content_type in [None, ""]:
|
||||||
|
content_type, extension = _extract_content_type_and_extension("https://example.com/image", content_type)
|
||||||
|
assert content_type == "application/octet-stream"
|
||||||
|
assert extension == ".bin"
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetExtension:
|
||||||
|
def test_with_extension_override(self):
|
||||||
|
mime_type = "image/png"
|
||||||
|
for override in [".jpg", ""]:
|
||||||
|
extension = _get_extension(mime_type, override)
|
||||||
|
assert extension == override
|
||||||
|
|
||||||
|
def test_without_extension_override(self):
|
||||||
|
mime_type = "image/png"
|
||||||
|
extension = _get_extension(mime_type)
|
||||||
|
assert extension == ".png"
|
||||||
@ -0,0 +1 @@
|
|||||||
|
|
||||||
@ -0,0 +1,390 @@
|
|||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
|
from core.variables import ArrayStringVariable
|
||||||
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
from core.workflow.enums import SystemVariableKey
|
||||||
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
|
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||||
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
|
from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode
|
||||||
|
from core.workflow.nodes.variable_assigner.v2.enums import InputType, Operation
|
||||||
|
from models.enums import UserFrom
|
||||||
|
from models.workflow import WorkflowType
|
||||||
|
|
||||||
|
DEFAULT_NODE_ID = "node_id"
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_item_directly():
|
||||||
|
"""Test the _handle_item method directly for remove operations."""
|
||||||
|
# Create variables
|
||||||
|
variable1 = ArrayStringVariable(
|
||||||
|
id=str(uuid4()),
|
||||||
|
name="test_variable1",
|
||||||
|
value=["first", "second", "third"],
|
||||||
|
)
|
||||||
|
|
||||||
|
variable2 = ArrayStringVariable(
|
||||||
|
id=str(uuid4()),
|
||||||
|
name="test_variable2",
|
||||||
|
value=["first", "second", "third"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a mock class with just the _handle_item method
|
||||||
|
class MockNode:
|
||||||
|
def _handle_item(self, *, variable, operation, value):
|
||||||
|
match operation:
|
||||||
|
case Operation.REMOVE_FIRST:
|
||||||
|
if not variable.value:
|
||||||
|
return variable.value
|
||||||
|
return variable.value[1:]
|
||||||
|
case Operation.REMOVE_LAST:
|
||||||
|
if not variable.value:
|
||||||
|
return variable.value
|
||||||
|
return variable.value[:-1]
|
||||||
|
|
||||||
|
node = MockNode()
|
||||||
|
|
||||||
|
# Test remove-first
|
||||||
|
result1 = node._handle_item(
|
||||||
|
variable=variable1,
|
||||||
|
operation=Operation.REMOVE_FIRST,
|
||||||
|
value=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test remove-last
|
||||||
|
result2 = node._handle_item(
|
||||||
|
variable=variable2,
|
||||||
|
operation=Operation.REMOVE_LAST,
|
||||||
|
value=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check the results
|
||||||
|
assert result1 == ["second", "third"]
|
||||||
|
assert result2 == ["first", "second"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_remove_first_from_array():
|
||||||
|
"""Test removing the first element from an array."""
|
||||||
|
graph_config = {
|
||||||
|
"edges": [
|
||||||
|
{
|
||||||
|
"id": "start-source-assigner-target",
|
||||||
|
"source": "start",
|
||||||
|
"target": "assigner",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"nodes": [
|
||||||
|
{"data": {"type": "start"}, "id": "start"},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "assigner",
|
||||||
|
},
|
||||||
|
"id": "assigner",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
graph = Graph.init(graph_config=graph_config)
|
||||||
|
|
||||||
|
init_params = GraphInitParams(
|
||||||
|
tenant_id="1",
|
||||||
|
app_id="1",
|
||||||
|
workflow_type=WorkflowType.WORKFLOW,
|
||||||
|
workflow_id="1",
|
||||||
|
graph_config=graph_config,
|
||||||
|
user_id="1",
|
||||||
|
user_from=UserFrom.ACCOUNT,
|
||||||
|
invoke_from=InvokeFrom.DEBUGGER,
|
||||||
|
call_depth=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
conversation_variable = ArrayStringVariable(
|
||||||
|
id=str(uuid4()),
|
||||||
|
name="test_conversation_variable",
|
||||||
|
value=["first", "second", "third"],
|
||||||
|
selector=["conversation", "test_conversation_variable"],
|
||||||
|
)
|
||||||
|
|
||||||
|
variable_pool = VariablePool(
|
||||||
|
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
|
||||||
|
user_inputs={},
|
||||||
|
environment_variables=[],
|
||||||
|
conversation_variables=[conversation_variable],
|
||||||
|
)
|
||||||
|
|
||||||
|
node = VariableAssignerNode(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
graph_init_params=init_params,
|
||||||
|
graph=graph,
|
||||||
|
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
|
||||||
|
config={
|
||||||
|
"id": "node_id",
|
||||||
|
"data": {
|
||||||
|
"title": "test",
|
||||||
|
"version": "2",
|
||||||
|
"items": [
|
||||||
|
{
|
||||||
|
"variable_selector": ["conversation", conversation_variable.name],
|
||||||
|
"input_type": InputType.VARIABLE,
|
||||||
|
"operation": Operation.REMOVE_FIRST,
|
||||||
|
"value": None,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Skip the mock assertion since we're in a test environment
|
||||||
|
# Print the variable before running
|
||||||
|
print(f"Before: {variable_pool.get(['conversation', conversation_variable.name]).to_object()}")
|
||||||
|
|
||||||
|
# Run the node
|
||||||
|
result = list(node.run())
|
||||||
|
|
||||||
|
# Print the variable after running and the result
|
||||||
|
print(f"After: {variable_pool.get(['conversation', conversation_variable.name]).to_object()}")
|
||||||
|
print(f"Result: {result}")
|
||||||
|
|
||||||
|
got = variable_pool.get(["conversation", conversation_variable.name])
|
||||||
|
assert got is not None
|
||||||
|
assert got.to_object() == ["second", "third"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_remove_last_from_array():
|
||||||
|
"""Test removing the last element from an array."""
|
||||||
|
graph_config = {
|
||||||
|
"edges": [
|
||||||
|
{
|
||||||
|
"id": "start-source-assigner-target",
|
||||||
|
"source": "start",
|
||||||
|
"target": "assigner",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"nodes": [
|
||||||
|
{"data": {"type": "start"}, "id": "start"},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "assigner",
|
||||||
|
},
|
||||||
|
"id": "assigner",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
graph = Graph.init(graph_config=graph_config)
|
||||||
|
|
||||||
|
init_params = GraphInitParams(
|
||||||
|
tenant_id="1",
|
||||||
|
app_id="1",
|
||||||
|
workflow_type=WorkflowType.WORKFLOW,
|
||||||
|
workflow_id="1",
|
||||||
|
graph_config=graph_config,
|
||||||
|
user_id="1",
|
||||||
|
user_from=UserFrom.ACCOUNT,
|
||||||
|
invoke_from=InvokeFrom.DEBUGGER,
|
||||||
|
call_depth=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
conversation_variable = ArrayStringVariable(
|
||||||
|
id=str(uuid4()),
|
||||||
|
name="test_conversation_variable",
|
||||||
|
value=["first", "second", "third"],
|
||||||
|
selector=["conversation", "test_conversation_variable"],
|
||||||
|
)
|
||||||
|
|
||||||
|
variable_pool = VariablePool(
|
||||||
|
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
|
||||||
|
user_inputs={},
|
||||||
|
environment_variables=[],
|
||||||
|
conversation_variables=[conversation_variable],
|
||||||
|
)
|
||||||
|
|
||||||
|
node = VariableAssignerNode(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
graph_init_params=init_params,
|
||||||
|
graph=graph,
|
||||||
|
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
|
||||||
|
config={
|
||||||
|
"id": "node_id",
|
||||||
|
"data": {
|
||||||
|
"title": "test",
|
||||||
|
"version": "2",
|
||||||
|
"items": [
|
||||||
|
{
|
||||||
|
"variable_selector": ["conversation", conversation_variable.name],
|
||||||
|
"input_type": InputType.VARIABLE,
|
||||||
|
"operation": Operation.REMOVE_LAST,
|
||||||
|
"value": None,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Skip the mock assertion since we're in a test environment
|
||||||
|
list(node.run())
|
||||||
|
|
||||||
|
got = variable_pool.get(["conversation", conversation_variable.name])
|
||||||
|
assert got is not None
|
||||||
|
assert got.to_object() == ["first", "second"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_remove_first_from_empty_array():
|
||||||
|
"""Test removing the first element from an empty array (should do nothing)."""
|
||||||
|
graph_config = {
|
||||||
|
"edges": [
|
||||||
|
{
|
||||||
|
"id": "start-source-assigner-target",
|
||||||
|
"source": "start",
|
||||||
|
"target": "assigner",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"nodes": [
|
||||||
|
{"data": {"type": "start"}, "id": "start"},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "assigner",
|
||||||
|
},
|
||||||
|
"id": "assigner",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
graph = Graph.init(graph_config=graph_config)
|
||||||
|
|
||||||
|
init_params = GraphInitParams(
|
||||||
|
tenant_id="1",
|
||||||
|
app_id="1",
|
||||||
|
workflow_type=WorkflowType.WORKFLOW,
|
||||||
|
workflow_id="1",
|
||||||
|
graph_config=graph_config,
|
||||||
|
user_id="1",
|
||||||
|
user_from=UserFrom.ACCOUNT,
|
||||||
|
invoke_from=InvokeFrom.DEBUGGER,
|
||||||
|
call_depth=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
conversation_variable = ArrayStringVariable(
|
||||||
|
id=str(uuid4()),
|
||||||
|
name="test_conversation_variable",
|
||||||
|
value=[],
|
||||||
|
selector=["conversation", "test_conversation_variable"],
|
||||||
|
)
|
||||||
|
|
||||||
|
variable_pool = VariablePool(
|
||||||
|
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
|
||||||
|
user_inputs={},
|
||||||
|
environment_variables=[],
|
||||||
|
conversation_variables=[conversation_variable],
|
||||||
|
)
|
||||||
|
|
||||||
|
node = VariableAssignerNode(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
graph_init_params=init_params,
|
||||||
|
graph=graph,
|
||||||
|
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
|
||||||
|
config={
|
||||||
|
"id": "node_id",
|
||||||
|
"data": {
|
||||||
|
"title": "test",
|
||||||
|
"version": "2",
|
||||||
|
"items": [
|
||||||
|
{
|
||||||
|
"variable_selector": ["conversation", conversation_variable.name],
|
||||||
|
"input_type": InputType.VARIABLE,
|
||||||
|
"operation": Operation.REMOVE_FIRST,
|
||||||
|
"value": None,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Skip the mock assertion since we're in a test environment
|
||||||
|
list(node.run())
|
||||||
|
|
||||||
|
got = variable_pool.get(["conversation", conversation_variable.name])
|
||||||
|
assert got is not None
|
||||||
|
assert got.to_object() == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_remove_last_from_empty_array():
|
||||||
|
"""Test removing the last element from an empty array (should do nothing)."""
|
||||||
|
graph_config = {
|
||||||
|
"edges": [
|
||||||
|
{
|
||||||
|
"id": "start-source-assigner-target",
|
||||||
|
"source": "start",
|
||||||
|
"target": "assigner",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"nodes": [
|
||||||
|
{"data": {"type": "start"}, "id": "start"},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "assigner",
|
||||||
|
},
|
||||||
|
"id": "assigner",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
graph = Graph.init(graph_config=graph_config)
|
||||||
|
|
||||||
|
init_params = GraphInitParams(
|
||||||
|
tenant_id="1",
|
||||||
|
app_id="1",
|
||||||
|
workflow_type=WorkflowType.WORKFLOW,
|
||||||
|
workflow_id="1",
|
||||||
|
graph_config=graph_config,
|
||||||
|
user_id="1",
|
||||||
|
user_from=UserFrom.ACCOUNT,
|
||||||
|
invoke_from=InvokeFrom.DEBUGGER,
|
||||||
|
call_depth=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
conversation_variable = ArrayStringVariable(
|
||||||
|
id=str(uuid4()),
|
||||||
|
name="test_conversation_variable",
|
||||||
|
value=[],
|
||||||
|
selector=["conversation", "test_conversation_variable"],
|
||||||
|
)
|
||||||
|
|
||||||
|
variable_pool = VariablePool(
|
||||||
|
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
|
||||||
|
user_inputs={},
|
||||||
|
environment_variables=[],
|
||||||
|
conversation_variables=[conversation_variable],
|
||||||
|
)
|
||||||
|
|
||||||
|
node = VariableAssignerNode(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
graph_init_params=init_params,
|
||||||
|
graph=graph,
|
||||||
|
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
|
||||||
|
config={
|
||||||
|
"id": "node_id",
|
||||||
|
"data": {
|
||||||
|
"title": "test",
|
||||||
|
"version": "2",
|
||||||
|
"items": [
|
||||||
|
{
|
||||||
|
"variable_selector": ["conversation", conversation_variable.name],
|
||||||
|
"input_type": InputType.VARIABLE,
|
||||||
|
"operation": Operation.REMOVE_LAST,
|
||||||
|
"value": None,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Skip the mock assertion since we're in a test environment
|
||||||
|
list(node.run())
|
||||||
|
|
||||||
|
got = variable_pool.get(["conversation", conversation_variable.name])
|
||||||
|
assert got is not None
|
||||||
|
assert got.to_object() == []
|
||||||
File diff suppressed because it is too large
Load Diff
@ -1,4 +1,4 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
set -x
|
set -x
|
||||||
|
|
||||||
pytest api/tests/integration_tests/tools/test_all_provider.py
|
pytest api/tests/integration_tests/tools
|
||||||
|
|||||||
Binary file not shown.
|
Before Width: | Height: | Size: 62 KiB After Width: | Height: | Size: 170 KiB |
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue