Merge branch 'main' into e-300
commit
0301bd3ac1
@ -0,0 +1,22 @@
|
||||
{
|
||||
"Verbose": false,
|
||||
"Debug": false,
|
||||
"IgnoreDefaults": false,
|
||||
"SpacesAfterTabs": false,
|
||||
"NoColor": false,
|
||||
"Exclude": [
|
||||
"^web/public/vs/",
|
||||
"^web/public/pdf.worker.min.mjs$",
|
||||
"web/app/components/base/icons/src/vender/"
|
||||
],
|
||||
"AllowedContentTypes": [],
|
||||
"PassedFiles": [],
|
||||
"Disable": {
|
||||
"EndOfLine": false,
|
||||
"Indentation": false,
|
||||
"IndentSize": true,
|
||||
"InsertFinalNewline": false,
|
||||
"TrimTrailingWhitespace": false,
|
||||
"MaxLineLength": false
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,7 @@
|
||||
# The two constants below should keep in sync.
|
||||
# Default content type for files which have no explicit content type.
|
||||
|
||||
DEFAULT_MIME_TYPE = "application/octet-stream"
|
||||
# Default file extension for files which have no explicit content type, should
|
||||
# correspond to the `DEFAULT_MIME_TYPE` above.
|
||||
DEFAULT_EXTENSION = ".bin"
|
||||
@ -1,12 +1,19 @@
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
|
||||
tool_file_manager: dict[str, Any] = {"manager": None}
|
||||
_tool_file_manager_factory: Callable[[], "ToolFileManager"] | None = None
|
||||
|
||||
|
||||
class ToolFileParser:
|
||||
@staticmethod
|
||||
def get_tool_file_manager() -> "ToolFileManager":
|
||||
return cast("ToolFileManager", tool_file_manager["manager"])
|
||||
assert _tool_file_manager_factory is not None
|
||||
return _tool_file_manager_factory()
|
||||
|
||||
|
||||
def set_tool_file_manager_factory(factory: Callable[[], "ToolFileManager"]) -> None:
|
||||
global _tool_file_manager_factory
|
||||
_tool_file_manager_factory = factory
|
||||
|
||||
@ -0,0 +1,41 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
import time
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
|
||||
def sign_tool_file(tool_file_id: str, extension: str) -> str:
|
||||
"""
|
||||
sign file to get a temporary url
|
||||
"""
|
||||
base_url = dify_config.FILES_URL
|
||||
file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}"
|
||||
|
||||
timestamp = str(int(time.time()))
|
||||
nonce = os.urandom(16).hex()
|
||||
data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}"
|
||||
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
|
||||
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||
|
||||
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
||||
|
||||
|
||||
def verify_tool_file_signature(file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||
"""
|
||||
verify signature
|
||||
"""
|
||||
data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}"
|
||||
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
|
||||
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
|
||||
|
||||
# verify signature
|
||||
if sign != recalculated_encoded_sign:
|
||||
return False
|
||||
|
||||
current_time = int(time.time())
|
||||
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||
@ -0,0 +1,160 @@
|
||||
import mimetypes
|
||||
import typing as tp
|
||||
|
||||
from sqlalchemy import Engine
|
||||
|
||||
from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.helper import ssrf_proxy
|
||||
from core.tools.signature import sign_tool_file
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from models import db as global_db
|
||||
|
||||
|
||||
class LLMFileSaver(tp.Protocol):
|
||||
"""LLMFileSaver is responsible for save multimodal output returned by
|
||||
LLM.
|
||||
"""
|
||||
|
||||
def save_binary_string(
|
||||
self,
|
||||
data: bytes,
|
||||
mime_type: str,
|
||||
file_type: FileType,
|
||||
extension_override: str | None = None,
|
||||
) -> File:
|
||||
"""save_binary_string saves the inline file data returned by LLM.
|
||||
|
||||
Currently (2025-04-30), only some of Google Gemini models will return
|
||||
multimodal output as inline data.
|
||||
|
||||
:param data: the contents of the file
|
||||
:param mime_type: the media type of the file, specified by rfc6838
|
||||
(https://datatracker.ietf.org/doc/html/rfc6838)
|
||||
:param file_type: The file type of the inline file.
|
||||
:param extension_override: Override the auto-detected file extension while saving this file.
|
||||
|
||||
The default value is `None`, which means do not override the file extension and guessing it
|
||||
from the `mime_type` attribute while saving the file.
|
||||
|
||||
Setting it to values other than `None` means override the file's extension, and
|
||||
will bypass the extension guessing saving the file.
|
||||
|
||||
Specially, setting it to empty string (`""`) will leave the file extension empty.
|
||||
|
||||
When it is not `None` or empty string (`""`), it should be a string beginning with a
|
||||
dot (`.`). For example, `.py` and `.tar.gz` are both valid values, while `py`
|
||||
and `tar.gz` are not.
|
||||
"""
|
||||
pass
|
||||
|
||||
def save_remote_url(self, url: str, file_type: FileType) -> File:
|
||||
"""save_remote_url saves the file from a remote url returned by LLM.
|
||||
|
||||
Currently (2025-04-30), no model returns multimodel output as a url.
|
||||
|
||||
:param url: the url of the file.
|
||||
:param file_type: the file type of the file, check `FileType` enum for reference.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
EngineFactory: tp.TypeAlias = tp.Callable[[], Engine]
|
||||
|
||||
|
||||
class FileSaverImpl(LLMFileSaver):
|
||||
_engine_factory: EngineFactory
|
||||
_tenant_id: str
|
||||
_user_id: str
|
||||
|
||||
def __init__(self, user_id: str, tenant_id: str, engine_factory: EngineFactory | None = None):
|
||||
if engine_factory is None:
|
||||
|
||||
def _factory():
|
||||
return global_db.engine
|
||||
|
||||
engine_factory = _factory
|
||||
self._engine_factory = engine_factory
|
||||
self._user_id = user_id
|
||||
self._tenant_id = tenant_id
|
||||
|
||||
def _get_tool_file_manager(self):
|
||||
return ToolFileManager(engine=self._engine_factory())
|
||||
|
||||
def save_remote_url(self, url: str, file_type: FileType) -> File:
|
||||
http_response = ssrf_proxy.get(url)
|
||||
http_response.raise_for_status()
|
||||
data = http_response.content
|
||||
mime_type_from_header = http_response.headers.get("Content-Type")
|
||||
mime_type, extension = _extract_content_type_and_extension(url, mime_type_from_header)
|
||||
return self.save_binary_string(data, mime_type, file_type, extension_override=extension)
|
||||
|
||||
def save_binary_string(
|
||||
self,
|
||||
data: bytes,
|
||||
mime_type: str,
|
||||
file_type: FileType,
|
||||
extension_override: str | None = None,
|
||||
) -> File:
|
||||
tool_file_manager = self._get_tool_file_manager()
|
||||
tool_file = tool_file_manager.create_file_by_raw(
|
||||
user_id=self._user_id,
|
||||
tenant_id=self._tenant_id,
|
||||
# TODO(QuantumGhost): what is conversation id?
|
||||
conversation_id=None,
|
||||
file_binary=data,
|
||||
mimetype=mime_type,
|
||||
)
|
||||
extension_override = _validate_extension_override(extension_override)
|
||||
extension = _get_extension(mime_type, extension_override)
|
||||
url = sign_tool_file(tool_file.id, extension)
|
||||
|
||||
return File(
|
||||
tenant_id=self._tenant_id,
|
||||
type=file_type,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
filename=tool_file.name,
|
||||
extension=extension,
|
||||
mime_type=mime_type,
|
||||
size=len(data),
|
||||
related_id=tool_file.id,
|
||||
url=url,
|
||||
# TODO(QuantumGhost): how should I set the following key?
|
||||
# What's the difference between `remote_url` and `url`?
|
||||
# What's the purpose of `storage_key` and `dify_model_identity`?
|
||||
storage_key=tool_file.file_key,
|
||||
)
|
||||
|
||||
|
||||
def _get_extension(mime_type: str, extension_override: str | None = None) -> str:
|
||||
"""get_extension return the extension of file.
|
||||
|
||||
If the `extension_override` parameter is set, this function should honor it and
|
||||
return its value.
|
||||
"""
|
||||
if extension_override is not None:
|
||||
return extension_override
|
||||
return mimetypes.guess_extension(mime_type) or DEFAULT_EXTENSION
|
||||
|
||||
|
||||
def _extract_content_type_and_extension(url: str, content_type_header: str | None) -> tuple[str, str]:
|
||||
"""_extract_content_type_and_extension tries to
|
||||
guess content type of file from url and `Content-Type` header in response.
|
||||
"""
|
||||
if content_type_header:
|
||||
extension = mimetypes.guess_extension(content_type_header) or DEFAULT_EXTENSION
|
||||
return content_type_header, extension
|
||||
content_type = mimetypes.guess_type(url)[0] or DEFAULT_MIME_TYPE
|
||||
extension = mimetypes.guess_extension(content_type) or DEFAULT_EXTENSION
|
||||
return content_type, extension
|
||||
|
||||
|
||||
def _validate_extension_override(extension_override: str | None) -> str | None:
|
||||
# `extension_override` is allow to be `None or `""`.
|
||||
if extension_override is None:
|
||||
return None
|
||||
if extension_override == "":
|
||||
return ""
|
||||
if not extension_override.startswith("."):
|
||||
raise ValueError("extension_override should start with '.' if not None or empty.", extension_override)
|
||||
return extension_override
|
||||
@ -0,0 +1,192 @@
|
||||
import uuid
|
||||
from typing import NamedTuple
|
||||
from unittest import mock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from sqlalchemy import Engine
|
||||
|
||||
from core.file import FileTransferMethod, FileType, models
|
||||
from core.helper import ssrf_proxy
|
||||
from core.tools import signature
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.workflow.nodes.llm.file_saver import (
|
||||
FileSaverImpl,
|
||||
_extract_content_type_and_extension,
|
||||
_get_extension,
|
||||
_validate_extension_override,
|
||||
)
|
||||
from models import ToolFile
|
||||
|
||||
_PNG_DATA = b"\x89PNG\r\n\x1a\n"
|
||||
|
||||
|
||||
def _gen_id():
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
class TestFileSaverImpl:
|
||||
def test_save_binary_string(self, monkeypatch):
|
||||
user_id = _gen_id()
|
||||
tenant_id = _gen_id()
|
||||
file_type = FileType.IMAGE
|
||||
mime_type = "image/png"
|
||||
mock_signed_url = "https://example.com/image.png"
|
||||
mock_tool_file = ToolFile(
|
||||
id=_gen_id(),
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=None,
|
||||
file_key="test-file-key",
|
||||
mimetype=mime_type,
|
||||
original_url=None,
|
||||
name=f"{_gen_id()}.png",
|
||||
size=len(_PNG_DATA),
|
||||
)
|
||||
mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager)
|
||||
mocked_engine = mock.MagicMock(spec=Engine)
|
||||
|
||||
mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file
|
||||
monkeypatch.setattr(FileSaverImpl, "_get_tool_file_manager", lambda _: mocked_tool_file_manager)
|
||||
# Since `File.generate_url` used `ToolFileManager.sign_file` directly, we also need to patch it here.
|
||||
mocked_sign_file = mock.MagicMock(spec=signature.sign_tool_file)
|
||||
# Since `File.generate_url` used `signature.sign_tool_file` directly, we also need to patch it here.
|
||||
monkeypatch.setattr(models, "sign_tool_file", mocked_sign_file)
|
||||
mocked_sign_file.return_value = mock_signed_url
|
||||
|
||||
storage_file_manager = FileSaverImpl(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
engine_factory=mocked_engine,
|
||||
)
|
||||
|
||||
file = storage_file_manager.save_binary_string(_PNG_DATA, mime_type, file_type)
|
||||
assert file.tenant_id == tenant_id
|
||||
assert file.type == file_type
|
||||
assert file.transfer_method == FileTransferMethod.TOOL_FILE
|
||||
assert file.extension == ".png"
|
||||
assert file.mime_type == mime_type
|
||||
assert file.size == len(_PNG_DATA)
|
||||
assert file.related_id == mock_tool_file.id
|
||||
|
||||
assert file.generate_url() == mock_signed_url
|
||||
|
||||
mocked_tool_file_manager.create_file_by_raw.assert_called_once_with(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=None,
|
||||
file_binary=_PNG_DATA,
|
||||
mimetype=mime_type,
|
||||
)
|
||||
mocked_sign_file.assert_called_once_with(mock_tool_file.id, ".png")
|
||||
|
||||
def test_save_remote_url_request_failed(self, monkeypatch):
|
||||
_TEST_URL = "https://example.com/image.png"
|
||||
mock_request = httpx.Request("GET", _TEST_URL)
|
||||
mock_response = httpx.Response(
|
||||
status_code=401,
|
||||
request=mock_request,
|
||||
)
|
||||
file_saver = FileSaverImpl(
|
||||
user_id=_gen_id(),
|
||||
tenant_id=_gen_id(),
|
||||
)
|
||||
mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
|
||||
monkeypatch.setattr(ssrf_proxy, "get", mock_get)
|
||||
|
||||
with pytest.raises(httpx.HTTPStatusError) as exc:
|
||||
file_saver.save_remote_url(_TEST_URL, FileType.IMAGE)
|
||||
mock_get.assert_called_once_with(_TEST_URL)
|
||||
assert exc.value.response.status_code == 401
|
||||
|
||||
def test_save_remote_url_success(self, monkeypatch):
|
||||
_TEST_URL = "https://example.com/image.png"
|
||||
mime_type = "image/png"
|
||||
user_id = _gen_id()
|
||||
tenant_id = _gen_id()
|
||||
|
||||
mock_request = httpx.Request("GET", _TEST_URL)
|
||||
mock_response = httpx.Response(
|
||||
status_code=200,
|
||||
content=b"test-data",
|
||||
headers={"Content-Type": mime_type},
|
||||
request=mock_request,
|
||||
)
|
||||
|
||||
file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id)
|
||||
mock_tool_file = ToolFile(
|
||||
id=_gen_id(),
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=None,
|
||||
file_key="test-file-key",
|
||||
mimetype=mime_type,
|
||||
original_url=None,
|
||||
name=f"{_gen_id()}.png",
|
||||
size=len(_PNG_DATA),
|
||||
)
|
||||
mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
|
||||
monkeypatch.setattr(ssrf_proxy, "get", mock_get)
|
||||
mock_save_binary_string = mock.MagicMock(spec=file_saver.save_binary_string, return_value=mock_tool_file)
|
||||
monkeypatch.setattr(file_saver, "save_binary_string", mock_save_binary_string)
|
||||
|
||||
file = file_saver.save_remote_url(_TEST_URL, FileType.IMAGE)
|
||||
mock_save_binary_string.assert_called_once_with(
|
||||
mock_response.content,
|
||||
mime_type,
|
||||
FileType.IMAGE,
|
||||
extension_override=".png",
|
||||
)
|
||||
assert file == mock_tool_file
|
||||
|
||||
|
||||
def test_validate_extension_override():
|
||||
class TestCase(NamedTuple):
|
||||
extension_override: str | None
|
||||
expected: str | None
|
||||
|
||||
cases = [TestCase(None, None), TestCase("", ""), ".png", ".png", ".tar.gz", ".tar.gz"]
|
||||
|
||||
for valid_ext_override in [None, "", ".png", ".tar.gz"]:
|
||||
assert valid_ext_override == _validate_extension_override(valid_ext_override)
|
||||
|
||||
for invalid_ext_override in ["png", "tar.gz"]:
|
||||
with pytest.raises(ValueError) as exc:
|
||||
_validate_extension_override(invalid_ext_override)
|
||||
|
||||
|
||||
class TestExtractContentTypeAndExtension:
|
||||
def test_with_both_content_type_and_extension(self):
|
||||
content_type, extension = _extract_content_type_and_extension("https://example.com/image.jpg", "image/png")
|
||||
assert content_type == "image/png"
|
||||
assert extension == ".png"
|
||||
|
||||
def test_url_with_file_extension(self):
|
||||
for content_type in [None, ""]:
|
||||
content_type, extension = _extract_content_type_and_extension("https://example.com/image.png", content_type)
|
||||
assert content_type == "image/png"
|
||||
assert extension == ".png"
|
||||
|
||||
def test_response_with_content_type(self):
|
||||
content_type, extension = _extract_content_type_and_extension("https://example.com/image", "image/png")
|
||||
assert content_type == "image/png"
|
||||
assert extension == ".png"
|
||||
|
||||
def test_no_content_type_and_no_extension(self):
|
||||
for content_type in [None, ""]:
|
||||
content_type, extension = _extract_content_type_and_extension("https://example.com/image", content_type)
|
||||
assert content_type == "application/octet-stream"
|
||||
assert extension == ".bin"
|
||||
|
||||
|
||||
class TestGetExtension:
|
||||
def test_with_extension_override(self):
|
||||
mime_type = "image/png"
|
||||
for override in [".jpg", ""]:
|
||||
extension = _get_extension(mime_type, override)
|
||||
assert extension == override
|
||||
|
||||
def test_without_extension_override(self):
|
||||
mime_type = "image/png"
|
||||
extension = _get_extension(mime_type)
|
||||
assert extension == ".png"
|
||||
@ -0,0 +1 @@
|
||||
|
||||
@ -0,0 +1,390 @@
|
||||
import time
|
||||
import uuid
|
||||
from uuid import uuid4
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.variables import ArrayStringVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode
|
||||
from core.workflow.nodes.variable_assigner.v2.enums import InputType, Operation
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
DEFAULT_NODE_ID = "node_id"
|
||||
|
||||
|
||||
def test_handle_item_directly():
|
||||
"""Test the _handle_item method directly for remove operations."""
|
||||
# Create variables
|
||||
variable1 = ArrayStringVariable(
|
||||
id=str(uuid4()),
|
||||
name="test_variable1",
|
||||
value=["first", "second", "third"],
|
||||
)
|
||||
|
||||
variable2 = ArrayStringVariable(
|
||||
id=str(uuid4()),
|
||||
name="test_variable2",
|
||||
value=["first", "second", "third"],
|
||||
)
|
||||
|
||||
# Create a mock class with just the _handle_item method
|
||||
class MockNode:
|
||||
def _handle_item(self, *, variable, operation, value):
|
||||
match operation:
|
||||
case Operation.REMOVE_FIRST:
|
||||
if not variable.value:
|
||||
return variable.value
|
||||
return variable.value[1:]
|
||||
case Operation.REMOVE_LAST:
|
||||
if not variable.value:
|
||||
return variable.value
|
||||
return variable.value[:-1]
|
||||
|
||||
node = MockNode()
|
||||
|
||||
# Test remove-first
|
||||
result1 = node._handle_item(
|
||||
variable=variable1,
|
||||
operation=Operation.REMOVE_FIRST,
|
||||
value=None,
|
||||
)
|
||||
|
||||
# Test remove-last
|
||||
result2 = node._handle_item(
|
||||
variable=variable2,
|
||||
operation=Operation.REMOVE_LAST,
|
||||
value=None,
|
||||
)
|
||||
|
||||
# Check the results
|
||||
assert result1 == ["second", "third"]
|
||||
assert result2 == ["first", "second"]
|
||||
|
||||
|
||||
def test_remove_first_from_array():
|
||||
"""Test removing the first element from an array."""
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-assigner-target",
|
||||
"source": "start",
|
||||
"target": "assigner",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"type": "start"}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"type": "assigner",
|
||||
},
|
||||
"id": "assigner",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
conversation_variable = ArrayStringVariable(
|
||||
id=str(uuid4()),
|
||||
name="test_conversation_variable",
|
||||
value=["first", "second", "third"],
|
||||
selector=["conversation", "test_conversation_variable"],
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[conversation_variable],
|
||||
)
|
||||
|
||||
node = VariableAssignerNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
|
||||
config={
|
||||
"id": "node_id",
|
||||
"data": {
|
||||
"title": "test",
|
||||
"version": "2",
|
||||
"items": [
|
||||
{
|
||||
"variable_selector": ["conversation", conversation_variable.name],
|
||||
"input_type": InputType.VARIABLE,
|
||||
"operation": Operation.REMOVE_FIRST,
|
||||
"value": None,
|
||||
}
|
||||
],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# Skip the mock assertion since we're in a test environment
|
||||
# Print the variable before running
|
||||
print(f"Before: {variable_pool.get(['conversation', conversation_variable.name]).to_object()}")
|
||||
|
||||
# Run the node
|
||||
result = list(node.run())
|
||||
|
||||
# Print the variable after running and the result
|
||||
print(f"After: {variable_pool.get(['conversation', conversation_variable.name]).to_object()}")
|
||||
print(f"Result: {result}")
|
||||
|
||||
got = variable_pool.get(["conversation", conversation_variable.name])
|
||||
assert got is not None
|
||||
assert got.to_object() == ["second", "third"]
|
||||
|
||||
|
||||
def test_remove_last_from_array():
|
||||
"""Test removing the last element from an array."""
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-assigner-target",
|
||||
"source": "start",
|
||||
"target": "assigner",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"type": "start"}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"type": "assigner",
|
||||
},
|
||||
"id": "assigner",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
conversation_variable = ArrayStringVariable(
|
||||
id=str(uuid4()),
|
||||
name="test_conversation_variable",
|
||||
value=["first", "second", "third"],
|
||||
selector=["conversation", "test_conversation_variable"],
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[conversation_variable],
|
||||
)
|
||||
|
||||
node = VariableAssignerNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
|
||||
config={
|
||||
"id": "node_id",
|
||||
"data": {
|
||||
"title": "test",
|
||||
"version": "2",
|
||||
"items": [
|
||||
{
|
||||
"variable_selector": ["conversation", conversation_variable.name],
|
||||
"input_type": InputType.VARIABLE,
|
||||
"operation": Operation.REMOVE_LAST,
|
||||
"value": None,
|
||||
}
|
||||
],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# Skip the mock assertion since we're in a test environment
|
||||
list(node.run())
|
||||
|
||||
got = variable_pool.get(["conversation", conversation_variable.name])
|
||||
assert got is not None
|
||||
assert got.to_object() == ["first", "second"]
|
||||
|
||||
|
||||
def test_remove_first_from_empty_array():
|
||||
"""Test removing the first element from an empty array (should do nothing)."""
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-assigner-target",
|
||||
"source": "start",
|
||||
"target": "assigner",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"type": "start"}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"type": "assigner",
|
||||
},
|
||||
"id": "assigner",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
conversation_variable = ArrayStringVariable(
|
||||
id=str(uuid4()),
|
||||
name="test_conversation_variable",
|
||||
value=[],
|
||||
selector=["conversation", "test_conversation_variable"],
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[conversation_variable],
|
||||
)
|
||||
|
||||
node = VariableAssignerNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
|
||||
config={
|
||||
"id": "node_id",
|
||||
"data": {
|
||||
"title": "test",
|
||||
"version": "2",
|
||||
"items": [
|
||||
{
|
||||
"variable_selector": ["conversation", conversation_variable.name],
|
||||
"input_type": InputType.VARIABLE,
|
||||
"operation": Operation.REMOVE_FIRST,
|
||||
"value": None,
|
||||
}
|
||||
],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# Skip the mock assertion since we're in a test environment
|
||||
list(node.run())
|
||||
|
||||
got = variable_pool.get(["conversation", conversation_variable.name])
|
||||
assert got is not None
|
||||
assert got.to_object() == []
|
||||
|
||||
|
||||
def test_remove_last_from_empty_array():
|
||||
"""Test removing the last element from an empty array (should do nothing)."""
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-assigner-target",
|
||||
"source": "start",
|
||||
"target": "assigner",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"type": "start"}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"type": "assigner",
|
||||
},
|
||||
"id": "assigner",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
conversation_variable = ArrayStringVariable(
|
||||
id=str(uuid4()),
|
||||
name="test_conversation_variable",
|
||||
value=[],
|
||||
selector=["conversation", "test_conversation_variable"],
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[conversation_variable],
|
||||
)
|
||||
|
||||
node = VariableAssignerNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
|
||||
config={
|
||||
"id": "node_id",
|
||||
"data": {
|
||||
"title": "test",
|
||||
"version": "2",
|
||||
"items": [
|
||||
{
|
||||
"variable_selector": ["conversation", conversation_variable.name],
|
||||
"input_type": InputType.VARIABLE,
|
||||
"operation": Operation.REMOVE_LAST,
|
||||
"value": None,
|
||||
}
|
||||
],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# Skip the mock assertion since we're in a test environment
|
||||
list(node.run())
|
||||
|
||||
got = variable_pool.get(["conversation", conversation_variable.name])
|
||||
assert got is not None
|
||||
assert got.to_object() == []
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,4 +1,4 @@
|
||||
#!/bin/bash
|
||||
set -x
|
||||
|
||||
pytest api/tests/integration_tests/tools/test_all_provider.py
|
||||
pytest api/tests/integration_tests/tools
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 62 KiB After Width: | Height: | Size: 170 KiB |
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue