diff --git a/api/templates/without-brand/invite_member_mail_template_zh-CN.html b/api/templates/without-brand/invite_member_mail_template_zh-CN.html
index d4f80c66f8..fd2d6b873f 100644
--- a/api/templates/without-brand/invite_member_mail_template_zh-CN.html
+++ b/api/templates/without-brand/invite_member_mail_template_zh-CN.html
@@ -1,69 +1,91 @@
-
+
-
-
-
尊敬的 {{ to }},
-
{{ inviter_name }} 现邀请您加入我们在 {{application_title}} 的工作区,这是一个专为 LLM 应用开发而设计的平台。在 {{application_title}} 上,您可以探索、创造和合作,构建和运营 AI 应用。
-
点击下方按钮即可登录 {{application_title}} 并且加入空间。
-
在此登录
-
-
+
+
+
+
尊敬的 {{ to }},
+
{{ inviter_name }} 现邀请您加入我们在 {{application_title}} 的工作区,这是一个专为 LLM 应用开发而设计的平台。在 {{application_title}} 上,您可以探索、创造和合作,构建和运营 AI 应用。
+
点击下方按钮即可登录 {{application_title}} 并且加入空间。
+
在此登录
+
此致,
+
{{application_title}} 团队
+
diff --git a/api/templates/without-brand/transfer_workspace_new_owner_notify_template_en-US.html b/api/templates/without-brand/transfer_workspace_new_owner_notify_template_en-US.html
new file mode 100644
index 0000000000..a5758a2184
--- /dev/null
+++ b/api/templates/without-brand/transfer_workspace_new_owner_notify_template_en-US.html
@@ -0,0 +1,89 @@
+
+
+
+
+
+
+
+
+
+
+
You are now the owner of {{WorkspaceName}}
+
+
You have been assigned as the new owner of the workspace "{{WorkspaceName}}".
+
As the new owner, you now have full administrative privileges for this workspace.
+
If you have any questions, please contact support@dify.ai.
+
+
+
+
+
+
diff --git a/api/templates/without-brand/transfer_workspace_new_owner_notify_template_zh-CN.html b/api/templates/without-brand/transfer_workspace_new_owner_notify_template_zh-CN.html
new file mode 100644
index 0000000000..53bab92552
--- /dev/null
+++ b/api/templates/without-brand/transfer_workspace_new_owner_notify_template_zh-CN.html
@@ -0,0 +1,89 @@
+
+
+
+
+
+
+
+
+
+
+
您现在是 {{WorkspaceName}} 的所有者
+
+
您已被分配为工作空间“{{WorkspaceName}}”的新所有者。
+
作为新所有者,您现在对该工作空间拥有完全的管理权限。
+
如果您有任何问题,请联系support@dify.ai。
+
+
+
+
+
+
diff --git a/api/templates/without-brand/transfer_workspace_old_owner_notify_template_en-US.html b/api/templates/without-brand/transfer_workspace_old_owner_notify_template_en-US.html
new file mode 100644
index 0000000000..3e7faeb01e
--- /dev/null
+++ b/api/templates/without-brand/transfer_workspace_old_owner_notify_template_en-US.html
@@ -0,0 +1,119 @@
+
+
+
+
+
+
+
+
+
+
+
Workspace ownership has been transferred
+
+
You have successfully transferred ownership of the workspace "{{WorkspaceName}}" to {{NewOwnerEmail}}.
+
You no longer have owner privileges for this workspace. Your access level has been changed to Admin.
+
If you did not initiate this transfer or have concerns about this change, please contact support@dify.ai immediately.
+
+
+
+
+
+
diff --git a/api/templates/without-brand/transfer_workspace_old_owner_notify_template_zh-CN.html b/api/templates/without-brand/transfer_workspace_old_owner_notify_template_zh-CN.html
new file mode 100644
index 0000000000..31e3c23140
--- /dev/null
+++ b/api/templates/without-brand/transfer_workspace_old_owner_notify_template_zh-CN.html
@@ -0,0 +1,119 @@
+
+
+
+
+
+
+
+
+
+
+
工作区所有权已转移
+
+
您已成功将工作空间“{{WorkspaceName}}”的所有权转移给{{NewOwnerEmail}}。
+
您不再拥有此工作空间的拥有者权限。您的访问级别已更改为管理员。
+
如果您没有发起此转移或对此变更有任何疑问,请立即联系support@dify.ai。
+
+
+
+
+
+
diff --git a/api/templates/without-brand/transfer_workspace_owner_confirm_template_en-US.html b/api/templates/without-brand/transfer_workspace_owner_confirm_template_en-US.html
new file mode 100644
index 0000000000..11ce275641
--- /dev/null
+++ b/api/templates/without-brand/transfer_workspace_owner_confirm_template_en-US.html
@@ -0,0 +1,150 @@
+
+
+
+
+
+
+
+
+
+
+
Verify Your Request to Transfer Workspace Ownership
+
+
We received a request to transfer ownership of your workspace “{{WorkspaceName}}”.
+
To confirm this action, please use the verification code below.
+
This code will only be valid for the next 5 minutes:
+
+
+ {{code}}
+
+
Please note:
+
+ The ownership transfer will take effect immediately once confirmed and cannot be undone.
+ You’ll become an admin member, and the new owner will have full control of the workspace.
+
+
If you didn’t make this request, please ignore this email or contact support immediately.
+
+
+
+
+
diff --git a/api/templates/without-brand/transfer_workspace_owner_confirm_template_zh-CN.html b/api/templates/without-brand/transfer_workspace_owner_confirm_template_zh-CN.html
new file mode 100644
index 0000000000..36b9a24a89
--- /dev/null
+++ b/api/templates/without-brand/transfer_workspace_owner_confirm_template_zh-CN.html
@@ -0,0 +1,150 @@
+
+
+
+
+
+
+
+
+
+
+
验证您的工作空间所有权转移请求
+
+
我们收到了将您的工作空间“{{WorkspaceName}}”的所有权转移的请求。
+
为了确认此操作,请使用以下验证码。
+
此验证码仅在5分钟内有效:
+
+
+ {{code}}
+
+
请注意:
+
+ 所有权转移一旦确认将立即生效且无法撤销。
+ 您将成为管理员成员,新的所有者将拥有工作空间的完全控制权。
+
+
如果您没有发起此请求,请忽略此邮件或立即联系客服。
+
+
+
+
+
diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example
index 4046096c27..2e98dec964 100644
--- a/api/tests/integration_tests/.env.example
+++ b/api/tests/integration_tests/.env.example
@@ -203,6 +203,8 @@ ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}
# Reset password token expiry minutes
RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5
+CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES=5
+OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES=5
CREATE_TIDB_SERVICE_JOB_ENABLED=false
diff --git a/api/tests/integration_tests/workflow/nodes/__mock/model.py b/api/tests/integration_tests/workflow/nodes/__mock/model.py
index 7c48d84d69..330ebfd54a 100644
--- a/api/tests/integration_tests/workflow/nodes/__mock/model.py
+++ b/api/tests/integration_tests/workflow/nodes/__mock/model.py
@@ -15,7 +15,7 @@ def get_mocked_fetch_model_config(
mode: str,
credentials: dict,
):
- model_provider_factory = ModelProviderFactory(tenant_id="test_tenant")
+ model_provider_factory = ModelProviderFactory(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b")
model_type_instance = model_provider_factory.get_model_type_instance(provider, ModelType.LLM)
provider_model_bundle = ProviderModelBundle(
configuration=ProviderConfiguration(
diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py
index 13d78c2d83..707b28e6d8 100644
--- a/api/tests/integration_tests/workflow/nodes/test_code.py
+++ b/api/tests/integration_tests/workflow/nodes/test_code.py
@@ -9,12 +9,12 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.entities import CodeNodeData
+from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
@@ -50,7 +50,7 @@ def init_code_node(code_config: dict):
# construct variable pool
variable_pool = VariablePool(
- system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
+ system_variables=SystemVariable(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
conversation_variables=[],
@@ -66,6 +66,10 @@ def init_code_node(code_config: dict):
config=code_config,
)
+ # Initialize node data
+ if "data" in code_config:
+ node.init_node_data(code_config["data"])
+
return node
@@ -234,10 +238,10 @@ def test_execute_code_output_validator_depth():
"object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}},
}
- node.node_data = cast(CodeNodeData, node.node_data)
+ node._node_data = cast(CodeNodeData, node._node_data)
# validate
- node._transform_result(result, node.node_data.outputs)
+ node._transform_result(result, node._node_data.outputs)
# construct result
result = {
@@ -250,7 +254,7 @@ def test_execute_code_output_validator_depth():
# validate
with pytest.raises(ValueError):
- node._transform_result(result, node.node_data.outputs)
+ node._transform_result(result, node._node_data.outputs)
# construct result
result = {
@@ -263,7 +267,7 @@ def test_execute_code_output_validator_depth():
# validate
with pytest.raises(ValueError):
- node._transform_result(result, node.node_data.outputs)
+ node._transform_result(result, node._node_data.outputs)
# construct result
result = {
@@ -276,7 +280,7 @@ def test_execute_code_output_validator_depth():
# validate
with pytest.raises(ValueError):
- node._transform_result(result, node.node_data.outputs)
+ node._transform_result(result, node._node_data.outputs)
def test_execute_code_output_object_list():
@@ -330,10 +334,10 @@ def test_execute_code_output_object_list():
]
}
- node.node_data = cast(CodeNodeData, node.node_data)
+ node._node_data = cast(CodeNodeData, node._node_data)
# validate
- node._transform_result(result, node.node_data.outputs)
+ node._transform_result(result, node._node_data.outputs)
# construct result
result = {
@@ -353,4 +357,36 @@ def test_execute_code_output_object_list():
# validate
with pytest.raises(ValueError):
- node._transform_result(result, node.node_data.outputs)
+ node._transform_result(result, node._node_data.outputs)
+
+
+def test_execute_code_scientific_notation():
+ code = """
+ def main() -> dict:
+ return {
+ "result": -8.0E-5
+ }
+ """
+ code = "\n".join([line[4:] for line in code.split("\n")])
+
+ code_config = {
+ "id": "code",
+ "data": {
+ "outputs": {
+ "result": {
+ "type": "number",
+ },
+ },
+ "title": "123",
+ "variables": [],
+ "answer": "123",
+ "code_language": "python3",
+ "code": code,
+ },
+ }
+
+ node = init_code_node(code_config)
+ # execute node
+ result = node._run()
+ assert isinstance(result, NodeRunResult)
+ assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py
index 1ab0cc2451..d7856129a3 100644
--- a/api/tests/integration_tests/workflow/nodes/test_http.py
+++ b/api/tests/integration_tests/workflow/nodes/test_http.py
@@ -6,11 +6,11 @@ import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.http_request.node import HttpRequestNode
+from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock
@@ -44,7 +44,7 @@ def init_http_node(config: dict):
# construct variable pool
variable_pool = VariablePool(
- system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
+ system_variables=SystemVariable(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
conversation_variables=[],
@@ -52,7 +52,7 @@ def init_http_node(config: dict):
variable_pool.add(["a", "b123", "args1"], 1)
variable_pool.add(["a", "b123", "args2"], 2)
- return HttpRequestNode(
+ node = HttpRequestNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
@@ -60,6 +60,12 @@ def init_http_node(config: dict):
config=config,
)
+ # Initialize node data
+ if "data" in config:
+ node.init_node_data(config["data"])
+
+ return node
+
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_get(setup_http_mock):
diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py
index 638323f850..a14791bc67 100644
--- a/api/tests/integration_tests/workflow/nodes/test_llm.py
+++ b/api/tests/integration_tests/workflow/nodes/test_llm.py
@@ -2,30 +2,23 @@ import json
import time
import uuid
from collections.abc import Generator
-from decimal import Decimal
from unittest.mock import MagicMock, patch
-import pytest
-
from core.app.entities.app_invoke_entities import InvokeFrom
from core.llm_generator.output_parser.structured_output import _parse_structured_output
-from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
-from core.model_runtime.entities.message_entities import AssistantPromptMessage
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.llm.node import LLMNode
+from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from models.enums import UserFrom
from models.workflow import WorkflowType
"""FOR MOCK FIXTURES, DO NOT REMOVE"""
-from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock
-from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
def init_llm_node(config: dict) -> LLMNode:
@@ -62,12 +55,14 @@ def init_llm_node(config: dict) -> LLMNode:
# construct variable pool
variable_pool = VariablePool(
- system_variables={
- SystemVariableKey.QUERY: "what's the weather today?",
- SystemVariableKey.FILES: [],
- SystemVariableKey.CONVERSATION_ID: "abababa",
- SystemVariableKey.USER_ID: "aaa",
- },
+ system_variables=SystemVariable(
+ user_id="aaa",
+ app_id=app_id,
+ workflow_id=workflow_id,
+ files=[],
+ query="what's the weather today?",
+ conversation_id="abababa",
+ ),
user_inputs={},
environment_variables=[],
conversation_variables=[],
@@ -82,10 +77,14 @@ def init_llm_node(config: dict) -> LLMNode:
config=config,
)
+ # Initialize node data
+ if "data" in config:
+ node.init_node_data(config["data"])
+
return node
-def test_execute_llm(flask_req_ctx):
+def test_execute_llm():
node = init_llm_node(
config={
"id": "llm",
@@ -93,7 +92,7 @@ def test_execute_llm(flask_req_ctx):
"title": "123",
"type": "llm",
"model": {
- "provider": "langgenius/openai/openai",
+ "provider": "openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {},
@@ -112,53 +111,62 @@ def test_execute_llm(flask_req_ctx):
},
)
- # Create a proper LLM result with real entities
- mock_usage = LLMUsage(
- prompt_tokens=30,
- prompt_unit_price=Decimal("0.001"),
- prompt_price_unit=Decimal("1000"),
- prompt_price=Decimal("0.00003"),
- completion_tokens=20,
- completion_unit_price=Decimal("0.002"),
- completion_price_unit=Decimal("1000"),
- completion_price=Decimal("0.00004"),
- total_tokens=50,
- total_price=Decimal("0.00007"),
- currency="USD",
- latency=0.5,
- )
-
- mock_message = AssistantPromptMessage(content="This is a test response from the mocked LLM.")
-
- mock_llm_result = LLMResult(
- model="gpt-3.5-turbo",
- prompt_messages=[],
- message=mock_message,
- usage=mock_usage,
- )
-
- # Create a simple mock model instance that doesn't call real providers
- mock_model_instance = MagicMock()
- mock_model_instance.invoke_llm.return_value = mock_llm_result
+ db.session.close = MagicMock()
- # Create a simple mock model config with required attributes
- mock_model_config = MagicMock()
- mock_model_config.mode = "chat"
- mock_model_config.provider = "langgenius/openai/openai"
- mock_model_config.model = "gpt-3.5-turbo"
- mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
+ # Mock the _fetch_model_config to avoid database calls
+ def mock_fetch_model_config(**_kwargs):
+ from decimal import Decimal
+ from unittest.mock import MagicMock
+
+ from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
+ from core.model_runtime.entities.message_entities import AssistantPromptMessage
+
+ # Create mock model instance
+ mock_model_instance = MagicMock()
+ mock_usage = LLMUsage(
+ prompt_tokens=30,
+ prompt_unit_price=Decimal("0.001"),
+ prompt_price_unit=Decimal(1000),
+ prompt_price=Decimal("0.00003"),
+ completion_tokens=20,
+ completion_unit_price=Decimal("0.002"),
+ completion_price_unit=Decimal(1000),
+ completion_price=Decimal("0.00004"),
+ total_tokens=50,
+ total_price=Decimal("0.00007"),
+ currency="USD",
+ latency=0.5,
+ )
+ mock_message = AssistantPromptMessage(content="Test response from mock")
+ mock_llm_result = LLMResult(
+ model="gpt-3.5-turbo",
+ prompt_messages=[],
+ message=mock_message,
+ usage=mock_usage,
+ )
+ mock_model_instance.invoke_llm.return_value = mock_llm_result
+
+ # Create mock model config
+ mock_model_config = MagicMock()
+ mock_model_config.mode = "chat"
+ mock_model_config.provider = "openai"
+ mock_model_config.model = "gpt-3.5-turbo"
+ mock_model_config.parameters = {}
- # Mock the _fetch_model_config method
- def mock_fetch_model_config_func(_node_data_model):
return mock_model_instance, mock_model_config
- # Also mock ModelManager.get_model_instance to avoid database calls
- def mock_get_model_instance(_self, **kwargs):
- return mock_model_instance
+ # Mock fetch_prompt_messages to avoid database calls
+ def mock_fetch_prompt_messages_1(**_kwargs):
+ from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
+
+ return [
+ SystemPromptMessage(content="you are a helpful assistant. today's weather is sunny."),
+ UserPromptMessage(content="what's the weather today?"),
+ ], []
with (
- patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
- patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
+ patch.object(LLMNode, "_fetch_model_config", mock_fetch_model_config),
+ patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_1),
):
# execute node
result = node._run()
@@ -166,6 +174,9 @@ def test_execute_llm(flask_req_ctx):
for item in result:
if isinstance(item, RunCompletedEvent):
+ if item.run_result.status != WorkflowNodeExecutionStatus.SUCCEEDED:
+ print(f"Error: {item.run_result.error}")
+ print(f"Error type: {item.run_result.error_type}")
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.process_data is not None
assert item.run_result.outputs is not None
@@ -173,8 +184,7 @@ def test_execute_llm(flask_req_ctx):
assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0
-@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
-def test_execute_llm_with_jinja2(flask_req_ctx, setup_code_executor_mock):
+def test_execute_llm_with_jinja2():
"""
Test execute LLM node with jinja2
"""
@@ -215,53 +225,60 @@ def test_execute_llm_with_jinja2(flask_req_ctx, setup_code_executor_mock):
# Mock db.session.close()
db.session.close = MagicMock()
- # Create a proper LLM result with real entities
- mock_usage = LLMUsage(
- prompt_tokens=30,
- prompt_unit_price=Decimal("0.001"),
- prompt_price_unit=Decimal("1000"),
- prompt_price=Decimal("0.00003"),
- completion_tokens=20,
- completion_unit_price=Decimal("0.002"),
- completion_price_unit=Decimal("1000"),
- completion_price=Decimal("0.00004"),
- total_tokens=50,
- total_price=Decimal("0.00007"),
- currency="USD",
- latency=0.5,
- )
-
- mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?")
-
- mock_llm_result = LLMResult(
- model="gpt-3.5-turbo",
- prompt_messages=[],
- message=mock_message,
- usage=mock_usage,
- )
-
- # Create a simple mock model instance that doesn't call real providers
- mock_model_instance = MagicMock()
- mock_model_instance.invoke_llm.return_value = mock_llm_result
-
- # Create a simple mock model config with required attributes
- mock_model_config = MagicMock()
- mock_model_config.mode = "chat"
- mock_model_config.provider = "openai"
- mock_model_config.model = "gpt-3.5-turbo"
- mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
-
# Mock the _fetch_model_config method
- def mock_fetch_model_config_func(_node_data_model):
+ def mock_fetch_model_config(**_kwargs):
+ from decimal import Decimal
+ from unittest.mock import MagicMock
+
+ from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
+ from core.model_runtime.entities.message_entities import AssistantPromptMessage
+
+ # Create mock model instance
+ mock_model_instance = MagicMock()
+ mock_usage = LLMUsage(
+ prompt_tokens=30,
+ prompt_unit_price=Decimal("0.001"),
+ prompt_price_unit=Decimal(1000),
+ prompt_price=Decimal("0.00003"),
+ completion_tokens=20,
+ completion_unit_price=Decimal("0.002"),
+ completion_price_unit=Decimal(1000),
+ completion_price=Decimal("0.00004"),
+ total_tokens=50,
+ total_price=Decimal("0.00007"),
+ currency="USD",
+ latency=0.5,
+ )
+ mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?")
+ mock_llm_result = LLMResult(
+ model="gpt-3.5-turbo",
+ prompt_messages=[],
+ message=mock_message,
+ usage=mock_usage,
+ )
+ mock_model_instance.invoke_llm.return_value = mock_llm_result
+
+ # Create mock model config
+ mock_model_config = MagicMock()
+ mock_model_config.mode = "chat"
+ mock_model_config.provider = "openai"
+ mock_model_config.model = "gpt-3.5-turbo"
+ mock_model_config.parameters = {}
+
return mock_model_instance, mock_model_config
- # Also mock ModelManager.get_model_instance to avoid database calls
- def mock_get_model_instance(_self, **kwargs):
- return mock_model_instance
+ # Mock fetch_prompt_messages to avoid database calls
+ def mock_fetch_prompt_messages_2(**_kwargs):
+ from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
+
+ return [
+ SystemPromptMessage(content="you are a helpful assistant. today's weather is sunny."),
+ UserPromptMessage(content="what's the weather today?"),
+ ], []
with (
- patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
- patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
+ patch.object(LLMNode, "_fetch_model_config", mock_fetch_model_config),
+ patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_2),
):
# execute node
result = node._run()
diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
index 0df8e8b146..edd70193a8 100644
--- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
+++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
@@ -8,11 +8,11 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.model_runtime.entities import AssistantPromptMessage
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
+from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from models.enums import UserFrom
from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_config
@@ -64,12 +64,9 @@ def init_parameter_extractor_node(config: dict):
# construct variable pool
variable_pool = VariablePool(
- system_variables={
- SystemVariableKey.QUERY: "what's the weather in SF",
- SystemVariableKey.FILES: [],
- SystemVariableKey.CONVERSATION_ID: "abababa",
- SystemVariableKey.USER_ID: "aaa",
- },
+ system_variables=SystemVariable(
+ user_id="aaa", files=[], query="what's the weather in SF", conversation_id="abababa"
+ ),
user_inputs={},
environment_variables=[],
conversation_variables=[],
@@ -77,13 +74,15 @@ def init_parameter_extractor_node(config: dict):
variable_pool.add(["a", "b123", "args1"], 1)
variable_pool.add(["a", "b123", "args2"], 2)
- return ParameterExtractorNode(
+ node = ParameterExtractorNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=config,
)
+ node.init_node_data(config.get("data", {}))
+ return node
def test_function_calling_parameter_extractor(setup_model_mock):
diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py
index a5f2677a59..f71a5ee140 100644
--- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py
+++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py
@@ -6,11 +6,11 @@ import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
+from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
@@ -61,7 +61,7 @@ def test_execute_code(setup_code_executor_mock):
# construct variable pool
variable_pool = VariablePool(
- system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
+ system_variables=SystemVariable(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
conversation_variables=[],
@@ -76,6 +76,7 @@ def test_execute_code(setup_code_executor_mock):
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=config,
)
+ node.init_node_data(config.get("data", {}))
# execute node
result = node._run()
diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py
index 039beedafe..8476c1f874 100644
--- a/api/tests/integration_tests/workflow/nodes/test_tool.py
+++ b/api/tests/integration_tests/workflow/nodes/test_tool.py
@@ -6,12 +6,12 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.utils.configuration import ToolParameterConfigurationManager
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.event.event import RunCompletedEvent
from core.workflow.nodes.tool.tool_node import ToolNode
+from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
@@ -44,19 +44,21 @@ def init_tool_node(config: dict):
# construct variable pool
variable_pool = VariablePool(
- system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
+ system_variables=SystemVariable(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
conversation_variables=[],
)
- return ToolNode(
+ node = ToolNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=config,
)
+ node.init_node_data(config.get("data", {}))
+ return node
def test_tool_variable_invoke():
diff --git a/api/tests/unit_tests/configs/test_dify_config.py b/api/tests/unit_tests/configs/test_dify_config.py
index b70c8830ed..e9d4ee1935 100644
--- a/api/tests/unit_tests/configs/test_dify_config.py
+++ b/api/tests/unit_tests/configs/test_dify_config.py
@@ -88,6 +88,7 @@ def test_flask_configs(monkeypatch):
"pool_pre_ping": False,
"pool_recycle": 3600,
"pool_size": 30,
+ "pool_use_lifo": False,
}
assert config["CONSOLE_WEB_URL"] == "https://example.com"
diff --git a/api/tests/unit_tests/controllers/console/auth/test_oauth.py b/api/tests/unit_tests/controllers/console/auth/test_oauth.py
new file mode 100644
index 0000000000..037c9f2745
--- /dev/null
+++ b/api/tests/unit_tests/controllers/console/auth/test_oauth.py
@@ -0,0 +1,496 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+from flask import Flask
+
+from controllers.console.auth.oauth import (
+ OAuthCallback,
+ OAuthLogin,
+ _generate_account,
+ _get_account_by_openid_or_email,
+ get_oauth_providers,
+)
+from libs.oauth import OAuthUserInfo
+from models.account import AccountStatus
+from services.errors.account import AccountNotFoundError
+
+
+class TestGetOAuthProviders:
+ @pytest.fixture
+ def app(self):
+ app = Flask(__name__)
+ app.config["TESTING"] = True
+ return app
+
+ @pytest.mark.parametrize(
+ ("github_config", "google_config", "expected_github", "expected_google"),
+ [
+ # Both providers configured
+ (
+ {"id": "github_id", "secret": "github_secret"},
+ {"id": "google_id", "secret": "google_secret"},
+ True,
+ True,
+ ),
+ # Only GitHub configured
+ ({"id": "github_id", "secret": "github_secret"}, {"id": None, "secret": None}, True, False),
+ # Only Google configured
+ ({"id": None, "secret": None}, {"id": "google_id", "secret": "google_secret"}, False, True),
+ # No providers configured
+ ({"id": None, "secret": None}, {"id": None, "secret": None}, False, False),
+ ],
+ )
+ @patch("controllers.console.auth.oauth.dify_config")
+ def test_should_configure_oauth_providers_correctly(
+ self, mock_config, app, github_config, google_config, expected_github, expected_google
+ ):
+ mock_config.GITHUB_CLIENT_ID = github_config["id"]
+ mock_config.GITHUB_CLIENT_SECRET = github_config["secret"]
+ mock_config.GOOGLE_CLIENT_ID = google_config["id"]
+ mock_config.GOOGLE_CLIENT_SECRET = google_config["secret"]
+ mock_config.CONSOLE_API_URL = "http://localhost"
+
+ with app.app_context():
+ providers = get_oauth_providers()
+
+ assert (providers["github"] is not None) == expected_github
+ assert (providers["google"] is not None) == expected_google
+
+
+class TestOAuthLogin:
+ @pytest.fixture
+ def resource(self):
+ return OAuthLogin()
+
+ @pytest.fixture
+ def app(self):
+ app = Flask(__name__)
+ app.config["TESTING"] = True
+ return app
+
+ @pytest.fixture
+ def mock_oauth_provider(self):
+ provider = MagicMock()
+ provider.get_authorization_url.return_value = "https://github.com/login/oauth/authorize?..."
+ return provider
+
+ @pytest.mark.parametrize(
+ ("invite_token", "expected_token"),
+ [
+ (None, None),
+ ("test_invite_token", "test_invite_token"),
+ ("", None),
+ ],
+ )
+ @patch("controllers.console.auth.oauth.get_oauth_providers")
+ @patch("controllers.console.auth.oauth.redirect")
+ def test_should_handle_oauth_login_with_various_tokens(
+ self,
+ mock_redirect,
+ mock_get_providers,
+ resource,
+ app,
+ mock_oauth_provider,
+ invite_token,
+ expected_token,
+ ):
+ mock_get_providers.return_value = {"github": mock_oauth_provider, "google": None}
+
+ query_string = f"invite_token={invite_token}" if invite_token else ""
+ with app.test_request_context(f"/auth/oauth/github?{query_string}"):
+ resource.get("github")
+
+ mock_oauth_provider.get_authorization_url.assert_called_once_with(invite_token=expected_token)
+ mock_redirect.assert_called_once_with("https://github.com/login/oauth/authorize?...")
+
+ @pytest.mark.parametrize(
+ ("provider", "expected_error"),
+ [
+ ("invalid_provider", "Invalid provider"),
+ ("github", "Invalid provider"), # When GitHub is not configured
+ ("google", "Invalid provider"), # When Google is not configured
+ ],
+ )
+ @patch("controllers.console.auth.oauth.get_oauth_providers")
+ def test_should_return_error_for_invalid_providers(
+ self, mock_get_providers, resource, app, provider, expected_error
+ ):
+ mock_get_providers.return_value = {"github": None, "google": None}
+
+ with app.test_request_context(f"/auth/oauth/{provider}"):
+ response, status_code = resource.get(provider)
+
+ assert status_code == 400
+ assert response["error"] == expected_error
+
+
+class TestOAuthCallback:
+ @pytest.fixture
+ def resource(self):
+ return OAuthCallback()
+
+ @pytest.fixture
+ def app(self):
+ app = Flask(__name__)
+ app.config["TESTING"] = True
+ return app
+
+ @pytest.fixture
+ def oauth_setup(self):
+ """Common OAuth setup for callback tests"""
+ oauth_provider = MagicMock()
+ oauth_provider.get_access_token.return_value = "access_token"
+ oauth_provider.get_user_info.return_value = OAuthUserInfo(id="123", name="Test User", email="test@example.com")
+
+ account = MagicMock()
+ account.status = AccountStatus.ACTIVE.value
+
+ token_pair = MagicMock()
+ token_pair.access_token = "jwt_access_token"
+ token_pair.refresh_token = "jwt_refresh_token"
+
+ return {"provider": oauth_provider, "account": account, "token_pair": token_pair}
+
+ @patch("controllers.console.auth.oauth.dify_config")
+ @patch("controllers.console.auth.oauth.get_oauth_providers")
+ @patch("controllers.console.auth.oauth._generate_account")
+ @patch("controllers.console.auth.oauth.AccountService")
+ @patch("controllers.console.auth.oauth.TenantService")
+ @patch("controllers.console.auth.oauth.redirect")
+ def test_should_handle_successful_oauth_callback(
+ self,
+ mock_redirect,
+ mock_tenant_service,
+ mock_account_service,
+ mock_generate_account,
+ mock_get_providers,
+ mock_config,
+ resource,
+ app,
+ oauth_setup,
+ ):
+ mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
+ mock_get_providers.return_value = {"github": oauth_setup["provider"]}
+ mock_generate_account.return_value = oauth_setup["account"]
+ mock_account_service.login.return_value = oauth_setup["token_pair"]
+
+ with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
+ resource.get("github")
+
+ oauth_setup["provider"].get_access_token.assert_called_once_with("test_code")
+ oauth_setup["provider"].get_user_info.assert_called_once_with("access_token")
+ mock_redirect.assert_called_once_with(
+ "http://localhost:3000?access_token=jwt_access_token&refresh_token=jwt_refresh_token"
+ )
+
+ @pytest.mark.parametrize(
+ ("exception", "expected_error"),
+ [
+ (Exception("OAuth error"), "OAuth process failed"),
+ (ValueError("Invalid token"), "OAuth process failed"),
+ (KeyError("Missing key"), "OAuth process failed"),
+ ],
+ )
+ @patch("controllers.console.auth.oauth.db")
+ @patch("controllers.console.auth.oauth.get_oauth_providers")
+ def test_should_handle_oauth_exceptions(
+ self, mock_get_providers, mock_db, resource, app, exception, expected_error
+ ):
+ # Mock database session
+ mock_db.session = MagicMock()
+ mock_db.session.rollback = MagicMock()
+
+ # Import the real requests module to create a proper exception
+ import requests
+
+ request_exception = requests.exceptions.RequestException("OAuth error")
+ request_exception.response = MagicMock()
+ request_exception.response.text = str(exception)
+
+ mock_oauth_provider = MagicMock()
+ mock_oauth_provider.get_access_token.side_effect = request_exception
+ mock_get_providers.return_value = {"github": mock_oauth_provider}
+
+ with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
+ response, status_code = resource.get("github")
+
+ assert status_code == 400
+ assert response["error"] == expected_error
+
+ @pytest.mark.parametrize(
+ ("account_status", "expected_redirect"),
+ [
+ (AccountStatus.BANNED.value, "http://localhost:3000/signin?message=Account is banned."),
+ # CLOSED status: Currently NOT handled, will proceed to login (security issue)
+ # This documents actual behavior. See test_defensive_check_for_closed_account_status for details
+ (
+ AccountStatus.CLOSED.value,
+ "http://localhost:3000?access_token=jwt_access_token&refresh_token=jwt_refresh_token",
+ ),
+ ],
+ )
+ @patch("controllers.console.auth.oauth.AccountService")
+ @patch("controllers.console.auth.oauth.TenantService")
+ @patch("controllers.console.auth.oauth.db")
+ @patch("controllers.console.auth.oauth.dify_config")
+ @patch("controllers.console.auth.oauth.get_oauth_providers")
+ @patch("controllers.console.auth.oauth._generate_account")
+ @patch("controllers.console.auth.oauth.redirect")
+ def test_should_redirect_based_on_account_status(
+ self,
+ mock_redirect,
+ mock_generate_account,
+ mock_get_providers,
+ mock_config,
+ mock_db,
+ mock_tenant_service,
+ mock_account_service,
+ resource,
+ app,
+ oauth_setup,
+ account_status,
+ expected_redirect,
+ ):
+ # Mock database session
+ mock_db.session = MagicMock()
+ mock_db.session.rollback = MagicMock()
+ mock_db.session.commit = MagicMock()
+
+ mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
+ mock_get_providers.return_value = {"github": oauth_setup["provider"]}
+
+ account = MagicMock()
+ account.status = account_status
+ account.id = "123"
+ mock_generate_account.return_value = account
+
+ # Mock login for CLOSED status
+ mock_token_pair = MagicMock()
+ mock_token_pair.access_token = "jwt_access_token"
+ mock_token_pair.refresh_token = "jwt_refresh_token"
+ mock_account_service.login.return_value = mock_token_pair
+
+ with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
+ resource.get("github")
+
+ mock_redirect.assert_called_once_with(expected_redirect)
+
+ @patch("controllers.console.auth.oauth.dify_config")
+ @patch("controllers.console.auth.oauth.get_oauth_providers")
+ @patch("controllers.console.auth.oauth._generate_account")
+ @patch("controllers.console.auth.oauth.db")
+ @patch("controllers.console.auth.oauth.TenantService")
+ @patch("controllers.console.auth.oauth.AccountService")
+ def test_should_activate_pending_account(
+ self,
+ mock_account_service,
+ mock_tenant_service,
+ mock_db,
+ mock_generate_account,
+ mock_get_providers,
+ mock_config,
+ resource,
+ app,
+ oauth_setup,
+ ):
+ mock_get_providers.return_value = {"github": oauth_setup["provider"]}
+
+ mock_account = MagicMock()
+ mock_account.status = AccountStatus.PENDING.value
+ mock_generate_account.return_value = mock_account
+
+ with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
+ resource.get("github")
+
+ assert mock_account.status == AccountStatus.ACTIVE.value
+ assert mock_account.initialized_at is not None
+ mock_db.session.commit.assert_called_once()
+
+ @patch("controllers.console.auth.oauth.dify_config")
+ @patch("controllers.console.auth.oauth.get_oauth_providers")
+ @patch("controllers.console.auth.oauth._generate_account")
+ @patch("controllers.console.auth.oauth.db")
+ @patch("controllers.console.auth.oauth.TenantService")
+ @patch("controllers.console.auth.oauth.AccountService")
+ @patch("controllers.console.auth.oauth.redirect")
+ def test_defensive_check_for_closed_account_status(
+ self,
+ mock_redirect,
+ mock_account_service,
+ mock_tenant_service,
+ mock_db,
+ mock_generate_account,
+ mock_get_providers,
+ mock_config,
+ resource,
+ app,
+ oauth_setup,
+ ):
+ """Defensive test for CLOSED account status handling in OAuth callback.
+
+ This is a defensive test documenting expected security behavior for CLOSED accounts.
+
+ Current behavior: CLOSED status is NOT checked, allowing closed accounts to login.
+ Expected behavior: CLOSED accounts should be rejected like BANNED accounts.
+
+ Context:
+ - AccountStatus.CLOSED is defined in the enum but never used in production
+ - The close_account() method exists but is never called
+ - Account deletion uses external service instead of status change
+ - All authentication services (OAuth, password, email) don't check CLOSED status
+
+ TODO: If CLOSED status is implemented in the future:
+ 1. Update OAuth callback to check for CLOSED status
+ 2. Add similar checks to all authentication services for consistency
+ 3. Update this test to verify the rejection behavior
+
+ Security consideration: Until properly implemented, CLOSED status provides no protection.
+ """
+ # Setup
+ mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
+ mock_get_providers.return_value = {"github": oauth_setup["provider"]}
+
+ # Create account with CLOSED status
+ closed_account = MagicMock()
+ closed_account.status = AccountStatus.CLOSED.value
+ closed_account.id = "123"
+ closed_account.name = "Closed Account"
+ mock_generate_account.return_value = closed_account
+
+ # Mock successful login (current behavior)
+ mock_token_pair = MagicMock()
+ mock_token_pair.access_token = "jwt_access_token"
+ mock_token_pair.refresh_token = "jwt_refresh_token"
+ mock_account_service.login.return_value = mock_token_pair
+
+ # Execute OAuth callback
+ with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
+ resource.get("github")
+
+ # Verify current behavior: login succeeds (this is NOT ideal)
+ mock_redirect.assert_called_once_with(
+ "http://localhost:3000?access_token=jwt_access_token&refresh_token=jwt_refresh_token"
+ )
+ mock_account_service.login.assert_called_once()
+
+ # Document expected behavior in comments:
+ # Expected: mock_redirect.assert_called_once_with(
+ # "http://localhost:3000/signin?message=Account is closed."
+ # )
+ # Expected: mock_account_service.login.assert_not_called()
+
+
+class TestAccountGeneration:
+ @pytest.fixture
+ def user_info(self):
+ return OAuthUserInfo(id="123", name="Test User", email="test@example.com")
+
+ @pytest.fixture
+ def mock_account(self):
+ account = MagicMock()
+ account.name = "Test User"
+ return account
+
+ @patch("controllers.console.auth.oauth.db")
+ @patch("controllers.console.auth.oauth.Account")
+ @patch("controllers.console.auth.oauth.Session")
+ @patch("controllers.console.auth.oauth.select")
+ def test_should_get_account_by_openid_or_email(
+ self, mock_select, mock_session, mock_account_model, mock_db, user_info, mock_account
+ ):
+ # Mock db.engine for Session creation
+ mock_db.engine = MagicMock()
+
+ # Test OpenID found
+ mock_account_model.get_by_openid.return_value = mock_account
+ result = _get_account_by_openid_or_email("github", user_info)
+ assert result == mock_account
+ mock_account_model.get_by_openid.assert_called_once_with("github", "123")
+
+ # Test fallback to email
+ mock_account_model.get_by_openid.return_value = None
+ mock_session_instance = MagicMock()
+ mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
+ mock_session.return_value.__enter__.return_value = mock_session_instance
+
+ result = _get_account_by_openid_or_email("github", user_info)
+ assert result == mock_account
+
+ @pytest.mark.parametrize(
+ ("allow_register", "existing_account", "should_create"),
+ [
+ (True, None, True), # New account creation allowed
+ (True, "existing", False), # Existing account
+ (False, None, False), # Registration not allowed
+ ],
+ )
+ @patch("controllers.console.auth.oauth._get_account_by_openid_or_email")
+ @patch("controllers.console.auth.oauth.FeatureService")
+ @patch("controllers.console.auth.oauth.RegisterService")
+ @patch("controllers.console.auth.oauth.AccountService")
+ @patch("controllers.console.auth.oauth.TenantService")
+ @patch("controllers.console.auth.oauth.db")
+ def test_should_handle_account_generation_scenarios(
+ self,
+ mock_db,
+ mock_tenant_service,
+ mock_account_service,
+ mock_register_service,
+ mock_feature_service,
+ mock_get_account,
+ app,
+ user_info,
+ mock_account,
+ allow_register,
+ existing_account,
+ should_create,
+ ):
+ mock_get_account.return_value = mock_account if existing_account else None
+ mock_feature_service.get_system_features.return_value.is_allow_register = allow_register
+ mock_register_service.register.return_value = mock_account
+
+ with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}):
+ if not allow_register and not existing_account:
+ with pytest.raises(AccountNotFoundError):
+ _generate_account("github", user_info)
+ else:
+ result = _generate_account("github", user_info)
+ assert result == mock_account
+
+ if should_create:
+ mock_register_service.register.assert_called_once_with(
+ email="test@example.com", name="Test User", password=None, open_id="123", provider="github"
+ )
+
+ @patch("controllers.console.auth.oauth._get_account_by_openid_or_email")
+ @patch("controllers.console.auth.oauth.TenantService")
+ @patch("controllers.console.auth.oauth.FeatureService")
+ @patch("controllers.console.auth.oauth.AccountService")
+ @patch("controllers.console.auth.oauth.tenant_was_created")
+ def test_should_create_workspace_for_account_without_tenant(
+ self,
+ mock_event,
+ mock_account_service,
+ mock_feature_service,
+ mock_tenant_service,
+ mock_get_account,
+ app,
+ user_info,
+ mock_account,
+ ):
+ mock_get_account.return_value = mock_account
+ mock_tenant_service.get_join_tenants.return_value = []
+ mock_feature_service.get_system_features.return_value.is_allow_create_workspace = True
+
+ mock_new_tenant = MagicMock()
+ mock_tenant_service.create_tenant.return_value = mock_new_tenant
+
+ with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}):
+ result = _generate_account("github", user_info)
+
+ assert result == mock_account
+ mock_tenant_service.create_tenant.assert_called_once_with("Test User's Workspace")
+ mock_tenant_service.create_tenant_member.assert_called_once_with(
+ mock_new_tenant, mock_account, role="owner"
+ )
+ mock_event.send.assert_called_once_with(mock_new_tenant)
diff --git a/api/tests/unit_tests/controllers/console/test_wraps.py b/api/tests/unit_tests/controllers/console/test_wraps.py
new file mode 100644
index 0000000000..9742368f04
--- /dev/null
+++ b/api/tests/unit_tests/controllers/console/test_wraps.py
@@ -0,0 +1,380 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+from flask import Flask
+from flask_login import LoginManager, UserMixin
+
+from controllers.console.error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
+from controllers.console.workspace.error import AccountNotInitializedError
+from controllers.console.wraps import (
+ account_initialization_required,
+ cloud_edition_billing_rate_limit_check,
+ cloud_edition_billing_resource_check,
+ enterprise_license_required,
+ only_edition_cloud,
+ only_edition_enterprise,
+ only_edition_self_hosted,
+ setup_required,
+)
+from models.account import AccountStatus
+from services.feature_service import LicenseStatus
+
+
+class MockUser(UserMixin):
+ """Simple User class for testing."""
+
+ def __init__(self, user_id: str):
+ self.id = user_id
+ self.current_tenant_id = "tenant123"
+
+ def get_id(self) -> str:
+ return self.id
+
+
+def create_app_with_login():
+ """Create a Flask app with LoginManager configured."""
+ app = Flask(__name__)
+ app.config["SECRET_KEY"] = "test-secret-key"
+
+ login_manager = LoginManager()
+ login_manager.init_app(app)
+
+ @login_manager.user_loader
+ def load_user(user_id: str):
+ return MockUser(user_id)
+
+ return app
+
+
+class TestAccountInitialization:
+ """Test account initialization decorator"""
+
+ def test_should_allow_initialized_account(self):
+ """Test that initialized accounts can access protected views"""
+ # Arrange
+ mock_user = MagicMock()
+ mock_user.status = AccountStatus.ACTIVE
+
+ @account_initialization_required
+ def protected_view():
+ return "success"
+
+ # Act
+ with patch("controllers.console.wraps.current_user", mock_user):
+ result = protected_view()
+
+ # Assert
+ assert result == "success"
+
+ def test_should_reject_uninitialized_account(self):
+ """Test that uninitialized accounts raise AccountNotInitializedError"""
+ # Arrange
+ mock_user = MagicMock()
+ mock_user.status = AccountStatus.UNINITIALIZED
+
+ @account_initialization_required
+ def protected_view():
+ return "success"
+
+ # Act & Assert
+ with patch("controllers.console.wraps.current_user", mock_user):
+ with pytest.raises(AccountNotInitializedError):
+ protected_view()
+
+
+class TestEditionChecks:
+ """Test edition-specific decorators"""
+
+ def test_only_edition_cloud_allows_cloud_edition(self):
+ """Test cloud edition decorator allows CLOUD edition"""
+
+ # Arrange
+ @only_edition_cloud
+ def cloud_view():
+ return "cloud_success"
+
+ # Act
+ with patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"):
+ result = cloud_view()
+
+ # Assert
+ assert result == "cloud_success"
+
+ def test_only_edition_cloud_rejects_other_editions(self):
+ """Test cloud edition decorator rejects non-CLOUD editions"""
+ # Arrange
+ app = Flask(__name__)
+
+ @only_edition_cloud
+ def cloud_view():
+ return "cloud_success"
+
+ # Act & Assert
+ with app.test_request_context():
+ with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"):
+ with pytest.raises(Exception) as exc_info:
+ cloud_view()
+ assert exc_info.value.code == 404
+
+ def test_only_edition_enterprise_allows_when_enabled(self):
+ """Test enterprise edition decorator allows when ENTERPRISE_ENABLED is True"""
+
+ # Arrange
+ @only_edition_enterprise
+ def enterprise_view():
+ return "enterprise_success"
+
+ # Act
+ with patch("controllers.console.wraps.dify_config.ENTERPRISE_ENABLED", True):
+ result = enterprise_view()
+
+ # Assert
+ assert result == "enterprise_success"
+
+ def test_only_edition_self_hosted_allows_self_hosted(self):
+ """Test self-hosted edition decorator allows SELF_HOSTED edition"""
+
+ # Arrange
+ @only_edition_self_hosted
+ def self_hosted_view():
+ return "self_hosted_success"
+
+ # Act
+ with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"):
+ result = self_hosted_view()
+
+ # Assert
+ assert result == "self_hosted_success"
+
+
+class TestBillingResourceLimits:
+ """Test billing resource limit decorators"""
+
+ def test_should_allow_when_under_resource_limit(self):
+ """Test that requests are allowed when under resource limits"""
+ # Arrange
+ mock_features = MagicMock()
+ mock_features.billing.enabled = True
+ mock_features.members.limit = 10
+ mock_features.members.size = 5
+
+ @cloud_edition_billing_resource_check("members")
+ def add_member():
+ return "member_added"
+
+ # Act
+ with patch("controllers.console.wraps.current_user"):
+ with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
+ result = add_member()
+
+ # Assert
+ assert result == "member_added"
+
+ def test_should_reject_when_over_resource_limit(self):
+ """Test that requests are rejected when over resource limits"""
+ # Arrange
+ app = create_app_with_login()
+ mock_features = MagicMock()
+ mock_features.billing.enabled = True
+ mock_features.members.limit = 10
+ mock_features.members.size = 10
+
+ @cloud_edition_billing_resource_check("members")
+ def add_member():
+ return "member_added"
+
+ # Act & Assert
+ with app.test_request_context():
+ with patch("controllers.console.wraps.current_user", MockUser("test_user")):
+ with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
+ with pytest.raises(Exception) as exc_info:
+ add_member()
+ assert exc_info.value.code == 403
+ assert "members has reached the limit" in str(exc_info.value.description)
+
+ def test_should_check_source_for_documents_limit(self):
+ """Test document limit checks request source"""
+ # Arrange
+ app = create_app_with_login()
+ mock_features = MagicMock()
+ mock_features.billing.enabled = True
+ mock_features.documents_upload_quota.limit = 100
+ mock_features.documents_upload_quota.size = 100
+
+ @cloud_edition_billing_resource_check("documents")
+ def upload_document():
+ return "document_uploaded"
+
+ # Test 1: Should reject when source is datasets
+ with app.test_request_context("/?source=datasets"):
+ with patch("controllers.console.wraps.current_user", MockUser("test_user")):
+ with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
+ with pytest.raises(Exception) as exc_info:
+ upload_document()
+ assert exc_info.value.code == 403
+
+ # Test 2: Should allow when source is not datasets
+ with app.test_request_context("/?source=other"):
+ with patch("controllers.console.wraps.current_user", MockUser("test_user")):
+ with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
+ result = upload_document()
+ assert result == "document_uploaded"
+
+
+class TestRateLimiting:
+ """Test rate limiting decorator"""
+
+ @patch("controllers.console.wraps.redis_client")
+ @patch("controllers.console.wraps.db")
+ def test_should_allow_requests_within_rate_limit(self, mock_db, mock_redis):
+ """Test that requests within rate limit are allowed"""
+ # Arrange
+ mock_rate_limit = MagicMock()
+ mock_rate_limit.enabled = True
+ mock_rate_limit.limit = 10
+ mock_redis.zcard.return_value = 5 # 5 requests in window
+
+ @cloud_edition_billing_rate_limit_check("knowledge")
+ def knowledge_request():
+ return "knowledge_success"
+
+ # Act
+ with patch("controllers.console.wraps.current_user"):
+ with patch(
+ "controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
+ ):
+ result = knowledge_request()
+
+ # Assert
+ assert result == "knowledge_success"
+ mock_redis.zadd.assert_called_once()
+ mock_redis.zremrangebyscore.assert_called_once()
+
+ @patch("controllers.console.wraps.redis_client")
+ @patch("controllers.console.wraps.db")
+ def test_should_reject_requests_over_rate_limit(self, mock_db, mock_redis):
+ """Test that requests over rate limit are rejected and logged"""
+ # Arrange
+ app = create_app_with_login()
+ mock_rate_limit = MagicMock()
+ mock_rate_limit.enabled = True
+ mock_rate_limit.limit = 10
+ mock_rate_limit.subscription_plan = "pro"
+ mock_redis.zcard.return_value = 11 # Over limit
+
+ mock_session = MagicMock()
+ mock_db.session = mock_session
+
+ @cloud_edition_billing_rate_limit_check("knowledge")
+ def knowledge_request():
+ return "knowledge_success"
+
+ # Act & Assert
+ with app.test_request_context():
+ with patch("controllers.console.wraps.current_user", MockUser("test_user")):
+ with patch(
+ "controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
+ ):
+ with pytest.raises(Exception) as exc_info:
+ knowledge_request()
+
+ # Verify error
+ assert exc_info.value.code == 403
+ assert "rate limit" in str(exc_info.value.description)
+
+ # Verify rate limit log was created
+ mock_session.add.assert_called_once()
+ mock_session.commit.assert_called_once()
+
+
+class TestSystemSetup:
+ """Test system setup decorator"""
+
+ @patch("controllers.console.wraps.db")
+ def test_should_allow_when_setup_complete(self, mock_db):
+ """Test that requests are allowed when setup is complete"""
+ # Arrange
+ mock_db.session.query.return_value.first.return_value = MagicMock() # Setup exists
+
+ @setup_required
+ def admin_view():
+ return "admin_success"
+
+ # Act
+ with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"):
+ result = admin_view()
+
+ # Assert
+ assert result == "admin_success"
+
+ @patch("controllers.console.wraps.db")
+ @patch("controllers.console.wraps.os.environ.get")
+ def test_should_raise_not_init_validate_error_with_init_password(self, mock_environ_get, mock_db):
+ """Test NotInitValidateError when INIT_PASSWORD is set but setup not complete"""
+ # Arrange
+ mock_db.session.query.return_value.first.return_value = None # No setup
+ mock_environ_get.return_value = "some_password"
+
+ @setup_required
+ def admin_view():
+ return "admin_success"
+
+ # Act & Assert
+ with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"):
+ with pytest.raises(NotInitValidateError):
+ admin_view()
+
+ @patch("controllers.console.wraps.db")
+ @patch("controllers.console.wraps.os.environ.get")
+ def test_should_raise_not_setup_error_without_init_password(self, mock_environ_get, mock_db):
+ """Test NotSetupError when no INIT_PASSWORD and setup not complete"""
+ # Arrange
+ mock_db.session.query.return_value.first.return_value = None # No setup
+ mock_environ_get.return_value = None # No INIT_PASSWORD
+
+ @setup_required
+ def admin_view():
+ return "admin_success"
+
+ # Act & Assert
+ with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"):
+ with pytest.raises(NotSetupError):
+ admin_view()
+
+
+class TestEnterpriseLicense:
+ """Test enterprise license decorator"""
+
+ def test_should_allow_with_valid_license(self):
+ """Test that valid licenses allow access"""
+ # Arrange
+ mock_settings = MagicMock()
+ mock_settings.license.status = LicenseStatus.ACTIVE
+
+ @enterprise_license_required
+ def enterprise_feature():
+ return "enterprise_success"
+
+ # Act
+ with patch("controllers.console.wraps.FeatureService.get_system_features", return_value=mock_settings):
+ result = enterprise_feature()
+
+ # Assert
+ assert result == "enterprise_success"
+
+ @pytest.mark.parametrize("invalid_status", [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST])
+ def test_should_reject_with_invalid_license(self, invalid_status):
+ """Test that invalid licenses raise UnauthorizedAndForceLogout"""
+ # Arrange
+ mock_settings = MagicMock()
+ mock_settings.license.status = invalid_status
+
+ @enterprise_license_required
+ def enterprise_feature():
+ return "enterprise_success"
+
+ # Act & Assert
+ with patch("controllers.console.wraps.FeatureService.get_system_features", return_value=mock_settings):
+ with pytest.raises(UnauthorizedAndForceLogout) as exc_info:
+ enterprise_feature()
+ assert "license is invalid" in str(exc_info.value)
diff --git a/api/tests/unit_tests/core/helper/test_trace_id_helper.py b/api/tests/unit_tests/core/helper/test_trace_id_helper.py
new file mode 100644
index 0000000000..27bfe1af05
--- /dev/null
+++ b/api/tests/unit_tests/core/helper/test_trace_id_helper.py
@@ -0,0 +1,86 @@
+import pytest
+
+from core.helper.trace_id_helper import extract_external_trace_id_from_args, get_external_trace_id, is_valid_trace_id
+
+
+class DummyRequest:
+ def __init__(self, headers=None, args=None, json=None, is_json=False):
+ self.headers = headers or {}
+ self.args = args or {}
+ self.json = json
+ self.is_json = is_json
+
+
+class TestTraceIdHelper:
+ """Test cases for trace_id_helper.py"""
+
+ @pytest.mark.parametrize(
+ ("trace_id", "expected"),
+ [
+ ("abc123", True),
+ ("A-B_C-123", True),
+ ("a" * 128, True),
+ ("", False),
+ ("a" * 129, False),
+ ("abc!@#", False),
+ ("空格", False),
+ ("with space", False),
+ ],
+ )
+ def test_is_valid_trace_id(self, trace_id, expected):
+ """Test trace_id validation for various cases"""
+ assert is_valid_trace_id(trace_id) is expected
+
+ def test_get_external_trace_id_from_header(self):
+ """Should extract valid trace_id from header"""
+ req = DummyRequest(headers={"X-Trace-Id": "abc123"})
+ assert get_external_trace_id(req) == "abc123"
+
+ def test_get_external_trace_id_from_args(self):
+ """Should extract valid trace_id from args if header missing"""
+ req = DummyRequest(args={"trace_id": "abc123"})
+ assert get_external_trace_id(req) == "abc123"
+
+ def test_get_external_trace_id_from_json(self):
+ """Should extract valid trace_id from JSON body if header and args missing"""
+ req = DummyRequest(is_json=True, json={"trace_id": "abc123"})
+ assert get_external_trace_id(req) == "abc123"
+
+ def test_get_external_trace_id_priority(self):
+ """Header > args > json priority"""
+ req = DummyRequest(
+ headers={"X-Trace-Id": "header_id"},
+ args={"trace_id": "args_id"},
+ is_json=True,
+ json={"trace_id": "json_id"},
+ )
+ assert get_external_trace_id(req) == "header_id"
+ req2 = DummyRequest(args={"trace_id": "args_id"}, is_json=True, json={"trace_id": "json_id"})
+ assert get_external_trace_id(req2) == "args_id"
+ req3 = DummyRequest(is_json=True, json={"trace_id": "json_id"})
+ assert get_external_trace_id(req3) == "json_id"
+
+ @pytest.mark.parametrize(
+ "req",
+ [
+ DummyRequest(headers={"X-Trace-Id": "!!!"}),
+ DummyRequest(args={"trace_id": "!!!"}),
+ DummyRequest(is_json=True, json={"trace_id": "!!!"}),
+ DummyRequest(),
+ ],
+ )
+ def test_get_external_trace_id_invalid(self, req):
+ """Should return None for invalid or missing trace_id"""
+ assert get_external_trace_id(req) is None
+
+ @pytest.mark.parametrize(
+ ("args", "expected"),
+ [
+ ({"external_trace_id": "abc123"}, {"external_trace_id": "abc123"}),
+ ({"other": "value"}, {}),
+ ({}, {}),
+ ],
+ )
+ def test_extract_external_trace_id_from_args(self, args, expected):
+ """Test extraction of external_trace_id from args mapping"""
+ assert extract_external_trace_id_from_args(args) == expected
diff --git a/api/tests/unit_tests/core/helper/test_url_signer.py b/api/tests/unit_tests/core/helper/test_url_signer.py
deleted file mode 100644
index 5af24777de..0000000000
--- a/api/tests/unit_tests/core/helper/test_url_signer.py
+++ /dev/null
@@ -1,194 +0,0 @@
-from unittest.mock import patch
-from urllib.parse import parse_qs, urlparse
-
-import pytest
-
-from core.helper.url_signer import SignedUrlParams, UrlSigner
-
-
-class TestUrlSigner:
- """Test cases for UrlSigner class"""
-
- @patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345")
- def test_should_generate_signed_url_params(self):
- """Test generation of signed URL parameters with all required fields"""
- sign_key = "test-sign-key"
- prefix = "test-prefix"
-
- params = UrlSigner.get_signed_url_params(sign_key, prefix)
-
- # Verify the returned object and required fields
- assert isinstance(params, SignedUrlParams)
- assert params.sign_key == sign_key
- assert params.timestamp is not None
- assert params.nonce is not None
- assert params.sign is not None
-
- # Verify nonce format (32 character hex string)
- assert len(params.nonce) == 32
- assert all(c in "0123456789abcdef" for c in params.nonce)
-
- @patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345")
- def test_should_generate_complete_signed_url(self):
- """Test generation of complete signed URL with query parameters"""
- base_url = "https://example.com/api/test"
- sign_key = "test-sign-key"
- prefix = "test-prefix"
-
- signed_url = UrlSigner.get_signed_url(base_url, sign_key, prefix)
-
- # Parse URL and verify structure
- parsed = urlparse(signed_url)
- assert f"{parsed.scheme}://{parsed.netloc}{parsed.path}" == base_url
-
- # Verify query parameters
- query_params = parse_qs(parsed.query)
- assert "timestamp" in query_params
- assert "nonce" in query_params
- assert "sign" in query_params
-
- # Verify each parameter has exactly one value
- assert len(query_params["timestamp"]) == 1
- assert len(query_params["nonce"]) == 1
- assert len(query_params["sign"]) == 1
-
- # Verify parameter values are not empty
- assert query_params["timestamp"][0]
- assert query_params["nonce"][0]
- assert query_params["sign"][0]
-
- @patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345")
- def test_should_verify_valid_signature(self):
- """Test verification of valid signature"""
- sign_key = "test-sign-key"
- prefix = "test-prefix"
-
- # Generate and verify signature
- params = UrlSigner.get_signed_url_params(sign_key, prefix)
-
- is_valid = UrlSigner.verify(
- sign_key=sign_key, timestamp=params.timestamp, nonce=params.nonce, sign=params.sign, prefix=prefix
- )
-
- assert is_valid is True
-
- @patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345")
- @pytest.mark.parametrize(
- ("field", "modifier"),
- [
- ("sign_key", lambda _: "wrong-sign-key"),
- ("timestamp", lambda t: str(int(t) + 1000)),
- ("nonce", lambda _: "different-nonce-123456789012345"),
- ("prefix", lambda _: "wrong-prefix"),
- ("sign", lambda s: s + "tampered"),
- ],
- )
- def test_should_reject_invalid_signature_params(self, field, modifier):
- """Test signature verification rejects invalid parameters"""
- sign_key = "test-sign-key"
- prefix = "test-prefix"
-
- # Generate valid signed parameters
- params = UrlSigner.get_signed_url_params(sign_key, prefix)
-
- # Prepare verification parameters
- verify_params = {
- "sign_key": sign_key,
- "timestamp": params.timestamp,
- "nonce": params.nonce,
- "sign": params.sign,
- "prefix": prefix,
- }
-
- # Modify the specific field
- verify_params[field] = modifier(verify_params[field])
-
- # Verify should fail
- is_valid = UrlSigner.verify(**verify_params)
- assert is_valid is False
-
- @patch("configs.dify_config.SECRET_KEY", None)
- def test_should_raise_error_without_secret_key(self):
- """Test that signing fails when SECRET_KEY is not configured"""
- with pytest.raises(Exception) as exc_info:
- UrlSigner.get_signed_url_params("key", "prefix")
-
- assert "SECRET_KEY is not set" in str(exc_info.value)
-
- @patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345")
- def test_should_generate_unique_signatures(self):
- """Test that different inputs produce different signatures"""
- params1 = UrlSigner.get_signed_url_params("key1", "prefix1")
- params2 = UrlSigner.get_signed_url_params("key2", "prefix2")
-
- # Different inputs should produce different signatures
- assert params1.sign != params2.sign
- assert params1.nonce != params2.nonce
-
- @patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345")
- def test_should_handle_special_characters(self):
- """Test handling of special characters in parameters"""
- special_cases = [
- "test with spaces",
- "test/with/slashes",
- "test中文字符",
- ]
-
- for sign_key in special_cases:
- params = UrlSigner.get_signed_url_params(sign_key, "prefix")
-
- # Should generate valid signature and verify correctly
- is_valid = UrlSigner.verify(
- sign_key=sign_key, timestamp=params.timestamp, nonce=params.nonce, sign=params.sign, prefix="prefix"
- )
- assert is_valid is True
-
- @patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345")
- def test_should_ensure_nonce_randomness(self):
- """Test that nonce is random for each generation - critical for security"""
- sign_key = "test-sign-key"
- prefix = "test-prefix"
-
- # Generate multiple nonces
- nonces = set()
- for _ in range(5):
- params = UrlSigner.get_signed_url_params(sign_key, prefix)
- nonces.add(params.nonce)
-
- # All nonces should be unique
- assert len(nonces) == 5
-
- @patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345")
- @patch("time.time", return_value=1234567890)
- @patch("os.urandom", return_value=b"\xab\xcd\xef\x12\x34\x56\x78\x90\xab\xcd\xef\x12\x34\x56\x78\x90")
- def test_should_produce_consistent_signatures(self, mock_urandom, mock_time):
- """Test that same inputs produce same signature - ensures deterministic behavior"""
- sign_key = "test-sign-key"
- prefix = "test-prefix"
-
- # Generate signature multiple times with same inputs (time and nonce are mocked)
- params1 = UrlSigner.get_signed_url_params(sign_key, prefix)
- params2 = UrlSigner.get_signed_url_params(sign_key, prefix)
-
- # With mocked time and random, should produce identical results
- assert params1.timestamp == params2.timestamp
- assert params1.nonce == params2.nonce
- assert params1.sign == params2.sign
-
- # Verify the signature is valid
- assert UrlSigner.verify(
- sign_key=sign_key, timestamp=params1.timestamp, nonce=params1.nonce, sign=params1.sign, prefix=prefix
- )
-
- @patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345")
- def test_should_handle_empty_strings(self):
- """Test handling of empty string parameters - common edge case"""
- # Empty sign_key and prefix should still work
- params = UrlSigner.get_signed_url_params("", "")
- assert params.sign is not None
-
- # Should verify correctly
- is_valid = UrlSigner.verify(
- sign_key="", timestamp=params.timestamp, nonce=params.nonce, sign=params.sign, prefix=""
- )
- assert is_valid is True
diff --git a/api/tests/unit_tests/core/repositories/__init__.py b/api/tests/unit_tests/core/repositories/__init__.py
new file mode 100644
index 0000000000..c65d7da61d
--- /dev/null
+++ b/api/tests/unit_tests/core/repositories/__init__.py
@@ -0,0 +1 @@
+# Unit tests for core repositories module
diff --git a/api/tests/unit_tests/core/repositories/test_factory.py b/api/tests/unit_tests/core/repositories/test_factory.py
new file mode 100644
index 0000000000..fce4a6fb6b
--- /dev/null
+++ b/api/tests/unit_tests/core/repositories/test_factory.py
@@ -0,0 +1,455 @@
+"""
+Unit tests for the RepositoryFactory.
+
+This module tests the factory pattern implementation for creating repository instances
+based on configuration, including error handling and validation.
+"""
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+from pytest_mock import MockerFixture
+from sqlalchemy.engine import Engine
+from sqlalchemy.orm import sessionmaker
+
+from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError
+from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
+from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
+from models import Account, EndUser
+from models.enums import WorkflowRunTriggeredFrom
+from models.workflow import WorkflowNodeExecutionTriggeredFrom
+
+
+class TestRepositoryFactory:
+ """Test cases for RepositoryFactory."""
+
+ def test_import_class_success(self):
+ """Test successful class import."""
+ # Test importing a real class
+ class_path = "unittest.mock.MagicMock"
+ result = DifyCoreRepositoryFactory._import_class(class_path)
+ assert result is MagicMock
+
+ def test_import_class_invalid_path(self):
+ """Test import with invalid module path."""
+ with pytest.raises(RepositoryImportError) as exc_info:
+ DifyCoreRepositoryFactory._import_class("invalid.module.path")
+ assert "Cannot import repository class" in str(exc_info.value)
+
+ def test_import_class_invalid_class_name(self):
+ """Test import with invalid class name."""
+ with pytest.raises(RepositoryImportError) as exc_info:
+ DifyCoreRepositoryFactory._import_class("unittest.mock.NonExistentClass")
+ assert "Cannot import repository class" in str(exc_info.value)
+
+ def test_import_class_malformed_path(self):
+ """Test import with malformed path (no dots)."""
+ with pytest.raises(RepositoryImportError) as exc_info:
+ DifyCoreRepositoryFactory._import_class("invalidpath")
+ assert "Cannot import repository class" in str(exc_info.value)
+
+ def test_validate_repository_interface_success(self):
+ """Test successful interface validation."""
+
+ # Create a mock class that implements the required methods
+ class MockRepository:
+ def save(self):
+ pass
+
+ def get_by_id(self):
+ pass
+
+ # Create a mock interface with the same methods
+ class MockInterface:
+ def save(self):
+ pass
+
+ def get_by_id(self):
+ pass
+
+ # Should not raise an exception
+ DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface)
+
+ def test_validate_repository_interface_missing_methods(self):
+ """Test interface validation with missing methods."""
+
+ # Create a mock class that doesn't implement all required methods
+ class IncompleteRepository:
+ def save(self):
+ pass
+
+ # Missing get_by_id method
+
+ # Create a mock interface with required methods
+ class MockInterface:
+ def save(self):
+ pass
+
+ def get_by_id(self):
+ pass
+
+ with pytest.raises(RepositoryImportError) as exc_info:
+ DifyCoreRepositoryFactory._validate_repository_interface(IncompleteRepository, MockInterface)
+ assert "does not implement required methods" in str(exc_info.value)
+ assert "get_by_id" in str(exc_info.value)
+
+ def test_validate_constructor_signature_success(self):
+ """Test successful constructor signature validation."""
+
+ class MockRepository:
+ def __init__(self, session_factory, user, app_id, triggered_from):
+ pass
+
+ # Should not raise an exception
+ DifyCoreRepositoryFactory._validate_constructor_signature(
+ MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
+ )
+
+ def test_validate_constructor_signature_missing_params(self):
+ """Test constructor validation with missing parameters."""
+
+ class IncompleteRepository:
+ def __init__(self, session_factory, user):
+ # Missing app_id and triggered_from parameters
+ pass
+
+ with pytest.raises(RepositoryImportError) as exc_info:
+ DifyCoreRepositoryFactory._validate_constructor_signature(
+ IncompleteRepository, ["session_factory", "user", "app_id", "triggered_from"]
+ )
+ assert "does not accept required parameters" in str(exc_info.value)
+ assert "app_id" in str(exc_info.value)
+ assert "triggered_from" in str(exc_info.value)
+
+ def test_validate_constructor_signature_inspection_error(self, mocker: MockerFixture):
+ """Test constructor validation when inspection fails."""
+ # Mock inspect.signature to raise an exception
+ mocker.patch("inspect.signature", side_effect=Exception("Inspection failed"))
+
+ class MockRepository:
+ def __init__(self, session_factory):
+ pass
+
+ with pytest.raises(RepositoryImportError) as exc_info:
+ DifyCoreRepositoryFactory._validate_constructor_signature(MockRepository, ["session_factory"])
+ assert "Failed to validate constructor signature" in str(exc_info.value)
+
+ @patch("core.repositories.factory.dify_config")
+ def test_create_workflow_execution_repository_success(self, mock_config, mocker: MockerFixture):
+ """Test successful creation of WorkflowExecutionRepository."""
+ # Setup mock configuration
+ mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
+
+ # Create mock dependencies
+ mock_session_factory = MagicMock(spec=sessionmaker)
+ mock_user = MagicMock(spec=Account)
+ app_id = "test-app-id"
+ triggered_from = WorkflowRunTriggeredFrom.APP_RUN
+
+ # Mock the imported class to be a valid repository
+ mock_repository_class = MagicMock()
+ mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository)
+ mock_repository_class.return_value = mock_repository_instance
+
+ # Mock the validation methods
+ with (
+ patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
+ patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
+ patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
+ ):
+ result = DifyCoreRepositoryFactory.create_workflow_execution_repository(
+ session_factory=mock_session_factory,
+ user=mock_user,
+ app_id=app_id,
+ triggered_from=triggered_from,
+ )
+
+ # Verify the repository was created with correct parameters
+ mock_repository_class.assert_called_once_with(
+ session_factory=mock_session_factory,
+ user=mock_user,
+ app_id=app_id,
+ triggered_from=triggered_from,
+ )
+ assert result is mock_repository_instance
+
+ @patch("core.repositories.factory.dify_config")
+ def test_create_workflow_execution_repository_import_error(self, mock_config):
+ """Test WorkflowExecutionRepository creation with import error."""
+ # Setup mock configuration with invalid class path
+ mock_config.WORKFLOW_EXECUTION_REPOSITORY = "invalid.module.InvalidClass"
+
+ mock_session_factory = MagicMock(spec=sessionmaker)
+ mock_user = MagicMock(spec=Account)
+
+ with pytest.raises(RepositoryImportError) as exc_info:
+ DifyCoreRepositoryFactory.create_workflow_execution_repository(
+ session_factory=mock_session_factory,
+ user=mock_user,
+ app_id="test-app-id",
+ triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
+ )
+ assert "Cannot import repository class" in str(exc_info.value)
+
+ @patch("core.repositories.factory.dify_config")
+ def test_create_workflow_execution_repository_validation_error(self, mock_config, mocker: MockerFixture):
+ """Test WorkflowExecutionRepository creation with validation error."""
+ # Setup mock configuration
+ mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
+
+ mock_session_factory = MagicMock(spec=sessionmaker)
+ mock_user = MagicMock(spec=Account)
+
+ # Mock import to succeed but validation to fail
+ mock_repository_class = MagicMock()
+ with (
+ patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
+ patch.object(
+ DifyCoreRepositoryFactory,
+ "_validate_repository_interface",
+ side_effect=RepositoryImportError("Interface validation failed"),
+ ),
+ ):
+ with pytest.raises(RepositoryImportError) as exc_info:
+ DifyCoreRepositoryFactory.create_workflow_execution_repository(
+ session_factory=mock_session_factory,
+ user=mock_user,
+ app_id="test-app-id",
+ triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
+ )
+ assert "Interface validation failed" in str(exc_info.value)
+
+ @patch("core.repositories.factory.dify_config")
+ def test_create_workflow_execution_repository_instantiation_error(self, mock_config, mocker: MockerFixture):
+ """Test WorkflowExecutionRepository creation with instantiation error."""
+ # Setup mock configuration
+ mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
+
+ mock_session_factory = MagicMock(spec=sessionmaker)
+ mock_user = MagicMock(spec=Account)
+
+ # Mock import and validation to succeed but instantiation to fail
+ mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed"))
+ with (
+ patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
+ patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
+ patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
+ ):
+ with pytest.raises(RepositoryImportError) as exc_info:
+ DifyCoreRepositoryFactory.create_workflow_execution_repository(
+ session_factory=mock_session_factory,
+ user=mock_user,
+ app_id="test-app-id",
+ triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
+ )
+ assert "Failed to create WorkflowExecutionRepository" in str(exc_info.value)
+
+ @patch("core.repositories.factory.dify_config")
+ def test_create_workflow_node_execution_repository_success(self, mock_config, mocker: MockerFixture):
+ """Test successful creation of WorkflowNodeExecutionRepository."""
+ # Setup mock configuration
+ mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
+
+ # Create mock dependencies
+ mock_session_factory = MagicMock(spec=sessionmaker)
+ mock_user = MagicMock(spec=EndUser)
+ app_id = "test-app-id"
+ triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
+
+ # Mock the imported class to be a valid repository
+ mock_repository_class = MagicMock()
+ mock_repository_instance = MagicMock(spec=WorkflowNodeExecutionRepository)
+ mock_repository_class.return_value = mock_repository_instance
+
+ # Mock the validation methods
+ with (
+ patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
+ patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
+ patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
+ ):
+ result = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
+ session_factory=mock_session_factory,
+ user=mock_user,
+ app_id=app_id,
+ triggered_from=triggered_from,
+ )
+
+ # Verify the repository was created with correct parameters
+ mock_repository_class.assert_called_once_with(
+ session_factory=mock_session_factory,
+ user=mock_user,
+ app_id=app_id,
+ triggered_from=triggered_from,
+ )
+ assert result is mock_repository_instance
+
+ @patch("core.repositories.factory.dify_config")
+ def test_create_workflow_node_execution_repository_import_error(self, mock_config):
+ """Test WorkflowNodeExecutionRepository creation with import error."""
+ # Setup mock configuration with invalid class path
+ mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "invalid.module.InvalidClass"
+
+ mock_session_factory = MagicMock(spec=sessionmaker)
+ mock_user = MagicMock(spec=EndUser)
+
+ with pytest.raises(RepositoryImportError) as exc_info:
+ DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
+ session_factory=mock_session_factory,
+ user=mock_user,
+ app_id="test-app-id",
+ triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
+ )
+ assert "Cannot import repository class" in str(exc_info.value)
+
+ def test_repository_import_error_exception(self):
+ """Test RepositoryImportError exception."""
+ error_message = "Test error message"
+ exception = RepositoryImportError(error_message)
+ assert str(exception) == error_message
+ assert isinstance(exception, Exception)
+
+ @patch("core.repositories.factory.dify_config")
+ def test_create_with_engine_instead_of_sessionmaker(self, mock_config, mocker: MockerFixture):
+ """Test repository creation with Engine instead of sessionmaker."""
+ # Setup mock configuration
+ mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
+
+ # Create mock dependencies with Engine instead of sessionmaker
+ mock_engine = MagicMock(spec=Engine)
+ mock_user = MagicMock(spec=Account)
+
+ # Mock the imported class to be a valid repository
+ mock_repository_class = MagicMock()
+ mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository)
+ mock_repository_class.return_value = mock_repository_instance
+
+ # Mock the validation methods
+ with (
+ patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
+ patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
+ patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
+ ):
+ result = DifyCoreRepositoryFactory.create_workflow_execution_repository(
+ session_factory=mock_engine, # Using Engine instead of sessionmaker
+ user=mock_user,
+ app_id="test-app-id",
+ triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
+ )
+
+ # Verify the repository was created with the Engine
+ mock_repository_class.assert_called_once_with(
+ session_factory=mock_engine,
+ user=mock_user,
+ app_id="test-app-id",
+ triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
+ )
+ assert result is mock_repository_instance
+
+ @patch("core.repositories.factory.dify_config")
+ def test_create_workflow_node_execution_repository_validation_error(self, mock_config):
+ """Test WorkflowNodeExecutionRepository creation with validation error."""
+ # Setup mock configuration
+ mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
+
+ mock_session_factory = MagicMock(spec=sessionmaker)
+ mock_user = MagicMock(spec=EndUser)
+
+ # Mock import to succeed but validation to fail
+ mock_repository_class = MagicMock()
+ with (
+ patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
+ patch.object(
+ DifyCoreRepositoryFactory,
+ "_validate_repository_interface",
+ side_effect=RepositoryImportError("Interface validation failed"),
+ ),
+ ):
+ with pytest.raises(RepositoryImportError) as exc_info:
+ DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
+ session_factory=mock_session_factory,
+ user=mock_user,
+ app_id="test-app-id",
+ triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
+ )
+ assert "Interface validation failed" in str(exc_info.value)
+
+ @patch("core.repositories.factory.dify_config")
+ def test_create_workflow_node_execution_repository_instantiation_error(self, mock_config):
+ """Test WorkflowNodeExecutionRepository creation with instantiation error."""
+ # Setup mock configuration
+ mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
+
+ mock_session_factory = MagicMock(spec=sessionmaker)
+ mock_user = MagicMock(spec=EndUser)
+
+ # Mock import and validation to succeed but instantiation to fail
+ mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed"))
+ with (
+ patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
+ patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
+ patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
+ ):
+ with pytest.raises(RepositoryImportError) as exc_info:
+ DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
+ session_factory=mock_session_factory,
+ user=mock_user,
+ app_id="test-app-id",
+ triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
+ )
+ assert "Failed to create WorkflowNodeExecutionRepository" in str(exc_info.value)
+
+ def test_validate_repository_interface_with_private_methods(self):
+ """Test interface validation ignores private methods."""
+
+ # Create a mock class with private methods
+ class MockRepository:
+ def save(self):
+ pass
+
+ def get_by_id(self):
+ pass
+
+ def _private_method(self):
+ pass
+
+ # Create a mock interface with private methods
+ class MockInterface:
+ def save(self):
+ pass
+
+ def get_by_id(self):
+ pass
+
+ def _private_method(self):
+ pass
+
+ # Should not raise an exception (private methods are ignored)
+ DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface)
+
+ def test_validate_constructor_signature_with_extra_params(self):
+ """Test constructor validation with extra parameters (should pass)."""
+
+ class MockRepository:
+ def __init__(self, session_factory, user, app_id, triggered_from, extra_param=None):
+ pass
+
+ # Should not raise an exception (extra parameters are allowed)
+ DifyCoreRepositoryFactory._validate_constructor_signature(
+ MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
+ )
+
+ def test_validate_constructor_signature_with_kwargs(self):
+ """Test constructor validation with **kwargs (current implementation doesn't support this)."""
+
+ class MockRepository:
+ def __init__(self, session_factory, user, **kwargs):
+ pass
+
+ # Current implementation doesn't handle **kwargs, so this should raise an exception
+ with pytest.raises(RepositoryImportError) as exc_info:
+ DifyCoreRepositoryFactory._validate_constructor_signature(
+ MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
+ )
+ assert "does not accept required parameters" in str(exc_info.value)
+ assert "app_id" in str(exc_info.value)
+ assert "triggered_from" in str(exc_info.value)
diff --git a/api/tests/unit_tests/core/tools/utils/__init__.py b/api/tests/unit_tests/core/tools/utils/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/unit_tests/core/tools/utils/test_parser.py b/api/tests/unit_tests/core/tools/utils/test_parser.py
new file mode 100644
index 0000000000..8e07293ce0
--- /dev/null
+++ b/api/tests/unit_tests/core/tools/utils/test_parser.py
@@ -0,0 +1,56 @@
+import pytest
+from flask import Flask
+
+from core.tools.utils.parser import ApiBasedToolSchemaParser
+
+
+@pytest.fixture
+def app():
+ app = Flask(__name__)
+ return app
+
+
+def test_parse_openapi_to_tool_bundle_operation_id(app):
+ openapi = {
+ "openapi": "3.0.0",
+ "info": {"title": "Simple API", "version": "1.0.0"},
+ "servers": [{"url": "http://localhost:3000"}],
+ "paths": {
+ "/": {
+ "get": {
+ "summary": "Root endpoint",
+ "responses": {
+ "200": {
+ "description": "Successful response",
+ }
+ },
+ }
+ },
+ "/api/resources": {
+ "get": {
+ "summary": "Non-root endpoint without an operationId",
+ "responses": {
+ "200": {
+ "description": "Successful response",
+ }
+ },
+ },
+ "post": {
+ "summary": "Non-root endpoint with an operationId",
+ "operationId": "createResource",
+ "responses": {
+ "201": {
+ "description": "Resource created",
+ }
+ },
+ },
+ },
+ },
+ }
+ with app.test_request_context():
+ tool_bundles = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi)
+
+ assert len(tool_bundles) == 3
+ assert tool_bundles[0].operation_id == "
_get"
+ assert tool_bundles[1].operation_id == "apiresources_get"
+ assert tool_bundles[2].operation_id == "createResource"
diff --git a/api/tests/unit_tests/core/variables/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py
index 1b035d01a7..4c8d983d20 100644
--- a/api/tests/unit_tests/core/variables/test_segment.py
+++ b/api/tests/unit_tests/core/variables/test_segment.py
@@ -1,14 +1,49 @@
+import dataclasses
+
+from pydantic import BaseModel
+
+from core.file import File, FileTransferMethod, FileType
from core.helper import encrypter
-from core.variables import SecretVariable, StringVariable
+from core.variables.segments import (
+ ArrayAnySegment,
+ ArrayFileSegment,
+ ArrayNumberSegment,
+ ArrayObjectSegment,
+ ArrayStringSegment,
+ FileSegment,
+ FloatSegment,
+ IntegerSegment,
+ NoneSegment,
+ ObjectSegment,
+ Segment,
+ SegmentUnion,
+ StringSegment,
+ get_segment_discriminator,
+)
+from core.variables.types import SegmentType
+from core.variables.variables import (
+ ArrayAnyVariable,
+ ArrayFileVariable,
+ ArrayNumberVariable,
+ ArrayObjectVariable,
+ ArrayStringVariable,
+ FileVariable,
+ FloatVariable,
+ IntegerVariable,
+ NoneVariable,
+ ObjectVariable,
+ SecretVariable,
+ StringVariable,
+ Variable,
+ VariableUnion,
+)
from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.enums import SystemVariableKey
+from core.workflow.system_variable import SystemVariable
def test_segment_group_to_text():
variable_pool = VariablePool(
- system_variables={
- SystemVariableKey("user_id"): "fake-user-id",
- },
+ system_variables=SystemVariable(user_id="fake-user-id"),
user_inputs={},
environment_variables=[
SecretVariable(name="secret_key", value="fake-secret-key"),
@@ -30,7 +65,7 @@ def test_segment_group_to_text():
def test_convert_constant_to_segment_group():
variable_pool = VariablePool(
- system_variables={},
+ system_variables=SystemVariable(user_id="1", app_id="1", workflow_id="1"),
user_inputs={},
environment_variables=[],
conversation_variables=[],
@@ -43,9 +78,7 @@ def test_convert_constant_to_segment_group():
def test_convert_variable_to_segment_group():
variable_pool = VariablePool(
- system_variables={
- SystemVariableKey("user_id"): "fake-user-id",
- },
+ system_variables=SystemVariable(user_id="fake-user-id"),
user_inputs={},
environment_variables=[],
conversation_variables=[],
@@ -56,3 +89,297 @@ def test_convert_variable_to_segment_group():
assert segments_group.log == "fake-user-id"
assert isinstance(segments_group.value[0], StringVariable)
assert segments_group.value[0].value == "fake-user-id"
+
+
+class _Segments(BaseModel):
+ segments: list[SegmentUnion]
+
+
+class _Variables(BaseModel):
+ variables: list[VariableUnion]
+
+
+def create_test_file(
+ file_type: FileType = FileType.DOCUMENT,
+ transfer_method: FileTransferMethod = FileTransferMethod.LOCAL_FILE,
+ filename: str = "test.txt",
+ extension: str = ".txt",
+ mime_type: str = "text/plain",
+ size: int = 1024,
+) -> File:
+ """Factory function to create File objects for testing"""
+ return File(
+ tenant_id="test-tenant",
+ type=file_type,
+ transfer_method=transfer_method,
+ filename=filename,
+ extension=extension,
+ mime_type=mime_type,
+ size=size,
+ related_id="test-file-id" if transfer_method != FileTransferMethod.REMOTE_URL else None,
+ remote_url="https://example.com/file.txt" if transfer_method == FileTransferMethod.REMOTE_URL else None,
+ storage_key="test-storage-key",
+ )
+
+
+class TestSegmentDumpAndLoad:
+ """Test suite for segment and variable serialization/deserialization"""
+
+ def test_segments(self):
+ """Test basic segment serialization compatibility"""
+ model = _Segments(segments=[IntegerSegment(value=1), StringSegment(value="a")])
+ json = model.model_dump_json()
+ print("Json: ", json)
+ loaded = _Segments.model_validate_json(json)
+ assert loaded == model
+
+ def test_segment_number(self):
+ """Test number segment serialization compatibility"""
+ model = _Segments(segments=[IntegerSegment(value=1), FloatSegment(value=1.0)])
+ json = model.model_dump_json()
+ print("Json: ", json)
+ loaded = _Segments.model_validate_json(json)
+ assert loaded == model
+
+ def test_variables(self):
+ """Test variable serialization compatibility"""
+ model = _Variables(variables=[IntegerVariable(value=1, name="int"), StringVariable(value="a", name="str")])
+ json = model.model_dump_json()
+ print("Json: ", json)
+ restored = _Variables.model_validate_json(json)
+ assert restored == model
+
+ def test_all_segments_serialization(self):
+ """Test serialization/deserialization of all segment types"""
+ # Create one instance of each segment type
+ test_file = create_test_file()
+
+ all_segments: list[SegmentUnion] = [
+ NoneSegment(),
+ StringSegment(value="test string"),
+ IntegerSegment(value=42),
+ FloatSegment(value=3.14),
+ ObjectSegment(value={"key": "value", "number": 123}),
+ FileSegment(value=test_file),
+ ArrayAnySegment(value=[1, "string", 3.14, {"key": "value"}]),
+ ArrayStringSegment(value=["hello", "world"]),
+ ArrayNumberSegment(value=[1, 2.5, 3]),
+ ArrayObjectSegment(value=[{"id": 1}, {"id": 2}]),
+ ArrayFileSegment(value=[]), # Empty array to avoid file complexity
+ ]
+
+ # Test serialization and deserialization
+ model = _Segments(segments=all_segments)
+ json_str = model.model_dump_json()
+ loaded = _Segments.model_validate_json(json_str)
+
+ # Verify all segments are preserved
+ assert len(loaded.segments) == len(all_segments)
+
+ for original, loaded_segment in zip(all_segments, loaded.segments):
+ assert type(loaded_segment) == type(original)
+ assert loaded_segment.value_type == original.value_type
+
+ # For file segments, compare key properties instead of exact equality
+ if isinstance(original, FileSegment) and isinstance(loaded_segment, FileSegment):
+ orig_file = original.value
+ loaded_file = loaded_segment.value
+ assert isinstance(orig_file, File)
+ assert isinstance(loaded_file, File)
+ assert loaded_file.tenant_id == orig_file.tenant_id
+ assert loaded_file.type == orig_file.type
+ assert loaded_file.filename == orig_file.filename
+ else:
+ assert loaded_segment.value == original.value
+
+ def test_all_variables_serialization(self):
+ """Test serialization/deserialization of all variable types"""
+ # Create one instance of each variable type
+ test_file = create_test_file()
+
+ all_variables: list[VariableUnion] = [
+ NoneVariable(name="none_var"),
+ StringVariable(value="test string", name="string_var"),
+ IntegerVariable(value=42, name="int_var"),
+ FloatVariable(value=3.14, name="float_var"),
+ ObjectVariable(value={"key": "value", "number": 123}, name="object_var"),
+ FileVariable(value=test_file, name="file_var"),
+ ArrayAnyVariable(value=[1, "string", 3.14, {"key": "value"}], name="array_any_var"),
+ ArrayStringVariable(value=["hello", "world"], name="array_string_var"),
+ ArrayNumberVariable(value=[1, 2.5, 3], name="array_number_var"),
+ ArrayObjectVariable(value=[{"id": 1}, {"id": 2}], name="array_object_var"),
+ ArrayFileVariable(value=[], name="array_file_var"), # Empty array to avoid file complexity
+ ]
+
+ # Test serialization and deserialization
+ model = _Variables(variables=all_variables)
+ json_str = model.model_dump_json()
+ loaded = _Variables.model_validate_json(json_str)
+
+ # Verify all variables are preserved
+ assert len(loaded.variables) == len(all_variables)
+
+ for original, loaded_variable in zip(all_variables, loaded.variables):
+ assert type(loaded_variable) == type(original)
+ assert loaded_variable.value_type == original.value_type
+ assert loaded_variable.name == original.name
+
+ # For file variables, compare key properties instead of exact equality
+ if isinstance(original, FileVariable) and isinstance(loaded_variable, FileVariable):
+ orig_file = original.value
+ loaded_file = loaded_variable.value
+ assert isinstance(orig_file, File)
+ assert isinstance(loaded_file, File)
+ assert loaded_file.tenant_id == orig_file.tenant_id
+ assert loaded_file.type == orig_file.type
+ assert loaded_file.filename == orig_file.filename
+ else:
+ assert loaded_variable.value == original.value
+
+ def test_segment_discriminator_function_for_segment_types(self):
+ """Test the segment discriminator function"""
+
+ @dataclasses.dataclass
+ class TestCase:
+ segment: Segment
+ expected_segment_type: SegmentType
+
+ file1 = create_test_file()
+ file2 = create_test_file(filename="test2.txt")
+
+ cases = [
+ TestCase(
+ NoneSegment(),
+ SegmentType.NONE,
+ ),
+ TestCase(
+ StringSegment(value=""),
+ SegmentType.STRING,
+ ),
+ TestCase(
+ FloatSegment(value=0.0),
+ SegmentType.FLOAT,
+ ),
+ TestCase(
+ IntegerSegment(value=0),
+ SegmentType.INTEGER,
+ ),
+ TestCase(
+ ObjectSegment(value={}),
+ SegmentType.OBJECT,
+ ),
+ TestCase(
+ FileSegment(value=file1),
+ SegmentType.FILE,
+ ),
+ TestCase(
+ ArrayAnySegment(value=[0, 0.0, ""]),
+ SegmentType.ARRAY_ANY,
+ ),
+ TestCase(
+ ArrayStringSegment(value=[""]),
+ SegmentType.ARRAY_STRING,
+ ),
+ TestCase(
+ ArrayNumberSegment(value=[0, 0.0]),
+ SegmentType.ARRAY_NUMBER,
+ ),
+ TestCase(
+ ArrayObjectSegment(value=[{}]),
+ SegmentType.ARRAY_OBJECT,
+ ),
+ TestCase(
+ ArrayFileSegment(value=[file1, file2]),
+ SegmentType.ARRAY_FILE,
+ ),
+ ]
+
+ for test_case in cases:
+ segment = test_case.segment
+ assert get_segment_discriminator(segment) == test_case.expected_segment_type, (
+ f"get_segment_discriminator failed for type {type(segment)}"
+ )
+ model_dict = segment.model_dump(mode="json")
+ assert get_segment_discriminator(model_dict) == test_case.expected_segment_type, (
+ f"get_segment_discriminator failed for serialized form of type {type(segment)}"
+ )
+
+ def test_variable_discriminator_function_for_variable_types(self):
+ """Test the variable discriminator function"""
+
+ @dataclasses.dataclass
+ class TestCase:
+ variable: Variable
+ expected_segment_type: SegmentType
+
+ file1 = create_test_file()
+ file2 = create_test_file(filename="test2.txt")
+
+ cases = [
+ TestCase(
+ NoneVariable(name="none_var"),
+ SegmentType.NONE,
+ ),
+ TestCase(
+ StringVariable(value="test", name="string_var"),
+ SegmentType.STRING,
+ ),
+ TestCase(
+ FloatVariable(value=0.0, name="float_var"),
+ SegmentType.FLOAT,
+ ),
+ TestCase(
+ IntegerVariable(value=0, name="int_var"),
+ SegmentType.INTEGER,
+ ),
+ TestCase(
+ ObjectVariable(value={}, name="object_var"),
+ SegmentType.OBJECT,
+ ),
+ TestCase(
+ FileVariable(value=file1, name="file_var"),
+ SegmentType.FILE,
+ ),
+ TestCase(
+ SecretVariable(value="secret", name="secret_var"),
+ SegmentType.SECRET,
+ ),
+ TestCase(
+ ArrayAnyVariable(value=[0, 0.0, ""], name="array_any_var"),
+ SegmentType.ARRAY_ANY,
+ ),
+ TestCase(
+ ArrayStringVariable(value=[""], name="array_string_var"),
+ SegmentType.ARRAY_STRING,
+ ),
+ TestCase(
+ ArrayNumberVariable(value=[0, 0.0], name="array_number_var"),
+ SegmentType.ARRAY_NUMBER,
+ ),
+ TestCase(
+ ArrayObjectVariable(value=[{}], name="array_object_var"),
+ SegmentType.ARRAY_OBJECT,
+ ),
+ TestCase(
+ ArrayFileVariable(value=[file1, file2], name="array_file_var"),
+ SegmentType.ARRAY_FILE,
+ ),
+ ]
+
+ for test_case in cases:
+ variable = test_case.variable
+ assert get_segment_discriminator(variable) == test_case.expected_segment_type, (
+ f"get_segment_discriminator failed for type {type(variable)}"
+ )
+ model_dict = variable.model_dump(mode="json")
+ assert get_segment_discriminator(model_dict) == test_case.expected_segment_type, (
+ f"get_segment_discriminator failed for serialized form of type {type(variable)}"
+ )
+
+ def test_invalid_value_for_discriminator(self):
+ # Test invalid cases
+ assert get_segment_discriminator({"value_type": "invalid"}) is None
+ assert get_segment_discriminator({}) is None
+ assert get_segment_discriminator("not_a_dict") is None
+ assert get_segment_discriminator(42) is None
+ assert get_segment_discriminator(object) is None
diff --git a/api/tests/unit_tests/core/variables/test_segment_type.py b/api/tests/unit_tests/core/variables/test_segment_type.py
new file mode 100644
index 0000000000..64d0d8c7e7
--- /dev/null
+++ b/api/tests/unit_tests/core/variables/test_segment_type.py
@@ -0,0 +1,60 @@
+from core.variables.types import SegmentType
+
+
+class TestSegmentTypeIsArrayType:
+ """
+ Test class for SegmentType.is_array_type method.
+
+ Provides comprehensive coverage of all SegmentType values to ensure
+ correct identification of array and non-array types.
+ """
+
+ def test_is_array_type(self):
+ """
+ Test that all SegmentType enum values are covered in our test cases.
+
+ Ensures comprehensive coverage by verifying that every SegmentType
+ value is tested for the is_array_type method.
+ """
+ # Arrange
+ all_segment_types = set(SegmentType)
+ expected_array_types = [
+ SegmentType.ARRAY_ANY,
+ SegmentType.ARRAY_STRING,
+ SegmentType.ARRAY_NUMBER,
+ SegmentType.ARRAY_OBJECT,
+ SegmentType.ARRAY_FILE,
+ ]
+ expected_non_array_types = [
+ SegmentType.INTEGER,
+ SegmentType.FLOAT,
+ SegmentType.NUMBER,
+ SegmentType.STRING,
+ SegmentType.OBJECT,
+ SegmentType.SECRET,
+ SegmentType.FILE,
+ SegmentType.NONE,
+ SegmentType.GROUP,
+ ]
+
+ for seg_type in expected_array_types:
+ assert seg_type.is_array_type()
+
+ for seg_type in expected_non_array_types:
+ assert not seg_type.is_array_type()
+
+ # Act & Assert
+ covered_types = set(expected_array_types) | set(expected_non_array_types)
+ assert covered_types == set(SegmentType), "All SegmentType values should be covered in tests"
+
+ def test_all_enum_values_are_supported(self):
+ """
+ Test that all enum values are supported and return boolean values.
+
+ Validates that every SegmentType enum value can be processed by
+ is_array_type method and returns a boolean value.
+ """
+ enum_values: list[SegmentType] = list(SegmentType)
+ for seg_type in enum_values:
+ is_array = seg_type.is_array_type()
+ assert isinstance(is_array, bool), f"is_array_type does not return a boolean for segment type {seg_type}"
diff --git a/api/tests/unit_tests/core/variables/test_variables.py b/api/tests/unit_tests/core/variables/test_variables.py
index 426557c716..925142892c 100644
--- a/api/tests/unit_tests/core/variables/test_variables.py
+++ b/api/tests/unit_tests/core/variables/test_variables.py
@@ -11,6 +11,7 @@ from core.variables import (
SegmentType,
StringVariable,
)
+from core.variables.variables import Variable
def test_frozen_variables():
@@ -75,7 +76,7 @@ def test_object_variable_to_object():
def test_variable_to_object():
- var = StringVariable(name="text", value="text")
+ var: Variable = StringVariable(name="text", value="text")
assert var.to_object() == "text"
var = IntegerVariable(name="integer", value=42)
assert var.to_object() == 42
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/entities/test_graph_runtime_state.py b/api/tests/unit_tests/core/workflow/graph_engine/entities/test_graph_runtime_state.py
new file mode 100644
index 0000000000..cf7cee8710
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/graph_engine/entities/test_graph_runtime_state.py
@@ -0,0 +1,146 @@
+import time
+from decimal import Decimal
+
+from core.model_runtime.entities.llm_entities import LLMUsage
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
+from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
+from core.workflow.system_variable import SystemVariable
+
+
+def create_test_graph_runtime_state() -> GraphRuntimeState:
+ """Factory function to create a GraphRuntimeState with non-empty values for testing."""
+ # Create a variable pool with system variables
+ system_vars = SystemVariable(
+ user_id="test_user_123",
+ app_id="test_app_456",
+ workflow_id="test_workflow_789",
+ workflow_execution_id="test_execution_001",
+ query="test query",
+ conversation_id="test_conv_123",
+ dialogue_count=5,
+ )
+ variable_pool = VariablePool(system_variables=system_vars)
+
+ # Add some variables to the variable pool
+ variable_pool.add(["test_node", "test_var"], "test_value")
+ variable_pool.add(["another_node", "another_var"], 42)
+
+ # Create LLM usage with realistic values
+ llm_usage = LLMUsage(
+ prompt_tokens=150,
+ prompt_unit_price=Decimal("0.001"),
+ prompt_price_unit=Decimal(1000),
+ prompt_price=Decimal("0.15"),
+ completion_tokens=75,
+ completion_unit_price=Decimal("0.002"),
+ completion_price_unit=Decimal(1000),
+ completion_price=Decimal("0.15"),
+ total_tokens=225,
+ total_price=Decimal("0.30"),
+ currency="USD",
+ latency=1.25,
+ )
+
+ # Create runtime route state with some node states
+ node_run_state = RuntimeRouteState()
+ node_state = node_run_state.create_node_state("test_node_1")
+ node_run_state.add_route(node_state.id, "target_node_id")
+
+ return GraphRuntimeState(
+ variable_pool=variable_pool,
+ start_at=time.perf_counter(),
+ total_tokens=100,
+ llm_usage=llm_usage,
+ outputs={
+ "string_output": "test result",
+ "int_output": 42,
+ "float_output": 3.14,
+ "list_output": ["item1", "item2", "item3"],
+ "dict_output": {"key1": "value1", "key2": 123},
+ "nested_dict": {"level1": {"level2": ["nested", "list", 456]}},
+ },
+ node_run_steps=5,
+ node_run_state=node_run_state,
+ )
+
+
+def test_basic_round_trip_serialization():
+ """Test basic round-trip serialization ensures GraphRuntimeState values remain unchanged."""
+ # Create a state with non-empty values
+ original_state = create_test_graph_runtime_state()
+
+ # Serialize to JSON and deserialize back
+ json_data = original_state.model_dump_json()
+ deserialized_state = GraphRuntimeState.model_validate_json(json_data)
+
+ # Core test: ensure the round-trip preserves all values
+ assert deserialized_state == original_state
+
+ # Serialize to JSON and deserialize back
+ dict_data = original_state.model_dump(mode="python")
+ deserialized_state = GraphRuntimeState.model_validate(dict_data)
+ assert deserialized_state == original_state
+
+ # Serialize to JSON and deserialize back
+ dict_data = original_state.model_dump(mode="json")
+ deserialized_state = GraphRuntimeState.model_validate(dict_data)
+ assert deserialized_state == original_state
+
+
+def test_outputs_field_round_trip():
+ """Test the problematic outputs field maintains values through round-trip serialization."""
+ original_state = create_test_graph_runtime_state()
+
+ # Serialize and deserialize
+ json_data = original_state.model_dump_json()
+ deserialized_state = GraphRuntimeState.model_validate_json(json_data)
+
+ # Verify the outputs field specifically maintains its values
+ assert deserialized_state.outputs == original_state.outputs
+ assert deserialized_state == original_state
+
+
+def test_empty_outputs_round_trip():
+ """Test round-trip serialization with empty outputs field."""
+ variable_pool = VariablePool.empty()
+ original_state = GraphRuntimeState(
+ variable_pool=variable_pool,
+ start_at=time.perf_counter(),
+ outputs={}, # Empty outputs
+ )
+
+ json_data = original_state.model_dump_json()
+ deserialized_state = GraphRuntimeState.model_validate_json(json_data)
+
+ assert deserialized_state == original_state
+
+
+def test_llm_usage_round_trip():
+ # Create LLM usage with specific decimal values
+ llm_usage = LLMUsage(
+ prompt_tokens=100,
+ prompt_unit_price=Decimal("0.0015"),
+ prompt_price_unit=Decimal(1000),
+ prompt_price=Decimal("0.15"),
+ completion_tokens=50,
+ completion_unit_price=Decimal("0.003"),
+ completion_price_unit=Decimal(1000),
+ completion_price=Decimal("0.15"),
+ total_tokens=150,
+ total_price=Decimal("0.30"),
+ currency="USD",
+ latency=2.5,
+ )
+
+ json_data = llm_usage.model_dump_json()
+ deserialized = LLMUsage.model_validate_json(json_data)
+ assert deserialized == llm_usage
+
+ dict_data = llm_usage.model_dump(mode="python")
+ deserialized = LLMUsage.model_validate(dict_data)
+ assert deserialized == llm_usage
+
+ dict_data = llm_usage.model_dump(mode="json")
+ deserialized = LLMUsage.model_validate(dict_data)
+ assert deserialized == llm_usage
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/entities/test_node_run_state.py b/api/tests/unit_tests/core/workflow/graph_engine/entities/test_node_run_state.py
new file mode 100644
index 0000000000..f3de42479a
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/graph_engine/entities/test_node_run_state.py
@@ -0,0 +1,401 @@
+import json
+import uuid
+from datetime import UTC, datetime
+
+import pytest
+from pydantic import ValidationError
+
+from core.workflow.entities.node_entities import NodeRunResult
+from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
+from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState, RuntimeRouteState
+
+_TEST_DATETIME = datetime(2024, 1, 15, 10, 30, 45)
+
+
+class TestRouteNodeStateSerialization:
+ """Test cases for RouteNodeState Pydantic serialization/deserialization."""
+
+ def _test_route_node_state(self):
+ """Test comprehensive RouteNodeState serialization with all core fields validation."""
+
+ node_run_result = NodeRunResult(
+ status=WorkflowNodeExecutionStatus.SUCCEEDED,
+ inputs={"input_key": "input_value"},
+ outputs={"output_key": "output_value"},
+ )
+
+ node_state = RouteNodeState(
+ node_id="comprehensive_test_node",
+ start_at=_TEST_DATETIME,
+ finished_at=_TEST_DATETIME,
+ status=RouteNodeState.Status.SUCCESS,
+ node_run_result=node_run_result,
+ index=5,
+ paused_at=_TEST_DATETIME,
+ paused_by="user_123",
+ failed_reason="test_reason",
+ )
+ return node_state
+
+ def test_route_node_state_comprehensive_field_validation(self):
+ """Test comprehensive RouteNodeState serialization with all core fields validation."""
+ node_state = self._test_route_node_state()
+ serialized = node_state.model_dump()
+
+ # Comprehensive validation of all RouteNodeState fields
+ assert serialized["node_id"] == "comprehensive_test_node"
+ assert serialized["status"] == RouteNodeState.Status.SUCCESS
+ assert serialized["start_at"] == _TEST_DATETIME
+ assert serialized["finished_at"] == _TEST_DATETIME
+ assert serialized["paused_at"] == _TEST_DATETIME
+ assert serialized["paused_by"] == "user_123"
+ assert serialized["failed_reason"] == "test_reason"
+ assert serialized["index"] == 5
+ assert "id" in serialized
+ assert isinstance(serialized["id"], str)
+ uuid.UUID(serialized["id"]) # Validate UUID format
+
+ # Validate nested NodeRunResult structure
+ assert serialized["node_run_result"] is not None
+ assert serialized["node_run_result"]["status"] == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert serialized["node_run_result"]["inputs"] == {"input_key": "input_value"}
+ assert serialized["node_run_result"]["outputs"] == {"output_key": "output_value"}
+
+ def test_route_node_state_minimal_required_fields(self):
+ """Test RouteNodeState with only required fields, focusing on defaults."""
+ node_state = RouteNodeState(node_id="minimal_node", start_at=_TEST_DATETIME)
+
+ serialized = node_state.model_dump()
+
+ # Focus on required fields and default values (not re-testing all fields)
+ assert serialized["node_id"] == "minimal_node"
+ assert serialized["start_at"] == _TEST_DATETIME
+ assert serialized["status"] == RouteNodeState.Status.RUNNING # Default status
+ assert serialized["index"] == 1 # Default index
+ assert serialized["node_run_result"] is None # Default None
+ json = node_state.model_dump_json()
+ deserialized = RouteNodeState.model_validate_json(json)
+ assert deserialized == node_state
+
+ def test_route_node_state_deserialization_from_dict(self):
+ """Test RouteNodeState deserialization from dictionary data."""
+ test_datetime = datetime(2024, 1, 15, 10, 30, 45)
+ test_id = str(uuid.uuid4())
+
+ dict_data = {
+ "id": test_id,
+ "node_id": "deserialized_node",
+ "start_at": test_datetime,
+ "status": "success",
+ "finished_at": test_datetime,
+ "index": 3,
+ }
+
+ node_state = RouteNodeState.model_validate(dict_data)
+
+ # Focus on deserialization accuracy
+ assert node_state.id == test_id
+ assert node_state.node_id == "deserialized_node"
+ assert node_state.start_at == test_datetime
+ assert node_state.status == RouteNodeState.Status.SUCCESS
+ assert node_state.finished_at == test_datetime
+ assert node_state.index == 3
+
+ def test_route_node_state_round_trip_consistency(self):
+ node_states = (
+ self._test_route_node_state(),
+ RouteNodeState(node_id="minimal_node", start_at=_TEST_DATETIME),
+ )
+ for node_state in node_states:
+ json = node_state.model_dump_json()
+ deserialized = RouteNodeState.model_validate_json(json)
+ assert deserialized == node_state
+
+ dict_ = node_state.model_dump(mode="python")
+ deserialized = RouteNodeState.model_validate(dict_)
+ assert deserialized == node_state
+
+ dict_ = node_state.model_dump(mode="json")
+ deserialized = RouteNodeState.model_validate(dict_)
+ assert deserialized == node_state
+
+
+class TestRouteNodeStateEnumSerialization:
+ """Dedicated tests for RouteNodeState Status enum serialization behavior."""
+
+ def test_status_enum_model_dump_behavior(self):
+ """Test Status enum serialization in model_dump() returns enum objects."""
+
+ for status_enum in RouteNodeState.Status:
+ node_state = RouteNodeState(node_id="enum_test", start_at=_TEST_DATETIME, status=status_enum)
+ serialized = node_state.model_dump(mode="python")
+ assert serialized["status"] == status_enum
+ serialized = node_state.model_dump(mode="json")
+ assert serialized["status"] == status_enum.value
+
+ def test_status_enum_json_serialization_behavior(self):
+ """Test Status enum serialization in JSON returns string values."""
+ test_datetime = datetime(2024, 1, 15, 10, 30, 45)
+
+ enum_to_string_mapping = {
+ RouteNodeState.Status.RUNNING: "running",
+ RouteNodeState.Status.SUCCESS: "success",
+ RouteNodeState.Status.FAILED: "failed",
+ RouteNodeState.Status.PAUSED: "paused",
+ RouteNodeState.Status.EXCEPTION: "exception",
+ }
+
+ for status_enum, expected_string in enum_to_string_mapping.items():
+ node_state = RouteNodeState(node_id="json_enum_test", start_at=test_datetime, status=status_enum)
+
+ json_data = json.loads(node_state.model_dump_json())
+ assert json_data["status"] == expected_string
+
+ def test_status_enum_deserialization_from_string(self):
+ """Test Status enum deserialization from string values."""
+ test_datetime = datetime(2024, 1, 15, 10, 30, 45)
+
+ string_to_enum_mapping = {
+ "running": RouteNodeState.Status.RUNNING,
+ "success": RouteNodeState.Status.SUCCESS,
+ "failed": RouteNodeState.Status.FAILED,
+ "paused": RouteNodeState.Status.PAUSED,
+ "exception": RouteNodeState.Status.EXCEPTION,
+ }
+
+ for status_string, expected_enum in string_to_enum_mapping.items():
+ dict_data = {
+ "node_id": "enum_deserialize_test",
+ "start_at": test_datetime,
+ "status": status_string,
+ }
+
+ node_state = RouteNodeState.model_validate(dict_data)
+ assert node_state.status == expected_enum
+
+
+class TestRuntimeRouteStateSerialization:
+ """Test cases for RuntimeRouteState Pydantic serialization/deserialization."""
+
+ _NODE1_ID = "node_1"
+ _ROUTE_STATE1_ID = str(uuid.uuid4())
+ _NODE2_ID = "node_2"
+ _ROUTE_STATE2_ID = str(uuid.uuid4())
+ _NODE3_ID = "node_3"
+ _ROUTE_STATE3_ID = str(uuid.uuid4())
+
+ def _get_runtime_route_state(self):
+ # Create node states with different configurations
+ node_state_1 = RouteNodeState(
+ id=self._ROUTE_STATE1_ID,
+ node_id=self._NODE1_ID,
+ start_at=_TEST_DATETIME,
+ index=1,
+ )
+ node_state_2 = RouteNodeState(
+ id=self._ROUTE_STATE2_ID,
+ node_id=self._NODE2_ID,
+ start_at=_TEST_DATETIME,
+ status=RouteNodeState.Status.SUCCESS,
+ finished_at=_TEST_DATETIME,
+ index=2,
+ )
+ node_state_3 = RouteNodeState(
+ id=self._ROUTE_STATE3_ID,
+ node_id=self._NODE3_ID,
+ start_at=_TEST_DATETIME,
+ status=RouteNodeState.Status.FAILED,
+ failed_reason="Test failure",
+ index=3,
+ )
+
+ runtime_state = RuntimeRouteState(
+ routes={node_state_1.id: [node_state_2.id, node_state_3.id], node_state_2.id: [node_state_3.id]},
+ node_state_mapping={
+ node_state_1.id: node_state_1,
+ node_state_2.id: node_state_2,
+ node_state_3.id: node_state_3,
+ },
+ )
+
+ return runtime_state
+
+ def test_runtime_route_state_comprehensive_structure_validation(self):
+ """Test comprehensive RuntimeRouteState serialization with full structure validation."""
+
+ runtime_state = self._get_runtime_route_state()
+ serialized = runtime_state.model_dump()
+
+ # Comprehensive validation of RuntimeRouteState structure
+ assert "routes" in serialized
+ assert "node_state_mapping" in serialized
+ assert isinstance(serialized["routes"], dict)
+ assert isinstance(serialized["node_state_mapping"], dict)
+
+ # Validate routes dictionary structure and content
+ assert len(serialized["routes"]) == 2
+ assert self._ROUTE_STATE1_ID in serialized["routes"]
+ assert self._ROUTE_STATE2_ID in serialized["routes"]
+ assert serialized["routes"][self._ROUTE_STATE1_ID] == [self._ROUTE_STATE2_ID, self._ROUTE_STATE3_ID]
+ assert serialized["routes"][self._ROUTE_STATE2_ID] == [self._ROUTE_STATE3_ID]
+
+ # Validate node_state_mapping dictionary structure and content
+ assert len(serialized["node_state_mapping"]) == 3
+ for state_id in [
+ self._ROUTE_STATE1_ID,
+ self._ROUTE_STATE2_ID,
+ self._ROUTE_STATE3_ID,
+ ]:
+ assert state_id in serialized["node_state_mapping"]
+ node_data = serialized["node_state_mapping"][state_id]
+ node_state = runtime_state.node_state_mapping[state_id]
+ assert node_data["node_id"] == node_state.node_id
+ assert node_data["status"] == node_state.status
+ assert node_data["index"] == node_state.index
+
+ def test_runtime_route_state_empty_collections(self):
+ """Test RuntimeRouteState with empty collections, focusing on default behavior."""
+ runtime_state = RuntimeRouteState()
+ serialized = runtime_state.model_dump()
+
+ # Focus on default empty collection behavior
+ assert serialized["routes"] == {}
+ assert serialized["node_state_mapping"] == {}
+ assert isinstance(serialized["routes"], dict)
+ assert isinstance(serialized["node_state_mapping"], dict)
+
+ def test_runtime_route_state_json_serialization_structure(self):
+ """Test RuntimeRouteState JSON serialization structure."""
+ node_state = RouteNodeState(node_id="json_node", start_at=_TEST_DATETIME)
+
+ runtime_state = RuntimeRouteState(
+ routes={"source": ["target1", "target2"]}, node_state_mapping={node_state.id: node_state}
+ )
+
+ json_str = runtime_state.model_dump_json()
+ json_data = json.loads(json_str)
+
+ # Focus on JSON structure validation
+ assert isinstance(json_str, str)
+ assert isinstance(json_data, dict)
+ assert "routes" in json_data
+ assert "node_state_mapping" in json_data
+ assert json_data["routes"]["source"] == ["target1", "target2"]
+ assert node_state.id in json_data["node_state_mapping"]
+
+ def test_runtime_route_state_deserialization_from_dict(self):
+ """Test RuntimeRouteState deserialization from dictionary data."""
+ node_id = str(uuid.uuid4())
+
+ dict_data = {
+ "routes": {"source_node": ["target_node_1", "target_node_2"]},
+ "node_state_mapping": {
+ node_id: {
+ "id": node_id,
+ "node_id": "test_node",
+ "start_at": _TEST_DATETIME,
+ "status": "running",
+ "index": 1,
+ }
+ },
+ }
+
+ runtime_state = RuntimeRouteState.model_validate(dict_data)
+
+ # Focus on deserialization accuracy
+ assert runtime_state.routes == {"source_node": ["target_node_1", "target_node_2"]}
+ assert len(runtime_state.node_state_mapping) == 1
+ assert node_id in runtime_state.node_state_mapping
+
+ deserialized_node = runtime_state.node_state_mapping[node_id]
+ assert deserialized_node.node_id == "test_node"
+ assert deserialized_node.status == RouteNodeState.Status.RUNNING
+ assert deserialized_node.index == 1
+
+ def test_runtime_route_state_round_trip_consistency(self):
+ """Test RuntimeRouteState round-trip serialization consistency."""
+ original = self._get_runtime_route_state()
+
+ # Dictionary round trip
+ dict_data = original.model_dump(mode="python")
+ reconstructed = RuntimeRouteState.model_validate(dict_data)
+ assert reconstructed == original
+
+ dict_data = original.model_dump(mode="json")
+ reconstructed = RuntimeRouteState.model_validate(dict_data)
+ assert reconstructed == original
+
+ # JSON round trip
+ json_str = original.model_dump_json()
+ json_reconstructed = RuntimeRouteState.model_validate_json(json_str)
+ assert json_reconstructed == original
+
+
+class TestSerializationEdgeCases:
+ """Test edge cases and error conditions for serialization/deserialization."""
+
+ def test_invalid_status_deserialization(self):
+ """Test deserialization with invalid status values."""
+ test_datetime = _TEST_DATETIME
+ invalid_data = {
+ "node_id": "invalid_test",
+ "start_at": test_datetime,
+ "status": "invalid_status",
+ }
+
+ with pytest.raises(ValidationError) as exc_info:
+ RouteNodeState.model_validate(invalid_data)
+ assert "status" in str(exc_info.value)
+
+ def test_missing_required_fields_deserialization(self):
+ """Test deserialization with missing required fields."""
+ incomplete_data = {"id": str(uuid.uuid4())}
+
+ with pytest.raises(ValidationError) as exc_info:
+ RouteNodeState.model_validate(incomplete_data)
+ error_str = str(exc_info.value)
+ assert "node_id" in error_str or "start_at" in error_str
+
+ def test_invalid_datetime_deserialization(self):
+ """Test deserialization with invalid datetime values."""
+ invalid_data = {
+ "node_id": "datetime_test",
+ "start_at": "invalid_datetime",
+ "status": "running",
+ }
+
+ with pytest.raises(ValidationError) as exc_info:
+ RouteNodeState.model_validate(invalid_data)
+ assert "start_at" in str(exc_info.value)
+
+ def test_invalid_routes_structure_deserialization(self):
+ """Test RuntimeRouteState deserialization with invalid routes structure."""
+ invalid_data = {
+ "routes": "invalid_routes_structure", # Should be dict
+ "node_state_mapping": {},
+ }
+
+ with pytest.raises(ValidationError) as exc_info:
+ RuntimeRouteState.model_validate(invalid_data)
+ assert "routes" in str(exc_info.value)
+
+ def test_timezone_handling_in_datetime_fields(self):
+ """Test timezone handling in datetime field serialization."""
+ utc_datetime = datetime.now(UTC)
+ naive_datetime = utc_datetime.replace(tzinfo=None)
+
+ node_state = RouteNodeState(node_id="timezone_test", start_at=naive_datetime)
+ dict_ = node_state.model_dump()
+
+ assert dict_["start_at"] == naive_datetime
+
+ # Test round trip
+ reconstructed = RouteNodeState.model_validate(dict_)
+ assert reconstructed.start_at == naive_datetime
+ assert reconstructed.start_at.tzinfo is None
+
+ json = node_state.model_dump_json()
+
+ reconstructed = RouteNodeState.model_validate_json(json)
+ assert reconstructed.start_at == naive_datetime
+ assert reconstructed.start_at.tzinfo is None
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py
index c288a5fa13..ed4e42425e 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py
@@ -8,7 +8,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
BaseNodeEvent,
GraphRunFailedEvent,
@@ -27,6 +26,7 @@ from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
+from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
@@ -171,7 +171,8 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
graph = Graph.init(graph_config=graph_config)
variable_pool = VariablePool(
- system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"}
+ system_variables=SystemVariable(user_id="aaa", app_id="1", workflow_id="1", files=[]),
+ user_inputs={"query": "hi"},
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
@@ -293,12 +294,12 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
graph = Graph.init(graph_config=graph_config)
variable_pool = VariablePool(
- system_variables={
- SystemVariableKey.QUERY: "what's the weather in SF",
- SystemVariableKey.FILES: [],
- SystemVariableKey.CONVERSATION_ID: "abababa",
- SystemVariableKey.USER_ID: "aaa",
- },
+ system_variables=SystemVariable(
+ user_id="aaa",
+ files=[],
+ query="what's the weather in SF",
+ conversation_id="abababa",
+ ),
user_inputs={},
)
@@ -474,12 +475,12 @@ def test_run_branch(mock_close, mock_remove):
graph = Graph.init(graph_config=graph_config)
variable_pool = VariablePool(
- system_variables={
- SystemVariableKey.QUERY: "hi",
- SystemVariableKey.FILES: [],
- SystemVariableKey.CONVERSATION_ID: "abababa",
- SystemVariableKey.USER_ID: "aaa",
- },
+ system_variables=SystemVariable(
+ user_id="aaa",
+ files=[],
+ query="hi",
+ conversation_id="abababa",
+ ),
user_inputs={"uid": "takato"},
)
@@ -804,18 +805,22 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
# construct variable pool
pool = VariablePool(
- system_variables={
- SystemVariableKey.QUERY: "dify",
- SystemVariableKey.FILES: [],
- SystemVariableKey.CONVERSATION_ID: "abababa",
- SystemVariableKey.USER_ID: "1",
- },
+ system_variables=SystemVariable(
+ user_id="1",
+ files=[],
+ query="dify",
+ conversation_id="abababa",
+ ),
user_inputs={},
environment_variables=[],
)
pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
variable_pool = VariablePool(
- system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"}
+ system_variables=SystemVariable(
+ user_id="aaa",
+ files=[],
+ ),
+ user_inputs={"query": "hi"},
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py
index b7f78d91fa..1ef024f46b 100644
--- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py
+++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py
@@ -5,11 +5,11 @@ from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.answer.answer_node import AnswerNode
+from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from models.enums import UserFrom
from models.workflow import WorkflowType
@@ -51,28 +51,33 @@ def test_execute_answer():
# construct variable pool
pool = VariablePool(
- system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
+ system_variables=SystemVariable(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
)
pool.add(["start", "weather"], "sunny")
pool.add(["llm", "text"], "You are a helpful AI.")
+ node_config = {
+ "id": "answer",
+ "data": {
+ "title": "123",
+ "type": "answer",
+ "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
+ },
+ }
+
node = AnswerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
- config={
- "id": "answer",
- "data": {
- "title": "123",
- "type": "answer",
- "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
- },
- },
+ config=node_config,
)
+ # Initialize node data
+ node.init_node_data(node_config["data"])
+
# Mock db.session.close()
db.session.close = MagicMock()
diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py
index c3a3818655..137e8b889d 100644
--- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py
+++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py
@@ -3,7 +3,6 @@ from collections.abc import Generator
from datetime import UTC, datetime
from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
NodeRunStartedEvent,
@@ -15,6 +14,7 @@ from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeSta
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.start.entities import StartNodeData
+from core.workflow.system_variable import SystemVariable
def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]:
@@ -180,12 +180,12 @@ def test_process():
graph = Graph.init(graph_config=graph_config)
variable_pool = VariablePool(
- system_variables={
- SystemVariableKey.QUERY: "what's the weather in SF",
- SystemVariableKey.FILES: [],
- SystemVariableKey.CONVERSATION_ID: "abababa",
- SystemVariableKey.USER_ID: "aaa",
- },
+ system_variables=SystemVariable(
+ user_id="aaa",
+ files=[],
+ query="what's the weather in SF",
+ conversation_id="abababa",
+ ),
user_inputs={},
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py
index d066fc1e33..bb6d72f51e 100644
--- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py
+++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py
@@ -7,12 +7,13 @@ from core.workflow.nodes.http_request import (
)
from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout
from core.workflow.nodes.http_request.executor import Executor
+from core.workflow.system_variable import SystemVariable
def test_executor_with_json_body_and_number_variable():
# Prepare the variable pool
variable_pool = VariablePool(
- system_variables={},
+ system_variables=SystemVariable.empty(),
user_inputs={},
)
variable_pool.add(["pre_node_id", "number"], 42)
@@ -65,7 +66,7 @@ def test_executor_with_json_body_and_number_variable():
def test_executor_with_json_body_and_object_variable():
# Prepare the variable pool
variable_pool = VariablePool(
- system_variables={},
+ system_variables=SystemVariable.empty(),
user_inputs={},
)
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
@@ -120,7 +121,7 @@ def test_executor_with_json_body_and_object_variable():
def test_executor_with_json_body_and_nested_object_variable():
# Prepare the variable pool
variable_pool = VariablePool(
- system_variables={},
+ system_variables=SystemVariable.empty(),
user_inputs={},
)
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
@@ -174,7 +175,7 @@ def test_executor_with_json_body_and_nested_object_variable():
def test_extract_selectors_from_template_with_newline():
- variable_pool = VariablePool()
+ variable_pool = VariablePool(system_variables=SystemVariable.empty())
variable_pool.add(("node_id", "custom_query"), "line1\nline2")
node_data = HttpRequestNodeData(
title="Test JSON Body with Nested Object Variable",
@@ -201,7 +202,7 @@ def test_extract_selectors_from_template_with_newline():
def test_executor_with_form_data():
# Prepare the variable pool
variable_pool = VariablePool(
- system_variables={},
+ system_variables=SystemVariable.empty(),
user_inputs={},
)
variable_pool.add(["pre_node_id", "text_field"], "Hello, World!")
@@ -280,7 +281,11 @@ def test_init_headers():
authorization=HttpRequestNodeAuthorization(type="no-auth"),
)
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
- return Executor(node_data=node_data, timeout=timeout, variable_pool=VariablePool())
+ return Executor(
+ node_data=node_data,
+ timeout=timeout,
+ variable_pool=VariablePool(system_variables=SystemVariable.empty()),
+ )
executor = create_executor("aa\n cc:")
executor._init_headers()
@@ -310,7 +315,11 @@ def test_init_params():
authorization=HttpRequestNodeAuthorization(type="no-auth"),
)
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
- return Executor(node_data=node_data, timeout=timeout, variable_pool=VariablePool())
+ return Executor(
+ node_data=node_data,
+ timeout=timeout,
+ variable_pool=VariablePool(system_variables=SystemVariable.empty()),
+ )
# Test basic key-value pairs
executor = create_executor("key1:value1\nkey2:value2")
diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py
index 7fd32a4826..71b3a8f7d8 100644
--- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py
@@ -15,6 +15,7 @@ from core.workflow.nodes.http_request import (
HttpRequestNodeBody,
HttpRequestNodeData,
)
+from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
@@ -40,7 +41,7 @@ def test_http_request_node_binary_file(monkeypatch):
),
)
variable_pool = VariablePool(
- system_variables={},
+ system_variables=SystemVariable.empty(),
user_inputs={},
)
variable_pool.add(
@@ -56,12 +57,15 @@ def test_http_request_node_binary_file(monkeypatch):
),
),
)
+
+ node_config = {
+ "id": "1",
+ "data": data.model_dump(),
+ }
+
node = HttpRequestNode(
id="1",
- config={
- "id": "1",
- "data": data.model_dump(),
- },
+ config=node_config,
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
@@ -89,6 +93,9 @@ def test_http_request_node_binary_file(monkeypatch):
start_at=0,
),
)
+
+ # Initialize node data
+ node.init_node_data(node_config["data"])
monkeypatch.setattr(
"core.workflow.nodes.http_request.executor.file_manager.download",
lambda *args, **kwargs: b"test",
@@ -128,7 +135,7 @@ def test_http_request_node_form_with_file(monkeypatch):
),
)
variable_pool = VariablePool(
- system_variables={},
+ system_variables=SystemVariable.empty(),
user_inputs={},
)
variable_pool.add(
@@ -144,12 +151,15 @@ def test_http_request_node_form_with_file(monkeypatch):
),
),
)
+
+ node_config = {
+ "id": "1",
+ "data": data.model_dump(),
+ }
+
node = HttpRequestNode(
id="1",
- config={
- "id": "1",
- "data": data.model_dump(),
- },
+ config=node_config,
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
@@ -177,6 +187,10 @@ def test_http_request_node_form_with_file(monkeypatch):
start_at=0,
),
)
+
+ # Initialize node data
+ node.init_node_data(node_config["data"])
+
monkeypatch.setattr(
"core.workflow.nodes.http_request.executor.file_manager.download",
lambda *args, **kwargs: b"test",
@@ -223,7 +237,7 @@ def test_http_request_node_form_with_multiple_files(monkeypatch):
)
variable_pool = VariablePool(
- system_variables={},
+ system_variables=SystemVariable.empty(),
user_inputs={},
)
@@ -256,12 +270,14 @@ def test_http_request_node_form_with_multiple_files(monkeypatch):
),
)
+ node_config = {
+ "id": "1",
+ "data": data.model_dump(),
+ }
+
node = HttpRequestNode(
id="1",
- config={
- "id": "1",
- "data": data.model_dump(),
- },
+ config=node_config,
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
@@ -290,6 +306,9 @@ def test_http_request_node_form_with_multiple_files(monkeypatch):
),
)
+ # Initialize node data
+ node.init_node_data(node_config["data"])
+
monkeypatch.setattr(
"core.workflow.nodes.http_request.executor.file_manager.download",
lambda file: b"test_image_data" if file.mime_type == "image/jpeg" else b"test_pdf_data",
diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py
index 362072a3db..f53f391433 100644
--- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py
+++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py
@@ -7,7 +7,6 @@ from core.variables.segments import ArrayAnySegment, ArrayStringSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
@@ -15,6 +14,7 @@ from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.iteration.entities import ErrorHandleMode
from core.workflow.nodes.iteration.iteration_node import IterationNode
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
+from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
@@ -151,36 +151,41 @@ def test_run():
# construct variable pool
pool = VariablePool(
- system_variables={
- SystemVariableKey.QUERY: "dify",
- SystemVariableKey.FILES: [],
- SystemVariableKey.CONVERSATION_ID: "abababa",
- SystemVariableKey.USER_ID: "1",
- },
+ system_variables=SystemVariable(
+ user_id="1",
+ files=[],
+ query="dify",
+ conversation_id="abababa",
+ ),
user_inputs={},
environment_variables=[],
)
pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
+ node_config = {
+ "data": {
+ "iterator_selector": ["pe", "list_output"],
+ "output_selector": ["tt", "output"],
+ "output_type": "array[string]",
+ "startNodeType": "template-transform",
+ "start_node_id": "tt",
+ "title": "迭代",
+ "type": "iteration",
+ },
+ "id": "iteration-1",
+ }
+
iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
- config={
- "data": {
- "iterator_selector": ["pe", "list_output"],
- "output_selector": ["tt", "output"],
- "output_type": "array[string]",
- "startNodeType": "template-transform",
- "start_node_id": "tt",
- "title": "迭代",
- "type": "iteration",
- },
- "id": "iteration-1",
- },
+ config=node_config,
)
+ # Initialize node data
+ iteration_node.init_node_data(node_config["data"])
+
def tt_generator(self):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -368,36 +373,41 @@ def test_run_parallel():
# construct variable pool
pool = VariablePool(
- system_variables={
- SystemVariableKey.QUERY: "dify",
- SystemVariableKey.FILES: [],
- SystemVariableKey.CONVERSATION_ID: "abababa",
- SystemVariableKey.USER_ID: "1",
- },
+ system_variables=SystemVariable(
+ user_id="1",
+ files=[],
+ query="dify",
+ conversation_id="abababa",
+ ),
user_inputs={},
environment_variables=[],
)
pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
+ node_config = {
+ "data": {
+ "iterator_selector": ["pe", "list_output"],
+ "output_selector": ["tt", "output"],
+ "output_type": "array[string]",
+ "startNodeType": "template-transform",
+ "start_node_id": "iteration-start",
+ "title": "迭代",
+ "type": "iteration",
+ },
+ "id": "iteration-1",
+ }
+
iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
- config={
- "data": {
- "iterator_selector": ["pe", "list_output"],
- "output_selector": ["tt", "output"],
- "output_type": "array[string]",
- "startNodeType": "template-transform",
- "start_node_id": "iteration-start",
- "title": "迭代",
- "type": "iteration",
- },
- "id": "iteration-1",
- },
+ config=node_config,
)
+ # Initialize node data
+ iteration_node.init_node_data(node_config["data"])
+
def tt_generator(self):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -584,56 +594,66 @@ def test_iteration_run_in_parallel_mode():
# construct variable pool
pool = VariablePool(
- system_variables={
- SystemVariableKey.QUERY: "dify",
- SystemVariableKey.FILES: [],
- SystemVariableKey.CONVERSATION_ID: "abababa",
- SystemVariableKey.USER_ID: "1",
- },
+ system_variables=SystemVariable(
+ user_id="1",
+ files=[],
+ query="dify",
+ conversation_id="abababa",
+ ),
user_inputs={},
environment_variables=[],
)
pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
+ parallel_node_config = {
+ "data": {
+ "iterator_selector": ["pe", "list_output"],
+ "output_selector": ["tt", "output"],
+ "output_type": "array[string]",
+ "startNodeType": "template-transform",
+ "start_node_id": "iteration-start",
+ "title": "迭代",
+ "type": "iteration",
+ "is_parallel": True,
+ },
+ "id": "iteration-1",
+ }
+
parallel_iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
- config={
- "data": {
- "iterator_selector": ["pe", "list_output"],
- "output_selector": ["tt", "output"],
- "output_type": "array[string]",
- "startNodeType": "template-transform",
- "start_node_id": "iteration-start",
- "title": "迭代",
- "type": "iteration",
- "is_parallel": True,
- },
- "id": "iteration-1",
- },
+ config=parallel_node_config,
)
+
+ # Initialize node data
+ parallel_iteration_node.init_node_data(parallel_node_config["data"])
+ sequential_node_config = {
+ "data": {
+ "iterator_selector": ["pe", "list_output"],
+ "output_selector": ["tt", "output"],
+ "output_type": "array[string]",
+ "startNodeType": "template-transform",
+ "start_node_id": "iteration-start",
+ "title": "迭代",
+ "type": "iteration",
+ "is_parallel": True,
+ },
+ "id": "iteration-1",
+ }
+
sequential_iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
- config={
- "data": {
- "iterator_selector": ["pe", "list_output"],
- "output_selector": ["tt", "output"],
- "output_type": "array[string]",
- "startNodeType": "template-transform",
- "start_node_id": "iteration-start",
- "title": "迭代",
- "type": "iteration",
- "is_parallel": True,
- },
- "id": "iteration-1",
- },
+ config=sequential_node_config,
)
+ # Initialize node data
+ sequential_iteration_node.init_node_data(sequential_node_config["data"])
+
def tt_generator(self):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -645,8 +665,8 @@ def test_iteration_run_in_parallel_mode():
# execute node
parallel_result = parallel_iteration_node._run()
sequential_result = sequential_iteration_node._run()
- assert parallel_iteration_node.node_data.parallel_nums == 10
- assert parallel_iteration_node.node_data.error_handle_mode == ErrorHandleMode.TERMINATED
+ assert parallel_iteration_node._node_data.parallel_nums == 10
+ assert parallel_iteration_node._node_data.error_handle_mode == ErrorHandleMode.TERMINATED
count = 0
parallel_arr = []
sequential_arr = []
@@ -808,36 +828,41 @@ def test_iteration_run_error_handle():
# construct variable pool
pool = VariablePool(
- system_variables={
- SystemVariableKey.QUERY: "dify",
- SystemVariableKey.FILES: [],
- SystemVariableKey.CONVERSATION_ID: "abababa",
- SystemVariableKey.USER_ID: "1",
- },
+ system_variables=SystemVariable(
+ user_id="1",
+ files=[],
+ query="dify",
+ conversation_id="abababa",
+ ),
user_inputs={},
environment_variables=[],
)
pool.add(["pe", "list_output"], ["1", "1"])
+ error_node_config = {
+ "data": {
+ "iterator_selector": ["pe", "list_output"],
+ "output_selector": ["tt", "output"],
+ "output_type": "array[string]",
+ "startNodeType": "template-transform",
+ "start_node_id": "iteration-start",
+ "title": "iteration",
+ "type": "iteration",
+ "is_parallel": True,
+ "error_handle_mode": ErrorHandleMode.CONTINUE_ON_ERROR,
+ },
+ "id": "iteration-1",
+ }
+
iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
- config={
- "data": {
- "iterator_selector": ["pe", "list_output"],
- "output_selector": ["tt", "output"],
- "output_type": "array[string]",
- "startNodeType": "template-transform",
- "start_node_id": "iteration-start",
- "title": "iteration",
- "type": "iteration",
- "is_parallel": True,
- "error_handle_mode": ErrorHandleMode.CONTINUE_ON_ERROR,
- },
- "id": "iteration-1",
- },
+ config=error_node_config,
)
+
+ # Initialize node data
+ iteration_node.init_node_data(error_node_config["data"])
# execute continue on error node
result = iteration_node._run()
result_arr = []
@@ -851,7 +876,7 @@ def test_iteration_run_error_handle():
assert count == 14
# execute remove abnormal output
- iteration_node.node_data.error_handle_mode = ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT
+ iteration_node._node_data.error_handle_mode = ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT
result = iteration_node._run()
count = 0
for item in result:
diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
index 336c2befcc..23a7fab7cf 100644
--- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
@@ -36,6 +36,7 @@ from core.workflow.nodes.llm.entities import (
)
from core.workflow.nodes.llm.file_saver import LLMFileSaver
from core.workflow.nodes.llm.node import LLMNode
+from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.provider import ProviderType
from models.workflow import WorkflowType
@@ -104,7 +105,7 @@ def graph() -> Graph:
@pytest.fixture
def graph_runtime_state() -> GraphRuntimeState:
variable_pool = VariablePool(
- system_variables={},
+ system_variables=SystemVariable.empty(),
user_inputs={},
)
return GraphRuntimeState(
@@ -118,17 +119,20 @@ def llm_node(
llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph: Graph, graph_runtime_state: GraphRuntimeState
) -> LLMNode:
mock_file_saver = mock.MagicMock(spec=LLMFileSaver)
+ node_config = {
+ "id": "1",
+ "data": llm_node_data.model_dump(),
+ }
node = LLMNode(
id="1",
- config={
- "id": "1",
- "data": llm_node_data.model_dump(),
- },
+ config=node_config,
graph_init_params=graph_init_params,
graph=graph,
graph_runtime_state=graph_runtime_state,
llm_file_saver=mock_file_saver,
)
+ # Initialize node data
+ node.init_node_data(node_config["data"])
return node
@@ -181,7 +185,7 @@ def test_fetch_files_with_file_segment():
related_id="1",
storage_key="",
)
- variable_pool = VariablePool()
+ variable_pool = VariablePool.empty()
variable_pool.add(["sys", "files"], file)
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
@@ -209,7 +213,7 @@ def test_fetch_files_with_array_file_segment():
storage_key="",
),
]
- variable_pool = VariablePool()
+ variable_pool = VariablePool.empty()
variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
@@ -217,7 +221,7 @@ def test_fetch_files_with_array_file_segment():
def test_fetch_files_with_none_segment():
- variable_pool = VariablePool()
+ variable_pool = VariablePool.empty()
variable_pool.add(["sys", "files"], NoneSegment())
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
@@ -225,7 +229,7 @@ def test_fetch_files_with_none_segment():
def test_fetch_files_with_array_any_segment():
- variable_pool = VariablePool()
+ variable_pool = VariablePool.empty()
variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
@@ -233,7 +237,7 @@ def test_fetch_files_with_array_any_segment():
def test_fetch_files_with_non_existent_variable():
- variable_pool = VariablePool()
+ variable_pool = VariablePool.empty()
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
assert result == []
@@ -487,7 +491,7 @@ def test_handle_list_messages_basic(llm_node):
variable_pool = llm_node.graph_runtime_state.variable_pool
vision_detail_config = ImagePromptMessageContent.DETAIL.HIGH
- result = llm_node._handle_list_messages(
+ result = llm_node.handle_list_messages(
messages=messages,
context=context,
jinja2_variables=jinja2_variables,
@@ -505,17 +509,20 @@ def llm_node_for_multimodal(
llm_node_data, graph_init_params, graph, graph_runtime_state
) -> tuple[LLMNode, LLMFileSaver]:
mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver)
+ node_config = {
+ "id": "1",
+ "data": llm_node_data.model_dump(),
+ }
node = LLMNode(
id="1",
- config={
- "id": "1",
- "data": llm_node_data.model_dump(),
- },
+ config=node_config,
graph_init_params=graph_init_params,
graph=graph,
graph_runtime_state=graph_runtime_state,
llm_file_saver=mock_file_saver,
)
+ # Initialize node data
+ node.init_node_data(node_config["data"])
return node, mock_file_saver
@@ -539,7 +546,12 @@ class TestLLMNodeSaveMultiModalImageOutput:
size=9,
)
mock_file_saver.save_binary_string.return_value = mock_file
- file = llm_node._save_multimodal_image_output(content=content)
+ file = llm_node.save_multimodal_image_output(
+ content=content,
+ file_saver=mock_file_saver,
+ )
+ # Manually append to _file_outputs since the static method doesn't do it
+ llm_node._file_outputs.append(file)
assert llm_node._file_outputs == [mock_file]
assert file == mock_file
mock_file_saver.save_binary_string.assert_called_once_with(
@@ -565,7 +577,12 @@ class TestLLMNodeSaveMultiModalImageOutput:
size=9,
)
mock_file_saver.save_remote_url.return_value = mock_file
- file = llm_node._save_multimodal_image_output(content=content)
+ file = llm_node.save_multimodal_image_output(
+ content=content,
+ file_saver=mock_file_saver,
+ )
+ # Manually append to _file_outputs since the static method doesn't do it
+ llm_node._file_outputs.append(file)
assert llm_node._file_outputs == [mock_file]
assert file == mock_file
mock_file_saver.save_remote_url.assert_called_once_with(content.url, FileType.IMAGE)
@@ -581,7 +598,9 @@ def test_llm_node_image_file_to_markdown(llm_node: LLMNode):
class TestSaveMultimodalOutputAndConvertResultToMarkdown:
def test_str_content(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal
- gen = llm_node._save_multimodal_output_and_convert_result_to_markdown("hello world")
+ gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
+ contents="hello world", file_saver=mock_file_saver, file_outputs=[]
+ )
assert list(gen) == ["hello world"]
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()
@@ -589,7 +608,7 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
def test_text_prompt_message_content(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
- [TextPromptMessageContent(data="hello world")]
+ contents=[TextPromptMessageContent(data="hello world")], file_saver=mock_file_saver, file_outputs=[]
)
assert list(gen) == ["hello world"]
mock_file_saver.save_binary_string.assert_not_called()
@@ -615,13 +634,15 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
)
mock_file_saver.save_binary_string.return_value = mock_saved_file
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
- [
+ contents=[
ImagePromptMessageContent(
format="png",
base64_data=image_b64_data,
mime_type="image/png",
)
- ]
+ ],
+ file_saver=mock_file_saver,
+ file_outputs=llm_node._file_outputs,
)
yielded_strs = list(gen)
assert len(yielded_strs) == 1
@@ -644,21 +665,27 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
def test_unknown_content_type(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal
- gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(frozenset(["hello world"]))
+ gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
+ contents=frozenset(["hello world"]), file_saver=mock_file_saver, file_outputs=[]
+ )
assert list(gen) == ["frozenset({'hello world'})"]
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()
def test_unknown_item_type(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal
- gen = llm_node._save_multimodal_output_and_convert_result_to_markdown([frozenset(["hello world"])])
+ gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
+ contents=[frozenset(["hello world"])], file_saver=mock_file_saver, file_outputs=[]
+ )
assert list(gen) == ["frozenset({'hello world'})"]
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()
def test_none_content(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal
- gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(None)
+ gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
+ contents=None, file_saver=mock_file_saver, file_outputs=[]
+ )
assert list(gen) == []
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/test_answer.py
index abc822e98b..466d7bad06 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_answer.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_answer.py
@@ -5,11 +5,11 @@ from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.answer.answer_node import AnswerNode
+from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from models.enums import UserFrom
from models.workflow import WorkflowType
@@ -53,7 +53,7 @@ def test_execute_answer():
# construct variable pool
variable_pool = VariablePool(
- system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
+ system_variables=SystemVariable(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
conversation_variables=[],
@@ -61,21 +61,26 @@ def test_execute_answer():
variable_pool.add(["start", "weather"], "sunny")
variable_pool.add(["llm", "text"], "You are a helpful AI.")
+ node_config = {
+ "id": "answer",
+ "data": {
+ "title": "123",
+ "type": "answer",
+ "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
+ },
+ }
+
node = AnswerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
- config={
- "id": "answer",
- "data": {
- "title": "123",
- "type": "answer",
- "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
- },
- },
+ config=node_config,
)
+ # Initialize node data
+ node.init_node_data(node_config["data"])
+
# Mock db.session.close()
db.session.close = MagicMock()
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py
index a6c553faf0..3f83428834 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py
@@ -5,7 +5,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
GraphRunPartialSucceededEvent,
NodeRunExceptionEvent,
@@ -17,6 +16,7 @@ from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntime
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.node import LLMNode
+from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
@@ -167,12 +167,12 @@ class ContinueOnErrorTestHelper:
"""Helper method to create a graph engine instance for testing"""
graph = Graph.init(graph_config=graph_config)
variable_pool = VariablePool(
- system_variables={
- SystemVariableKey.QUERY: "clear",
- SystemVariableKey.FILES: [],
- SystemVariableKey.CONVERSATION_ID: "abababa",
- SystemVariableKey.USER_ID: "aaa",
- },
+ system_variables=SystemVariable(
+ user_id="aaa",
+ files=[],
+ query="clear",
+ conversation_id="abababa",
+ ),
user_inputs=user_inputs or {"uid": "takato"},
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py
index 66c7818adf..486ae51e5f 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py
@@ -27,13 +27,17 @@ def document_extractor_node():
title="Test Document Extractor",
variable_selector=["node_id", "variable_name"],
)
- return DocumentExtractorNode(
+ node_config = {"id": "test_node_id", "data": node_data.model_dump()}
+ node = DocumentExtractorNode(
id="test_node_id",
- config={"id": "test_node_id", "data": node_data.model_dump()},
+ config=node_config,
graph_init_params=Mock(),
graph=Mock(),
graph_runtime_state=Mock(),
)
+ # Initialize node data
+ node.init_node_data(node_config["data"])
+ return node
@pytest.fixture
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py
index c4e411f9d6..8383aee0e4 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py
@@ -7,12 +7,12 @@ from core.file import File, FileTransferMethod, FileType
from core.variables import ArrayFileSegment
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.if_else.entities import IfElseNodeData
from core.workflow.nodes.if_else.if_else_node import IfElseNode
+from core.workflow.system_variable import SystemVariable
from core.workflow.utils.condition.entities import Condition, SubCondition, SubVariableCondition
from extensions.ext_database import db
from models.enums import UserFrom
@@ -37,9 +37,7 @@ def test_execute_if_else_result_true():
)
# construct variable pool
- pool = VariablePool(
- system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={}
- )
+ pool = VariablePool(system_variables=SystemVariable(user_id="aaa", files=[]), user_inputs={})
pool.add(["start", "array_contains"], ["ab", "def"])
pool.add(["start", "array_not_contains"], ["ac", "def"])
pool.add(["start", "contains"], "cabcde")
@@ -59,57 +57,62 @@ def test_execute_if_else_result_true():
pool.add(["start", "null"], None)
pool.add(["start", "not_null"], "1212")
+ node_config = {
+ "id": "if-else",
+ "data": {
+ "title": "123",
+ "type": "if-else",
+ "logical_operator": "and",
+ "conditions": [
+ {
+ "comparison_operator": "contains",
+ "variable_selector": ["start", "array_contains"],
+ "value": "ab",
+ },
+ {
+ "comparison_operator": "not contains",
+ "variable_selector": ["start", "array_not_contains"],
+ "value": "ab",
+ },
+ {"comparison_operator": "contains", "variable_selector": ["start", "contains"], "value": "ab"},
+ {
+ "comparison_operator": "not contains",
+ "variable_selector": ["start", "not_contains"],
+ "value": "ab",
+ },
+ {"comparison_operator": "start with", "variable_selector": ["start", "start_with"], "value": "ab"},
+ {"comparison_operator": "end with", "variable_selector": ["start", "end_with"], "value": "ab"},
+ {"comparison_operator": "is", "variable_selector": ["start", "is"], "value": "ab"},
+ {"comparison_operator": "is not", "variable_selector": ["start", "is_not"], "value": "ab"},
+ {"comparison_operator": "empty", "variable_selector": ["start", "empty"], "value": "ab"},
+ {"comparison_operator": "not empty", "variable_selector": ["start", "not_empty"], "value": "ab"},
+ {"comparison_operator": "=", "variable_selector": ["start", "equals"], "value": "22"},
+ {"comparison_operator": "≠", "variable_selector": ["start", "not_equals"], "value": "22"},
+ {"comparison_operator": ">", "variable_selector": ["start", "greater_than"], "value": "22"},
+ {"comparison_operator": "<", "variable_selector": ["start", "less_than"], "value": "22"},
+ {
+ "comparison_operator": "≥",
+ "variable_selector": ["start", "greater_than_or_equal"],
+ "value": "22",
+ },
+ {"comparison_operator": "≤", "variable_selector": ["start", "less_than_or_equal"], "value": "22"},
+ {"comparison_operator": "null", "variable_selector": ["start", "null"]},
+ {"comparison_operator": "not null", "variable_selector": ["start", "not_null"]},
+ ],
+ },
+ }
+
node = IfElseNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
- config={
- "id": "if-else",
- "data": {
- "title": "123",
- "type": "if-else",
- "logical_operator": "and",
- "conditions": [
- {
- "comparison_operator": "contains",
- "variable_selector": ["start", "array_contains"],
- "value": "ab",
- },
- {
- "comparison_operator": "not contains",
- "variable_selector": ["start", "array_not_contains"],
- "value": "ab",
- },
- {"comparison_operator": "contains", "variable_selector": ["start", "contains"], "value": "ab"},
- {
- "comparison_operator": "not contains",
- "variable_selector": ["start", "not_contains"],
- "value": "ab",
- },
- {"comparison_operator": "start with", "variable_selector": ["start", "start_with"], "value": "ab"},
- {"comparison_operator": "end with", "variable_selector": ["start", "end_with"], "value": "ab"},
- {"comparison_operator": "is", "variable_selector": ["start", "is"], "value": "ab"},
- {"comparison_operator": "is not", "variable_selector": ["start", "is_not"], "value": "ab"},
- {"comparison_operator": "empty", "variable_selector": ["start", "empty"], "value": "ab"},
- {"comparison_operator": "not empty", "variable_selector": ["start", "not_empty"], "value": "ab"},
- {"comparison_operator": "=", "variable_selector": ["start", "equals"], "value": "22"},
- {"comparison_operator": "≠", "variable_selector": ["start", "not_equals"], "value": "22"},
- {"comparison_operator": ">", "variable_selector": ["start", "greater_than"], "value": "22"},
- {"comparison_operator": "<", "variable_selector": ["start", "less_than"], "value": "22"},
- {
- "comparison_operator": "≥",
- "variable_selector": ["start", "greater_than_or_equal"],
- "value": "22",
- },
- {"comparison_operator": "≤", "variable_selector": ["start", "less_than_or_equal"], "value": "22"},
- {"comparison_operator": "null", "variable_selector": ["start", "null"]},
- {"comparison_operator": "not null", "variable_selector": ["start", "not_null"]},
- ],
- },
- },
+ config=node_config,
)
+ # Initialize node data
+ node.init_node_data(node_config["data"])
+
# Mock db.session.close()
db.session.close = MagicMock()
@@ -157,40 +160,45 @@ def test_execute_if_else_result_false():
# construct variable pool
pool = VariablePool(
- system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
+ system_variables=SystemVariable(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
)
pool.add(["start", "array_contains"], ["1ab", "def"])
pool.add(["start", "array_not_contains"], ["ab", "def"])
+ node_config = {
+ "id": "if-else",
+ "data": {
+ "title": "123",
+ "type": "if-else",
+ "logical_operator": "or",
+ "conditions": [
+ {
+ "comparison_operator": "contains",
+ "variable_selector": ["start", "array_contains"],
+ "value": "ab",
+ },
+ {
+ "comparison_operator": "not contains",
+ "variable_selector": ["start", "array_not_contains"],
+ "value": "ab",
+ },
+ ],
+ },
+ }
+
node = IfElseNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
- config={
- "id": "if-else",
- "data": {
- "title": "123",
- "type": "if-else",
- "logical_operator": "or",
- "conditions": [
- {
- "comparison_operator": "contains",
- "variable_selector": ["start", "array_contains"],
- "value": "ab",
- },
- {
- "comparison_operator": "not contains",
- "variable_selector": ["start", "array_not_contains"],
- "value": "ab",
- },
- ],
- },
- },
+ config=node_config,
)
+ # Initialize node data
+ node.init_node_data(node_config["data"])
+
# Mock db.session.close()
db.session.close = MagicMock()
@@ -230,17 +238,22 @@ def test_array_file_contains_file_name():
],
)
+ node_config = {
+ "id": "if-else",
+ "data": node_data.model_dump(),
+ }
+
node = IfElseNode(
id=str(uuid.uuid4()),
graph_init_params=Mock(),
graph=Mock(),
graph_runtime_state=Mock(),
- config={
- "id": "if-else",
- "data": node_data.model_dump(),
- },
+ config=node_config,
)
+ # Initialize node data
+ node.init_node_data(node_config["data"])
+
node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment(
value=[
File(
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py
index 7d3a1d6a2d..5fc9eab2df 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py
@@ -33,16 +33,19 @@ def list_operator_node():
"title": "Test Title",
}
node_data = ListOperatorNodeData(**config)
+ node_config = {
+ "id": "test_node_id",
+ "data": node_data.model_dump(),
+ }
node = ListOperatorNode(
id="test_node_id",
- config={
- "id": "test_node_id",
- "data": node_data.model_dump(),
- },
+ config=node_config,
graph_init_params=MagicMock(),
graph=MagicMock(),
graph_runtime_state=MagicMock(),
)
+ # Initialize node data
+ node.init_node_data(node_config["data"])
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.variable_pool = MagicMock()
return node
diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py
index e121f6338c..0eaabd0c40 100644
--- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py
@@ -15,6 +15,7 @@ from core.workflow.nodes.enums import ErrorStrategy
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.tool import ToolNode
from core.workflow.nodes.tool.entities import ToolNodeData
+from core.workflow.system_variable import SystemVariable
from models import UserFrom, WorkflowType
@@ -34,15 +35,16 @@ def _create_tool_node():
version="1",
)
variable_pool = VariablePool(
- system_variables={},
+ system_variables=SystemVariable.empty(),
user_inputs={},
)
+ node_config = {
+ "id": "1",
+ "data": data.model_dump(),
+ }
node = ToolNode(
id="1",
- config={
- "id": "1",
- "data": data.model_dump(),
- },
+ config=node_config,
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
@@ -70,6 +72,8 @@ def _create_tool_node():
start_at=0,
),
)
+ # Initialize node data
+ node.init_node_data(node_config["data"])
return node
diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py
index deb3e29b86..ee51339427 100644
--- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py
+++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py
@@ -7,12 +7,12 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import ArrayStringVariable, StringVariable
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode
from core.workflow.nodes.variable_assigner.v1.node_data import WriteMode
+from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
@@ -68,7 +68,7 @@ def test_overwrite_string_variable():
# construct variable pool
variable_pool = VariablePool(
- system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id},
+ system_variables=SystemVariable(conversation_id=conversation_id),
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
@@ -82,23 +82,28 @@ def test_overwrite_string_variable():
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
+ node_config = {
+ "id": "node_id",
+ "data": {
+ "title": "test",
+ "assigned_variable_selector": ["conversation", conversation_variable.name],
+ "write_mode": WriteMode.OVER_WRITE.value,
+ "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
+ },
+ }
+
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
- config={
- "id": "node_id",
- "data": {
- "title": "test",
- "assigned_variable_selector": ["conversation", conversation_variable.name],
- "write_mode": WriteMode.OVER_WRITE.value,
- "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
- },
- },
+ config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory,
)
+ # Initialize node data
+ node.init_node_data(node_config["data"])
+
list(node.run())
expected_var = StringVariable(
id=conversation_variable.id,
@@ -165,7 +170,7 @@ def test_append_variable_to_array():
conversation_id = str(uuid.uuid4())
variable_pool = VariablePool(
- system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id},
+ system_variables=SystemVariable(conversation_id=conversation_id),
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
@@ -178,23 +183,28 @@ def test_append_variable_to_array():
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
+ node_config = {
+ "id": "node_id",
+ "data": {
+ "title": "test",
+ "assigned_variable_selector": ["conversation", conversation_variable.name],
+ "write_mode": WriteMode.APPEND.value,
+ "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
+ },
+ }
+
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
- config={
- "id": "node_id",
- "data": {
- "title": "test",
- "assigned_variable_selector": ["conversation", conversation_variable.name],
- "write_mode": WriteMode.APPEND.value,
- "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
- },
- },
+ config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory,
)
+ # Initialize node data
+ node.init_node_data(node_config["data"])
+
list(node.run())
expected_value = list(conversation_variable.value)
expected_value.append(input_variable.value)
@@ -256,7 +266,7 @@ def test_clear_array():
conversation_id = str(uuid.uuid4())
variable_pool = VariablePool(
- system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id},
+ system_variables=SystemVariable(conversation_id=conversation_id),
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
@@ -265,23 +275,28 @@ def test_clear_array():
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
+ node_config = {
+ "id": "node_id",
+ "data": {
+ "title": "test",
+ "assigned_variable_selector": ["conversation", conversation_variable.name],
+ "write_mode": WriteMode.CLEAR.value,
+ "input_variable_selector": [],
+ },
+ }
+
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
- config={
- "id": "node_id",
- "data": {
- "title": "test",
- "assigned_variable_selector": ["conversation", conversation_variable.name],
- "write_mode": WriteMode.CLEAR.value,
- "input_variable_selector": [],
- },
- },
+ config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory,
)
+ # Initialize node data
+ node.init_node_data(node_config["data"])
+
list(node.run())
expected_var = ArrayStringVariable(
id=conversation_variable.id,
diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py
index 7c5597dd89..987eaf7534 100644
--- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py
+++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py
@@ -5,12 +5,12 @@ from uuid import uuid4
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import ArrayStringVariable
from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode
from core.workflow.nodes.variable_assigner.v2.enums import InputType, Operation
+from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
@@ -109,34 +109,39 @@ def test_remove_first_from_array():
)
variable_pool = VariablePool(
- system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
+ system_variables=SystemVariable(conversation_id="conversation_id"),
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
)
+ node_config = {
+ "id": "node_id",
+ "data": {
+ "title": "test",
+ "version": "2",
+ "items": [
+ {
+ "variable_selector": ["conversation", conversation_variable.name],
+ "input_type": InputType.VARIABLE,
+ "operation": Operation.REMOVE_FIRST,
+ "value": None,
+ }
+ ],
+ },
+ }
+
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
- config={
- "id": "node_id",
- "data": {
- "title": "test",
- "version": "2",
- "items": [
- {
- "variable_selector": ["conversation", conversation_variable.name],
- "input_type": InputType.VARIABLE,
- "operation": Operation.REMOVE_FIRST,
- "value": None,
- }
- ],
- },
- },
+ config=node_config,
)
+ # Initialize node data
+ node.init_node_data(node_config["data"])
+
# Skip the mock assertion since we're in a test environment
# Print the variable before running
print(f"Before: {variable_pool.get(['conversation', conversation_variable.name]).to_object()}")
@@ -196,34 +201,39 @@ def test_remove_last_from_array():
)
variable_pool = VariablePool(
- system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
+ system_variables=SystemVariable(conversation_id="conversation_id"),
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
)
+ node_config = {
+ "id": "node_id",
+ "data": {
+ "title": "test",
+ "version": "2",
+ "items": [
+ {
+ "variable_selector": ["conversation", conversation_variable.name],
+ "input_type": InputType.VARIABLE,
+ "operation": Operation.REMOVE_LAST,
+ "value": None,
+ }
+ ],
+ },
+ }
+
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
- config={
- "id": "node_id",
- "data": {
- "title": "test",
- "version": "2",
- "items": [
- {
- "variable_selector": ["conversation", conversation_variable.name],
- "input_type": InputType.VARIABLE,
- "operation": Operation.REMOVE_LAST,
- "value": None,
- }
- ],
- },
- },
+ config=node_config,
)
+ # Initialize node data
+ node.init_node_data(node_config["data"])
+
# Skip the mock assertion since we're in a test environment
list(node.run())
@@ -275,34 +285,39 @@ def test_remove_first_from_empty_array():
)
variable_pool = VariablePool(
- system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
+ system_variables=SystemVariable(conversation_id="conversation_id"),
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
)
+ node_config = {
+ "id": "node_id",
+ "data": {
+ "title": "test",
+ "version": "2",
+ "items": [
+ {
+ "variable_selector": ["conversation", conversation_variable.name],
+ "input_type": InputType.VARIABLE,
+ "operation": Operation.REMOVE_FIRST,
+ "value": None,
+ }
+ ],
+ },
+ }
+
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
- config={
- "id": "node_id",
- "data": {
- "title": "test",
- "version": "2",
- "items": [
- {
- "variable_selector": ["conversation", conversation_variable.name],
- "input_type": InputType.VARIABLE,
- "operation": Operation.REMOVE_FIRST,
- "value": None,
- }
- ],
- },
- },
+ config=node_config,
)
+ # Initialize node data
+ node.init_node_data(node_config["data"])
+
# Skip the mock assertion since we're in a test environment
list(node.run())
@@ -354,34 +369,39 @@ def test_remove_last_from_empty_array():
)
variable_pool = VariablePool(
- system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
+ system_variables=SystemVariable(conversation_id="conversation_id"),
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
)
+ node_config = {
+ "id": "node_id",
+ "data": {
+ "title": "test",
+ "version": "2",
+ "items": [
+ {
+ "variable_selector": ["conversation", conversation_variable.name],
+ "input_type": InputType.VARIABLE,
+ "operation": Operation.REMOVE_LAST,
+ "value": None,
+ }
+ ],
+ },
+ }
+
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
- config={
- "id": "node_id",
- "data": {
- "title": "test",
- "version": "2",
- "items": [
- {
- "variable_selector": ["conversation", conversation_variable.name],
- "input_type": InputType.VARIABLE,
- "operation": Operation.REMOVE_LAST,
- "value": None,
- }
- ],
- },
- },
+ config=node_config,
)
+ # Initialize node data
+ node.init_node_data(node_config["data"])
+
# Skip the mock assertion since we're in a test environment
list(node.run())
diff --git a/api/tests/unit_tests/core/workflow/test_system_variable.py b/api/tests/unit_tests/core/workflow/test_system_variable.py
new file mode 100644
index 0000000000..11d788ed79
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/test_system_variable.py
@@ -0,0 +1,251 @@
+import json
+from typing import Any
+
+import pytest
+from pydantic import ValidationError
+
+from core.file.enums import FileTransferMethod, FileType
+from core.file.models import File
+from core.workflow.system_variable import SystemVariable
+
+# Test data constants for SystemVariable serialization tests
+VALID_BASE_DATA: dict[str, Any] = {
+ "user_id": "a20f06b1-8703-45ab-937c-860a60072113",
+ "app_id": "661bed75-458d-49c9-b487-fda0762677b9",
+ "workflow_id": "d31f2136-b292-4ae0-96d4-1e77894a4f43",
+}
+
+COMPLETE_VALID_DATA: dict[str, Any] = {
+ **VALID_BASE_DATA,
+ "query": "test query",
+ "files": [],
+ "conversation_id": "91f1eb7d-69f4-4d7b-b82f-4003d51744b9",
+ "dialogue_count": 5,
+ "workflow_run_id": "eb4704b5-2274-47f2-bfcd-0452daa82cb5",
+}
+
+
+def create_test_file() -> File:
+ """Create a test File object for serialization tests."""
+ return File(
+ tenant_id="test-tenant-id",
+ type=FileType.DOCUMENT,
+ transfer_method=FileTransferMethod.LOCAL_FILE,
+ related_id="test-file-id",
+ filename="test.txt",
+ extension=".txt",
+ mime_type="text/plain",
+ size=1024,
+ storage_key="test-storage-key",
+ )
+
+
+class TestSystemVariableSerialization:
+ """Focused tests for SystemVariable serialization/deserialization logic."""
+
+ def test_basic_deserialization(self):
+ """Test successful deserialization from JSON structure with all fields correctly mapped."""
+ # Test with complete data
+ system_var = SystemVariable(**COMPLETE_VALID_DATA)
+
+ # Verify all fields are correctly mapped
+ assert system_var.user_id == COMPLETE_VALID_DATA["user_id"]
+ assert system_var.app_id == COMPLETE_VALID_DATA["app_id"]
+ assert system_var.workflow_id == COMPLETE_VALID_DATA["workflow_id"]
+ assert system_var.query == COMPLETE_VALID_DATA["query"]
+ assert system_var.conversation_id == COMPLETE_VALID_DATA["conversation_id"]
+ assert system_var.dialogue_count == COMPLETE_VALID_DATA["dialogue_count"]
+ assert system_var.workflow_execution_id == COMPLETE_VALID_DATA["workflow_run_id"]
+ assert system_var.files == []
+
+ # Test with minimal data (only required fields)
+ minimal_var = SystemVariable(**VALID_BASE_DATA)
+ assert minimal_var.user_id == VALID_BASE_DATA["user_id"]
+ assert minimal_var.app_id == VALID_BASE_DATA["app_id"]
+ assert minimal_var.workflow_id == VALID_BASE_DATA["workflow_id"]
+ assert minimal_var.query is None
+ assert minimal_var.conversation_id is None
+ assert minimal_var.dialogue_count is None
+ assert minimal_var.workflow_execution_id is None
+ assert minimal_var.files == []
+
+ def test_alias_handling(self):
+ """Test workflow_execution_id vs workflow_run_id alias resolution - core deserialization logic."""
+ workflow_id = "eb4704b5-2274-47f2-bfcd-0452daa82cb5"
+
+ # Test workflow_run_id only (preferred alias)
+ data_run_id = {**VALID_BASE_DATA, "workflow_run_id": workflow_id}
+ system_var1 = SystemVariable(**data_run_id)
+ assert system_var1.workflow_execution_id == workflow_id
+
+ # Test workflow_execution_id only (direct field name)
+ data_execution_id = {**VALID_BASE_DATA, "workflow_execution_id": workflow_id}
+ system_var2 = SystemVariable(**data_execution_id)
+ assert system_var2.workflow_execution_id == workflow_id
+
+ # Test both present - workflow_run_id should take precedence
+ data_both = {
+ **VALID_BASE_DATA,
+ "workflow_execution_id": "should-be-ignored",
+ "workflow_run_id": workflow_id,
+ }
+ system_var3 = SystemVariable(**data_both)
+ assert system_var3.workflow_execution_id == workflow_id
+
+ # Test neither present - should be None
+ system_var4 = SystemVariable(**VALID_BASE_DATA)
+ assert system_var4.workflow_execution_id is None
+
+ def test_serialization_round_trip(self):
+ """Test that serialize → deserialize produces the same result with alias handling."""
+ # Create original SystemVariable
+ original = SystemVariable(**COMPLETE_VALID_DATA)
+
+ # Serialize to dict
+ serialized = original.model_dump(mode="json")
+
+ # Verify alias is used in serialization (workflow_run_id, not workflow_execution_id)
+ assert "workflow_run_id" in serialized
+ assert "workflow_execution_id" not in serialized
+ assert serialized["workflow_run_id"] == COMPLETE_VALID_DATA["workflow_run_id"]
+
+ # Deserialize back
+ deserialized = SystemVariable(**serialized)
+
+ # Verify all fields match after round-trip
+ assert deserialized.user_id == original.user_id
+ assert deserialized.app_id == original.app_id
+ assert deserialized.workflow_id == original.workflow_id
+ assert deserialized.query == original.query
+ assert deserialized.conversation_id == original.conversation_id
+ assert deserialized.dialogue_count == original.dialogue_count
+ assert deserialized.workflow_execution_id == original.workflow_execution_id
+ assert list(deserialized.files) == list(original.files)
+
+ def test_json_round_trip(self):
+ """Test JSON serialization/deserialization consistency with proper structure."""
+ # Create original SystemVariable
+ original = SystemVariable(**COMPLETE_VALID_DATA)
+
+ # Serialize to JSON string
+ json_str = original.model_dump_json()
+
+ # Parse JSON and verify structure
+ json_data = json.loads(json_str)
+ assert "workflow_run_id" in json_data
+ assert "workflow_execution_id" not in json_data
+ assert json_data["workflow_run_id"] == COMPLETE_VALID_DATA["workflow_run_id"]
+
+ # Deserialize from JSON data
+ deserialized = SystemVariable(**json_data)
+
+ # Verify key fields match after JSON round-trip
+ assert deserialized.workflow_execution_id == original.workflow_execution_id
+ assert deserialized.user_id == original.user_id
+ assert deserialized.app_id == original.app_id
+ assert deserialized.workflow_id == original.workflow_id
+
+ def test_files_field_deserialization(self):
+ """Test deserialization with File objects in the files field - SystemVariable specific logic."""
+ # Test with empty files list
+ data_empty = {**VALID_BASE_DATA, "files": []}
+ system_var_empty = SystemVariable(**data_empty)
+ assert system_var_empty.files == []
+
+ # Test with single File object
+ test_file = create_test_file()
+ data_single = {**VALID_BASE_DATA, "files": [test_file]}
+ system_var_single = SystemVariable(**data_single)
+ assert len(system_var_single.files) == 1
+ assert system_var_single.files[0].filename == "test.txt"
+ assert system_var_single.files[0].tenant_id == "test-tenant-id"
+
+ # Test with multiple File objects
+ file1 = File(
+ tenant_id="tenant1",
+ type=FileType.DOCUMENT,
+ transfer_method=FileTransferMethod.LOCAL_FILE,
+ related_id="file1",
+ filename="doc1.txt",
+ storage_key="key1",
+ )
+ file2 = File(
+ tenant_id="tenant2",
+ type=FileType.IMAGE,
+ transfer_method=FileTransferMethod.REMOTE_URL,
+ remote_url="https://example.com/image.jpg",
+ filename="image.jpg",
+ storage_key="key2",
+ )
+
+ data_multiple = {**VALID_BASE_DATA, "files": [file1, file2]}
+ system_var_multiple = SystemVariable(**data_multiple)
+ assert len(system_var_multiple.files) == 2
+ assert system_var_multiple.files[0].filename == "doc1.txt"
+ assert system_var_multiple.files[1].filename == "image.jpg"
+
+ # Verify files field serialization/deserialization
+ serialized = system_var_multiple.model_dump(mode="json")
+ deserialized = SystemVariable(**serialized)
+ assert len(deserialized.files) == 2
+ assert deserialized.files[0].filename == "doc1.txt"
+ assert deserialized.files[1].filename == "image.jpg"
+
+ def test_alias_serialization_consistency(self):
+ """Test that alias handling works consistently in both serialization directions."""
+ workflow_id = "test-workflow-id"
+
+ # Create with workflow_run_id (alias)
+ data_with_alias = {**VALID_BASE_DATA, "workflow_run_id": workflow_id}
+ system_var = SystemVariable(**data_with_alias)
+
+ # Serialize and verify alias is used
+ serialized = system_var.model_dump()
+ assert serialized["workflow_run_id"] == workflow_id
+ assert "workflow_execution_id" not in serialized
+
+ # Deserialize and verify field mapping
+ deserialized = SystemVariable(**serialized)
+ assert deserialized.workflow_execution_id == workflow_id
+
+ # Test JSON serialization path
+ json_serialized = json.loads(system_var.model_dump_json())
+ assert json_serialized["workflow_run_id"] == workflow_id
+ assert "workflow_execution_id" not in json_serialized
+
+ json_deserialized = SystemVariable(**json_serialized)
+ assert json_deserialized.workflow_execution_id == workflow_id
+
+ def test_model_validator_serialization_logic(self):
+ """Test the custom model validator behavior for serialization scenarios."""
+ workflow_id = "test-workflow-execution-id"
+
+ # Test direct instantiation with workflow_execution_id (should work)
+ data1 = {**VALID_BASE_DATA, "workflow_execution_id": workflow_id}
+ system_var1 = SystemVariable(**data1)
+ assert system_var1.workflow_execution_id == workflow_id
+
+ # Test serialization of the above (should use alias)
+ serialized1 = system_var1.model_dump()
+ assert "workflow_run_id" in serialized1
+ assert serialized1["workflow_run_id"] == workflow_id
+
+ # Test both present - workflow_run_id takes precedence (validator logic)
+ data2 = {
+ **VALID_BASE_DATA,
+ "workflow_execution_id": "should-be-removed",
+ "workflow_run_id": workflow_id,
+ }
+ system_var2 = SystemVariable(**data2)
+ assert system_var2.workflow_execution_id == workflow_id
+
+ # Verify serialization consistency
+ serialized2 = system_var2.model_dump()
+ assert serialized2["workflow_run_id"] == workflow_id
+
+
+def test_constructor_with_extra_key():
+ # Test that SystemVariable should forbid extra keys
+ with pytest.raises(ValidationError):
+ # This should fail because there is an unexpected key.
+ SystemVariable(invalid_key=1) # type: ignore
diff --git a/api/tests/unit_tests/core/workflow/test_variable_pool.py b/api/tests/unit_tests/core/workflow/test_variable_pool.py
index bb8d34fad5..c65b60cb4d 100644
--- a/api/tests/unit_tests/core/workflow/test_variable_pool.py
+++ b/api/tests/unit_tests/core/workflow/test_variable_pool.py
@@ -1,17 +1,43 @@
+import uuid
+from collections import defaultdict
+
import pytest
-from pydantic import ValidationError
from core.file import File, FileTransferMethod, FileType
from core.variables import FileSegment, StringSegment
-from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID
+from core.variables.segments import (
+ ArrayAnySegment,
+ ArrayFileSegment,
+ ArrayNumberSegment,
+ ArrayObjectSegment,
+ ArrayStringSegment,
+ FloatSegment,
+ IntegerSegment,
+ NoneSegment,
+ ObjectSegment,
+)
+from core.variables.variables import (
+ ArrayNumberVariable,
+ ArrayObjectVariable,
+ ArrayStringVariable,
+ FloatVariable,
+ IntegerVariable,
+ ObjectVariable,
+ StringVariable,
+ VariableUnion,
+)
+from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.enums import SystemVariableKey
+from core.workflow.system_variable import SystemVariable
from factories.variable_factory import build_segment, segment_to_variable
@pytest.fixture
def pool():
- return VariablePool(system_variables={}, user_inputs={})
+ return VariablePool(
+ system_variables=SystemVariable(user_id="test_user_id", app_id="test_app_id", workflow_id="test_workflow_id"),
+ user_inputs={},
+ )
@pytest.fixture
@@ -52,18 +78,28 @@ def test_use_long_selector(pool):
class TestVariablePool:
def test_constructor(self):
- pool = VariablePool()
+ # Test with minimal required SystemVariable
+ minimal_system_vars = SystemVariable(
+ user_id="test_user_id", app_id="test_app_id", workflow_id="test_workflow_id"
+ )
+ pool = VariablePool(system_variables=minimal_system_vars)
+
+ # Test with all parameters
pool = VariablePool(
variable_dictionary={},
user_inputs={},
- system_variables={},
+ system_variables=minimal_system_vars,
environment_variables=[],
conversation_variables=[],
)
+ # Test with more complex SystemVariable
+ complex_system_vars = SystemVariable(
+ user_id="test_user_id", app_id="test_app_id", workflow_id="test_workflow_id"
+ )
pool = VariablePool(
user_inputs={"key": "value"},
- system_variables={SystemVariableKey.WORKFLOW_ID: "test_workflow_id"},
+ system_variables=complex_system_vars,
environment_variables=[
segment_to_variable(
segment=build_segment(1),
@@ -80,6 +116,323 @@ class TestVariablePool:
],
)
- def test_constructor_with_invalid_system_variable_key(self):
- with pytest.raises(ValidationError):
- VariablePool(system_variables={"invalid_key": "value"}) # type: ignore
+ def test_get_system_variables(self):
+ sys_var = SystemVariable(
+ user_id="test_user_id",
+ app_id="test_app_id",
+ workflow_id="test_workflow_id",
+ workflow_execution_id="test_execution_123",
+ query="test query",
+ conversation_id="test_conv_id",
+ dialogue_count=5,
+ )
+ pool = VariablePool(system_variables=sys_var)
+
+ kv = [
+ ("user_id", sys_var.user_id),
+ ("app_id", sys_var.app_id),
+ ("workflow_id", sys_var.workflow_id),
+ ("workflow_run_id", sys_var.workflow_execution_id),
+ ("query", sys_var.query),
+ ("conversation_id", sys_var.conversation_id),
+ ("dialogue_count", sys_var.dialogue_count),
+ ]
+ for key, expected_value in kv:
+ segment = pool.get([SYSTEM_VARIABLE_NODE_ID, key])
+ assert segment is not None
+ assert segment.value == expected_value
+
+
+class TestVariablePoolSerialization:
+ """Test cases for VariablePool serialization and deserialization using Pydantic's built-in methods.
+
+ These tests focus exclusively on serialization/deserialization logic to ensure that
+ VariablePool data can be properly serialized to dictionaries/JSON and reconstructed
+ while preserving all data integrity.
+ """
+
+ _NODE1_ID = "node_1"
+ _NODE2_ID = "node_2"
+ _NODE3_ID = "node_3"
+
+ def _create_pool_without_file(self):
+ # Create comprehensive system variables
+ system_vars = SystemVariable(
+ user_id="test_user_id",
+ app_id="test_app_id",
+ workflow_id="test_workflow_id",
+ workflow_execution_id="test_execution_123",
+ query="test query",
+ conversation_id="test_conv_id",
+ dialogue_count=5,
+ )
+
+ # Create environment variables with all types including ArrayFileVariable
+ env_vars: list[VariableUnion] = [
+ StringVariable(
+ id="env_string_id",
+ name="env_string",
+ value="env_string_value",
+ selector=[ENVIRONMENT_VARIABLE_NODE_ID, "env_string"],
+ ),
+ IntegerVariable(
+ id="env_integer_id",
+ name="env_integer",
+ value=1,
+ selector=[ENVIRONMENT_VARIABLE_NODE_ID, "env_integer"],
+ ),
+ FloatVariable(
+ id="env_float_id",
+ name="env_float",
+ value=1.0,
+ selector=[ENVIRONMENT_VARIABLE_NODE_ID, "env_float"],
+ ),
+ ]
+
+ # Create conversation variables with complex data
+ conv_vars: list[VariableUnion] = [
+ StringVariable(
+ id="conv_string_id",
+ name="conv_string",
+ value="conv_string_value",
+ selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_string"],
+ ),
+ IntegerVariable(
+ id="conv_integer_id",
+ name="conv_integer",
+ value=1,
+ selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_integer"],
+ ),
+ FloatVariable(
+ id="conv_float_id",
+ name="conv_float",
+ value=1.0,
+ selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_float"],
+ ),
+ ObjectVariable(
+ id="conv_object_id",
+ name="conv_object",
+ value={"key": "value", "nested": {"data": 123}},
+ selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_object"],
+ ),
+ ArrayStringVariable(
+ id="conv_array_string_id",
+ name="conv_array_string",
+ value=["conv_array_string_value"],
+ selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_array_string"],
+ ),
+ ArrayNumberVariable(
+ id="conv_array_number_id",
+ name="conv_array_number",
+ value=[1, 1.0],
+ selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_array_number"],
+ ),
+ ArrayObjectVariable(
+ id="conv_array_object_id",
+ name="conv_array_object",
+ value=[{"a": 1}, {"b": "2"}],
+ selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_array_object"],
+ ),
+ ]
+
+ # Create comprehensive user inputs
+ user_inputs = {
+ "string_input": "test_value",
+ "number_input": 42,
+ "object_input": {"nested": {"key": "value"}},
+ "array_input": ["item1", "item2", "item3"],
+ }
+
+ # Create VariablePool
+ pool = VariablePool(
+ system_variables=system_vars,
+ user_inputs=user_inputs,
+ environment_variables=env_vars,
+ conversation_variables=conv_vars,
+ )
+ return pool
+
+ def _add_node_data_to_pool(self, pool: VariablePool, with_file=False):
+ test_file = File(
+ tenant_id="test_tenant_id",
+ type=FileType.DOCUMENT,
+ transfer_method=FileTransferMethod.LOCAL_FILE,
+ related_id="test_related_id",
+ remote_url="test_url",
+ filename="test_file.txt",
+ storage_key="test_storage_key",
+ )
+
+ # Add various segment types to variable dictionary
+ pool.add((self._NODE1_ID, "string_var"), StringSegment(value="test_string"))
+ pool.add((self._NODE1_ID, "int_var"), IntegerSegment(value=123))
+ pool.add((self._NODE1_ID, "float_var"), FloatSegment(value=45.67))
+ pool.add((self._NODE1_ID, "object_var"), ObjectSegment(value={"test": "data"}))
+ if with_file:
+ pool.add((self._NODE1_ID, "file_var"), FileSegment(value=test_file))
+ pool.add((self._NODE1_ID, "none_var"), NoneSegment())
+
+ # Add array segments including ArrayFileVariable
+ pool.add((self._NODE2_ID, "array_string"), ArrayStringSegment(value=["a", "b", "c"]))
+ pool.add((self._NODE2_ID, "array_number"), ArrayNumberSegment(value=[1, 2, 3]))
+ pool.add((self._NODE2_ID, "array_object"), ArrayObjectSegment(value=[{"a": 1}, {"b": 2}]))
+ if with_file:
+ pool.add((self._NODE2_ID, "array_file"), ArrayFileSegment(value=[test_file]))
+ pool.add((self._NODE2_ID, "array_any"), ArrayAnySegment(value=["mixed", 123, {"key": "value"}]))
+
+ # Add nested variables
+ pool.add((self._NODE3_ID, "nested", "deep", "var"), StringSegment(value="deep_value"))
+
+ def test_system_variables(self):
+ sys_vars = SystemVariable(
+ user_id="test_user_id",
+ app_id="test_app_id",
+ workflow_id="test_workflow_id",
+ workflow_execution_id="test_execution_123",
+ query="test query",
+ conversation_id="test_conv_id",
+ dialogue_count=5,
+ )
+ pool = VariablePool(system_variables=sys_vars)
+ json = pool.model_dump_json()
+ pool2 = VariablePool.model_validate_json(json)
+ assert pool2.system_variables == sys_vars
+
+ for mode in ["json", "python"]:
+ dict_ = pool.model_dump(mode=mode)
+ pool2 = VariablePool.model_validate(dict_)
+ assert pool2.system_variables == sys_vars
+
+ def test_pool_without_file_vars(self):
+ pool = self._create_pool_without_file()
+ json = pool.model_dump_json()
+ pool2 = pool.model_validate_json(json)
+ assert pool2.system_variables == pool.system_variables
+ assert pool2.conversation_variables == pool.conversation_variables
+ assert pool2.environment_variables == pool.environment_variables
+ assert pool2.user_inputs == pool.user_inputs
+ assert pool2.variable_dictionary == pool.variable_dictionary
+ assert pool2 == pool
+
+ def test_basic_dictionary_round_trip(self):
+ """Test basic round-trip serialization: model_dump() → model_validate()"""
+ # Create a comprehensive VariablePool with all data types
+ original_pool = self._create_pool_without_file()
+ self._add_node_data_to_pool(original_pool)
+
+ # Serialize to dictionary using Pydantic's model_dump()
+ serialized_data = original_pool.model_dump()
+
+ # Verify serialized data structure
+ assert isinstance(serialized_data, dict)
+ assert "system_variables" in serialized_data
+ assert "user_inputs" in serialized_data
+ assert "environment_variables" in serialized_data
+ assert "conversation_variables" in serialized_data
+ assert "variable_dictionary" in serialized_data
+
+ # Deserialize back using Pydantic's model_validate()
+ reconstructed_pool = VariablePool.model_validate(serialized_data)
+
+ # Verify data integrity is preserved
+ self._assert_pools_equal(original_pool, reconstructed_pool)
+
+ def test_json_round_trip(self):
+ """Test JSON round-trip serialization: model_dump_json() → model_validate_json()"""
+ # Create a comprehensive VariablePool with all data types
+ original_pool = self._create_pool_without_file()
+ self._add_node_data_to_pool(original_pool)
+
+ # Serialize to JSON string using Pydantic's model_dump_json()
+ json_data = original_pool.model_dump_json()
+
+ # Verify JSON is valid string
+ assert isinstance(json_data, str)
+ assert len(json_data) > 0
+
+ # Deserialize back using Pydantic's model_validate_json()
+ reconstructed_pool = VariablePool.model_validate_json(json_data)
+
+ # Verify data integrity is preserved
+ self._assert_pools_equal(original_pool, reconstructed_pool)
+
+ def test_complex_data_serialization(self):
+ """Test serialization of complex data structures including ArrayFileVariable"""
+ original_pool = self._create_pool_without_file()
+ self._add_node_data_to_pool(original_pool, with_file=True)
+
+ # Test dictionary round-trip
+ dict_data = original_pool.model_dump()
+ reconstructed_dict = VariablePool.model_validate(dict_data)
+
+ # Test JSON round-trip
+ json_data = original_pool.model_dump_json()
+ reconstructed_json = VariablePool.model_validate_json(json_data)
+
+ # Verify both reconstructed pools are equivalent
+ self._assert_pools_equal(reconstructed_dict, reconstructed_json)
+ # TODO: assert the data for file object...
+
+ def _assert_pools_equal(self, pool1: VariablePool, pool2: VariablePool) -> None:
+ """Assert that two VariablePools contain equivalent data"""
+
+ # Compare system variables
+ assert pool1.system_variables == pool2.system_variables
+
+ # Compare user inputs
+ assert dict(pool1.user_inputs) == dict(pool2.user_inputs)
+
+ # Compare environment variables count
+ assert pool1.environment_variables == pool2.environment_variables
+
+ # Compare conversation variables count
+ assert pool1.conversation_variables == pool2.conversation_variables
+
+ # Test key variable retrievals to ensure functionality is preserved
+ test_selectors = [
+ (SYSTEM_VARIABLE_NODE_ID, "user_id"),
+ (SYSTEM_VARIABLE_NODE_ID, "app_id"),
+ (ENVIRONMENT_VARIABLE_NODE_ID, "env_string"),
+ (ENVIRONMENT_VARIABLE_NODE_ID, "env_number"),
+ (CONVERSATION_VARIABLE_NODE_ID, "conv_string"),
+ (self._NODE1_ID, "string_var"),
+ (self._NODE1_ID, "int_var"),
+ (self._NODE1_ID, "float_var"),
+ (self._NODE2_ID, "array_string"),
+ (self._NODE2_ID, "array_number"),
+ (self._NODE3_ID, "nested", "deep", "var"),
+ ]
+
+ for selector in test_selectors:
+ val1 = pool1.get(selector)
+ val2 = pool2.get(selector)
+
+ # Both should exist or both should be None
+ assert (val1 is None) == (val2 is None)
+
+ if val1 is not None and val2 is not None:
+ # Values should be equal
+ assert val1.value == val2.value
+ # Value types should be the same (more important than exact class type)
+ assert val1.value_type == val2.value_type
+
+ def test_variable_pool_deserialization_default_dict(self):
+ variable_pool = VariablePool(
+ user_inputs={"a": 1, "b": "2"},
+ system_variables=SystemVariable(workflow_id=str(uuid.uuid4())),
+ environment_variables=[
+ StringVariable(name="str_var", value="a"),
+ ],
+ conversation_variables=[IntegerVariable(name="int_var", value=1)],
+ )
+ assert isinstance(variable_pool.variable_dictionary, defaultdict)
+ json = variable_pool.model_dump_json()
+ loaded = VariablePool.model_validate_json(json)
+ assert isinstance(loaded.variable_dictionary, defaultdict)
+
+ loaded.add(["non_exist_node", "a"], 1)
+
+ pool_dict = variable_pool.model_dump()
+ loaded = VariablePool.model_validate(pool_dict)
+ assert isinstance(loaded.variable_dictionary, defaultdict)
+ loaded.add(["non_exist_node", "a"], 1)
diff --git a/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py b/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py
index 646de8bf3a..4866db1fdb 100644
--- a/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py
+++ b/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py
@@ -18,10 +18,10 @@ from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
-from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
+from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from models.enums import CreatorUserRole
from models.model import AppMode
@@ -67,28 +67,25 @@ def real_app_generate_entity():
@pytest.fixture
def real_workflow_system_variables():
- return {
- SystemVariableKey.QUERY: "test query",
- SystemVariableKey.CONVERSATION_ID: "test-conversation-id",
- SystemVariableKey.USER_ID: "test-user-id",
- SystemVariableKey.APP_ID: "test-app-id",
- SystemVariableKey.WORKFLOW_ID: "test-workflow-id",
- SystemVariableKey.WORKFLOW_EXECUTION_ID: "test-workflow-run-id",
- }
+ return SystemVariable(
+ query="test query",
+ conversation_id="test-conversation-id",
+ user_id="test-user-id",
+ app_id="test-app-id",
+ workflow_id="test-workflow-id",
+ workflow_execution_id="test-workflow-run-id",
+ )
@pytest.fixture
def mock_node_execution_repository():
repo = MagicMock(spec=WorkflowNodeExecutionRepository)
- repo.get_by_node_execution_id.return_value = None
- repo.get_running_executions.return_value = []
return repo
@pytest.fixture
def mock_workflow_execution_repository():
repo = MagicMock(spec=WorkflowExecutionRepository)
- repo.get.return_value = None
return repo
@@ -217,8 +214,8 @@ def test_handle_workflow_run_success(workflow_cycle_manager, mock_workflow_execu
started_at=datetime.now(UTC).replace(tzinfo=None),
)
- # Mock _get_workflow_execution_or_raise_error to return the real workflow_execution
- workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
+ # Pre-populate the cache with the workflow execution
+ workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution
# Call the method
result = workflow_cycle_manager.handle_workflow_run_success(
@@ -251,11 +248,10 @@ def test_handle_workflow_run_failed(workflow_cycle_manager, mock_workflow_execut
started_at=datetime.now(UTC).replace(tzinfo=None),
)
- # Mock _get_workflow_execution_or_raise_error to return the real workflow_execution
- workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
+ # Pre-populate the cache with the workflow execution
+ workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution
- # Mock get_running_executions to return an empty list
- workflow_cycle_manager._workflow_node_execution_repository.get_running_executions.return_value = []
+ # No running node executions in cache (empty cache)
# Call the method
result = workflow_cycle_manager.handle_workflow_run_failed(
@@ -289,8 +285,8 @@ def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_execu
started_at=datetime.now(UTC).replace(tzinfo=None),
)
- # Mock _get_workflow_execution_or_raise_error to return the real workflow_execution
- workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
+ # Pre-populate the cache with the workflow execution
+ workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution
# Create a mock event
event = MagicMock(spec=QueueNodeStartedEvent)
@@ -342,8 +338,8 @@ def test_get_workflow_execution_or_raise_error(workflow_cycle_manager, mock_work
started_at=datetime.now(UTC).replace(tzinfo=None),
)
- # Mock the repository get method to return the real execution
- workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
+ # Pre-populate the cache with the workflow execution
+ workflow_cycle_manager._workflow_execution_cache["test-workflow-run-id"] = workflow_execution
# Call the method
result = workflow_cycle_manager._get_workflow_execution_or_raise_error("test-workflow-run-id")
@@ -351,11 +347,13 @@ def test_get_workflow_execution_or_raise_error(workflow_cycle_manager, mock_work
# Verify the result
assert result == workflow_execution
- # Test error case
- workflow_cycle_manager._workflow_execution_repository.get.return_value = None
+ # Test error case - clear cache
+ workflow_cycle_manager._workflow_execution_cache.clear()
# Expect an error when execution is not found
- with pytest.raises(ValueError):
+ from core.app.task_pipeline.exc import WorkflowRunNotFoundError
+
+ with pytest.raises(WorkflowRunNotFoundError):
workflow_cycle_manager._get_workflow_execution_or_raise_error("non-existent-id")
@@ -384,8 +382,8 @@ def test_handle_workflow_node_execution_success(workflow_cycle_manager):
created_at=datetime.now(UTC).replace(tzinfo=None),
)
- # Mock the repository to return the node execution
- workflow_cycle_manager._workflow_node_execution_repository.get_by_node_execution_id.return_value = node_execution
+ # Pre-populate the cache with the node execution
+ workflow_cycle_manager._node_execution_cache["test-node-execution-id"] = node_execution
# Call the method
result = workflow_cycle_manager.handle_workflow_node_execution_success(
@@ -414,8 +412,8 @@ def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_workfl
started_at=datetime.now(UTC).replace(tzinfo=None),
)
- # Mock _get_workflow_execution_or_raise_error to return the real workflow_execution
- workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
+ # Pre-populate the cache with the workflow execution
+ workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution
# Call the method
result = workflow_cycle_manager.handle_workflow_run_partial_success(
@@ -462,8 +460,8 @@ def test_handle_workflow_node_execution_failed(workflow_cycle_manager):
created_at=datetime.now(UTC).replace(tzinfo=None),
)
- # Mock the repository to return the node execution
- workflow_cycle_manager._workflow_node_execution_repository.get_by_node_execution_id.return_value = node_execution
+ # Pre-populate the cache with the node execution
+ workflow_cycle_manager._node_execution_cache["test-node-execution-id"] = node_execution
# Call the method
result = workflow_cycle_manager.handle_workflow_node_execution_failed(
diff --git a/api/tests/unit_tests/core/workflow/utils/test_variable_utils.py b/api/tests/unit_tests/core/workflow/utils/test_variable_utils.py
index f1cb937bb3..54bf6558bf 100644
--- a/api/tests/unit_tests/core/workflow/utils/test_variable_utils.py
+++ b/api/tests/unit_tests/core/workflow/utils/test_variable_utils.py
@@ -10,7 +10,7 @@ class TestAppendVariablesRecursively:
def test_append_simple_dict_value(self):
"""Test appending a simple dictionary value"""
- pool = VariablePool()
+ pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["output"]
variable_value = {"name": "John", "age": 30}
@@ -33,7 +33,7 @@ class TestAppendVariablesRecursively:
def test_append_object_segment_value(self):
"""Test appending an ObjectSegment value"""
- pool = VariablePool()
+ pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["result"]
@@ -60,7 +60,7 @@ class TestAppendVariablesRecursively:
def test_append_nested_dict_value(self):
"""Test appending a nested dictionary value"""
- pool = VariablePool()
+ pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["data"]
@@ -97,7 +97,7 @@ class TestAppendVariablesRecursively:
def test_append_non_dict_value(self):
"""Test appending a non-dictionary value (should not recurse)"""
- pool = VariablePool()
+ pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["simple"]
variable_value = "simple_string"
@@ -114,7 +114,7 @@ class TestAppendVariablesRecursively:
def test_append_segment_non_object_value(self):
"""Test appending a Segment that is not ObjectSegment (should not recurse)"""
- pool = VariablePool()
+ pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["text"]
variable_value = StringSegment(value="Hello World")
@@ -132,7 +132,7 @@ class TestAppendVariablesRecursively:
def test_append_empty_dict_value(self):
"""Test appending an empty dictionary value"""
- pool = VariablePool()
+ pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["empty"]
variable_value: dict[str, Any] = {}
diff --git a/api/tests/unit_tests/factories/test_variable_factory.py b/api/tests/unit_tests/factories/test_variable_factory.py
index edd4c5e93e..4f2542a323 100644
--- a/api/tests/unit_tests/factories/test_variable_factory.py
+++ b/api/tests/unit_tests/factories/test_variable_factory.py
@@ -505,8 +505,8 @@ def test_build_segment_type_for_scalar():
size=1000,
)
cases = [
- TestCase(0, SegmentType.NUMBER),
- TestCase(0.0, SegmentType.NUMBER),
+ TestCase(0, SegmentType.INTEGER),
+ TestCase(0.0, SegmentType.FLOAT),
TestCase("", SegmentType.STRING),
TestCase(file, SegmentType.FILE),
]
@@ -531,14 +531,14 @@ class TestBuildSegmentWithType:
result = build_segment_with_type(SegmentType.NUMBER, 42)
assert isinstance(result, IntegerSegment)
assert result.value == 42
- assert result.value_type == SegmentType.NUMBER
+ assert result.value_type == SegmentType.INTEGER
def test_number_type_float(self):
"""Test building a number segment with float value."""
result = build_segment_with_type(SegmentType.NUMBER, 3.14)
assert isinstance(result, FloatSegment)
assert result.value == 3.14
- assert result.value_type == SegmentType.NUMBER
+ assert result.value_type == SegmentType.FLOAT
def test_object_type(self):
"""Test building an object segment with correct type."""
@@ -652,14 +652,14 @@ class TestBuildSegmentWithType:
with pytest.raises(TypeMismatchError) as exc_info:
build_segment_with_type(SegmentType.STRING, None)
- assert "Expected string, but got None" in str(exc_info.value)
+ assert "expected string, but got None" in str(exc_info.value)
def test_type_mismatch_empty_list_to_non_array(self):
"""Test type mismatch when expecting non-array type but getting empty list."""
with pytest.raises(TypeMismatchError) as exc_info:
build_segment_with_type(SegmentType.STRING, [])
- assert "Expected string, but got empty list" in str(exc_info.value)
+ assert "expected string, but got empty list" in str(exc_info.value)
def test_type_mismatch_object_to_array(self):
"""Test type mismatch when expecting array but getting object."""
@@ -674,19 +674,19 @@ class TestBuildSegmentWithType:
# Integer should work
result_int = build_segment_with_type(SegmentType.NUMBER, 42)
assert isinstance(result_int, IntegerSegment)
- assert result_int.value_type == SegmentType.NUMBER
+ assert result_int.value_type == SegmentType.INTEGER
# Float should work
result_float = build_segment_with_type(SegmentType.NUMBER, 3.14)
assert isinstance(result_float, FloatSegment)
- assert result_float.value_type == SegmentType.NUMBER
+ assert result_float.value_type == SegmentType.FLOAT
@pytest.mark.parametrize(
("segment_type", "value", "expected_class"),
[
(SegmentType.STRING, "test", StringSegment),
- (SegmentType.NUMBER, 42, IntegerSegment),
- (SegmentType.NUMBER, 3.14, FloatSegment),
+ (SegmentType.INTEGER, 42, IntegerSegment),
+ (SegmentType.FLOAT, 3.14, FloatSegment),
(SegmentType.OBJECT, {}, ObjectSegment),
(SegmentType.NONE, None, NoneSegment),
(SegmentType.ARRAY_STRING, [], ArrayStringSegment),
@@ -857,5 +857,5 @@ class TestBuildSegmentValueErrors:
# Verify they are processed as integers, not as errors
assert true_segment.value == 1, "Test case 1 (boolean_true): Expected True to be processed as integer 1"
assert false_segment.value == 0, "Test case 2 (boolean_false): Expected False to be processed as integer 0"
- assert true_segment.value_type == SegmentType.NUMBER
- assert false_segment.value_type == SegmentType.NUMBER
+ assert true_segment.value_type == SegmentType.INTEGER
+ assert false_segment.value_type == SegmentType.INTEGER
diff --git a/api/tests/unit_tests/libs/test_login.py b/api/tests/unit_tests/libs/test_login.py
new file mode 100644
index 0000000000..39671077d4
--- /dev/null
+++ b/api/tests/unit_tests/libs/test_login.py
@@ -0,0 +1,232 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+from flask import Flask, g
+from flask_login import LoginManager, UserMixin
+
+from libs.login import _get_user, current_user, login_required
+
+
+class MockUser(UserMixin):
+ """Mock user class for testing."""
+
+ def __init__(self, id: str, is_authenticated: bool = True):
+ self.id = id
+ self._is_authenticated = is_authenticated
+
+ @property
+ def is_authenticated(self):
+ return self._is_authenticated
+
+
+class TestLoginRequired:
+ """Test cases for login_required decorator."""
+
+ @pytest.fixture
+ def setup_app(self, app: Flask):
+ """Set up Flask app with login manager."""
+ # Initialize login manager
+ login_manager = LoginManager()
+ login_manager.init_app(app)
+
+ # Mock unauthorized handler
+ login_manager.unauthorized = MagicMock(return_value="Unauthorized")
+
+ # Add a dummy user loader to prevent exceptions
+ @login_manager.user_loader
+ def load_user(user_id):
+ return None
+
+ return app
+
+ def test_authenticated_user_can_access_protected_view(self, setup_app: Flask):
+ """Test that authenticated users can access protected views."""
+
+ @login_required
+ def protected_view():
+ return "Protected content"
+
+ with setup_app.test_request_context():
+ # Mock authenticated user
+ mock_user = MockUser("test_user", is_authenticated=True)
+ with patch("libs.login._get_user", return_value=mock_user):
+ result = protected_view()
+ assert result == "Protected content"
+
+ def test_unauthenticated_user_cannot_access_protected_view(self, setup_app: Flask):
+ """Test that unauthenticated users are redirected."""
+
+ @login_required
+ def protected_view():
+ return "Protected content"
+
+ with setup_app.test_request_context():
+ # Mock unauthenticated user
+ mock_user = MockUser("test_user", is_authenticated=False)
+ with patch("libs.login._get_user", return_value=mock_user):
+ result = protected_view()
+ assert result == "Unauthorized"
+ setup_app.login_manager.unauthorized.assert_called_once()
+
+ def test_login_disabled_allows_unauthenticated_access(self, setup_app: Flask):
+ """Test that LOGIN_DISABLED config bypasses authentication."""
+
+ @login_required
+ def protected_view():
+ return "Protected content"
+
+ with setup_app.test_request_context():
+ # Mock unauthenticated user and LOGIN_DISABLED
+ mock_user = MockUser("test_user", is_authenticated=False)
+ with patch("libs.login._get_user", return_value=mock_user):
+ with patch("libs.login.dify_config") as mock_config:
+ mock_config.LOGIN_DISABLED = True
+
+ result = protected_view()
+ assert result == "Protected content"
+ # Ensure unauthorized was not called
+ setup_app.login_manager.unauthorized.assert_not_called()
+
+ def test_options_request_bypasses_authentication(self, setup_app: Flask):
+ """Test that OPTIONS requests are exempt from authentication."""
+
+ @login_required
+ def protected_view():
+ return "Protected content"
+
+ with setup_app.test_request_context(method="OPTIONS"):
+ # Mock unauthenticated user
+ mock_user = MockUser("test_user", is_authenticated=False)
+ with patch("libs.login._get_user", return_value=mock_user):
+ result = protected_view()
+ assert result == "Protected content"
+ # Ensure unauthorized was not called
+ setup_app.login_manager.unauthorized.assert_not_called()
+
+ def test_flask_2_compatibility(self, setup_app: Flask):
+ """Test Flask 2.x compatibility with ensure_sync."""
+
+ @login_required
+ def protected_view():
+ return "Protected content"
+
+ # Mock Flask 2.x ensure_sync
+ setup_app.ensure_sync = MagicMock(return_value=lambda: "Synced content")
+
+ with setup_app.test_request_context():
+ mock_user = MockUser("test_user", is_authenticated=True)
+ with patch("libs.login._get_user", return_value=mock_user):
+ result = protected_view()
+ assert result == "Synced content"
+ setup_app.ensure_sync.assert_called_once()
+
+ def test_flask_1_compatibility(self, setup_app: Flask):
+ """Test Flask 1.x compatibility without ensure_sync."""
+
+ @login_required
+ def protected_view():
+ return "Protected content"
+
+ # Remove ensure_sync to simulate Flask 1.x
+ if hasattr(setup_app, "ensure_sync"):
+ delattr(setup_app, "ensure_sync")
+
+ with setup_app.test_request_context():
+ mock_user = MockUser("test_user", is_authenticated=True)
+ with patch("libs.login._get_user", return_value=mock_user):
+ result = protected_view()
+ assert result == "Protected content"
+
+
+class TestGetUser:
+ """Test cases for _get_user function."""
+
+ def test_get_user_returns_user_from_g(self, app: Flask):
+ """Test that _get_user returns user from g._login_user."""
+ mock_user = MockUser("test_user")
+
+ with app.test_request_context():
+ g._login_user = mock_user
+ user = _get_user()
+ assert user == mock_user
+ assert user.id == "test_user"
+
+ def test_get_user_loads_user_if_not_in_g(self, app: Flask):
+ """Test that _get_user loads user if not already in g."""
+ mock_user = MockUser("test_user")
+
+ # Mock login manager
+ login_manager = MagicMock()
+ login_manager._load_user = MagicMock()
+ app.login_manager = login_manager
+
+ with app.test_request_context():
+ # Simulate _load_user setting g._login_user
+ def side_effect():
+ g._login_user = mock_user
+
+ login_manager._load_user.side_effect = side_effect
+
+ user = _get_user()
+ assert user == mock_user
+ login_manager._load_user.assert_called_once()
+
+ def test_get_user_returns_none_without_request_context(self, app: Flask):
+ """Test that _get_user returns None outside request context."""
+ # Outside of request context
+ user = _get_user()
+ assert user is None
+
+
+class TestCurrentUser:
+ """Test cases for current_user proxy."""
+
+ def test_current_user_proxy_returns_authenticated_user(self, app: Flask):
+ """Test that current_user proxy returns authenticated user."""
+ mock_user = MockUser("test_user", is_authenticated=True)
+
+ with app.test_request_context():
+ with patch("libs.login._get_user", return_value=mock_user):
+ assert current_user.id == "test_user"
+ assert current_user.is_authenticated is True
+
+ def test_current_user_proxy_returns_none_when_no_user(self, app: Flask):
+ """Test that current_user proxy handles None user."""
+ with app.test_request_context():
+ with patch("libs.login._get_user", return_value=None):
+ # When _get_user returns None, accessing attributes should fail
+ # or current_user should evaluate to falsy
+ try:
+ # Try to access an attribute that would exist on a real user
+ _ = current_user.id
+ pytest.fail("Should have raised AttributeError")
+ except AttributeError:
+ # This is expected when current_user is None
+ pass
+
+ def test_current_user_proxy_thread_safety(self, app: Flask):
+ """Test that current_user proxy is thread-safe."""
+ import threading
+
+ results = {}
+
+ def check_user_in_thread(user_id: str, index: int):
+ with app.test_request_context():
+ mock_user = MockUser(user_id)
+ with patch("libs.login._get_user", return_value=mock_user):
+ results[index] = current_user.id
+
+ # Create multiple threads with different users
+ threads = []
+ for i in range(5):
+ thread = threading.Thread(target=check_user_in_thread, args=(f"user_{i}", i))
+ threads.append(thread)
+ thread.start()
+
+ # Wait for all threads to complete
+ for thread in threads:
+ thread.join()
+
+ # Verify each thread got its own user
+ for i in range(5):
+ assert results[i] == f"user_{i}"
diff --git a/api/tests/unit_tests/libs/test_oauth_clients.py b/api/tests/unit_tests/libs/test_oauth_clients.py
new file mode 100644
index 0000000000..629d15b81a
--- /dev/null
+++ b/api/tests/unit_tests/libs/test_oauth_clients.py
@@ -0,0 +1,249 @@
+import urllib.parse
+from unittest.mock import MagicMock, patch
+
+import pytest
+import requests
+
+from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
+
+
+class BaseOAuthTest:
+ """Base class for OAuth provider tests with common fixtures"""
+
+ @pytest.fixture
+ def oauth_config(self):
+ return {
+ "client_id": "test_client_id",
+ "client_secret": "test_client_secret",
+ "redirect_uri": "http://localhost/callback",
+ }
+
+ @pytest.fixture
+ def mock_response(self):
+ response = MagicMock()
+ response.json.return_value = {}
+ return response
+
+ def parse_auth_url(self, url):
+ """Helper to parse authorization URL"""
+ parsed = urllib.parse.urlparse(url)
+ params = urllib.parse.parse_qs(parsed.query)
+ return parsed, params
+
+
+class TestGitHubOAuth(BaseOAuthTest):
+ @pytest.fixture
+ def oauth(self, oauth_config):
+ return GitHubOAuth(oauth_config["client_id"], oauth_config["client_secret"], oauth_config["redirect_uri"])
+
+ @pytest.mark.parametrize(
+ ("invite_token", "expected_state"),
+ [
+ (None, None),
+ ("test_invite_token", "test_invite_token"),
+ ("", None),
+ ],
+ )
+ def test_should_generate_authorization_url_correctly(self, oauth, oauth_config, invite_token, expected_state):
+ url = oauth.get_authorization_url(invite_token)
+ parsed, params = self.parse_auth_url(url)
+
+ assert parsed.scheme == "https"
+ assert parsed.netloc == "github.com"
+ assert parsed.path == "/login/oauth/authorize"
+ assert params["client_id"][0] == oauth_config["client_id"]
+ assert params["redirect_uri"][0] == oauth_config["redirect_uri"]
+ assert params["scope"][0] == "user:email"
+
+ if expected_state:
+ assert params["state"][0] == expected_state
+ else:
+ assert "state" not in params
+
+ @pytest.mark.parametrize(
+ ("response_data", "expected_token", "should_raise"),
+ [
+ ({"access_token": "test_token"}, "test_token", False),
+ ({"error": "invalid_grant"}, None, True),
+ ({}, None, True),
+ ],
+ )
+ @patch("requests.post")
+ def test_should_retrieve_access_token(
+ self, mock_post, oauth, mock_response, response_data, expected_token, should_raise
+ ):
+ mock_response.json.return_value = response_data
+ mock_post.return_value = mock_response
+
+ if should_raise:
+ with pytest.raises(ValueError) as exc_info:
+ oauth.get_access_token("test_code")
+ assert "Error in GitHub OAuth" in str(exc_info.value)
+ else:
+ token = oauth.get_access_token("test_code")
+ assert token == expected_token
+
+ @pytest.mark.parametrize(
+ ("user_data", "email_data", "expected_email"),
+ [
+ # User with primary email
+ (
+ {"id": 12345, "login": "testuser", "name": "Test User"},
+ [
+ {"email": "secondary@example.com", "primary": False},
+ {"email": "primary@example.com", "primary": True},
+ ],
+ "primary@example.com",
+ ),
+ # User with no emails - fallback to noreply
+ ({"id": 12345, "login": "testuser", "name": "Test User"}, [], "12345+testuser@users.noreply.github.com"),
+ # User with only secondary email - fallback to noreply
+ (
+ {"id": 12345, "login": "testuser", "name": "Test User"},
+ [{"email": "secondary@example.com", "primary": False}],
+ "12345+testuser@users.noreply.github.com",
+ ),
+ ],
+ )
+ @patch("requests.get")
+ def test_should_retrieve_user_info_correctly(self, mock_get, oauth, user_data, email_data, expected_email):
+ user_response = MagicMock()
+ user_response.json.return_value = user_data
+
+ email_response = MagicMock()
+ email_response.json.return_value = email_data
+
+ mock_get.side_effect = [user_response, email_response]
+
+ user_info = oauth.get_user_info("test_token")
+
+ assert user_info.id == str(user_data["id"])
+ assert user_info.name == user_data["name"]
+ assert user_info.email == expected_email
+
+ @patch("requests.get")
+ def test_should_handle_network_errors(self, mock_get, oauth):
+ mock_get.side_effect = requests.exceptions.RequestException("Network error")
+
+ with pytest.raises(requests.exceptions.RequestException):
+ oauth.get_raw_user_info("test_token")
+
+
+class TestGoogleOAuth(BaseOAuthTest):
+ @pytest.fixture
+ def oauth(self, oauth_config):
+ return GoogleOAuth(oauth_config["client_id"], oauth_config["client_secret"], oauth_config["redirect_uri"])
+
+ @pytest.mark.parametrize(
+ ("invite_token", "expected_state"),
+ [
+ (None, None),
+ ("test_invite_token", "test_invite_token"),
+ ("", None),
+ ],
+ )
+ def test_should_generate_authorization_url_correctly(self, oauth, oauth_config, invite_token, expected_state):
+ url = oauth.get_authorization_url(invite_token)
+ parsed, params = self.parse_auth_url(url)
+
+ assert parsed.scheme == "https"
+ assert parsed.netloc == "accounts.google.com"
+ assert parsed.path == "/o/oauth2/v2/auth"
+ assert params["client_id"][0] == oauth_config["client_id"]
+ assert params["redirect_uri"][0] == oauth_config["redirect_uri"]
+ assert params["response_type"][0] == "code"
+ assert params["scope"][0] == "openid email"
+
+ if expected_state:
+ assert params["state"][0] == expected_state
+ else:
+ assert "state" not in params
+
+ @pytest.mark.parametrize(
+ ("response_data", "expected_token", "should_raise"),
+ [
+ ({"access_token": "test_token"}, "test_token", False),
+ ({"error": "invalid_grant"}, None, True),
+ ({}, None, True),
+ ],
+ )
+ @patch("requests.post")
+ def test_should_retrieve_access_token(
+ self, mock_post, oauth, oauth_config, mock_response, response_data, expected_token, should_raise
+ ):
+ mock_response.json.return_value = response_data
+ mock_post.return_value = mock_response
+
+ if should_raise:
+ with pytest.raises(ValueError) as exc_info:
+ oauth.get_access_token("test_code")
+ assert "Error in Google OAuth" in str(exc_info.value)
+ else:
+ token = oauth.get_access_token("test_code")
+ assert token == expected_token
+
+ mock_post.assert_called_once_with(
+ oauth._TOKEN_URL,
+ data={
+ "client_id": oauth_config["client_id"],
+ "client_secret": oauth_config["client_secret"],
+ "code": "test_code",
+ "grant_type": "authorization_code",
+ "redirect_uri": oauth_config["redirect_uri"],
+ },
+ headers={"Accept": "application/json"},
+ )
+
+ @pytest.mark.parametrize(
+ ("user_data", "expected_name"),
+ [
+ ({"sub": "123", "email": "test@example.com", "email_verified": True}, ""),
+ ({"sub": "123", "email": "test@example.com", "name": "Test User"}, ""), # Always returns empty string
+ ],
+ )
+ @patch("requests.get")
+ def test_should_retrieve_user_info_correctly(self, mock_get, oauth, mock_response, user_data, expected_name):
+ mock_response.json.return_value = user_data
+ mock_get.return_value = mock_response
+
+ user_info = oauth.get_user_info("test_token")
+
+ assert user_info.id == user_data["sub"]
+ assert user_info.name == expected_name
+ assert user_info.email == user_data["email"]
+
+ mock_get.assert_called_once_with(oauth._USER_INFO_URL, headers={"Authorization": "Bearer test_token"})
+
+ @pytest.mark.parametrize(
+ "exception_type",
+ [
+ requests.exceptions.HTTPError,
+ requests.exceptions.ConnectionError,
+ requests.exceptions.Timeout,
+ ],
+ )
+ @patch("requests.get")
+ def test_should_handle_http_errors(self, mock_get, oauth, exception_type):
+ mock_response = MagicMock()
+ mock_response.raise_for_status.side_effect = exception_type("Error")
+ mock_get.return_value = mock_response
+
+ with pytest.raises(exception_type):
+ oauth.get_raw_user_info("invalid_token")
+
+
+class TestOAuthUserInfo:
+ @pytest.mark.parametrize(
+ "user_data",
+ [
+ {"id": "123", "name": "Test User", "email": "test@example.com"},
+ {"id": "456", "name": "", "email": "user@domain.com"},
+ {"id": "789", "name": "Another User", "email": "another@test.org"},
+ ],
+ )
+ def test_should_create_user_info_dataclass(self, user_data):
+ user_info = OAuthUserInfo(**user_data)
+
+ assert user_info.id == user_data["id"]
+ assert user_info.name == user_data["name"]
+ assert user_info.email == user_data["email"]
diff --git a/api/tests/unit_tests/libs/test_passport.py b/api/tests/unit_tests/libs/test_passport.py
new file mode 100644
index 0000000000..f33484c18d
--- /dev/null
+++ b/api/tests/unit_tests/libs/test_passport.py
@@ -0,0 +1,205 @@
+from datetime import UTC, datetime, timedelta
+from unittest.mock import patch
+
+import jwt
+import pytest
+from werkzeug.exceptions import Unauthorized
+
+from libs.passport import PassportService
+
+
+class TestPassportService:
+ """Test PassportService JWT operations"""
+
+ @pytest.fixture
+ def passport_service(self):
+ """Create PassportService instance with test secret key"""
+ with patch("libs.passport.dify_config") as mock_config:
+ mock_config.SECRET_KEY = "test-secret-key-for-testing"
+ return PassportService()
+
+ @pytest.fixture
+ def another_passport_service(self):
+ """Create another PassportService instance with different secret key"""
+ with patch("libs.passport.dify_config") as mock_config:
+ mock_config.SECRET_KEY = "another-secret-key-for-testing"
+ return PassportService()
+
+ # Core functionality tests
+ def test_should_issue_and_verify_token(self, passport_service):
+ """Test complete JWT lifecycle: issue and verify"""
+ payload = {"user_id": "123", "app_code": "test-app"}
+ token = passport_service.issue(payload)
+
+ # Verify token format
+ assert isinstance(token, str)
+ assert len(token.split(".")) == 3 # JWT format: header.payload.signature
+
+ # Verify token content
+ decoded = passport_service.verify(token)
+ assert decoded == payload
+
+ def test_should_handle_different_payload_types(self, passport_service):
+ """Test issuing and verifying tokens with different payload types"""
+ test_cases = [
+ {"string": "value"},
+ {"number": 42},
+ {"float": 3.14},
+ {"boolean": True},
+ {"null": None},
+ {"array": [1, 2, 3]},
+ {"nested": {"key": "value"}},
+ {"unicode": "中文测试"},
+ {"emoji": "🔐"},
+ {}, # Empty payload
+ ]
+
+ for payload in test_cases:
+ token = passport_service.issue(payload)
+ decoded = passport_service.verify(token)
+ assert decoded == payload
+
+ # Security tests
+ def test_should_reject_modified_token(self, passport_service):
+ """Test that any modification to token invalidates it"""
+ token = passport_service.issue({"user": "test"})
+
+ # Test multiple modification points
+ test_positions = [0, len(token) // 3, len(token) // 2, len(token) - 1]
+
+ for pos in test_positions:
+ if pos < len(token) and token[pos] != ".":
+ # Change one character
+ tampered = token[:pos] + ("X" if token[pos] != "X" else "Y") + token[pos + 1 :]
+ with pytest.raises(Unauthorized):
+ passport_service.verify(tampered)
+
+ def test_should_reject_token_with_different_secret_key(self, passport_service, another_passport_service):
+ """Test key isolation - token from one service should not work with another"""
+ payload = {"user_id": "123", "app_code": "test-app"}
+ token = passport_service.issue(payload)
+
+ with pytest.raises(Unauthorized) as exc_info:
+ another_passport_service.verify(token)
+ assert str(exc_info.value) == "401 Unauthorized: Invalid token signature."
+
+ def test_should_use_hs256_algorithm(self, passport_service):
+ """Test that HS256 algorithm is used for signing"""
+ payload = {"test": "data"}
+ token = passport_service.issue(payload)
+
+ # Decode header without relying on JWT internals
+ # Use jwt.get_unverified_header which is a public API
+ header = jwt.get_unverified_header(token)
+ assert header["alg"] == "HS256"
+
+ def test_should_reject_token_with_wrong_algorithm(self, passport_service):
+ """Test rejection of token signed with different algorithm"""
+ payload = {"user_id": "123"}
+
+ # Create token with different algorithm
+ with patch("libs.passport.dify_config") as mock_config:
+ mock_config.SECRET_KEY = "test-secret-key-for-testing"
+ # Create token with HS512 instead of HS256
+ wrong_alg_token = jwt.encode(payload, mock_config.SECRET_KEY, algorithm="HS512")
+
+ # Should fail because service expects HS256
+ # InvalidAlgorithmError is now caught by PyJWTError handler
+ with pytest.raises(Unauthorized) as exc_info:
+ passport_service.verify(wrong_alg_token)
+ assert str(exc_info.value) == "401 Unauthorized: Invalid token."
+
+ # Exception handling tests
+ def test_should_handle_invalid_tokens(self, passport_service):
+ """Test handling of various invalid token formats"""
+ invalid_tokens = [
+ ("not.a.token", "Invalid token."),
+ ("invalid-jwt-format", "Invalid token."),
+ ("xxx.yyy.zzz", "Invalid token."),
+ ("a.b", "Invalid token."), # Missing signature
+ ("", "Invalid token."), # Empty string
+ (" ", "Invalid token."), # Whitespace
+ (None, "Invalid token."), # None value
+ # Malformed base64
+ ("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.INVALID_BASE64!@#$.signature", "Invalid token."),
+ ]
+
+ for invalid_token, expected_message in invalid_tokens:
+ with pytest.raises(Unauthorized) as exc_info:
+ passport_service.verify(invalid_token)
+ assert expected_message in str(exc_info.value)
+
+ def test_should_reject_expired_token(self, passport_service):
+ """Test rejection of expired token"""
+ past_time = datetime.now(UTC) - timedelta(hours=1)
+ payload = {"user_id": "123", "exp": past_time.timestamp()}
+
+ with patch("libs.passport.dify_config") as mock_config:
+ mock_config.SECRET_KEY = "test-secret-key-for-testing"
+ token = jwt.encode(payload, mock_config.SECRET_KEY, algorithm="HS256")
+
+ with pytest.raises(Unauthorized) as exc_info:
+ passport_service.verify(token)
+ assert str(exc_info.value) == "401 Unauthorized: Token has expired."
+
+ # Configuration tests
+ def test_should_handle_empty_secret_key(self):
+ """Test behavior when SECRET_KEY is empty"""
+ with patch("libs.passport.dify_config") as mock_config:
+ mock_config.SECRET_KEY = ""
+ service = PassportService()
+
+ # Empty secret key should still work but is insecure
+ payload = {"test": "data"}
+ token = service.issue(payload)
+ decoded = service.verify(token)
+ assert decoded == payload
+
+ def test_should_handle_none_secret_key(self):
+ """Test behavior when SECRET_KEY is None"""
+ with patch("libs.passport.dify_config") as mock_config:
+ mock_config.SECRET_KEY = None
+ service = PassportService()
+
+ payload = {"test": "data"}
+ # JWT library will raise TypeError when secret is None
+ with pytest.raises((TypeError, jwt.exceptions.InvalidKeyError)):
+ service.issue(payload)
+
+ # Boundary condition tests
+ def test_should_handle_large_payload(self, passport_service):
+ """Test handling of large payload"""
+ # Test with 100KB instead of 1MB for faster tests
+ large_data = "x" * (100 * 1024)
+ payload = {"data": large_data}
+
+ token = passport_service.issue(payload)
+ decoded = passport_service.verify(token)
+
+ assert decoded["data"] == large_data
+
+ def test_should_handle_special_characters_in_payload(self, passport_service):
+ """Test handling of special characters in payload"""
+ special_payloads = [
+ {"special": "!@#$%^&*()"},
+ {"quotes": 'He said "Hello"'},
+ {"backslash": "path\\to\\file"},
+ {"newline": "line1\nline2"},
+ {"unicode": "🔐🔑🛡️"},
+ {"mixed": "Test123!@#中文🔐"},
+ ]
+
+ for payload in special_payloads:
+ token = passport_service.issue(payload)
+ decoded = passport_service.verify(token)
+ assert decoded == payload
+
+ def test_should_catch_generic_pyjwt_errors(self, passport_service):
+ """Test that generic PyJWTError exceptions are caught and converted to Unauthorized"""
+ # Mock jwt.decode to raise a generic PyJWTError
+ with patch("libs.passport.jwt.decode") as mock_decode:
+ mock_decode.side_effect = jwt.exceptions.PyJWTError("Generic JWT error")
+
+ with pytest.raises(Unauthorized) as exc_info:
+ passport_service.verify("some-token")
+ assert str(exc_info.value) == "401 Unauthorized: Invalid token."
diff --git a/api/tests/unit_tests/libs/test_uuid_utils.py b/api/tests/unit_tests/libs/test_uuid_utils.py
new file mode 100644
index 0000000000..7dbda95f45
--- /dev/null
+++ b/api/tests/unit_tests/libs/test_uuid_utils.py
@@ -0,0 +1,351 @@
+import struct
+import time
+import uuid
+from unittest import mock
+
+import pytest
+from hypothesis import given
+from hypothesis import strategies as st
+
+from libs.uuid_utils import _create_uuidv7_bytes, uuidv7, uuidv7_boundary, uuidv7_timestamp
+
+
+# Tests for private helper function _create_uuidv7_bytes
+def test_create_uuidv7_bytes_basic_structure():
+ """Test basic byte structure creation."""
+ timestamp_ms = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds
+ random_bytes = b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x11\x22"
+
+ result = _create_uuidv7_bytes(timestamp_ms, random_bytes)
+
+ # Should be exactly 16 bytes
+ assert len(result) == 16
+ assert isinstance(result, bytes)
+
+ # Create UUID from bytes to verify it's valid
+ uuid_obj = uuid.UUID(bytes=result)
+ assert uuid_obj.version == 7
+
+
+def test_create_uuidv7_bytes_timestamp_encoding():
+ """Test timestamp is correctly encoded in first 48 bits."""
+ timestamp_ms = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds
+ random_bytes = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
+
+ result = _create_uuidv7_bytes(timestamp_ms, random_bytes)
+
+ # Extract timestamp from first 6 bytes
+ timestamp_bytes = b"\x00\x00" + result[0:6]
+ extracted_timestamp = struct.unpack(">Q", timestamp_bytes)[0]
+
+ assert extracted_timestamp == timestamp_ms
+
+
+def test_create_uuidv7_bytes_version_bits():
+ """Test version bits are set to 7."""
+ timestamp_ms = 1609459200000
+ random_bytes = b"\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00" # Set first 2 bytes to all 1s
+
+ result = _create_uuidv7_bytes(timestamp_ms, random_bytes)
+
+ # Extract version from bytes 6-7
+ version_and_rand_a = struct.unpack(">H", result[6:8])[0]
+ version = (version_and_rand_a >> 12) & 0x0F
+
+ assert version == 7
+
+
+def test_create_uuidv7_bytes_variant_bits():
+ """Test variant bits are set correctly."""
+ timestamp_ms = 1609459200000
+ random_bytes = b"\x00\x00\xff\x00\x00\x00\x00\x00\x00\x00" # Set byte 8 to all 1s
+
+ result = _create_uuidv7_bytes(timestamp_ms, random_bytes)
+
+ # Check variant bits in byte 8 (should be 10xxxxxx)
+ variant_byte = result[8]
+ variant_bits = (variant_byte >> 6) & 0b11
+
+ assert variant_bits == 0b10 # Should be binary 10
+
+
+def test_create_uuidv7_bytes_random_data():
+ """Test random bytes are placed correctly."""
+ timestamp_ms = 1609459200000
+ random_bytes = b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x11\x22"
+
+ result = _create_uuidv7_bytes(timestamp_ms, random_bytes)
+
+ # Check random data A (12 bits from bytes 6-7, excluding version)
+ version_and_rand_a = struct.unpack(">H", result[6:8])[0]
+ rand_a = version_and_rand_a & 0x0FFF
+ expected_rand_a = struct.unpack(">H", random_bytes[0:2])[0] & 0x0FFF
+ assert rand_a == expected_rand_a
+
+ # Check random data B (bytes 8-15, with variant bits preserved)
+ # Byte 8 should have variant bits set but preserve lower 6 bits
+ expected_byte_8 = (random_bytes[2] & 0x3F) | 0x80
+ assert result[8] == expected_byte_8
+
+ # Bytes 9-15 should match random_bytes[3:10]
+ assert result[9:16] == random_bytes[3:10]
+
+
+def test_create_uuidv7_bytes_zero_random():
+ """Test with zero random bytes (boundary case)."""
+ timestamp_ms = 1609459200000
+ zero_random_bytes = b"\x00" * 10
+
+ result = _create_uuidv7_bytes(timestamp_ms, zero_random_bytes)
+
+ # Should still be valid UUIDv7
+ uuid_obj = uuid.UUID(bytes=result)
+ assert uuid_obj.version == 7
+
+ # Version bits should be 0x7000
+ version_and_rand_a = struct.unpack(">H", result[6:8])[0]
+ assert version_and_rand_a == 0x7000
+
+ # Variant byte should be 0x80 (variant bits + zero random bits)
+ assert result[8] == 0x80
+
+ # Remaining bytes should be zero
+ assert result[9:16] == b"\x00" * 7
+
+
+def test_uuidv7_basic_generation():
+ """Test basic UUID generation produces valid UUIDv7."""
+ result = uuidv7()
+
+ # Should be a UUID object
+ assert isinstance(result, uuid.UUID)
+
+ # Should be version 7
+ assert result.version == 7
+
+ # Should have correct variant (RFC 4122 variant)
+ # Variant bits should be 10xxxxxx (0x80-0xBF range)
+ variant_byte = result.bytes[8]
+ assert (variant_byte >> 6) == 0b10
+
+
+def test_uuidv7_with_custom_timestamp():
+ """Test UUID generation with custom timestamp."""
+ custom_timestamp = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds
+ result = uuidv7(custom_timestamp)
+
+ assert isinstance(result, uuid.UUID)
+ assert result.version == 7
+
+ # Extract and verify timestamp
+ extracted_timestamp = uuidv7_timestamp(result)
+ assert isinstance(extracted_timestamp, int)
+ assert extracted_timestamp == custom_timestamp # Exact match for integer milliseconds
+
+
+def test_uuidv7_with_none_timestamp(monkeypatch):
+ """Test UUID generation with None timestamp uses current time."""
+ mock_time = 1609459200
+ mock_time_func = mock.Mock(return_value=mock_time)
+ monkeypatch.setattr("time.time", mock_time_func)
+ result = uuidv7(None)
+
+ assert isinstance(result, uuid.UUID)
+ assert result.version == 7
+
+ # Should use the mocked current time (converted to milliseconds)
+ assert mock_time_func.called
+ extracted_timestamp = uuidv7_timestamp(result)
+ assert extracted_timestamp == mock_time * 1000 # 1609459200.0 * 1000
+
+
+def test_uuidv7_time_ordering():
+ """Test that sequential UUIDs have increasing timestamps."""
+ # Generate UUIDs with incrementing timestamps (in milliseconds)
+ timestamp1 = 1609459200000 # 2021-01-01 00:00:00 UTC
+ timestamp2 = 1609459201000 # 2021-01-01 00:00:01 UTC
+ timestamp3 = 1609459202000 # 2021-01-01 00:00:02 UTC
+
+ uuid1 = uuidv7(timestamp1)
+ uuid2 = uuidv7(timestamp2)
+ uuid3 = uuidv7(timestamp3)
+
+ # Extract timestamps
+ ts1 = uuidv7_timestamp(uuid1)
+ ts2 = uuidv7_timestamp(uuid2)
+ ts3 = uuidv7_timestamp(uuid3)
+
+ # Should be in ascending order
+ assert ts1 < ts2 < ts3
+
+ # UUIDs should be lexicographically ordered by their string representation
+ # due to time-ordering property of UUIDv7
+ uuid_strings = [str(uuid1), str(uuid2), str(uuid3)]
+ assert uuid_strings == sorted(uuid_strings)
+
+
+def test_uuidv7_uniqueness():
+ """Test that multiple calls generate different UUIDs."""
+ # Generate multiple UUIDs with the same timestamp (in milliseconds)
+ timestamp = 1609459200000
+ uuids = [uuidv7(timestamp) for _ in range(100)]
+
+ # All should be unique despite same timestamp (due to random bits)
+ assert len(set(uuids)) == 100
+
+ # All should have the same extracted timestamp
+ for uuid_obj in uuids:
+ extracted_ts = uuidv7_timestamp(uuid_obj)
+ assert extracted_ts == timestamp
+
+
+def test_uuidv7_timestamp_error_handling_wrong_version():
+ """Test error handling for non-UUIDv7 inputs."""
+
+ uuid_v4 = uuid.uuid4()
+ with pytest.raises(ValueError) as exc_ctx:
+ uuidv7_timestamp(uuid_v4)
+ assert "Expected UUIDv7 (version 7)" in str(exc_ctx.value)
+ assert f"got version {uuid_v4.version}" in str(exc_ctx.value)
+
+
+@given(st.integers(max_value=2**48 - 1, min_value=0))
+def test_uuidv7_timestamp_round_trip(timestamp_ms):
+ # Generate UUID with timestamp
+ uuid_obj = uuidv7(timestamp_ms)
+
+ # Extract timestamp back
+ extracted_timestamp = uuidv7_timestamp(uuid_obj)
+
+ # Should match exactly for integer millisecond timestamps
+ assert extracted_timestamp == timestamp_ms
+
+
+def test_uuidv7_timestamp_edge_cases():
+ """Test timestamp extraction with edge case values."""
+ # Test with very small timestamp
+ small_timestamp = 1 # 1ms after epoch
+ uuid_small = uuidv7(small_timestamp)
+ extracted_small = uuidv7_timestamp(uuid_small)
+ assert extracted_small == small_timestamp
+
+ # Test with large timestamp (year 2038+)
+ large_timestamp = 2147483647000 # 2038-01-19 03:14:07 UTC in milliseconds
+ uuid_large = uuidv7(large_timestamp)
+ extracted_large = uuidv7_timestamp(uuid_large)
+ assert extracted_large == large_timestamp
+
+
+def test_uuidv7_boundary_basic_generation():
+ """Test basic boundary UUID generation with a known timestamp."""
+ timestamp = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds
+ result = uuidv7_boundary(timestamp)
+
+ # Should be a UUID object
+ assert isinstance(result, uuid.UUID)
+
+ # Should be version 7
+ assert result.version == 7
+
+ # Should have correct variant (RFC 4122 variant)
+ # Variant bits should be 10xxxxxx (0x80-0xBF range)
+ variant_byte = result.bytes[8]
+ assert (variant_byte >> 6) == 0b10
+
+
+def test_uuidv7_boundary_timestamp_extraction():
+ """Test that boundary UUID timestamp can be extracted correctly."""
+ timestamp = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds
+ boundary_uuid = uuidv7_boundary(timestamp)
+
+ # Extract timestamp using existing function
+ extracted_timestamp = uuidv7_timestamp(boundary_uuid)
+
+ # Should match exactly
+ assert extracted_timestamp == timestamp
+
+
+def test_uuidv7_boundary_deterministic():
+ """Test that boundary UUIDs are deterministic for same timestamp."""
+ timestamp = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds
+
+ # Generate multiple boundary UUIDs with same timestamp
+ uuid1 = uuidv7_boundary(timestamp)
+ uuid2 = uuidv7_boundary(timestamp)
+ uuid3 = uuidv7_boundary(timestamp)
+
+ # Should all be identical
+ assert uuid1 == uuid2 == uuid3
+ assert str(uuid1) == str(uuid2) == str(uuid3)
+
+
+def test_uuidv7_boundary_is_minimum():
+ """Test that boundary UUID is lexicographically smaller than regular UUIDs."""
+ timestamp = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds
+
+ # Generate boundary UUID
+ boundary_uuid = uuidv7_boundary(timestamp)
+
+ # Generate multiple regular UUIDs with same timestamp
+ regular_uuids = [uuidv7(timestamp) for _ in range(50)]
+
+ # Boundary UUID should be lexicographically smaller than all regular UUIDs
+ boundary_str = str(boundary_uuid)
+ for regular_uuid in regular_uuids:
+ regular_str = str(regular_uuid)
+ assert boundary_str < regular_str, f"Boundary {boundary_str} should be < regular {regular_str}"
+
+ # Also test with bytes comparison
+ boundary_bytes = boundary_uuid.bytes
+ for regular_uuid in regular_uuids:
+ regular_bytes = regular_uuid.bytes
+ assert boundary_bytes < regular_bytes
+
+
+def test_uuidv7_boundary_different_timestamps():
+ """Test that boundary UUIDs with different timestamps are ordered correctly."""
+ timestamp1 = 1609459200000 # 2021-01-01 00:00:00 UTC
+ timestamp2 = 1609459201000 # 2021-01-01 00:00:01 UTC
+ timestamp3 = 1609459202000 # 2021-01-01 00:00:02 UTC
+
+ uuid1 = uuidv7_boundary(timestamp1)
+ uuid2 = uuidv7_boundary(timestamp2)
+ uuid3 = uuidv7_boundary(timestamp3)
+
+ # Extract timestamps to verify
+ ts1 = uuidv7_timestamp(uuid1)
+ ts2 = uuidv7_timestamp(uuid2)
+ ts3 = uuidv7_timestamp(uuid3)
+
+ # Should be in ascending order
+ assert ts1 < ts2 < ts3
+
+ # UUIDs should be lexicographically ordered
+ uuid_strings = [str(uuid1), str(uuid2), str(uuid3)]
+ assert uuid_strings == sorted(uuid_strings)
+
+ # Bytes should also be ordered
+ assert uuid1.bytes < uuid2.bytes < uuid3.bytes
+
+
+def test_uuidv7_boundary_edge_cases():
+ """Test boundary UUID generation with edge case timestamp values."""
+ # Test with timestamp 0 (Unix epoch)
+ epoch_uuid = uuidv7_boundary(0)
+ assert isinstance(epoch_uuid, uuid.UUID)
+ assert epoch_uuid.version == 7
+ assert uuidv7_timestamp(epoch_uuid) == 0
+
+ # Test with very large timestamp values
+ large_timestamp = 2147483647000 # 2038-01-19 03:14:07 UTC in milliseconds
+ large_uuid = uuidv7_boundary(large_timestamp)
+ assert isinstance(large_uuid, uuid.UUID)
+ assert large_uuid.version == 7
+ assert uuidv7_timestamp(large_uuid) == large_timestamp
+
+ # Test with current time
+ current_time = int(time.time() * 1000)
+ current_uuid = uuidv7_boundary(current_time)
+ assert isinstance(current_uuid, uuid.UUID)
+ assert current_uuid.version == 7
+ assert uuidv7_timestamp(current_uuid) == current_time
diff --git a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py
index 643efb0a0c..c60800c493 100644
--- a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py
+++ b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py
@@ -137,37 +137,6 @@ def test_save_with_existing_tenant_id(repository, session):
session_obj.merge.assert_called_once_with(modified_execution)
-def test_get_by_node_execution_id(repository, session, mocker: MockerFixture):
- """Test get_by_node_execution_id method."""
- session_obj, _ = session
- # Set up mock
- mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select")
- mock_stmt = mocker.MagicMock()
- mock_select.return_value = mock_stmt
- mock_stmt.where.return_value = mock_stmt
-
- # Create a properly configured mock execution
- mock_execution = mocker.MagicMock(spec=WorkflowNodeExecutionModel)
- configure_mock_execution(mock_execution)
- session_obj.scalar.return_value = mock_execution
-
- # Create a mock domain model to be returned by _to_domain_model
- mock_domain_model = mocker.MagicMock()
- # Mock the _to_domain_model method to return our mock domain model
- repository._to_domain_model = mocker.MagicMock(return_value=mock_domain_model)
-
- # Call method
- result = repository.get_by_node_execution_id("test-node-execution-id")
-
- # Assert select was called with correct parameters
- mock_select.assert_called_once()
- session_obj.scalar.assert_called_once_with(mock_stmt)
- # Assert _to_domain_model was called with the mock execution
- repository._to_domain_model.assert_called_once_with(mock_execution)
- # Assert the result is our mock domain model
- assert result is mock_domain_model
-
-
def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
"""Test get_by_workflow_run method."""
session_obj, _ = session
@@ -202,88 +171,6 @@ def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
assert result[0] is mock_domain_model
-def test_get_running_executions(repository, session, mocker: MockerFixture):
- """Test get_running_executions method."""
- session_obj, _ = session
- # Set up mock
- mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select")
- mock_stmt = mocker.MagicMock()
- mock_select.return_value = mock_stmt
- mock_stmt.where.return_value = mock_stmt
-
- # Create a properly configured mock execution
- mock_execution = mocker.MagicMock(spec=WorkflowNodeExecutionModel)
- configure_mock_execution(mock_execution)
- session_obj.scalars.return_value.all.return_value = [mock_execution]
-
- # Create a mock domain model to be returned by _to_domain_model
- mock_domain_model = mocker.MagicMock()
- # Mock the _to_domain_model method to return our mock domain model
- repository._to_domain_model = mocker.MagicMock(return_value=mock_domain_model)
-
- # Call method
- result = repository.get_running_executions("test-workflow-run-id")
-
- # Assert select was called with correct parameters
- mock_select.assert_called_once()
- session_obj.scalars.assert_called_once_with(mock_stmt)
- # Assert _to_domain_model was called with the mock execution
- repository._to_domain_model.assert_called_once_with(mock_execution)
- # Assert the result contains our mock domain model
- assert len(result) == 1
- assert result[0] is mock_domain_model
-
-
-def test_update_via_save(repository, session):
- """Test updating an existing record via save method."""
- session_obj, _ = session
- # Create a mock execution
- execution = MagicMock(spec=WorkflowNodeExecutionModel)
- execution.tenant_id = None
- execution.app_id = None
- execution.inputs = None
- execution.process_data = None
- execution.outputs = None
- execution.metadata = None
-
- # Mock the to_db_model method to return the execution itself
- # This simulates the behavior of setting tenant_id and app_id
- repository.to_db_model = MagicMock(return_value=execution)
-
- # Call save method to update an existing record
- repository.save(execution)
-
- # Assert to_db_model was called with the execution
- repository.to_db_model.assert_called_once_with(execution)
-
- # Assert session.merge was called (for updates)
- session_obj.merge.assert_called_once_with(execution)
-
-
-def test_clear(repository, session, mocker: MockerFixture):
- """Test clear method."""
- session_obj, _ = session
- # Set up mock
- mock_delete = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.delete")
- mock_stmt = mocker.MagicMock()
- mock_delete.return_value = mock_stmt
- mock_stmt.where.return_value = mock_stmt
-
- # Mock the execute result with rowcount
- mock_result = mocker.MagicMock()
- mock_result.rowcount = 5 # Simulate 5 records deleted
- session_obj.execute.return_value = mock_result
-
- # Call method
- repository.clear()
-
- # Assert delete was called with correct parameters
- mock_delete.assert_called_once_with(WorkflowNodeExecutionModel)
- mock_stmt.where.assert_called()
- session_obj.execute.assert_called_once_with(mock_stmt)
- session_obj.commit.assert_called_once()
-
-
def test_to_db_model(repository):
"""Test to_db_model method."""
# Create a domain model
diff --git a/api/tests/unit_tests/services/auth/__init__.py b/api/tests/unit_tests/services/auth/__init__.py
new file mode 100644
index 0000000000..852a892730
--- /dev/null
+++ b/api/tests/unit_tests/services/auth/__init__.py
@@ -0,0 +1 @@
+# API authentication service test module
diff --git a/api/tests/unit_tests/services/auth/test_api_key_auth_service.py b/api/tests/unit_tests/services/auth/test_api_key_auth_service.py
new file mode 100644
index 0000000000..f0e425e742
--- /dev/null
+++ b/api/tests/unit_tests/services/auth/test_api_key_auth_service.py
@@ -0,0 +1,382 @@
+import json
+from unittest.mock import Mock, patch
+
+import pytest
+
+from models.source import DataSourceApiKeyAuthBinding
+from services.auth.api_key_auth_service import ApiKeyAuthService
+
+
+class TestApiKeyAuthService:
+ """API key authentication service security tests"""
+
+ def setup_method(self):
+ """Setup test fixtures"""
+ self.tenant_id = "test_tenant_123"
+ self.category = "search"
+ self.provider = "google"
+ self.binding_id = "binding_123"
+ self.mock_credentials = {"auth_type": "api_key", "config": {"api_key": "test_secret_key_123"}}
+ self.mock_args = {"category": self.category, "provider": self.provider, "credentials": self.mock_credentials}
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ def test_get_provider_auth_list_success(self, mock_session):
+ """Test get provider auth list - success scenario"""
+ # Mock database query result
+ mock_binding = Mock()
+ mock_binding.tenant_id = self.tenant_id
+ mock_binding.provider = self.provider
+ mock_binding.disabled = False
+
+ mock_session.query.return_value.filter.return_value.all.return_value = [mock_binding]
+
+ result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
+
+ assert len(result) == 1
+ assert result[0].tenant_id == self.tenant_id
+ mock_session.query.assert_called_once_with(DataSourceApiKeyAuthBinding)
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ def test_get_provider_auth_list_empty(self, mock_session):
+ """Test get provider auth list - empty result"""
+ mock_session.query.return_value.filter.return_value.all.return_value = []
+
+ result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
+
+ assert result == []
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ def test_get_provider_auth_list_filters_disabled(self, mock_session):
+ """Test get provider auth list - filters disabled items"""
+ mock_session.query.return_value.filter.return_value.all.return_value = []
+
+ ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
+
+ # Verify filter conditions include disabled.is_(False)
+ filter_call = mock_session.query.return_value.filter.call_args[0]
+ assert len(filter_call) == 2 # tenant_id and disabled filter conditions
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
+ @patch("services.auth.api_key_auth_service.encrypter")
+ def test_create_provider_auth_success(self, mock_encrypter, mock_factory, mock_session):
+ """Test create provider auth - success scenario"""
+ # Mock successful auth validation
+ mock_auth_instance = Mock()
+ mock_auth_instance.validate_credentials.return_value = True
+ mock_factory.return_value = mock_auth_instance
+
+ # Mock encryption
+ encrypted_key = "encrypted_test_key_123"
+ mock_encrypter.encrypt_token.return_value = encrypted_key
+
+ # Mock database operations
+ mock_session.add = Mock()
+ mock_session.commit = Mock()
+
+ ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)
+
+ # Verify factory class calls
+ mock_factory.assert_called_once_with(self.provider, self.mock_credentials)
+ mock_auth_instance.validate_credentials.assert_called_once()
+
+ # Verify encryption calls
+ mock_encrypter.encrypt_token.assert_called_once_with(self.tenant_id, "test_secret_key_123")
+
+ # Verify database operations
+ mock_session.add.assert_called_once()
+ mock_session.commit.assert_called_once()
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
+ def test_create_provider_auth_validation_failed(self, mock_factory, mock_session):
+ """Test create provider auth - validation failed"""
+ # Mock failed auth validation
+ mock_auth_instance = Mock()
+ mock_auth_instance.validate_credentials.return_value = False
+ mock_factory.return_value = mock_auth_instance
+
+ ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)
+
+ # Verify no database operations when validation fails
+ mock_session.add.assert_not_called()
+ mock_session.commit.assert_not_called()
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
+ @patch("services.auth.api_key_auth_service.encrypter")
+ def test_create_provider_auth_encrypts_api_key(self, mock_encrypter, mock_factory, mock_session):
+ """Test create provider auth - ensures API key is encrypted"""
+ # Mock successful auth validation
+ mock_auth_instance = Mock()
+ mock_auth_instance.validate_credentials.return_value = True
+ mock_factory.return_value = mock_auth_instance
+
+ # Mock encryption
+ encrypted_key = "encrypted_test_key_123"
+ mock_encrypter.encrypt_token.return_value = encrypted_key
+
+ # Mock database operations
+ mock_session.add = Mock()
+ mock_session.commit = Mock()
+
+ args_copy = self.mock_args.copy()
+ original_key = args_copy["credentials"]["config"]["api_key"] # type: ignore
+
+ ApiKeyAuthService.create_provider_auth(self.tenant_id, args_copy)
+
+ # Verify original key is replaced with encrypted key
+ assert args_copy["credentials"]["config"]["api_key"] == encrypted_key # type: ignore
+ assert args_copy["credentials"]["config"]["api_key"] != original_key # type: ignore
+
+ # Verify encryption function is called correctly
+ mock_encrypter.encrypt_token.assert_called_once_with(self.tenant_id, original_key)
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ def test_get_auth_credentials_success(self, mock_session):
+ """Test get auth credentials - success scenario"""
+ # Mock database query result
+ mock_binding = Mock()
+ mock_binding.credentials = json.dumps(self.mock_credentials)
+ mock_session.query.return_value.filter.return_value.first.return_value = mock_binding
+
+ result = ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
+
+ assert result == self.mock_credentials
+ mock_session.query.assert_called_once_with(DataSourceApiKeyAuthBinding)
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ def test_get_auth_credentials_not_found(self, mock_session):
+ """Test get auth credentials - not found"""
+ mock_session.query.return_value.filter.return_value.first.return_value = None
+
+ result = ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
+
+ assert result is None
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ def test_get_auth_credentials_filters_correctly(self, mock_session):
+ """Test get auth credentials - applies correct filters"""
+ mock_session.query.return_value.filter.return_value.first.return_value = None
+
+ ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
+
+ # Verify filter conditions are correct
+ filter_call = mock_session.query.return_value.filter.call_args[0]
+ assert len(filter_call) == 4 # tenant_id, category, provider, disabled
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ def test_get_auth_credentials_json_parsing(self, mock_session):
+ """Test get auth credentials - JSON parsing"""
+ # Mock credentials with special characters
+ special_credentials = {"auth_type": "api_key", "config": {"api_key": "key_with_中文_and_special_chars_!@#$%"}}
+
+ mock_binding = Mock()
+ mock_binding.credentials = json.dumps(special_credentials, ensure_ascii=False)
+ mock_session.query.return_value.filter.return_value.first.return_value = mock_binding
+
+ result = ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
+
+ assert result == special_credentials
+ assert result["config"]["api_key"] == "key_with_中文_and_special_chars_!@#$%"
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ def test_delete_provider_auth_success(self, mock_session):
+ """Test delete provider auth - success scenario"""
+ # Mock database query result
+ mock_binding = Mock()
+ mock_session.query.return_value.filter.return_value.first.return_value = mock_binding
+
+ ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id)
+
+ # Verify delete operations
+ mock_session.delete.assert_called_once_with(mock_binding)
+ mock_session.commit.assert_called_once()
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ def test_delete_provider_auth_not_found(self, mock_session):
+ """Test delete provider auth - not found"""
+ mock_session.query.return_value.filter.return_value.first.return_value = None
+
+ ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id)
+
+ # Verify no delete operations when not found
+ mock_session.delete.assert_not_called()
+ mock_session.commit.assert_not_called()
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ def test_delete_provider_auth_filters_by_tenant(self, mock_session):
+ """Test delete provider auth - filters by tenant"""
+ mock_session.query.return_value.filter.return_value.first.return_value = None
+
+ ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id)
+
+ # Verify filter conditions include tenant_id and binding_id
+ filter_call = mock_session.query.return_value.filter.call_args[0]
+ assert len(filter_call) == 2
+
+ def test_validate_api_key_auth_args_success(self):
+ """Test API key auth args validation - success scenario"""
+ # Should not raise any exception
+ ApiKeyAuthService.validate_api_key_auth_args(self.mock_args)
+
+ def test_validate_api_key_auth_args_missing_category(self):
+ """Test API key auth args validation - missing category"""
+ args = self.mock_args.copy()
+ del args["category"]
+
+ with pytest.raises(ValueError, match="category is required"):
+ ApiKeyAuthService.validate_api_key_auth_args(args)
+
+ def test_validate_api_key_auth_args_empty_category(self):
+ """Test API key auth args validation - empty category"""
+ args = self.mock_args.copy()
+ args["category"] = ""
+
+ with pytest.raises(ValueError, match="category is required"):
+ ApiKeyAuthService.validate_api_key_auth_args(args)
+
+ def test_validate_api_key_auth_args_missing_provider(self):
+ """Test API key auth args validation - missing provider"""
+ args = self.mock_args.copy()
+ del args["provider"]
+
+ with pytest.raises(ValueError, match="provider is required"):
+ ApiKeyAuthService.validate_api_key_auth_args(args)
+
+ def test_validate_api_key_auth_args_empty_provider(self):
+ """Test API key auth args validation - empty provider"""
+ args = self.mock_args.copy()
+ args["provider"] = ""
+
+ with pytest.raises(ValueError, match="provider is required"):
+ ApiKeyAuthService.validate_api_key_auth_args(args)
+
+ def test_validate_api_key_auth_args_missing_credentials(self):
+ """Test API key auth args validation - missing credentials"""
+ args = self.mock_args.copy()
+ del args["credentials"]
+
+ with pytest.raises(ValueError, match="credentials is required"):
+ ApiKeyAuthService.validate_api_key_auth_args(args)
+
+ def test_validate_api_key_auth_args_empty_credentials(self):
+ """Test API key auth args validation - empty credentials"""
+ args = self.mock_args.copy()
+ args["credentials"] = None # type: ignore
+
+ with pytest.raises(ValueError, match="credentials is required"):
+ ApiKeyAuthService.validate_api_key_auth_args(args)
+
+ def test_validate_api_key_auth_args_invalid_credentials_type(self):
+ """Test API key auth args validation - invalid credentials type"""
+ args = self.mock_args.copy()
+ args["credentials"] = "not_a_dict"
+
+ with pytest.raises(ValueError, match="credentials must be a dictionary"):
+ ApiKeyAuthService.validate_api_key_auth_args(args)
+
+ def test_validate_api_key_auth_args_missing_auth_type(self):
+ """Test API key auth args validation - missing auth_type"""
+ args = self.mock_args.copy()
+ del args["credentials"]["auth_type"] # type: ignore
+
+ with pytest.raises(ValueError, match="auth_type is required"):
+ ApiKeyAuthService.validate_api_key_auth_args(args)
+
+ def test_validate_api_key_auth_args_empty_auth_type(self):
+ """Test API key auth args validation - empty auth_type"""
+ args = self.mock_args.copy()
+ args["credentials"]["auth_type"] = "" # type: ignore
+
+ with pytest.raises(ValueError, match="auth_type is required"):
+ ApiKeyAuthService.validate_api_key_auth_args(args)
+
+ @pytest.mark.parametrize(
+ "malicious_input",
+ [
+ "",
+ "'; DROP TABLE users; --",
+ "../../../etc/passwd",
+ "\\x00\\x00", # null bytes
+ "A" * 10000, # very long input
+ ],
+ )
+ def test_validate_api_key_auth_args_malicious_input(self, malicious_input):
+ """Test API key auth args validation - malicious input"""
+ args = self.mock_args.copy()
+ args["category"] = malicious_input
+
+ # Verify parameter validator doesn't crash on malicious input
+ # Should validate normally rather than raising security-related exceptions
+ ApiKeyAuthService.validate_api_key_auth_args(args)
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
+ @patch("services.auth.api_key_auth_service.encrypter")
+ def test_create_provider_auth_database_error_handling(self, mock_encrypter, mock_factory, mock_session):
+ """Test create provider auth - database error handling"""
+ # Mock successful auth validation
+ mock_auth_instance = Mock()
+ mock_auth_instance.validate_credentials.return_value = True
+ mock_factory.return_value = mock_auth_instance
+
+ # Mock encryption
+ mock_encrypter.encrypt_token.return_value = "encrypted_key"
+
+ # Mock database error
+ mock_session.commit.side_effect = Exception("Database error")
+
+ with pytest.raises(Exception, match="Database error"):
+ ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ def test_get_auth_credentials_invalid_json(self, mock_session):
+ """Test get auth credentials - invalid JSON"""
+ # Mock database returning invalid JSON
+ mock_binding = Mock()
+ mock_binding.credentials = "invalid json content"
+ mock_session.query.return_value.filter.return_value.first.return_value = mock_binding
+
+ with pytest.raises(json.JSONDecodeError):
+ ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
+ def test_create_provider_auth_factory_exception(self, mock_factory, mock_session):
+ """Test create provider auth - factory exception"""
+ # Mock factory raising exception
+ mock_factory.side_effect = Exception("Factory error")
+
+ with pytest.raises(Exception, match="Factory error"):
+ ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
+ @patch("services.auth.api_key_auth_service.encrypter")
+ def test_create_provider_auth_encryption_exception(self, mock_encrypter, mock_factory, mock_session):
+ """Test create provider auth - encryption exception"""
+ # Mock successful auth validation
+ mock_auth_instance = Mock()
+ mock_auth_instance.validate_credentials.return_value = True
+ mock_factory.return_value = mock_auth_instance
+
+ # Mock encryption exception
+ mock_encrypter.encrypt_token.side_effect = Exception("Encryption error")
+
+ with pytest.raises(Exception, match="Encryption error"):
+ ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)
+
+ def test_validate_api_key_auth_args_none_input(self):
+ """Test API key auth args validation - None input"""
+ with pytest.raises(TypeError):
+ ApiKeyAuthService.validate_api_key_auth_args(None)
+
+ def test_validate_api_key_auth_args_dict_credentials_with_list_auth_type(self):
+ """Test API key auth args validation - dict credentials with list auth_type"""
+ args = self.mock_args.copy()
+ args["credentials"]["auth_type"] = ["api_key"] # type: ignore # list instead of string
+
+ # Current implementation checks if auth_type exists and is truthy, list ["api_key"] is truthy
+ # So this should not raise exception, this test should pass
+ ApiKeyAuthService.validate_api_key_auth_args(args)
diff --git a/api/tests/unit_tests/services/services_test_help.py b/api/tests/unit_tests/services/services_test_help.py
new file mode 100644
index 0000000000..c6b962f7fc
--- /dev/null
+++ b/api/tests/unit_tests/services/services_test_help.py
@@ -0,0 +1,59 @@
+from unittest.mock import MagicMock
+
+
+class ServiceDbTestHelper:
+ """
+ Helper class for service database query tests.
+ """
+
+ @staticmethod
+ def setup_db_query_filter_by_mock(mock_db, query_results):
+ """
+ Smart database query mock that responds based on model type and query parameters.
+
+ Args:
+ mock_db: Mock database session
+ query_results: Dict mapping (model_name, filter_key, filter_value) to return value
+ Example: {('Account', 'email', 'test@example.com'): mock_account}
+ """
+
+ def query_side_effect(model):
+ mock_query = MagicMock()
+
+ def filter_by_side_effect(**kwargs):
+ mock_filter_result = MagicMock()
+
+ def first_side_effect():
+ # Find matching result based on model and filter parameters
+ for (model_name, filter_key, filter_value), result in query_results.items():
+ if model.__name__ == model_name and filter_key in kwargs and kwargs[filter_key] == filter_value:
+ return result
+ return None
+
+ mock_filter_result.first.side_effect = first_side_effect
+
+ # Handle order_by calls for complex queries
+ def order_by_side_effect(*args, **kwargs):
+ mock_order_result = MagicMock()
+
+ def order_first_side_effect():
+ # Look for order_by results in the same query_results dict
+ for (model_name, filter_key, filter_value), result in query_results.items():
+ if (
+ model.__name__ == model_name
+ and filter_key == "order_by"
+ and filter_value == "first_available"
+ ):
+ return result
+ return None
+
+ mock_order_result.first.side_effect = order_first_side_effect
+ return mock_order_result
+
+ mock_filter_result.order_by.side_effect = order_by_side_effect
+ return mock_filter_result
+
+ mock_query.filter_by.side_effect = filter_by_side_effect
+ return mock_query
+
+ mock_db.session.query.side_effect = query_side_effect
diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py
new file mode 100644
index 0000000000..13900ab6d1
--- /dev/null
+++ b/api/tests/unit_tests/services/test_account_service.py
@@ -0,0 +1,1545 @@
+import json
+from datetime import datetime, timedelta
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from configs import dify_config
+from models.account import Account
+from services.account_service import AccountService, RegisterService, TenantService
+from services.errors.account import (
+ AccountAlreadyInTenantError,
+ AccountLoginError,
+ AccountNotFoundError,
+ AccountPasswordError,
+ AccountRegisterError,
+ CurrentPasswordIncorrectError,
+)
+from tests.unit_tests.services.services_test_help import ServiceDbTestHelper
+
+
+class TestAccountAssociatedDataFactory:
+ """Factory class for creating test data and mock objects for account service tests."""
+
+ @staticmethod
+ def create_account_mock(
+ account_id: str = "user-123",
+ email: str = "test@example.com",
+ name: str = "Test User",
+ status: str = "active",
+ password: str = "hashed_password",
+ password_salt: str = "salt",
+ interface_language: str = "en-US",
+ interface_theme: str = "light",
+ timezone: str = "UTC",
+ **kwargs,
+ ) -> MagicMock:
+ """Create a mock account with specified attributes."""
+ account = MagicMock(spec=Account)
+ account.id = account_id
+ account.email = email
+ account.name = name
+ account.status = status
+ account.password = password
+ account.password_salt = password_salt
+ account.interface_language = interface_language
+ account.interface_theme = interface_theme
+ account.timezone = timezone
+ # Set last_active_at to a datetime object that's older than 10 minutes
+ account.last_active_at = datetime.now() - timedelta(minutes=15)
+ account.initialized_at = None
+ for key, value in kwargs.items():
+ setattr(account, key, value)
+ return account
+
+ @staticmethod
+ def create_tenant_join_mock(
+ tenant_id: str = "tenant-456",
+ account_id: str = "user-123",
+ current: bool = True,
+ role: str = "normal",
+ **kwargs,
+ ) -> MagicMock:
+ """Create a mock tenant account join record."""
+ tenant_join = MagicMock()
+ tenant_join.tenant_id = tenant_id
+ tenant_join.account_id = account_id
+ tenant_join.current = current
+ tenant_join.role = role
+ for key, value in kwargs.items():
+ setattr(tenant_join, key, value)
+ return tenant_join
+
+ @staticmethod
+ def create_feature_service_mock(allow_register: bool = True):
+ """Create a mock feature service."""
+ mock_service = MagicMock()
+ mock_service.get_system_features.return_value.is_allow_register = allow_register
+ return mock_service
+
+ @staticmethod
+ def create_billing_service_mock(email_frozen: bool = False):
+ """Create a mock billing service."""
+ mock_service = MagicMock()
+ mock_service.is_email_in_freeze.return_value = email_frozen
+ return mock_service
+
+
+class TestAccountService:
+ """
+ Comprehensive unit tests for AccountService methods.
+
+ This test suite covers all account-related operations including:
+ - Authentication and login
+ - Account creation and registration
+ - Password management
+ - JWT token generation
+ - User loading and tenant management
+ - Error conditions and edge cases
+ """
+
+ @pytest.fixture
+ def mock_db_dependencies(self):
+ """Common mock setup for database dependencies."""
+ with patch("services.account_service.db") as mock_db:
+ mock_db.session.add = MagicMock()
+ mock_db.session.commit = MagicMock()
+ yield {
+ "db": mock_db,
+ }
+
+ @pytest.fixture
+ def mock_password_dependencies(self):
+ """Mock setup for password-related functions."""
+ with (
+ patch("services.account_service.compare_password") as mock_compare_password,
+ patch("services.account_service.hash_password") as mock_hash_password,
+ patch("services.account_service.valid_password") as mock_valid_password,
+ ):
+ yield {
+ "compare_password": mock_compare_password,
+ "hash_password": mock_hash_password,
+ "valid_password": mock_valid_password,
+ }
+
+ @pytest.fixture
+ def mock_external_service_dependencies(self):
+ """Mock setup for external service dependencies."""
+ with (
+ patch("services.account_service.FeatureService") as mock_feature_service,
+ patch("services.account_service.BillingService") as mock_billing_service,
+ patch("services.account_service.PassportService") as mock_passport_service,
+ ):
+ yield {
+ "feature_service": mock_feature_service,
+ "billing_service": mock_billing_service,
+ "passport_service": mock_passport_service,
+ }
+
+ @pytest.fixture
+ def mock_db_with_autospec(self):
+ """
+ Mock database with autospec for more realistic behavior.
+ This approach preserves the actual method signatures and behavior.
+ """
+ with patch("services.account_service.db", autospec=True) as mock_db:
+ # Create a more realistic session mock
+ mock_session = MagicMock()
+ mock_db.session = mock_session
+
+ # Setup basic session methods
+ mock_session.add = MagicMock()
+ mock_session.commit = MagicMock()
+ mock_session.query = MagicMock()
+
+ yield mock_db
+
+ def _assert_database_operations_called(self, mock_db):
+ """Helper method to verify database operations were called."""
+ mock_db.session.commit.assert_called()
+
+ def _assert_database_operations_not_called(self, mock_db):
+ """Helper method to verify database operations were not called."""
+ mock_db.session.commit.assert_not_called()
+
+ def _assert_exception_raised(self, exception_type, callable_func, *args, **kwargs):
+ """Helper method to verify that specific exception is raised."""
+ with pytest.raises(exception_type):
+ callable_func(*args, **kwargs)
+
+ # ==================== Authentication Tests ====================
+
+ def test_authenticate_success(self, mock_db_dependencies, mock_password_dependencies):
+ """Test successful authentication with correct email and password."""
+ # Setup test data
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock()
+
+ # Setup smart database query mock
+ query_results = {("Account", "email", "test@example.com"): mock_account}
+ ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
+
+ mock_password_dependencies["compare_password"].return_value = True
+
+ # Execute test
+ result = AccountService.authenticate("test@example.com", "password")
+
+ # Verify results
+ assert result == mock_account
+ self._assert_database_operations_called(mock_db_dependencies["db"])
+
+ def test_authenticate_account_not_found(self, mock_db_dependencies):
+ """Test authentication when account does not exist."""
+ # Setup smart database query mock - no matching results
+ query_results = {("Account", "email", "notfound@example.com"): None}
+ ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
+
+ # Execute test and verify exception
+ self._assert_exception_raised(
+ AccountNotFoundError, AccountService.authenticate, "notfound@example.com", "password"
+ )
+
+ def test_authenticate_account_banned(self, mock_db_dependencies):
+ """Test authentication when account is banned."""
+ # Setup test data
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="banned")
+
+ # Setup smart database query mock
+ query_results = {("Account", "email", "banned@example.com"): mock_account}
+ ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
+
+ # Execute test and verify exception
+ self._assert_exception_raised(AccountLoginError, AccountService.authenticate, "banned@example.com", "password")
+
+ def test_authenticate_password_error(self, mock_db_dependencies, mock_password_dependencies):
+ """Test authentication with wrong password."""
+ # Setup test data
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock()
+
+ # Setup smart database query mock
+ query_results = {("Account", "email", "test@example.com"): mock_account}
+ ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
+
+ mock_password_dependencies["compare_password"].return_value = False
+
+ # Execute test and verify exception
+ self._assert_exception_raised(
+ AccountPasswordError, AccountService.authenticate, "test@example.com", "wrongpassword"
+ )
+
+ def test_authenticate_pending_account_activates(self, mock_db_dependencies, mock_password_dependencies):
+ """Test authentication for a pending account, which should activate on login."""
+ # Setup test data
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="pending")
+
+ # Setup smart database query mock
+ query_results = {("Account", "email", "pending@example.com"): mock_account}
+ ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
+
+ mock_password_dependencies["compare_password"].return_value = True
+
+ # Execute test
+ result = AccountService.authenticate("pending@example.com", "password")
+
+ # Verify results
+ assert result == mock_account
+ assert mock_account.status == "active"
+ self._assert_database_operations_called(mock_db_dependencies["db"])
+
+ # ==================== Account Creation Tests ====================
+
+ def test_create_account_success(
+ self, mock_db_dependencies, mock_password_dependencies, mock_external_service_dependencies
+ ):
+ """Test successful account creation with all required parameters."""
+ # Setup mocks
+ mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
+ mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
+ mock_password_dependencies["hash_password"].return_value = b"hashed_password"
+
+ # Execute test
+ result = AccountService.create_account(
+ email="test@example.com",
+ name="Test User",
+ interface_language="en-US",
+ password="password123",
+ interface_theme="light",
+ )
+
+ # Verify results
+ assert result.email == "test@example.com"
+ assert result.name == "Test User"
+ assert result.interface_language == "en-US"
+ assert result.interface_theme == "light"
+ assert result.password is not None
+ assert result.password_salt is not None
+ assert result.timezone is not None
+
+ # Verify database operations
+ mock_db_dependencies["db"].session.add.assert_called_once()
+ added_account = mock_db_dependencies["db"].session.add.call_args[0][0]
+ assert added_account.email == "test@example.com"
+ assert added_account.name == "Test User"
+ assert added_account.interface_language == "en-US"
+ assert added_account.interface_theme == "light"
+ assert added_account.password is not None
+ assert added_account.password_salt is not None
+ assert added_account.timezone is not None
+ self._assert_database_operations_called(mock_db_dependencies["db"])
+
+ def test_create_account_registration_disabled(self, mock_external_service_dependencies):
+ """Test account creation when registration is disabled."""
+ # Setup mocks
+ mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = False
+
+ # Execute test and verify exception
+ self._assert_exception_raised(
+ Exception, # AccountNotFound
+ AccountService.create_account,
+ email="test@example.com",
+ name="Test User",
+ interface_language="en-US",
+ )
+
+ def test_create_account_email_frozen(self, mock_db_dependencies, mock_external_service_dependencies):
+ """Test account creation with frozen email address."""
+ # Setup mocks
+ mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
+ mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = True
+ dify_config.BILLING_ENABLED = True
+
+ # Execute test and verify exception
+ self._assert_exception_raised(
+ AccountRegisterError,
+ AccountService.create_account,
+ email="frozen@example.com",
+ name="Test User",
+ interface_language="en-US",
+ )
+ dify_config.BILLING_ENABLED = False
+
+ def test_create_account_without_password(self, mock_db_dependencies, mock_external_service_dependencies):
+ """Test account creation without password (for invite-based registration)."""
+ # Setup mocks
+ mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
+ mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
+
+ # Execute test
+ result = AccountService.create_account(
+ email="test@example.com",
+ name="Test User",
+ interface_language="zh-CN",
+ password=None,
+ interface_theme="dark",
+ )
+
+ # Verify results
+ assert result.email == "test@example.com"
+ assert result.name == "Test User"
+ assert result.interface_language == "zh-CN"
+ assert result.interface_theme == "dark"
+ assert result.password is None
+ assert result.password_salt is None
+ assert result.timezone is not None
+
+ # Verify database operations
+ mock_db_dependencies["db"].session.add.assert_called_once()
+ added_account = mock_db_dependencies["db"].session.add.call_args[0][0]
+ assert added_account.email == "test@example.com"
+ assert added_account.name == "Test User"
+ assert added_account.interface_language == "zh-CN"
+ assert added_account.interface_theme == "dark"
+ assert added_account.password is None
+ assert added_account.password_salt is None
+ assert added_account.timezone is not None
+ self._assert_database_operations_called(mock_db_dependencies["db"])
+
+ # ==================== Password Management Tests ====================
+
+ def test_update_account_password_success(self, mock_db_dependencies, mock_password_dependencies):
+ """Test successful password update with correct current password and valid new password."""
+ # Setup test data
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock()
+ mock_password_dependencies["compare_password"].return_value = True
+ mock_password_dependencies["valid_password"].return_value = None
+ mock_password_dependencies["hash_password"].return_value = b"new_hashed_password"
+
+ # Execute test
+ result = AccountService.update_account_password(mock_account, "old_password", "new_password123")
+
+ # Verify results
+ assert result == mock_account
+ assert mock_account.password is not None
+ assert mock_account.password_salt is not None
+
+ # Verify password validation was called
+ mock_password_dependencies["compare_password"].assert_called_once_with(
+ "old_password", "hashed_password", "salt"
+ )
+ mock_password_dependencies["valid_password"].assert_called_once_with("new_password123")
+
+ # Verify database operations
+ self._assert_database_operations_called(mock_db_dependencies["db"])
+
+ def test_update_account_password_current_password_incorrect(self, mock_password_dependencies):
+ """Test password update with incorrect current password."""
+ # Setup test data
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock()
+ mock_password_dependencies["compare_password"].return_value = False
+
+ # Execute test and verify exception
+ self._assert_exception_raised(
+ CurrentPasswordIncorrectError,
+ AccountService.update_account_password,
+ mock_account,
+ "wrong_password",
+ "new_password123",
+ )
+
+ # Verify password comparison was called
+ mock_password_dependencies["compare_password"].assert_called_once_with(
+ "wrong_password", "hashed_password", "salt"
+ )
+
+ def test_update_account_password_invalid_new_password(self, mock_password_dependencies):
+ """Test password update with invalid new password."""
+ # Setup test data
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock()
+ mock_password_dependencies["compare_password"].return_value = True
+ mock_password_dependencies["valid_password"].side_effect = ValueError("Password too short")
+
+ # Execute test and verify exception
+ self._assert_exception_raised(
+ ValueError, AccountService.update_account_password, mock_account, "old_password", "short"
+ )
+
+ # Verify password validation was called
+ mock_password_dependencies["valid_password"].assert_called_once_with("short")
+
+ # ==================== User Loading Tests ====================
+
+ def test_load_user_success(self, mock_db_dependencies):
+ """Test successful user loading with current tenant."""
+ # Setup test data
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock()
+ mock_tenant_join = TestAccountAssociatedDataFactory.create_tenant_join_mock()
+
+ # Setup smart database query mock
+ query_results = {
+ ("Account", "id", "user-123"): mock_account,
+ ("TenantAccountJoin", "account_id", "user-123"): mock_tenant_join,
+ }
+ ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
+
+ # Mock datetime
+ with patch("services.account_service.datetime") as mock_datetime:
+ mock_now = datetime.now()
+ mock_datetime.now.return_value = mock_now
+ mock_datetime.UTC = "UTC"
+
+ # Execute test
+ result = AccountService.load_user("user-123")
+
+ # Verify results
+ assert result == mock_account
+ assert mock_account.set_tenant_id.called
+
+ def test_load_user_not_found(self, mock_db_dependencies):
+ """Test user loading when user does not exist."""
+ # Setup smart database query mock - no matching results
+ query_results = {("Account", "id", "non-existent-user"): None}
+ ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
+
+ # Execute test
+ result = AccountService.load_user("non-existent-user")
+
+ # Verify results
+ assert result is None
+
+ def test_load_user_banned(self, mock_db_dependencies):
+ """Test user loading when user is banned."""
+ # Setup test data
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="banned")
+
+ # Setup smart database query mock
+ query_results = {("Account", "id", "user-123"): mock_account}
+ ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
+
+ # Execute test and verify exception
+ self._assert_exception_raised(
+ Exception, # Unauthorized
+ AccountService.load_user,
+ "user-123",
+ )
+
+ def test_load_user_no_current_tenant(self, mock_db_dependencies):
+ """Test user loading when user has no current tenant but has available tenants."""
+ # Setup test data
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock()
+ mock_available_tenant = TestAccountAssociatedDataFactory.create_tenant_join_mock(current=False)
+
+ # Setup smart database query mock for complex scenario
+ query_results = {
+ ("Account", "id", "user-123"): mock_account,
+ ("TenantAccountJoin", "account_id", "user-123"): None, # No current tenant
+ ("TenantAccountJoin", "order_by", "first_available"): mock_available_tenant, # First available tenant
+ }
+ ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
+
+ # Mock datetime
+ with patch("services.account_service.datetime") as mock_datetime:
+ mock_now = datetime.now()
+ mock_datetime.now.return_value = mock_now
+ mock_datetime.UTC = "UTC"
+
+ # Execute test
+ result = AccountService.load_user("user-123")
+
+ # Verify results
+ assert result == mock_account
+ assert mock_available_tenant.current is True
+ self._assert_database_operations_called(mock_db_dependencies["db"])
+
+ def test_load_user_no_tenants(self, mock_db_dependencies):
+ """Test user loading when user has no tenants at all."""
+ # Setup test data
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock()
+
+ # Setup smart database query mock for no tenants scenario
+ query_results = {
+ ("Account", "id", "user-123"): mock_account,
+ ("TenantAccountJoin", "account_id", "user-123"): None, # No current tenant
+ ("TenantAccountJoin", "order_by", "first_available"): None, # No available tenants
+ }
+ ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
+
+ # Mock datetime
+ with patch("services.account_service.datetime") as mock_datetime:
+ mock_now = datetime.now()
+ mock_datetime.now.return_value = mock_now
+ mock_datetime.UTC = "UTC"
+
+ # Execute test
+ result = AccountService.load_user("user-123")
+
+ # Verify results
+ assert result is None
+
+
+class TestTenantService:
+ """
+ Comprehensive unit tests for TenantService methods.
+
+ This test suite covers all tenant-related operations including:
+ - Tenant creation and management
+ - Member management and permissions
+ - Tenant switching
+ - Role updates and permission checks
+ - Error conditions and edge cases
+ """
+
+ @pytest.fixture
+ def mock_db_dependencies(self):
+ """Common mock setup for database dependencies."""
+ with patch("services.account_service.db") as mock_db:
+ mock_db.session.add = MagicMock()
+ mock_db.session.commit = MagicMock()
+ yield {
+ "db": mock_db,
+ }
+
+ @pytest.fixture
+ def mock_rsa_dependencies(self):
+ """Mock setup for RSA-related functions."""
+ with patch("services.account_service.generate_key_pair") as mock_generate_key_pair:
+ yield mock_generate_key_pair
+
+ @pytest.fixture
+ def mock_external_service_dependencies(self):
+ """Mock setup for external service dependencies."""
+ with (
+ patch("services.account_service.FeatureService") as mock_feature_service,
+ patch("services.account_service.BillingService") as mock_billing_service,
+ ):
+ yield {
+ "feature_service": mock_feature_service,
+ "billing_service": mock_billing_service,
+ }
+
+ def _assert_database_operations_called(self, mock_db):
+ """Helper method to verify database operations were called."""
+ mock_db.session.commit.assert_called()
+
+ def _assert_exception_raised(self, exception_type, callable_func, *args, **kwargs):
+ """Helper method to verify that specific exception is raised."""
+ with pytest.raises(exception_type):
+ callable_func(*args, **kwargs)
+
+ # ==================== Tenant Creation Tests ====================
+
+ def test_create_owner_tenant_if_not_exist_new_user(
+ self, mock_db_dependencies, mock_rsa_dependencies, mock_external_service_dependencies
+ ):
+ """Test creating owner tenant for new user without existing tenants."""
+ # Setup test data
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock()
+
+ # Setup smart database query mock - no existing tenant joins
+ query_results = {
+ ("TenantAccountJoin", "account_id", "user-123"): None,
+ ("TenantAccountJoin", "tenant_id", "tenant-456"): None, # For has_roles check
+ }
+ ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
+
+ # Setup external service mocks
+ mock_external_service_dependencies[
+ "feature_service"
+ ].get_system_features.return_value.is_allow_create_workspace = True
+ mock_external_service_dependencies[
+ "feature_service"
+ ].get_system_features.return_value.license.workspaces.is_available.return_value = True
+
+ # Mock tenant creation
+ mock_tenant = MagicMock()
+ mock_tenant.id = "tenant-456"
+ mock_tenant.name = "Test User's Workspace"
+
+ # Mock database operations
+ mock_db_dependencies["db"].session.add = MagicMock()
+
+ # Mock RSA key generation
+ mock_rsa_dependencies.return_value = "mock_public_key"
+
+ # Mock has_roles method to return False (no existing owner)
+ with patch("services.account_service.TenantService.has_roles") as mock_has_roles:
+ mock_has_roles.return_value = False
+
+ # Mock Tenant creation to set proper ID
+ with patch("services.account_service.Tenant") as mock_tenant_class:
+ mock_tenant_instance = MagicMock()
+ mock_tenant_instance.id = "tenant-456"
+ mock_tenant_instance.name = "Test User's Workspace"
+ mock_tenant_class.return_value = mock_tenant_instance
+
+ # Execute test
+ TenantService.create_owner_tenant_if_not_exist(mock_account)
+
+ # Verify tenant was created with correct parameters
+ mock_db_dependencies["db"].session.add.assert_called()
+
+ # Get all calls to session.add
+ add_calls = mock_db_dependencies["db"].session.add.call_args_list
+
+ # Should have at least 2 calls: one for Tenant, one for TenantAccountJoin
+ assert len(add_calls) >= 2
+
+ # Verify Tenant was added with correct name
+ tenant_added = False
+ tenant_account_join_added = False
+
+ for call in add_calls:
+ added_object = call[0][0] # First argument of the call
+
+ # Check if it's a Tenant object
+ if hasattr(added_object, "name") and hasattr(added_object, "id"):
+ # This should be a Tenant object
+ assert added_object.name == "Test User's Workspace"
+ tenant_added = True
+
+ # Check if it's a TenantAccountJoin object
+ elif (
+ hasattr(added_object, "tenant_id")
+ and hasattr(added_object, "account_id")
+ and hasattr(added_object, "role")
+ ):
+ # This should be a TenantAccountJoin object
+ assert added_object.tenant_id is not None
+ assert added_object.account_id == "user-123"
+ assert added_object.role == "owner"
+ tenant_account_join_added = True
+
+ assert tenant_added, "Tenant object was not added to database"
+ assert tenant_account_join_added, "TenantAccountJoin object was not added to database"
+
+ self._assert_database_operations_called(mock_db_dependencies["db"])
+ assert mock_rsa_dependencies.called, "RSA key generation was not called"
+
+ # ==================== Member Management Tests ====================
+
+ def test_create_tenant_member_success(self, mock_db_dependencies):
+ """Test successful tenant member creation."""
+ # Setup test data
+ mock_tenant = MagicMock()
+ mock_tenant.id = "tenant-456"
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock()
+
+ # Setup smart database query mock - no existing member
+ query_results = {("TenantAccountJoin", "tenant_id", "tenant-456"): None}
+ ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
+
+ # Mock database operations
+ mock_db_dependencies["db"].session.add = MagicMock()
+
+ # Execute test
+ result = TenantService.create_tenant_member(mock_tenant, mock_account, "normal")
+
+ # Verify member was created with correct parameters
+ assert result is not None
+ mock_db_dependencies["db"].session.add.assert_called_once()
+
+ # Verify the TenantAccountJoin object was added with correct parameters
+ added_tenant_account_join = mock_db_dependencies["db"].session.add.call_args[0][0]
+ assert added_tenant_account_join.tenant_id == "tenant-456"
+ assert added_tenant_account_join.account_id == "user-123"
+ assert added_tenant_account_join.role == "normal"
+
+ self._assert_database_operations_called(mock_db_dependencies["db"])
+
+ # ==================== Tenant Switching Tests ====================
+
+ def test_switch_tenant_success(self):
+ """Test successful tenant switching."""
+ # Setup test data
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock()
+ mock_tenant_join = TestAccountAssociatedDataFactory.create_tenant_join_mock(
+ tenant_id="tenant-456", account_id="user-123", current=False
+ )
+
+ # Mock the complex query in switch_tenant method
+ with patch("services.account_service.db") as mock_db:
+ # Mock the join query that returns the tenant_account_join
+ mock_query = MagicMock()
+ mock_filter = MagicMock()
+ mock_filter.first.return_value = mock_tenant_join
+ mock_query.filter.return_value = mock_filter
+ mock_query.join.return_value = mock_query
+ mock_db.session.query.return_value = mock_query
+
+ # Execute test
+ TenantService.switch_tenant(mock_account, "tenant-456")
+
+ # Verify tenant was switched
+ assert mock_tenant_join.current is True
+ self._assert_database_operations_called(mock_db)
+
+ def test_switch_tenant_no_tenant_id(self):
+ """Test tenant switching without providing tenant ID."""
+ # Setup test data
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock()
+
+ # Execute test and verify exception
+ self._assert_exception_raised(ValueError, TenantService.switch_tenant, mock_account, None)
+
+ # ==================== Role Management Tests ====================
+
+ def test_update_member_role_success(self):
+ """Test successful member role update."""
+ # Setup test data
+ mock_tenant = MagicMock()
+ mock_tenant.id = "tenant-456"
+ mock_member = TestAccountAssociatedDataFactory.create_account_mock(account_id="member-789")
+ mock_operator = TestAccountAssociatedDataFactory.create_account_mock(account_id="operator-123")
+ mock_target_join = TestAccountAssociatedDataFactory.create_tenant_join_mock(
+ tenant_id="tenant-456", account_id="member-789", role="normal"
+ )
+ mock_operator_join = TestAccountAssociatedDataFactory.create_tenant_join_mock(
+ tenant_id="tenant-456", account_id="operator-123", role="owner"
+ )
+
+ # Mock the database queries in update_member_role method
+ with patch("services.account_service.db") as mock_db:
+ # Mock the first query for operator permission check
+ mock_query1 = MagicMock()
+ mock_filter1 = MagicMock()
+ mock_filter1.first.return_value = mock_operator_join
+ mock_query1.filter_by.return_value = mock_filter1
+
+ # Mock the second query for target member
+ mock_query2 = MagicMock()
+ mock_filter2 = MagicMock()
+ mock_filter2.first.return_value = mock_target_join
+ mock_query2.filter_by.return_value = mock_filter2
+
+ # Make the query method return different mocks for different calls
+ mock_db.session.query.side_effect = [mock_query1, mock_query2]
+
+ # Execute test
+ TenantService.update_member_role(mock_tenant, mock_member, "admin", mock_operator)
+
+ # Verify role was updated
+ assert mock_target_join.role == "admin"
+ self._assert_database_operations_called(mock_db)
+
+ # ==================== Permission Check Tests ====================
+
+ def test_check_member_permission_success(self, mock_db_dependencies):
+ """Test successful member permission check."""
+ # Setup test data
+ mock_tenant = MagicMock()
+ mock_tenant.id = "tenant-456"
+ mock_operator = TestAccountAssociatedDataFactory.create_account_mock(account_id="operator-123")
+ mock_member = TestAccountAssociatedDataFactory.create_account_mock(account_id="member-789")
+ mock_operator_join = TestAccountAssociatedDataFactory.create_tenant_join_mock(
+ tenant_id="tenant-456", account_id="operator-123", role="owner"
+ )
+
+ # Setup smart database query mock
+ query_results = {("TenantAccountJoin", "tenant_id", "tenant-456"): mock_operator_join}
+ ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
+
+ # Execute test - should not raise exception
+ TenantService.check_member_permission(mock_tenant, mock_operator, mock_member, "add")
+
+ def test_check_member_permission_operate_self(self):
+ """Test member permission check when operator tries to operate self."""
+ # Setup test data
+ mock_tenant = MagicMock()
+ mock_tenant.id = "tenant-456"
+ mock_operator = TestAccountAssociatedDataFactory.create_account_mock(account_id="operator-123")
+
+ # Execute test and verify exception
+ from services.errors.account import CannotOperateSelfError
+
+ self._assert_exception_raised(
+ CannotOperateSelfError,
+ TenantService.check_member_permission,
+ mock_tenant,
+ mock_operator,
+ mock_operator, # Same as operator
+ "add",
+ )
+
+
+class TestRegisterService:
+ """
+ Comprehensive unit tests for RegisterService methods.
+
+ This test suite covers all registration-related operations including:
+ - System setup
+ - Account registration
+ - Member invitation
+ - Token management
+ - Invitation validation
+ - Error conditions and edge cases
+ """
+
+ @pytest.fixture
+ def mock_db_dependencies(self):
+ """Common mock setup for database dependencies."""
+ with patch("services.account_service.db") as mock_db:
+ mock_db.session.add = MagicMock()
+ mock_db.session.commit = MagicMock()
+ mock_db.session.begin_nested = MagicMock()
+ mock_db.session.rollback = MagicMock()
+ yield {
+ "db": mock_db,
+ }
+
+ @pytest.fixture
+ def mock_redis_dependencies(self):
+ """Mock setup for Redis-related functions."""
+ with patch("services.account_service.redis_client") as mock_redis:
+ yield mock_redis
+
+ @pytest.fixture
+ def mock_external_service_dependencies(self):
+ """Mock setup for external service dependencies."""
+ with (
+ patch("services.account_service.FeatureService") as mock_feature_service,
+ patch("services.account_service.BillingService") as mock_billing_service,
+ patch("services.account_service.PassportService") as mock_passport_service,
+ ):
+ yield {
+ "feature_service": mock_feature_service,
+ "billing_service": mock_billing_service,
+ "passport_service": mock_passport_service,
+ }
+
+ @pytest.fixture
+ def mock_task_dependencies(self):
+ """Mock setup for task dependencies."""
+ with patch("services.account_service.send_invite_member_mail_task") as mock_send_mail:
+ yield mock_send_mail
+
+ def _assert_database_operations_called(self, mock_db):
+ """Helper method to verify database operations were called."""
+ mock_db.session.commit.assert_called()
+
+ def _assert_database_operations_not_called(self, mock_db):
+ """Helper method to verify database operations were not called."""
+ mock_db.session.commit.assert_not_called()
+
+ def _assert_exception_raised(self, exception_type, callable_func, *args, **kwargs):
+ """Helper method to verify that specific exception is raised."""
+ with pytest.raises(exception_type):
+ callable_func(*args, **kwargs)
+
+ # ==================== Setup Tests ====================
+
+ def test_setup_success(self, mock_db_dependencies, mock_external_service_dependencies):
+ """Test successful system setup."""
+ # Setup mocks
+ mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
+ mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
+
+ # Mock AccountService.create_account
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock()
+ with patch("services.account_service.AccountService.create_account") as mock_create_account:
+ mock_create_account.return_value = mock_account
+
+ # Mock TenantService.create_owner_tenant_if_not_exist
+ with patch("services.account_service.TenantService.create_owner_tenant_if_not_exist") as mock_create_tenant:
+ # Mock DifySetup
+ with patch("services.account_service.DifySetup") as mock_dify_setup:
+ mock_dify_setup_instance = MagicMock()
+ mock_dify_setup.return_value = mock_dify_setup_instance
+
+ # Execute test
+ RegisterService.setup("admin@example.com", "Admin User", "password123", "192.168.1.1")
+
+ # Verify results
+ mock_create_account.assert_called_once_with(
+ email="admin@example.com",
+ name="Admin User",
+ interface_language="en-US",
+ password="password123",
+ is_setup=True,
+ )
+ mock_create_tenant.assert_called_once_with(account=mock_account, is_setup=True)
+ mock_dify_setup.assert_called_once()
+ self._assert_database_operations_called(mock_db_dependencies["db"])
+
+ def test_setup_failure_rollback(self, mock_db_dependencies, mock_external_service_dependencies):
+ """Test setup failure with proper rollback."""
+ # Setup mocks to simulate failure
+ mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
+ mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
+
+ # Mock AccountService.create_account to raise exception
+ with patch("services.account_service.AccountService.create_account") as mock_create_account:
+ mock_create_account.side_effect = Exception("Database error")
+
+ # Execute test and verify exception
+ self._assert_exception_raised(
+ ValueError,
+ RegisterService.setup,
+ "admin@example.com",
+ "Admin User",
+ "password123",
+ "192.168.1.1",
+ )
+
+ # Verify rollback operations were called
+ mock_db_dependencies["db"].session.query.assert_called()
+
+ # ==================== Registration Tests ====================
+
+ def test_register_success(self, mock_db_dependencies, mock_external_service_dependencies):
+ """Test successful account registration."""
+ # Setup mocks
+ mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
+ mock_external_service_dependencies[
+ "feature_service"
+ ].get_system_features.return_value.is_allow_create_workspace = True
+ mock_external_service_dependencies[
+ "feature_service"
+ ].get_system_features.return_value.license.workspaces.is_available.return_value = True
+ mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
+
+ # Mock AccountService.create_account
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock()
+ with patch("services.account_service.AccountService.create_account") as mock_create_account:
+ mock_create_account.return_value = mock_account
+
+ # Mock TenantService.create_tenant and create_tenant_member
+ with (
+ patch("services.account_service.TenantService.create_tenant") as mock_create_tenant,
+ patch("services.account_service.TenantService.create_tenant_member") as mock_create_member,
+ patch("services.account_service.tenant_was_created") as mock_event,
+ ):
+ mock_tenant = MagicMock()
+ mock_tenant.id = "tenant-456"
+ mock_create_tenant.return_value = mock_tenant
+
+ # Execute test
+ result = RegisterService.register(
+ email="test@example.com",
+ name="Test User",
+ password="password123",
+ language="en-US",
+ )
+
+ # Verify results
+ assert result == mock_account
+ assert result.status == "active"
+ assert result.initialized_at is not None
+ mock_create_account.assert_called_once_with(
+ email="test@example.com",
+ name="Test User",
+ interface_language="en-US",
+ password="password123",
+ is_setup=False,
+ )
+ mock_create_tenant.assert_called_once_with("Test User's Workspace")
+ mock_create_member.assert_called_once_with(mock_tenant, mock_account, role="owner")
+ mock_event.send.assert_called_once_with(mock_tenant)
+ self._assert_database_operations_called(mock_db_dependencies["db"])
+
+ def test_register_with_oauth(self, mock_db_dependencies, mock_external_service_dependencies):
+ """Test account registration with OAuth integration."""
+ # Setup mocks
+ mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
+ mock_external_service_dependencies[
+ "feature_service"
+ ].get_system_features.return_value.is_allow_create_workspace = True
+ mock_external_service_dependencies[
+ "feature_service"
+ ].get_system_features.return_value.license.workspaces.is_available.return_value = True
+ mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
+
+ # Mock AccountService.create_account and link_account_integrate
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock()
+ with (
+ patch("services.account_service.AccountService.create_account") as mock_create_account,
+ patch("services.account_service.AccountService.link_account_integrate") as mock_link_account,
+ ):
+ mock_create_account.return_value = mock_account
+
+ # Mock TenantService methods
+ with (
+ patch("services.account_service.TenantService.create_tenant") as mock_create_tenant,
+ patch("services.account_service.TenantService.create_tenant_member") as mock_create_member,
+ patch("services.account_service.tenant_was_created") as mock_event,
+ ):
+ mock_tenant = MagicMock()
+ mock_create_tenant.return_value = mock_tenant
+
+ # Execute test
+ result = RegisterService.register(
+ email="test@example.com",
+ name="Test User",
+ password=None,
+ open_id="oauth123",
+ provider="google",
+ language="en-US",
+ )
+
+ # Verify results
+ assert result == mock_account
+ mock_link_account.assert_called_once_with("google", "oauth123", mock_account)
+ self._assert_database_operations_called(mock_db_dependencies["db"])
+
+ def test_register_with_pending_status(self, mock_db_dependencies, mock_external_service_dependencies):
+ """Test account registration with pending status."""
+ # Setup mocks
+ mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
+ mock_external_service_dependencies[
+ "feature_service"
+ ].get_system_features.return_value.is_allow_create_workspace = True
+ mock_external_service_dependencies[
+ "feature_service"
+ ].get_system_features.return_value.license.workspaces.is_available.return_value = True
+ mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
+
+ # Mock AccountService.create_account
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock()
+ with patch("services.account_service.AccountService.create_account") as mock_create_account:
+ mock_create_account.return_value = mock_account
+
+ # Mock TenantService methods
+ with (
+ patch("services.account_service.TenantService.create_tenant") as mock_create_tenant,
+ patch("services.account_service.TenantService.create_tenant_member") as mock_create_member,
+ patch("services.account_service.tenant_was_created") as mock_event,
+ ):
+ mock_tenant = MagicMock()
+ mock_create_tenant.return_value = mock_tenant
+
+ # Execute test with pending status
+ from models.account import AccountStatus
+
+ result = RegisterService.register(
+ email="test@example.com",
+ name="Test User",
+ password="password123",
+ language="en-US",
+ status=AccountStatus.PENDING,
+ )
+
+ # Verify results
+ assert result == mock_account
+ assert result.status == "pending"
+ self._assert_database_operations_called(mock_db_dependencies["db"])
+
+ def test_register_workspace_not_allowed(self, mock_db_dependencies, mock_external_service_dependencies):
+ """Test registration when workspace creation is not allowed."""
+ # Setup mocks
+ mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
+ mock_external_service_dependencies[
+ "feature_service"
+ ].get_system_features.return_value.is_allow_create_workspace = True
+ mock_external_service_dependencies[
+ "feature_service"
+ ].get_system_features.return_value.license.workspaces.is_available.return_value = True
+ mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
+
+ # Mock AccountService.create_account
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock()
+ with patch("services.account_service.AccountService.create_account") as mock_create_account:
+ mock_create_account.return_value = mock_account
+
+ # Execute test and verify exception
+ from services.errors.workspace import WorkSpaceNotAllowedCreateError
+
+ with patch("services.account_service.TenantService.create_tenant") as mock_create_tenant:
+ mock_create_tenant.side_effect = WorkSpaceNotAllowedCreateError()
+
+ self._assert_exception_raised(
+ AccountRegisterError,
+ RegisterService.register,
+ email="test@example.com",
+ name="Test User",
+ password="password123",
+ language="en-US",
+ )
+
+ # Verify rollback was called
+ mock_db_dependencies["db"].session.rollback.assert_called()
+
+ def test_register_general_exception(self, mock_db_dependencies, mock_external_service_dependencies):
+ """Test registration with general exception handling."""
+ # Setup mocks
+ mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
+ mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
+
+ # Mock AccountService.create_account to raise exception
+ with patch("services.account_service.AccountService.create_account") as mock_create_account:
+ mock_create_account.side_effect = Exception("Unexpected error")
+
+ # Execute test and verify exception
+ self._assert_exception_raised(
+ AccountRegisterError,
+ RegisterService.register,
+ email="test@example.com",
+ name="Test User",
+ password="password123",
+ language="en-US",
+ )
+
+ # Verify rollback was called
+ mock_db_dependencies["db"].session.rollback.assert_called()
+
+ # ==================== Member Invitation Tests ====================
+
+ def test_invite_new_member_new_account(self, mock_db_dependencies, mock_redis_dependencies, mock_task_dependencies):
+ """Test inviting a new member who doesn't have an account."""
+ # Setup test data
+ mock_tenant = MagicMock()
+ mock_tenant.id = "tenant-456"
+ mock_tenant.name = "Test Workspace"
+ mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter")
+
+ # Mock database queries - need to mock the Session query
+ mock_session = MagicMock()
+ mock_session.query.return_value.filter_by.return_value.first.return_value = None # No existing account
+
+ with patch("services.account_service.Session") as mock_session_class:
+ mock_session_class.return_value.__enter__.return_value = mock_session
+ mock_session_class.return_value.__exit__.return_value = None
+
+ # Mock RegisterService.register
+ mock_new_account = TestAccountAssociatedDataFactory.create_account_mock(
+ account_id="new-user-456", email="newuser@example.com", name="newuser", status="pending"
+ )
+ with patch("services.account_service.RegisterService.register") as mock_register:
+ mock_register.return_value = mock_new_account
+
+ # Mock TenantService methods
+ with (
+ patch("services.account_service.TenantService.check_member_permission") as mock_check_permission,
+ patch("services.account_service.TenantService.create_tenant_member") as mock_create_member,
+ patch("services.account_service.TenantService.switch_tenant") as mock_switch_tenant,
+ patch("services.account_service.RegisterService.generate_invite_token") as mock_generate_token,
+ ):
+ mock_generate_token.return_value = "invite-token-123"
+
+ # Execute test
+ result = RegisterService.invite_new_member(
+ tenant=mock_tenant,
+ email="newuser@example.com",
+ language="en-US",
+ role="normal",
+ inviter=mock_inviter,
+ )
+
+ # Verify results
+ assert result == "invite-token-123"
+ mock_register.assert_called_once_with(
+ email="newuser@example.com",
+ name="newuser",
+ language="en-US",
+ status="pending",
+ is_setup=True,
+ )
+ mock_create_member.assert_called_once_with(mock_tenant, mock_new_account, "normal")
+ mock_switch_tenant.assert_called_once_with(mock_new_account, mock_tenant.id)
+ mock_generate_token.assert_called_once_with(mock_tenant, mock_new_account)
+ mock_task_dependencies.delay.assert_called_once()
+
+ def test_invite_new_member_existing_account(
+ self, mock_db_dependencies, mock_redis_dependencies, mock_task_dependencies
+ ):
+ """Test inviting a new member who already has an account."""
+ # Setup test data
+ mock_tenant = MagicMock()
+ mock_tenant.id = "tenant-456"
+ mock_tenant.name = "Test Workspace"
+ mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter")
+ mock_existing_account = TestAccountAssociatedDataFactory.create_account_mock(
+ account_id="existing-user-456", email="existing@example.com", status="pending"
+ )
+
+ # Mock database queries - need to mock the Session query
+ mock_session = MagicMock()
+ mock_session.query.return_value.filter_by.return_value.first.return_value = mock_existing_account
+
+ with patch("services.account_service.Session") as mock_session_class:
+ mock_session_class.return_value.__enter__.return_value = mock_session
+ mock_session_class.return_value.__exit__.return_value = None
+
+ # Mock the db.session.query for TenantAccountJoin
+ mock_db_query = MagicMock()
+ mock_db_query.filter_by.return_value.first.return_value = None # No existing member
+ mock_db_dependencies["db"].session.query.return_value = mock_db_query
+
+ # Mock TenantService methods
+ with (
+ patch("services.account_service.TenantService.check_member_permission") as mock_check_permission,
+ patch("services.account_service.TenantService.create_tenant_member") as mock_create_member,
+ patch("services.account_service.RegisterService.generate_invite_token") as mock_generate_token,
+ ):
+ mock_generate_token.return_value = "invite-token-123"
+
+ # Execute test
+ result = RegisterService.invite_new_member(
+ tenant=mock_tenant,
+ email="existing@example.com",
+ language="en-US",
+ role="normal",
+ inviter=mock_inviter,
+ )
+
+ # Verify results
+ assert result == "invite-token-123"
+ mock_create_member.assert_called_once_with(mock_tenant, mock_existing_account, "normal")
+ mock_generate_token.assert_called_once_with(mock_tenant, mock_existing_account)
+ mock_task_dependencies.delay.assert_called_once()
+
+ def test_invite_new_member_already_in_tenant(self, mock_db_dependencies, mock_redis_dependencies):
+ """Test inviting a member who is already in the tenant."""
+ # Setup test data
+ mock_tenant = MagicMock()
+ mock_tenant.id = "tenant-456"
+ mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter")
+ mock_existing_account = TestAccountAssociatedDataFactory.create_account_mock(
+ account_id="existing-user-456", email="existing@example.com", status="active"
+ )
+
+ # Mock database queries
+ query_results = {
+ ("Account", "email", "existing@example.com"): mock_existing_account,
+ (
+ "TenantAccountJoin",
+ "tenant_id",
+ "tenant-456",
+ ): TestAccountAssociatedDataFactory.create_tenant_join_mock(),
+ }
+ ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
+
+ # Mock TenantService methods
+ with patch("services.account_service.TenantService.check_member_permission") as mock_check_permission:
+ # Execute test and verify exception
+ self._assert_exception_raised(
+ AccountAlreadyInTenantError,
+ RegisterService.invite_new_member,
+ tenant=mock_tenant,
+ email="existing@example.com",
+ language="en-US",
+ role="normal",
+ inviter=mock_inviter,
+ )
+
+ def test_invite_new_member_no_inviter(self):
+ """Test inviting a member without providing an inviter."""
+ # Setup test data
+ mock_tenant = MagicMock()
+
+ # Execute test and verify exception
+ self._assert_exception_raised(
+ ValueError,
+ RegisterService.invite_new_member,
+ tenant=mock_tenant,
+ email="test@example.com",
+ language="en-US",
+ role="normal",
+ inviter=None,
+ )
+
+ # ==================== Token Management Tests ====================
+
+ def test_generate_invite_token_success(self, mock_redis_dependencies):
+ """Test successful invite token generation."""
+ # Setup test data
+ mock_tenant = MagicMock()
+ mock_tenant.id = "tenant-456"
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock(
+ account_id="user-123", email="test@example.com"
+ )
+
+ # Mock uuid generation
+ with patch("services.account_service.uuid.uuid4") as mock_uuid:
+ mock_uuid.return_value = "test-uuid-123"
+
+ # Execute test
+ result = RegisterService.generate_invite_token(mock_tenant, mock_account)
+
+ # Verify results
+ assert result == "test-uuid-123"
+ mock_redis_dependencies.setex.assert_called_once()
+
+ # Verify the stored data
+ call_args = mock_redis_dependencies.setex.call_args
+ assert call_args[0][0] == "member_invite:token:test-uuid-123"
+ stored_data = json.loads(call_args[0][2])
+ assert stored_data["account_id"] == "user-123"
+ assert stored_data["email"] == "test@example.com"
+ assert stored_data["workspace_id"] == "tenant-456"
+
+ def test_is_valid_invite_token_valid(self, mock_redis_dependencies):
+ """Test checking valid invite token."""
+ # Setup mock
+ mock_redis_dependencies.get.return_value = b'{"test": "data"}'
+
+ # Execute test
+ result = RegisterService.is_valid_invite_token("valid-token")
+
+ # Verify results
+ assert result is True
+ mock_redis_dependencies.get.assert_called_once_with("member_invite:token:valid-token")
+
+ def test_is_valid_invite_token_invalid(self, mock_redis_dependencies):
+ """Test checking invalid invite token."""
+ # Setup mock
+ mock_redis_dependencies.get.return_value = None
+
+ # Execute test
+ result = RegisterService.is_valid_invite_token("invalid-token")
+
+ # Verify results
+ assert result is False
+ mock_redis_dependencies.get.assert_called_once_with("member_invite:token:invalid-token")
+
+ def test_revoke_token_with_workspace_and_email(self, mock_redis_dependencies):
+ """Test revoking token with workspace ID and email."""
+ # Execute test
+ RegisterService.revoke_token("workspace-123", "test@example.com", "token-123")
+
+ # Verify results
+ mock_redis_dependencies.delete.assert_called_once()
+ call_args = mock_redis_dependencies.delete.call_args
+ assert "workspace-123" in call_args[0][0]
+ # The email is hashed, so we check for the hash pattern instead
+ assert "member_invite_token:" in call_args[0][0]
+
+ def test_revoke_token_without_workspace_and_email(self, mock_redis_dependencies):
+ """Test revoking token without workspace ID and email."""
+ # Execute test
+ RegisterService.revoke_token("", "", "token-123")
+
+ # Verify results
+ mock_redis_dependencies.delete.assert_called_once_with("member_invite:token:token-123")
+
+ # ==================== Invitation Validation Tests ====================
+
+ def test_get_invitation_if_token_valid_success(self, mock_db_dependencies, mock_redis_dependencies):
+ """Test successful invitation validation."""
+ # Setup test data
+ mock_tenant = MagicMock()
+ mock_tenant.id = "tenant-456"
+ mock_tenant.status = "normal"
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock(
+ account_id="user-123", email="test@example.com"
+ )
+
+ with patch("services.account_service.RegisterService._get_invitation_by_token") as mock_get_invitation_by_token:
+ # Mock the invitation data returned by _get_invitation_by_token
+ invitation_data = {
+ "account_id": "user-123",
+ "email": "test@example.com",
+ "workspace_id": "tenant-456",
+ }
+ mock_get_invitation_by_token.return_value = invitation_data
+
+ # Mock database queries - complex query mocking
+ mock_query1 = MagicMock()
+ mock_query1.filter.return_value.first.return_value = mock_tenant
+
+ mock_query2 = MagicMock()
+ mock_query2.join.return_value.filter.return_value.first.return_value = (mock_account, "normal")
+
+ mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2]
+
+ # Execute test
+ result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
+
+ # Verify results
+ assert result is not None
+ assert result["account"] == mock_account
+ assert result["tenant"] == mock_tenant
+ assert result["data"] == invitation_data
+
+ def test_get_invitation_if_token_valid_no_token_data(self, mock_redis_dependencies):
+ """Test invitation validation with no token data."""
+ # Setup mock
+ mock_redis_dependencies.get.return_value = None
+
+ # Execute test
+ result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
+
+ # Verify results
+ assert result is None
+
+ def test_get_invitation_if_token_valid_tenant_not_found(self, mock_db_dependencies, mock_redis_dependencies):
+ """Test invitation validation when tenant is not found."""
+ # Setup mock Redis data
+ invitation_data = {
+ "account_id": "user-123",
+ "email": "test@example.com",
+ "workspace_id": "tenant-456",
+ }
+ mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode()
+
+ # Mock database queries - no tenant found
+ mock_query = MagicMock()
+ mock_query.filter.return_value.first.return_value = None
+ mock_db_dependencies["db"].session.query.return_value = mock_query
+
+ # Execute test
+ result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
+
+ # Verify results
+ assert result is None
+
+ def test_get_invitation_if_token_valid_account_not_found(self, mock_db_dependencies, mock_redis_dependencies):
+ """Test invitation validation when account is not found."""
+ # Setup test data
+ mock_tenant = MagicMock()
+ mock_tenant.id = "tenant-456"
+ mock_tenant.status = "normal"
+
+ # Mock Redis data
+ invitation_data = {
+ "account_id": "user-123",
+ "email": "test@example.com",
+ "workspace_id": "tenant-456",
+ }
+ mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode()
+
+ # Mock database queries
+ mock_query1 = MagicMock()
+ mock_query1.filter.return_value.first.return_value = mock_tenant
+
+ mock_query2 = MagicMock()
+ mock_query2.join.return_value.filter.return_value.first.return_value = None # No account found
+
+ mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2]
+
+ # Execute test
+ result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
+
+ # Verify results
+ assert result is None
+
+ def test_get_invitation_if_token_valid_account_id_mismatch(self, mock_db_dependencies, mock_redis_dependencies):
+ """Test invitation validation when account ID doesn't match."""
+ # Setup test data
+ mock_tenant = MagicMock()
+ mock_tenant.id = "tenant-456"
+ mock_tenant.status = "normal"
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock(
+ account_id="different-user-456", email="test@example.com"
+ )
+
+ # Mock Redis data with different account ID
+ invitation_data = {
+ "account_id": "user-123",
+ "email": "test@example.com",
+ "workspace_id": "tenant-456",
+ }
+ mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode()
+
+ # Mock database queries
+ mock_query1 = MagicMock()
+ mock_query1.filter.return_value.first.return_value = mock_tenant
+
+ mock_query2 = MagicMock()
+ mock_query2.join.return_value.filter.return_value.first.return_value = (mock_account, "normal")
+
+ mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2]
+
+ # Execute test
+ result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
+
+ # Verify results
+ assert result is None
+
+ # ==================== Helper Method Tests ====================
+
+ def test_get_invitation_token_key(self):
+ """Test the _get_invitation_token_key helper method."""
+ # Execute test
+ result = RegisterService._get_invitation_token_key("test-token")
+
+ # Verify results
+ assert result == "member_invite:token:test-token"
+
+ def test_get_invitation_by_token_with_workspace_and_email(self, mock_redis_dependencies):
+ """Test _get_invitation_by_token with workspace ID and email."""
+ # Setup mock
+ mock_redis_dependencies.get.return_value = b"user-123"
+
+ # Execute test
+ result = RegisterService._get_invitation_by_token("token-123", "workspace-456", "test@example.com")
+
+ # Verify results
+ assert result is not None
+ assert result["account_id"] == "user-123"
+ assert result["email"] == "test@example.com"
+ assert result["workspace_id"] == "workspace-456"
+
+ def test_get_invitation_by_token_without_workspace_and_email(self, mock_redis_dependencies):
+ """Test _get_invitation_by_token without workspace ID and email."""
+ # Setup mock
+ invitation_data = {
+ "account_id": "user-123",
+ "email": "test@example.com",
+ "workspace_id": "tenant-456",
+ }
+ mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode()
+
+ # Execute test
+ result = RegisterService._get_invitation_by_token("token-123")
+
+ # Verify results
+ assert result is not None
+ assert result == invitation_data
+
+ def test_get_invitation_by_token_no_data(self, mock_redis_dependencies):
+ """Test _get_invitation_by_token with no data."""
+ # Setup mock
+ mock_redis_dependencies.get.return_value = None
+
+ # Execute test
+ result = RegisterService._get_invitation_by_token("token-123")
+
+ # Verify results
+ assert result is None
diff --git a/api/tests/unit_tests/services/tools/__init__.py b/api/tests/unit_tests/services/tools/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/unit_tests/services/tools/test_tools_transform_service.py b/api/tests/unit_tests/services/tools/test_tools_transform_service.py
new file mode 100644
index 0000000000..549ad018e8
--- /dev/null
+++ b/api/tests/unit_tests/services/tools/test_tools_transform_service.py
@@ -0,0 +1,301 @@
+from unittest.mock import Mock
+
+from core.tools.__base.tool import Tool
+from core.tools.entities.api_entities import ToolApiEntity
+from core.tools.entities.common_entities import I18nObject
+from core.tools.entities.tool_entities import ToolParameter
+from services.tools.tools_transform_service import ToolTransformService
+
+
+class TestToolTransformService:
+ """Test cases for ToolTransformService.convert_tool_entity_to_api_entity method"""
+
+ def test_convert_tool_with_parameter_override(self):
+ """Test that runtime parameters correctly override base parameters"""
+ # Create mock base parameters
+ base_param1 = Mock(spec=ToolParameter)
+ base_param1.name = "param1"
+ base_param1.form = ToolParameter.ToolParameterForm.FORM
+ base_param1.type = "string"
+ base_param1.label = "Base Param 1"
+
+ base_param2 = Mock(spec=ToolParameter)
+ base_param2.name = "param2"
+ base_param2.form = ToolParameter.ToolParameterForm.FORM
+ base_param2.type = "string"
+ base_param2.label = "Base Param 2"
+
+ # Create mock runtime parameters that override base parameters
+ runtime_param1 = Mock(spec=ToolParameter)
+ runtime_param1.name = "param1"
+ runtime_param1.form = ToolParameter.ToolParameterForm.FORM
+ runtime_param1.type = "string"
+ runtime_param1.label = "Runtime Param 1" # Different label to verify override
+
+ # Create mock tool
+ mock_tool = Mock(spec=Tool)
+ mock_tool.entity = Mock()
+ mock_tool.entity.parameters = [base_param1, base_param2]
+ mock_tool.entity.identity = Mock()
+ mock_tool.entity.identity.author = "test_author"
+ mock_tool.entity.identity.name = "test_tool"
+ mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
+ mock_tool.entity.description = Mock()
+ mock_tool.entity.description.human = I18nObject(en_US="Test description")
+ mock_tool.entity.output_schema = {}
+ mock_tool.get_runtime_parameters.return_value = [runtime_param1]
+
+ # Mock fork_tool_runtime to return the same tool
+ mock_tool.fork_tool_runtime.return_value = mock_tool
+
+ # Call the method
+ result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
+
+ # Verify the result
+ assert isinstance(result, ToolApiEntity)
+ assert result.author == "test_author"
+ assert result.name == "test_tool"
+ assert result.parameters is not None
+ assert len(result.parameters) == 2
+
+ # Find the overridden parameter
+ overridden_param = next((p for p in result.parameters if p.name == "param1"), None)
+ assert overridden_param is not None
+ assert overridden_param.label == "Runtime Param 1" # Should be runtime version
+
+ # Find the non-overridden parameter
+ original_param = next((p for p in result.parameters if p.name == "param2"), None)
+ assert original_param is not None
+ assert original_param.label == "Base Param 2" # Should be base version
+
+ def test_convert_tool_with_additional_runtime_parameters(self):
+ """Test that additional runtime parameters are added to the final list"""
+ # Create mock base parameters
+ base_param1 = Mock(spec=ToolParameter)
+ base_param1.name = "param1"
+ base_param1.form = ToolParameter.ToolParameterForm.FORM
+ base_param1.type = "string"
+ base_param1.label = "Base Param 1"
+
+ # Create mock runtime parameters - one that overrides and one that's new
+ runtime_param1 = Mock(spec=ToolParameter)
+ runtime_param1.name = "param1"
+ runtime_param1.form = ToolParameter.ToolParameterForm.FORM
+ runtime_param1.type = "string"
+ runtime_param1.label = "Runtime Param 1"
+
+ runtime_param2 = Mock(spec=ToolParameter)
+ runtime_param2.name = "runtime_only"
+ runtime_param2.form = ToolParameter.ToolParameterForm.FORM
+ runtime_param2.type = "string"
+ runtime_param2.label = "Runtime Only Param"
+
+ # Create mock tool
+ mock_tool = Mock(spec=Tool)
+ mock_tool.entity = Mock()
+ mock_tool.entity.parameters = [base_param1]
+ mock_tool.entity.identity = Mock()
+ mock_tool.entity.identity.author = "test_author"
+ mock_tool.entity.identity.name = "test_tool"
+ mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
+ mock_tool.entity.description = Mock()
+ mock_tool.entity.description.human = I18nObject(en_US="Test description")
+ mock_tool.entity.output_schema = {}
+ mock_tool.get_runtime_parameters.return_value = [runtime_param1, runtime_param2]
+
+ # Mock fork_tool_runtime to return the same tool
+ mock_tool.fork_tool_runtime.return_value = mock_tool
+
+ # Call the method
+ result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
+
+ # Verify the result
+ assert isinstance(result, ToolApiEntity)
+ assert result.parameters is not None
+ assert len(result.parameters) == 2
+
+ # Check that both parameters are present
+ param_names = [p.name for p in result.parameters]
+ assert "param1" in param_names
+ assert "runtime_only" in param_names
+
+ # Verify the overridden parameter has runtime version
+ overridden_param = next((p for p in result.parameters if p.name == "param1"), None)
+ assert overridden_param is not None
+ assert overridden_param.label == "Runtime Param 1"
+
+ # Verify the new runtime parameter is included
+ new_param = next((p for p in result.parameters if p.name == "runtime_only"), None)
+ assert new_param is not None
+ assert new_param.label == "Runtime Only Param"
+
+ def test_convert_tool_with_non_form_runtime_parameters(self):
+ """Test that non-FORM runtime parameters are not added as new parameters"""
+ # Create mock base parameters
+ base_param1 = Mock(spec=ToolParameter)
+ base_param1.name = "param1"
+ base_param1.form = ToolParameter.ToolParameterForm.FORM
+ base_param1.type = "string"
+ base_param1.label = "Base Param 1"
+
+ # Create mock runtime parameters with different forms
+ runtime_param1 = Mock(spec=ToolParameter)
+ runtime_param1.name = "param1"
+ runtime_param1.form = ToolParameter.ToolParameterForm.FORM
+ runtime_param1.type = "string"
+ runtime_param1.label = "Runtime Param 1"
+
+ runtime_param2 = Mock(spec=ToolParameter)
+ runtime_param2.name = "llm_param"
+ runtime_param2.form = ToolParameter.ToolParameterForm.LLM
+ runtime_param2.type = "string"
+ runtime_param2.label = "LLM Param"
+
+ # Create mock tool
+ mock_tool = Mock(spec=Tool)
+ mock_tool.entity = Mock()
+ mock_tool.entity.parameters = [base_param1]
+ mock_tool.entity.identity = Mock()
+ mock_tool.entity.identity.author = "test_author"
+ mock_tool.entity.identity.name = "test_tool"
+ mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
+ mock_tool.entity.description = Mock()
+ mock_tool.entity.description.human = I18nObject(en_US="Test description")
+ mock_tool.entity.output_schema = {}
+ mock_tool.get_runtime_parameters.return_value = [runtime_param1, runtime_param2]
+
+ # Mock fork_tool_runtime to return the same tool
+ mock_tool.fork_tool_runtime.return_value = mock_tool
+
+ # Call the method
+ result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
+
+ # Verify the result
+ assert isinstance(result, ToolApiEntity)
+ assert result.parameters is not None
+ assert len(result.parameters) == 1 # Only the FORM parameter should be present
+
+ # Check that only the FORM parameter is present
+ param_names = [p.name for p in result.parameters]
+ assert "param1" in param_names
+ assert "llm_param" not in param_names
+
+ def test_convert_tool_with_empty_parameters(self):
+ """Test conversion with empty base and runtime parameters"""
+ # Create mock tool with no parameters
+ mock_tool = Mock(spec=Tool)
+ mock_tool.entity = Mock()
+ mock_tool.entity.parameters = []
+ mock_tool.entity.identity = Mock()
+ mock_tool.entity.identity.author = "test_author"
+ mock_tool.entity.identity.name = "test_tool"
+ mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
+ mock_tool.entity.description = Mock()
+ mock_tool.entity.description.human = I18nObject(en_US="Test description")
+ mock_tool.entity.output_schema = {}
+ mock_tool.get_runtime_parameters.return_value = []
+
+ # Mock fork_tool_runtime to return the same tool
+ mock_tool.fork_tool_runtime.return_value = mock_tool
+
+ # Call the method
+ result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
+
+ # Verify the result
+ assert isinstance(result, ToolApiEntity)
+ assert result.parameters is not None
+ assert len(result.parameters) == 0
+
+ def test_convert_tool_with_none_parameters(self):
+ """Test conversion when base parameters is None"""
+ # Create mock tool with None parameters
+ mock_tool = Mock(spec=Tool)
+ mock_tool.entity = Mock()
+ mock_tool.entity.parameters = None
+ mock_tool.entity.identity = Mock()
+ mock_tool.entity.identity.author = "test_author"
+ mock_tool.entity.identity.name = "test_tool"
+ mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
+ mock_tool.entity.description = Mock()
+ mock_tool.entity.description.human = I18nObject(en_US="Test description")
+ mock_tool.entity.output_schema = {}
+ mock_tool.get_runtime_parameters.return_value = []
+
+ # Mock fork_tool_runtime to return the same tool
+ mock_tool.fork_tool_runtime.return_value = mock_tool
+
+ # Call the method
+ result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
+
+ # Verify the result
+ assert isinstance(result, ToolApiEntity)
+ assert result.parameters is not None
+ assert len(result.parameters) == 0
+
+ def test_convert_tool_parameter_order_preserved(self):
+ """Test that parameter order is preserved correctly"""
+ # Create mock base parameters in specific order
+ base_param1 = Mock(spec=ToolParameter)
+ base_param1.name = "param1"
+ base_param1.form = ToolParameter.ToolParameterForm.FORM
+ base_param1.type = "string"
+ base_param1.label = "Base Param 1"
+
+ base_param2 = Mock(spec=ToolParameter)
+ base_param2.name = "param2"
+ base_param2.form = ToolParameter.ToolParameterForm.FORM
+ base_param2.type = "string"
+ base_param2.label = "Base Param 2"
+
+ base_param3 = Mock(spec=ToolParameter)
+ base_param3.name = "param3"
+ base_param3.form = ToolParameter.ToolParameterForm.FORM
+ base_param3.type = "string"
+ base_param3.label = "Base Param 3"
+
+ # Create runtime parameter that overrides middle parameter
+ runtime_param2 = Mock(spec=ToolParameter)
+ runtime_param2.name = "param2"
+ runtime_param2.form = ToolParameter.ToolParameterForm.FORM
+ runtime_param2.type = "string"
+ runtime_param2.label = "Runtime Param 2"
+
+ # Create new runtime parameter
+ runtime_param4 = Mock(spec=ToolParameter)
+ runtime_param4.name = "param4"
+ runtime_param4.form = ToolParameter.ToolParameterForm.FORM
+ runtime_param4.type = "string"
+ runtime_param4.label = "Runtime Param 4"
+
+ # Create mock tool
+ mock_tool = Mock(spec=Tool)
+ mock_tool.entity = Mock()
+ mock_tool.entity.parameters = [base_param1, base_param2, base_param3]
+ mock_tool.entity.identity = Mock()
+ mock_tool.entity.identity.author = "test_author"
+ mock_tool.entity.identity.name = "test_tool"
+ mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
+ mock_tool.entity.description = Mock()
+ mock_tool.entity.description.human = I18nObject(en_US="Test description")
+ mock_tool.entity.output_schema = {}
+ mock_tool.get_runtime_parameters.return_value = [runtime_param2, runtime_param4]
+
+ # Mock fork_tool_runtime to return the same tool
+ mock_tool.fork_tool_runtime.return_value = mock_tool
+
+ # Call the method
+ result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
+
+ # Verify the result
+ assert isinstance(result, ToolApiEntity)
+ assert result.parameters is not None
+ assert len(result.parameters) == 4
+
+ # Check that order is maintained: base parameters first, then new runtime parameters
+ param_names = [p.name for p in result.parameters]
+ assert param_names == ["param1", "param2", "param3", "param4"]
+
+ # Verify that param2 was overridden with runtime version
+ param2 = result.parameters[1]
+ assert param2.name == "param2"
+ assert param2.label == "Runtime Param 2"
diff --git a/api/tests/unit_tests/services/workflow/test_workflow_deletion.py b/api/tests/unit_tests/services/workflow/test_workflow_deletion.py
index 223020c2c5..2c87eaf805 100644
--- a/api/tests/unit_tests/services/workflow/test_workflow_deletion.py
+++ b/api/tests/unit_tests/services/workflow/test_workflow_deletion.py
@@ -10,7 +10,8 @@ from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseE
@pytest.fixture
def workflow_setup():
- workflow_service = WorkflowService()
+ mock_session_maker = MagicMock()
+ workflow_service = WorkflowService(mock_session_maker)
session = MagicMock(spec=Session)
tenant_id = "test-tenant-id"
workflow_id = "test-workflow-id"
diff --git a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py
index c5c9cf1050..8b1348b75b 100644
--- a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py
+++ b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py
@@ -1,14 +1,14 @@
import dataclasses
import secrets
-from unittest import mock
-from unittest.mock import Mock, patch
+from unittest.mock import MagicMock, Mock, patch
import pytest
+from sqlalchemy import Engine
from sqlalchemy.orm import Session
from core.variables import StringSegment
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
-from core.workflow.nodes import NodeType
+from core.workflow.nodes.enums import NodeType
from models.enums import DraftVariableType
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable
from services.workflow_draft_variable_service import (
@@ -18,13 +18,25 @@ from services.workflow_draft_variable_service import (
)
+@pytest.fixture
+def mock_engine() -> Engine:
+ return Mock(spec=Engine)
+
+
+@pytest.fixture
+def mock_session(mock_engine) -> Session:
+ mock_session = Mock(spec=Session)
+ mock_session.get_bind.return_value = mock_engine
+ return mock_session
+
+
class TestDraftVariableSaver:
def _get_test_app_id(self):
suffix = secrets.token_hex(6)
return f"test_app_id_{suffix}"
def test__should_variable_be_visible(self):
- mock_session = mock.MagicMock(spec=Session)
+ mock_session = MagicMock(spec=Session)
test_app_id = self._get_test_app_id()
saver = DraftVariableSaver(
session=mock_session,
@@ -70,7 +82,7 @@ class TestDraftVariableSaver:
),
]
- mock_session = mock.MagicMock(spec=Session)
+ mock_session = MagicMock(spec=Session)
test_app_id = self._get_test_app_id()
saver = DraftVariableSaver(
session=mock_session,
@@ -105,9 +117,8 @@ class TestWorkflowDraftVariableService:
conversation_variables=[],
)
- def test_reset_conversation_variable(self):
+ def test_reset_conversation_variable(self, mock_session):
"""Test resetting a conversation variable"""
- mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
test_app_id = self._get_test_app_id()
@@ -131,9 +142,8 @@ class TestWorkflowDraftVariableService:
mock_reset_conv.assert_called_once_with(workflow, variable)
assert result == expected_result
- def test_reset_node_variable_with_no_execution_id(self):
+ def test_reset_node_variable_with_no_execution_id(self, mock_session):
"""Test resetting a node variable with no execution ID - should delete variable"""
- mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
test_app_id = self._get_test_app_id()
@@ -158,11 +168,26 @@ class TestWorkflowDraftVariableService:
mock_session.flush.assert_called_once()
assert result is None
- def test_reset_node_variable_with_missing_execution_record(self):
+ def test_reset_node_variable_with_missing_execution_record(
+ self,
+ mock_engine,
+ mock_session,
+ monkeypatch,
+ ):
"""Test resetting a node variable when execution record doesn't exist"""
- mock_session = Mock(spec=Session)
+ mock_repo_session = Mock(spec=Session)
+
+ mock_session_maker = MagicMock()
+ # Mock the context manager protocol for sessionmaker
+ mock_session_maker.return_value.__enter__.return_value = mock_repo_session
+ mock_session_maker.return_value.__exit__.return_value = None
+ monkeypatch.setattr("services.workflow_draft_variable_service.sessionmaker", mock_session_maker)
service = WorkflowDraftVariableService(mock_session)
+ # Mock the repository to return None (no execution record found)
+ service._api_node_execution_repo = Mock()
+ service._api_node_execution_repo.get_execution_by_id.return_value = None
+
test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id)
@@ -171,24 +196,41 @@ class TestWorkflowDraftVariableService:
variable = WorkflowDraftVariable.new_node_variable(
app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id"
)
-
- # Mock session.scalars to return None (no execution record found)
- mock_scalars = Mock()
- mock_scalars.first.return_value = None
- mock_session.scalars.return_value = mock_scalars
+ # Variable is editable by default from factory method
result = service._reset_node_var_or_sys_var(workflow, variable)
+ mock_session_maker.assert_called_once_with(bind=mock_engine, expire_on_commit=False)
# Should delete the variable and return None
mock_session.delete.assert_called_once_with(instance=variable)
mock_session.flush.assert_called_once()
assert result is None
- def test_reset_node_variable_with_valid_execution_record(self):
+ def test_reset_node_variable_with_valid_execution_record(
+ self,
+ mock_session,
+ monkeypatch,
+ ):
"""Test resetting a node variable with valid execution record - should restore from execution"""
- mock_session = Mock(spec=Session)
+ mock_repo_session = Mock(spec=Session)
+
+ mock_session_maker = MagicMock()
+ # Mock the context manager protocol for sessionmaker
+ mock_session_maker.return_value.__enter__.return_value = mock_repo_session
+ mock_session_maker.return_value.__exit__.return_value = None
+ mock_session_maker = monkeypatch.setattr(
+ "services.workflow_draft_variable_service.sessionmaker", mock_session_maker
+ )
service = WorkflowDraftVariableService(mock_session)
+ # Create mock execution record
+ mock_execution = Mock(spec=WorkflowNodeExecutionModel)
+ mock_execution.outputs_dict = {"test_var": "output_value"}
+
+ # Mock the repository to return the execution record
+ service._api_node_execution_repo = Mock()
+ service._api_node_execution_repo.get_execution_by_id.return_value = mock_execution
+
test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id)
@@ -197,16 +239,7 @@ class TestWorkflowDraftVariableService:
variable = WorkflowDraftVariable.new_node_variable(
app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id"
)
-
- # Create mock execution record
- mock_execution = Mock(spec=WorkflowNodeExecutionModel)
- mock_execution.process_data_dict = {"test_var": "process_value"}
- mock_execution.outputs_dict = {"test_var": "output_value"}
-
- # Mock session.scalars to return the execution record
- mock_scalars = Mock()
- mock_scalars.first.return_value = mock_execution
- mock_session.scalars.return_value = mock_scalars
+ # Variable is editable by default from factory method
# Mock workflow methods
mock_node_config = {"type": "test_node"}
@@ -224,9 +257,8 @@ class TestWorkflowDraftVariableService:
# Should return the updated variable
assert result == variable
- def test_reset_non_editable_system_variable_raises_error(self):
+ def test_reset_non_editable_system_variable_raises_error(self, mock_session):
"""Test that resetting a non-editable system variable raises an error"""
- mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
test_app_id = self._get_test_app_id()
@@ -242,24 +274,13 @@ class TestWorkflowDraftVariableService:
editable=False, # Non-editable system variable
)
- # Mock the service to properly check system variable editability
- with patch.object(service, "reset_variable") as mock_reset:
-
- def side_effect(wf, var):
- if var.get_variable_type() == DraftVariableType.SYS and not is_system_variable_editable(var.name):
- raise VariableResetError(f"cannot reset system variable, variable_id={var.id}")
- return var
-
- mock_reset.side_effect = side_effect
-
- with pytest.raises(VariableResetError) as exc_info:
- service.reset_variable(workflow, variable)
- assert "cannot reset system variable" in str(exc_info.value)
- assert f"variable_id={variable.id}" in str(exc_info.value)
+ with pytest.raises(VariableResetError) as exc_info:
+ service.reset_variable(workflow, variable)
+ assert "cannot reset system variable" in str(exc_info.value)
+ assert f"variable_id={variable.id}" in str(exc_info.value)
- def test_reset_editable_system_variable_succeeds(self):
+ def test_reset_editable_system_variable_succeeds(self, mock_session):
"""Test that resetting an editable system variable succeeds"""
- mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
test_app_id = self._get_test_app_id()
@@ -279,10 +300,9 @@ class TestWorkflowDraftVariableService:
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
mock_execution.outputs_dict = {"sys.files": "[]"}
- # Mock session.scalars to return the execution record
- mock_scalars = Mock()
- mock_scalars.first.return_value = mock_execution
- mock_session.scalars.return_value = mock_scalars
+ # Mock the repository to return the execution record
+ service._api_node_execution_repo = Mock()
+ service._api_node_execution_repo.get_execution_by_id.return_value = mock_execution
result = service._reset_node_var_or_sys_var(workflow, variable)
@@ -291,9 +311,8 @@ class TestWorkflowDraftVariableService:
assert variable.last_edited_at is None
mock_session.flush.assert_called()
- def test_reset_query_system_variable_succeeds(self):
+ def test_reset_query_system_variable_succeeds(self, mock_session):
"""Test that resetting query system variable (another editable one) succeeds"""
- mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
test_app_id = self._get_test_app_id()
@@ -313,10 +332,9 @@ class TestWorkflowDraftVariableService:
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
mock_execution.outputs_dict = {"sys.query": "reset query"}
- # Mock session.scalars to return the execution record
- mock_scalars = Mock()
- mock_scalars.first.return_value = mock_execution
- mock_session.scalars.return_value = mock_scalars
+ # Mock the repository to return the execution record
+ service._api_node_execution_repo = Mock()
+ service._api_node_execution_repo.get_execution_by_id.return_value = mock_execution
result = service._reset_node_var_or_sys_var(workflow, variable)
diff --git a/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py b/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py
new file mode 100644
index 0000000000..32d2f8b7e0
--- /dev/null
+++ b/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py
@@ -0,0 +1,288 @@
+from datetime import datetime
+from unittest.mock import MagicMock
+from uuid import uuid4
+
+import pytest
+from sqlalchemy.orm import Session
+
+from models.workflow import WorkflowNodeExecutionModel
+from repositories.sqlalchemy_api_workflow_node_execution_repository import (
+ DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
+)
+
+
+class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
+ @pytest.fixture
+ def repository(self):
+ mock_session_maker = MagicMock()
+ return DifyAPISQLAlchemyWorkflowNodeExecutionRepository(session_maker=mock_session_maker)
+
+ @pytest.fixture
+ def mock_execution(self):
+ execution = MagicMock(spec=WorkflowNodeExecutionModel)
+ execution.id = str(uuid4())
+ execution.tenant_id = "tenant-123"
+ execution.app_id = "app-456"
+ execution.workflow_id = "workflow-789"
+ execution.workflow_run_id = "run-101"
+ execution.node_id = "node-202"
+ execution.index = 1
+ execution.created_at = "2023-01-01T00:00:00Z"
+ return execution
+
+ def test_get_node_last_execution_found(self, repository, mock_execution):
+ """Test getting the last execution for a node when it exists."""
+ # Arrange
+ mock_session = MagicMock(spec=Session)
+ repository._session_maker.return_value.__enter__.return_value = mock_session
+ mock_session.scalar.return_value = mock_execution
+
+ # Act
+ result = repository.get_node_last_execution(
+ tenant_id="tenant-123",
+ app_id="app-456",
+ workflow_id="workflow-789",
+ node_id="node-202",
+ )
+
+ # Assert
+ assert result == mock_execution
+ mock_session.scalar.assert_called_once()
+ # Verify the query was constructed correctly
+ call_args = mock_session.scalar.call_args[0][0]
+ assert hasattr(call_args, "compile") # It's a SQLAlchemy statement
+
+ def test_get_node_last_execution_not_found(self, repository):
+ """Test getting the last execution for a node when it doesn't exist."""
+ # Arrange
+ mock_session = MagicMock(spec=Session)
+ repository._session_maker.return_value.__enter__.return_value = mock_session
+ mock_session.scalar.return_value = None
+
+ # Act
+ result = repository.get_node_last_execution(
+ tenant_id="tenant-123",
+ app_id="app-456",
+ workflow_id="workflow-789",
+ node_id="node-202",
+ )
+
+ # Assert
+ assert result is None
+ mock_session.scalar.assert_called_once()
+
+ def test_get_executions_by_workflow_run(self, repository, mock_execution):
+ """Test getting all executions for a workflow run."""
+ # Arrange
+ mock_session = MagicMock(spec=Session)
+ repository._session_maker.return_value.__enter__.return_value = mock_session
+ executions = [mock_execution]
+ mock_session.execute.return_value.scalars.return_value.all.return_value = executions
+
+ # Act
+ result = repository.get_executions_by_workflow_run(
+ tenant_id="tenant-123",
+ app_id="app-456",
+ workflow_run_id="run-101",
+ )
+
+ # Assert
+ assert result == executions
+ mock_session.execute.assert_called_once()
+ # Verify the query was constructed correctly
+ call_args = mock_session.execute.call_args[0][0]
+ assert hasattr(call_args, "compile") # It's a SQLAlchemy statement
+
+ def test_get_executions_by_workflow_run_empty(self, repository):
+ """Test getting executions for a workflow run when none exist."""
+ # Arrange
+ mock_session = MagicMock(spec=Session)
+ repository._session_maker.return_value.__enter__.return_value = mock_session
+ mock_session.execute.return_value.scalars.return_value.all.return_value = []
+
+ # Act
+ result = repository.get_executions_by_workflow_run(
+ tenant_id="tenant-123",
+ app_id="app-456",
+ workflow_run_id="run-101",
+ )
+
+ # Assert
+ assert result == []
+ mock_session.execute.assert_called_once()
+
+ def test_get_execution_by_id_found(self, repository, mock_execution):
+ """Test getting execution by ID when it exists."""
+ # Arrange
+ mock_session = MagicMock(spec=Session)
+ repository._session_maker.return_value.__enter__.return_value = mock_session
+ mock_session.scalar.return_value = mock_execution
+
+ # Act
+ result = repository.get_execution_by_id(mock_execution.id)
+
+ # Assert
+ assert result == mock_execution
+ mock_session.scalar.assert_called_once()
+
+ def test_get_execution_by_id_not_found(self, repository):
+ """Test getting execution by ID when it doesn't exist."""
+ # Arrange
+ mock_session = MagicMock(spec=Session)
+ repository._session_maker.return_value.__enter__.return_value = mock_session
+ mock_session.scalar.return_value = None
+
+ # Act
+ result = repository.get_execution_by_id("non-existent-id")
+
+ # Assert
+ assert result is None
+ mock_session.scalar.assert_called_once()
+
+ def test_repository_implements_protocol(self, repository):
+ """Test that the repository implements the required protocol methods."""
+ # Verify all protocol methods are implemented
+ assert hasattr(repository, "get_node_last_execution")
+ assert hasattr(repository, "get_executions_by_workflow_run")
+ assert hasattr(repository, "get_execution_by_id")
+
+ # Verify methods are callable
+ assert callable(repository.get_node_last_execution)
+ assert callable(repository.get_executions_by_workflow_run)
+ assert callable(repository.get_execution_by_id)
+ assert callable(repository.delete_expired_executions)
+ assert callable(repository.delete_executions_by_app)
+ assert callable(repository.get_expired_executions_batch)
+ assert callable(repository.delete_executions_by_ids)
+
+ def test_delete_expired_executions(self, repository):
+ """Test deleting expired executions."""
+ # Arrange
+ mock_session = MagicMock(spec=Session)
+ repository._session_maker.return_value.__enter__.return_value = mock_session
+
+ # Mock the select query to return some IDs first time, then empty to stop loop
+ execution_ids = ["id1", "id2"] # Less than batch_size to trigger break
+
+ # Mock execute method to handle both select and delete statements
+ def mock_execute(stmt):
+ mock_result = MagicMock()
+ # For select statements, return execution IDs
+ if hasattr(stmt, "limit"): # This is our select statement
+ mock_result.scalars.return_value.all.return_value = execution_ids
+ else: # This is our delete statement
+ mock_result.rowcount = 2
+ return mock_result
+
+ mock_session.execute.side_effect = mock_execute
+
+ before_date = datetime(2023, 1, 1)
+
+ # Act
+ result = repository.delete_expired_executions(
+ tenant_id="tenant-123",
+ before_date=before_date,
+ batch_size=1000,
+ )
+
+ # Assert
+ assert result == 2
+ assert mock_session.execute.call_count == 2 # One select call, one delete call
+ mock_session.commit.assert_called_once()
+
+ def test_delete_executions_by_app(self, repository):
+ """Test deleting executions by app."""
+ # Arrange
+ mock_session = MagicMock(spec=Session)
+ repository._session_maker.return_value.__enter__.return_value = mock_session
+
+ # Mock the select query to return some IDs first time, then empty to stop loop
+ execution_ids = ["id1", "id2"]
+
+ # Mock execute method to handle both select and delete statements
+ def mock_execute(stmt):
+ mock_result = MagicMock()
+ # For select statements, return execution IDs
+ if hasattr(stmt, "limit"): # This is our select statement
+ mock_result.scalars.return_value.all.return_value = execution_ids
+ else: # This is our delete statement
+ mock_result.rowcount = 2
+ return mock_result
+
+ mock_session.execute.side_effect = mock_execute
+
+ # Act
+ result = repository.delete_executions_by_app(
+ tenant_id="tenant-123",
+ app_id="app-456",
+ batch_size=1000,
+ )
+
+ # Assert
+ assert result == 2
+ assert mock_session.execute.call_count == 2 # One select call, one delete call
+ mock_session.commit.assert_called_once()
+
+ def test_get_expired_executions_batch(self, repository):
+ """Test getting expired executions batch for backup."""
+ # Arrange
+ mock_session = MagicMock(spec=Session)
+ repository._session_maker.return_value.__enter__.return_value = mock_session
+
+ # Create mock execution objects
+ mock_execution1 = MagicMock()
+ mock_execution1.id = "exec-1"
+ mock_execution2 = MagicMock()
+ mock_execution2.id = "exec-2"
+
+ mock_session.execute.return_value.scalars.return_value.all.return_value = [mock_execution1, mock_execution2]
+
+ before_date = datetime(2023, 1, 1)
+
+ # Act
+ result = repository.get_expired_executions_batch(
+ tenant_id="tenant-123",
+ before_date=before_date,
+ batch_size=1000,
+ )
+
+ # Assert
+ assert len(result) == 2
+ assert result[0].id == "exec-1"
+ assert result[1].id == "exec-2"
+ mock_session.execute.assert_called_once()
+
+ def test_delete_executions_by_ids(self, repository):
+ """Test deleting executions by IDs."""
+ # Arrange
+ mock_session = MagicMock(spec=Session)
+ repository._session_maker.return_value.__enter__.return_value = mock_session
+
+ # Mock the delete query result
+ mock_result = MagicMock()
+ mock_result.rowcount = 3
+ mock_session.execute.return_value = mock_result
+
+ execution_ids = ["id1", "id2", "id3"]
+
+ # Act
+ result = repository.delete_executions_by_ids(execution_ids)
+
+ # Assert
+ assert result == 3
+ mock_session.execute.assert_called_once()
+ mock_session.commit.assert_called_once()
+
+ def test_delete_executions_by_ids_empty_list(self, repository):
+ """Test deleting executions with empty ID list."""
+ # Arrange
+ mock_session = MagicMock(spec=Session)
+ repository._session_maker.return_value.__enter__.return_value = mock_session
+
+ # Act
+ result = repository.delete_executions_by_ids([])
+
+ # Assert
+ assert result == 0
+ mock_session.query.assert_not_called()
+ mock_session.commit.assert_not_called()
diff --git a/api/tests/unit_tests/services/workflow/test_workflow_service.py b/api/tests/unit_tests/services/workflow/test_workflow_service.py
index 13393668ea..9700cbaf0e 100644
--- a/api/tests/unit_tests/services/workflow/test_workflow_service.py
+++ b/api/tests/unit_tests/services/workflow/test_workflow_service.py
@@ -10,7 +10,8 @@ from services.workflow_service import WorkflowService
class TestWorkflowService:
@pytest.fixture
def workflow_service(self):
- return WorkflowService()
+ mock_session_maker = MagicMock()
+ return WorkflowService(mock_session_maker)
@pytest.fixture
def mock_app(self):
diff --git a/api/tests/unit_tests/utils/oauth_encryption/test_system_oauth_encryption.py b/api/tests/unit_tests/utils/oauth_encryption/test_system_oauth_encryption.py
new file mode 100644
index 0000000000..30990f8d50
--- /dev/null
+++ b/api/tests/unit_tests/utils/oauth_encryption/test_system_oauth_encryption.py
@@ -0,0 +1,619 @@
+import base64
+import hashlib
+from unittest.mock import patch
+
+import pytest
+from Crypto.Cipher import AES
+from Crypto.Random import get_random_bytes
+from Crypto.Util.Padding import pad
+
+from core.tools.utils.system_oauth_encryption import (
+ OAuthEncryptionError,
+ SystemOAuthEncrypter,
+ create_system_oauth_encrypter,
+ decrypt_system_oauth_params,
+ encrypt_system_oauth_params,
+ get_system_oauth_encrypter,
+)
+
+
+class TestSystemOAuthEncrypter:
+ """Test cases for SystemOAuthEncrypter class"""
+
+ def test_init_with_secret_key(self):
+ """Test initialization with provided secret key"""
+ secret_key = "test_secret_key"
+ encrypter = SystemOAuthEncrypter(secret_key=secret_key)
+ expected_key = hashlib.sha256(secret_key.encode()).digest()
+ assert encrypter.key == expected_key
+
+ def test_init_with_none_secret_key(self):
+ """Test initialization with None secret key falls back to config"""
+ with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
+ mock_config.SECRET_KEY = "config_secret"
+ encrypter = SystemOAuthEncrypter(secret_key=None)
+ expected_key = hashlib.sha256(b"config_secret").digest()
+ assert encrypter.key == expected_key
+
+ def test_init_with_empty_secret_key(self):
+ """Test initialization with empty secret key"""
+ encrypter = SystemOAuthEncrypter(secret_key="")
+ expected_key = hashlib.sha256(b"").digest()
+ assert encrypter.key == expected_key
+
+ def test_init_without_secret_key_uses_config(self):
+ """Test initialization without secret key uses config"""
+ with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
+ mock_config.SECRET_KEY = "default_secret"
+ encrypter = SystemOAuthEncrypter()
+ expected_key = hashlib.sha256(b"default_secret").digest()
+ assert encrypter.key == expected_key
+
+ def test_encrypt_oauth_params_basic(self):
+ """Test basic OAuth parameters encryption"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+ oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
+
+ encrypted = encrypter.encrypt_oauth_params(oauth_params)
+
+ assert isinstance(encrypted, str)
+ assert len(encrypted) > 0
+ # Should be valid base64
+ try:
+ base64.b64decode(encrypted)
+ except Exception:
+ pytest.fail("Encrypted result is not valid base64")
+
+ def test_encrypt_oauth_params_empty_dict(self):
+ """Test encryption with empty dictionary"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+ oauth_params = {}
+
+ encrypted = encrypter.encrypt_oauth_params(oauth_params)
+ assert isinstance(encrypted, str)
+ assert len(encrypted) > 0
+
+ def test_encrypt_oauth_params_complex_data(self):
+ """Test encryption with complex data structures"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+ oauth_params = {
+ "client_id": "test_id",
+ "client_secret": "test_secret",
+ "scopes": ["read", "write", "admin"],
+ "metadata": {"issuer": "test_issuer", "expires_in": 3600, "is_active": True},
+ "numeric_value": 42,
+ "boolean_value": False,
+ "null_value": None,
+ }
+
+ encrypted = encrypter.encrypt_oauth_params(oauth_params)
+ assert isinstance(encrypted, str)
+ assert len(encrypted) > 0
+
+ def test_encrypt_oauth_params_unicode_data(self):
+ """Test encryption with unicode data"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+ oauth_params = {"client_id": "test_id", "client_secret": "test_secret", "description": "This is a test case 🚀"}
+
+ encrypted = encrypter.encrypt_oauth_params(oauth_params)
+ assert isinstance(encrypted, str)
+ assert len(encrypted) > 0
+
+ def test_encrypt_oauth_params_large_data(self):
+ """Test encryption with large data"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+ oauth_params = {
+ "client_id": "test_id",
+ "large_data": "x" * 10000, # 10KB of data
+ }
+
+ encrypted = encrypter.encrypt_oauth_params(oauth_params)
+ assert isinstance(encrypted, str)
+ assert len(encrypted) > 0
+
+ def test_encrypt_oauth_params_invalid_input(self):
+ """Test encryption with invalid input types"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+
+ with pytest.raises(Exception): # noqa: B017
+ encrypter.encrypt_oauth_params(None) # type: ignore
+
+ with pytest.raises(Exception): # noqa: B017
+ encrypter.encrypt_oauth_params("not_a_dict") # type: ignore
+
+ def test_decrypt_oauth_params_basic(self):
+ """Test basic OAuth parameters decryption"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+ original_params = {"client_id": "test_id", "client_secret": "test_secret"}
+
+ encrypted = encrypter.encrypt_oauth_params(original_params)
+ decrypted = encrypter.decrypt_oauth_params(encrypted)
+
+ assert decrypted == original_params
+
+ def test_decrypt_oauth_params_empty_dict(self):
+ """Test decryption of empty dictionary"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+ original_params = {}
+
+ encrypted = encrypter.encrypt_oauth_params(original_params)
+ decrypted = encrypter.decrypt_oauth_params(encrypted)
+
+ assert decrypted == original_params
+
+ def test_decrypt_oauth_params_complex_data(self):
+ """Test decryption with complex data structures"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+ original_params = {
+ "client_id": "test_id",
+ "client_secret": "test_secret",
+ "scopes": ["read", "write", "admin"],
+ "metadata": {"issuer": "test_issuer", "expires_in": 3600, "is_active": True},
+ "numeric_value": 42,
+ "boolean_value": False,
+ "null_value": None,
+ }
+
+ encrypted = encrypter.encrypt_oauth_params(original_params)
+ decrypted = encrypter.decrypt_oauth_params(encrypted)
+
+ assert decrypted == original_params
+
+ def test_decrypt_oauth_params_unicode_data(self):
+ """Test decryption with unicode data"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+ original_params = {
+ "client_id": "test_id",
+ "client_secret": "test_secret",
+ "description": "This is a test case 🚀",
+ }
+
+ encrypted = encrypter.encrypt_oauth_params(original_params)
+ decrypted = encrypter.decrypt_oauth_params(encrypted)
+
+ assert decrypted == original_params
+
+ def test_decrypt_oauth_params_large_data(self):
+ """Test decryption with large data"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+ original_params = {
+ "client_id": "test_id",
+ "large_data": "x" * 10000, # 10KB of data
+ }
+
+ encrypted = encrypter.encrypt_oauth_params(original_params)
+ decrypted = encrypter.decrypt_oauth_params(encrypted)
+
+ assert decrypted == original_params
+
+ def test_decrypt_oauth_params_invalid_base64(self):
+ """Test decryption with invalid base64 data"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+
+ with pytest.raises(OAuthEncryptionError):
+ encrypter.decrypt_oauth_params("invalid_base64!")
+
+ def test_decrypt_oauth_params_empty_string(self):
+ """Test decryption with empty string"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+
+ with pytest.raises(ValueError) as exc_info:
+ encrypter.decrypt_oauth_params("")
+
+ assert "encrypted_data cannot be empty" in str(exc_info.value)
+
+ def test_decrypt_oauth_params_non_string_input(self):
+ """Test decryption with non-string input"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+
+ with pytest.raises(ValueError) as exc_info:
+ encrypter.decrypt_oauth_params(123) # type: ignore
+
+ assert "encrypted_data must be a string" in str(exc_info.value)
+
+ with pytest.raises(ValueError) as exc_info:
+ encrypter.decrypt_oauth_params(None) # type: ignore
+
+ assert "encrypted_data must be a string" in str(exc_info.value)
+
+ def test_decrypt_oauth_params_too_short_data(self):
+ """Test decryption with too short encrypted data"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+
+ # Create data that's too short (less than 32 bytes)
+ short_data = base64.b64encode(b"short").decode()
+
+ with pytest.raises(OAuthEncryptionError) as exc_info:
+ encrypter.decrypt_oauth_params(short_data)
+
+ assert "Invalid encrypted data format" in str(exc_info.value)
+
+ def test_decrypt_oauth_params_corrupted_data(self):
+ """Test decryption with corrupted data"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+
+ # Create corrupted data (valid base64 but invalid encrypted content)
+ corrupted_data = base64.b64encode(b"x" * 48).decode() # 48 bytes of garbage
+
+ with pytest.raises(OAuthEncryptionError):
+ encrypter.decrypt_oauth_params(corrupted_data)
+
+ def test_decrypt_oauth_params_wrong_key(self):
+ """Test decryption with wrong key"""
+ encrypter1 = SystemOAuthEncrypter("secret1")
+ encrypter2 = SystemOAuthEncrypter("secret2")
+
+ original_params = {"client_id": "test_id", "client_secret": "test_secret"}
+ encrypted = encrypter1.encrypt_oauth_params(original_params)
+
+ with pytest.raises(OAuthEncryptionError):
+ encrypter2.decrypt_oauth_params(encrypted)
+
+ def test_encryption_decryption_consistency(self):
+ """Test that encryption and decryption are consistent"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+
+ test_cases = [
+ {},
+ {"simple": "value"},
+ {"client_id": "id", "client_secret": "secret"},
+ {"complex": {"nested": {"deep": "value"}}},
+ {"unicode": "test 🚀"},
+ {"numbers": 42, "boolean": True, "null": None},
+ {"array": [1, 2, 3, "four", {"five": 5}]},
+ ]
+
+ for original_params in test_cases:
+ encrypted = encrypter.encrypt_oauth_params(original_params)
+ decrypted = encrypter.decrypt_oauth_params(encrypted)
+ assert decrypted == original_params, f"Failed for case: {original_params}"
+
+ def test_encryption_randomness(self):
+ """Test that encryption produces different results for same input"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+ oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
+
+ encrypted1 = encrypter.encrypt_oauth_params(oauth_params)
+ encrypted2 = encrypter.encrypt_oauth_params(oauth_params)
+
+ # Should be different due to random IV
+ assert encrypted1 != encrypted2
+
+ # But should decrypt to same result
+ decrypted1 = encrypter.decrypt_oauth_params(encrypted1)
+ decrypted2 = encrypter.decrypt_oauth_params(encrypted2)
+ assert decrypted1 == decrypted2 == oauth_params
+
+ def test_different_secret_keys_produce_different_results(self):
+ """Test that different secret keys produce different encrypted results"""
+ encrypter1 = SystemOAuthEncrypter("secret1")
+ encrypter2 = SystemOAuthEncrypter("secret2")
+
+ oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
+
+ encrypted1 = encrypter1.encrypt_oauth_params(oauth_params)
+ encrypted2 = encrypter2.encrypt_oauth_params(oauth_params)
+
+ # Should produce different encrypted results
+ assert encrypted1 != encrypted2
+
+ # But each should decrypt correctly with its own key
+ decrypted1 = encrypter1.decrypt_oauth_params(encrypted1)
+ decrypted2 = encrypter2.decrypt_oauth_params(encrypted2)
+ assert decrypted1 == decrypted2 == oauth_params
+
+ @patch("core.tools.utils.system_oauth_encryption.get_random_bytes")
+ def test_encrypt_oauth_params_crypto_error(self, mock_get_random_bytes):
+ """Test encryption when crypto operation fails"""
+ mock_get_random_bytes.side_effect = Exception("Crypto error")
+
+ encrypter = SystemOAuthEncrypter("test_secret")
+ oauth_params = {"client_id": "test_id"}
+
+ with pytest.raises(OAuthEncryptionError) as exc_info:
+ encrypter.encrypt_oauth_params(oauth_params)
+
+ assert "Encryption failed" in str(exc_info.value)
+
+ @patch("core.tools.utils.system_oauth_encryption.TypeAdapter")
+ def test_encrypt_oauth_params_serialization_error(self, mock_type_adapter):
+ """Test encryption when JSON serialization fails"""
+ mock_type_adapter.return_value.dump_json.side_effect = Exception("Serialization error")
+
+ encrypter = SystemOAuthEncrypter("test_secret")
+ oauth_params = {"client_id": "test_id"}
+
+ with pytest.raises(OAuthEncryptionError) as exc_info:
+ encrypter.encrypt_oauth_params(oauth_params)
+
+ assert "Encryption failed" in str(exc_info.value)
+
+ def test_decrypt_oauth_params_invalid_json(self):
+ """Test decryption with invalid JSON data"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+
+ # Create valid encrypted data but with invalid JSON content
+ iv = get_random_bytes(16)
+ cipher = AES.new(encrypter.key, AES.MODE_CBC, iv)
+ invalid_json = b"invalid json content"
+ padded_data = pad(invalid_json, AES.block_size)
+ encrypted_data = cipher.encrypt(padded_data)
+ combined = iv + encrypted_data
+ encoded = base64.b64encode(combined).decode()
+
+ with pytest.raises(OAuthEncryptionError):
+ encrypter.decrypt_oauth_params(encoded)
+
+ def test_key_derivation_consistency(self):
+ """Test that key derivation is consistent"""
+ secret_key = "test_secret"
+ encrypter1 = SystemOAuthEncrypter(secret_key)
+ encrypter2 = SystemOAuthEncrypter(secret_key)
+
+ assert encrypter1.key == encrypter2.key
+
+ # Keys should be 32 bytes (256 bits)
+ assert len(encrypter1.key) == 32
+
+
+class TestFactoryFunctions:
+ """Test cases for factory functions"""
+
+ def test_create_system_oauth_encrypter_with_secret(self):
+ """Test factory function with secret key"""
+ secret_key = "test_secret"
+ encrypter = create_system_oauth_encrypter(secret_key)
+
+ assert isinstance(encrypter, SystemOAuthEncrypter)
+ expected_key = hashlib.sha256(secret_key.encode()).digest()
+ assert encrypter.key == expected_key
+
+ def test_create_system_oauth_encrypter_without_secret(self):
+ """Test factory function without secret key"""
+ with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
+ mock_config.SECRET_KEY = "config_secret"
+ encrypter = create_system_oauth_encrypter()
+
+ assert isinstance(encrypter, SystemOAuthEncrypter)
+ expected_key = hashlib.sha256(b"config_secret").digest()
+ assert encrypter.key == expected_key
+
+ def test_create_system_oauth_encrypter_with_none_secret(self):
+ """Test factory function with None secret key"""
+ with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
+ mock_config.SECRET_KEY = "config_secret"
+ encrypter = create_system_oauth_encrypter(None)
+
+ assert isinstance(encrypter, SystemOAuthEncrypter)
+ expected_key = hashlib.sha256(b"config_secret").digest()
+ assert encrypter.key == expected_key
+
+
+class TestGlobalEncrypterInstance:
+ """Test cases for global encrypter instance"""
+
+ def test_get_system_oauth_encrypter_singleton(self):
+ """Test that get_system_oauth_encrypter returns singleton instance"""
+ # Clear the global instance first
+ import core.tools.utils.system_oauth_encryption
+
+ core.tools.utils.system_oauth_encryption._oauth_encrypter = None
+
+ encrypter1 = get_system_oauth_encrypter()
+ encrypter2 = get_system_oauth_encrypter()
+
+ assert encrypter1 is encrypter2
+ assert isinstance(encrypter1, SystemOAuthEncrypter)
+
+ def test_get_system_oauth_encrypter_uses_config(self):
+ """Test that global encrypter uses config"""
+ # Clear the global instance first
+ import core.tools.utils.system_oauth_encryption
+
+ core.tools.utils.system_oauth_encryption._oauth_encrypter = None
+
+ with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
+ mock_config.SECRET_KEY = "global_secret"
+ encrypter = get_system_oauth_encrypter()
+
+ expected_key = hashlib.sha256(b"global_secret").digest()
+ assert encrypter.key == expected_key
+
+
+class TestConvenienceFunctions:
+ """Test cases for convenience functions"""
+
+ def test_encrypt_system_oauth_params(self):
+ """Test encrypt_system_oauth_params convenience function"""
+ oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
+
+ encrypted = encrypt_system_oauth_params(oauth_params)
+
+ assert isinstance(encrypted, str)
+ assert len(encrypted) > 0
+
+ def test_decrypt_system_oauth_params(self):
+ """Test decrypt_system_oauth_params convenience function"""
+ oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
+
+ encrypted = encrypt_system_oauth_params(oauth_params)
+ decrypted = decrypt_system_oauth_params(encrypted)
+
+ assert decrypted == oauth_params
+
+ def test_convenience_functions_consistency(self):
+ """Test that convenience functions work consistently"""
+ test_cases = [
+ {},
+ {"simple": "value"},
+ {"client_id": "id", "client_secret": "secret"},
+ {"complex": {"nested": {"deep": "value"}}},
+ {"unicode": "test 🚀"},
+ {"numbers": 42, "boolean": True, "null": None},
+ ]
+
+ for original_params in test_cases:
+ encrypted = encrypt_system_oauth_params(original_params)
+ decrypted = decrypt_system_oauth_params(encrypted)
+ assert decrypted == original_params, f"Failed for case: {original_params}"
+
+ def test_convenience_functions_with_errors(self):
+ """Test convenience functions with error conditions"""
+ # Test encryption with invalid input
+ with pytest.raises(Exception): # noqa: B017
+ encrypt_system_oauth_params(None) # type: ignore
+
+ # Test decryption with invalid input
+ with pytest.raises(ValueError):
+ decrypt_system_oauth_params("")
+
+ with pytest.raises(ValueError):
+ decrypt_system_oauth_params(None) # type: ignore
+
+
+class TestErrorHandling:
+ """Test cases for error handling"""
+
+ def test_oauth_encryption_error_inheritance(self):
+ """Test that OAuthEncryptionError is a proper exception"""
+ error = OAuthEncryptionError("Test error")
+ assert isinstance(error, Exception)
+ assert str(error) == "Test error"
+
+ def test_oauth_encryption_error_with_cause(self):
+ """Test OAuthEncryptionError with cause"""
+ original_error = ValueError("Original error")
+ error = OAuthEncryptionError("Wrapper error")
+ error.__cause__ = original_error
+
+ assert isinstance(error, Exception)
+ assert str(error) == "Wrapper error"
+ assert error.__cause__ is original_error
+
+ def test_error_messages_are_informative(self):
+ """Test that error messages are informative"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+
+ # Test empty string error
+ with pytest.raises(ValueError) as exc_info:
+ encrypter.decrypt_oauth_params("")
+ assert "encrypted_data cannot be empty" in str(exc_info.value)
+
+ # Test non-string error
+ with pytest.raises(ValueError) as exc_info:
+ encrypter.decrypt_oauth_params(123) # type: ignore
+ assert "encrypted_data must be a string" in str(exc_info.value)
+
+ # Test invalid format error
+ short_data = base64.b64encode(b"short").decode()
+ with pytest.raises(OAuthEncryptionError) as exc_info:
+ encrypter.decrypt_oauth_params(short_data)
+ assert "Invalid encrypted data format" in str(exc_info.value)
+
+
+class TestEdgeCases:
+ """Test cases for edge cases and boundary conditions"""
+
+ def test_very_long_secret_key(self):
+ """Test with very long secret key"""
+ long_secret = "x" * 10000
+ encrypter = SystemOAuthEncrypter(long_secret)
+
+ # Key should still be 32 bytes due to SHA-256
+ assert len(encrypter.key) == 32
+
+ # Should still work normally
+ oauth_params = {"client_id": "test_id"}
+ encrypted = encrypter.encrypt_oauth_params(oauth_params)
+ decrypted = encrypter.decrypt_oauth_params(encrypted)
+ assert decrypted == oauth_params
+
+ def test_special_characters_in_secret_key(self):
+ """Test with special characters in secret key"""
+ special_secret = "!@#$%^&*()_+-=[]{}|;':\",./<>?`~test🚀"
+ encrypter = SystemOAuthEncrypter(special_secret)
+
+ oauth_params = {"client_id": "test_id"}
+ encrypted = encrypter.encrypt_oauth_params(oauth_params)
+ decrypted = encrypter.decrypt_oauth_params(encrypted)
+ assert decrypted == oauth_params
+
+ def test_empty_values_in_oauth_params(self):
+ """Test with empty values in oauth params"""
+ oauth_params = {
+ "client_id": "",
+ "client_secret": "",
+ "empty_dict": {},
+ "empty_list": [],
+ "empty_string": "",
+ "zero": 0,
+ "false": False,
+ "none": None,
+ }
+
+ encrypter = SystemOAuthEncrypter("test_secret")
+ encrypted = encrypter.encrypt_oauth_params(oauth_params)
+ decrypted = encrypter.decrypt_oauth_params(encrypted)
+ assert decrypted == oauth_params
+
+ def test_deeply_nested_oauth_params(self):
+ """Test with deeply nested oauth params"""
+ oauth_params = {"level1": {"level2": {"level3": {"level4": {"level5": {"deep_value": "found"}}}}}}
+
+ encrypter = SystemOAuthEncrypter("test_secret")
+ encrypted = encrypter.encrypt_oauth_params(oauth_params)
+ decrypted = encrypter.decrypt_oauth_params(encrypted)
+ assert decrypted == oauth_params
+
+ def test_oauth_params_with_all_json_types(self):
+ """Test with all JSON-supported data types"""
+ oauth_params = {
+ "string": "test_string",
+ "integer": 42,
+ "float": 3.14159,
+ "boolean_true": True,
+ "boolean_false": False,
+ "null_value": None,
+ "empty_string": "",
+ "array": [1, "two", 3.0, True, False, None],
+ "object": {"nested_string": "nested_value", "nested_number": 123, "nested_bool": True},
+ }
+
+ encrypter = SystemOAuthEncrypter("test_secret")
+ encrypted = encrypter.encrypt_oauth_params(oauth_params)
+ decrypted = encrypter.decrypt_oauth_params(encrypted)
+ assert decrypted == oauth_params
+
+
+class TestPerformance:
+ """Test cases for performance considerations"""
+
+ def test_large_oauth_params(self):
+ """Test with large oauth params"""
+ large_value = "x" * 100000 # 100KB
+ oauth_params = {"client_id": "test_id", "large_data": large_value}
+
+ encrypter = SystemOAuthEncrypter("test_secret")
+ encrypted = encrypter.encrypt_oauth_params(oauth_params)
+ decrypted = encrypter.decrypt_oauth_params(encrypted)
+ assert decrypted == oauth_params
+
+ def test_many_fields_oauth_params(self):
+ """Test with many fields in oauth params"""
+ oauth_params = {f"field_{i}": f"value_{i}" for i in range(1000)}
+
+ encrypter = SystemOAuthEncrypter("test_secret")
+ encrypted = encrypter.encrypt_oauth_params(oauth_params)
+ decrypted = encrypter.decrypt_oauth_params(encrypted)
+ assert decrypted == oauth_params
+
+ def test_repeated_encryption_decryption(self):
+ """Test repeated encryption and decryption operations"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+ oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
+
+ # Test multiple rounds of encryption/decryption
+ for i in range(100):
+ encrypted = encrypter.encrypt_oauth_params(oauth_params)
+ decrypted = encrypter.decrypt_oauth_params(encrypted)
+ assert decrypted == oauth_params
diff --git a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py
index 728c58fc5b..93284eed4b 100644
--- a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py
+++ b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py
@@ -27,11 +27,11 @@ def create_mock_usage(prompt_tokens: int = 10, completion_tokens: int = 5) -> LL
return LLMUsage(
prompt_tokens=prompt_tokens,
prompt_unit_price=Decimal("0.001"),
- prompt_price_unit=Decimal("1"),
+ prompt_price_unit=Decimal(1),
prompt_price=Decimal(str(prompt_tokens)) * Decimal("0.001"),
completion_tokens=completion_tokens,
completion_unit_price=Decimal("0.002"),
- completion_price_unit=Decimal("1"),
+ completion_price_unit=Decimal(1),
completion_price=Decimal(str(completion_tokens)) * Decimal("0.002"),
total_tokens=prompt_tokens + completion_tokens,
total_price=Decimal(str(prompt_tokens)) * Decimal("0.001") + Decimal(str(completion_tokens)) * Decimal("0.002"),
diff --git a/api/uv.lock b/api/uv.lock
index e108e0c445..21b6b20f53 100644
--- a/api/uv.lock
+++ b/api/uv.lock
@@ -1498,7 +1498,7 @@ dev = [
{ name = "pytest-cov", specifier = "~=4.1.0" },
{ name = "pytest-env", specifier = "~=1.1.3" },
{ name = "pytest-mock", specifier = "~=3.14.0" },
- { name = "ruff", specifier = "~=0.11.5" },
+ { name = "ruff", specifier = "~=0.12.3" },
{ name = "scipy-stubs", specifier = ">=1.15.3.0" },
{ name = "types-aiofiles", specifier = "~=24.1.0" },
{ name = "types-beautifulsoup4", specifier = "~=4.12.0" },
@@ -5088,27 +5088,27 @@ wheels = [
[[package]]
name = "ruff"
-version = "0.11.13"
-source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/ed/da/9c6f995903b4d9474b39da91d2d626659af3ff1eeb43e9ae7c119349dba6/ruff-0.11.13.tar.gz", hash = "sha256:26fa247dc68d1d4e72c179e08889a25ac0c7ba4d78aecfc835d49cbfd60bf514", size = 4282054, upload-time = "2025-06-05T21:00:15.721Z" }
-wheels = [
- { url = "https://files.pythonhosted.org/packages/7d/ce/a11d381192966e0b4290842cc8d4fac7dc9214ddf627c11c1afff87da29b/ruff-0.11.13-py3-none-linux_armv6l.whl", hash = "sha256:4bdfbf1240533f40042ec00c9e09a3aade6f8c10b6414cf11b519488d2635d46", size = 10292516, upload-time = "2025-06-05T20:59:32.944Z" },
- { url = "https://files.pythonhosted.org/packages/78/db/87c3b59b0d4e753e40b6a3b4a2642dfd1dcaefbff121ddc64d6c8b47ba00/ruff-0.11.13-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:aef9c9ed1b5ca28bb15c7eac83b8670cf3b20b478195bd49c8d756ba0a36cf48", size = 11106083, upload-time = "2025-06-05T20:59:37.03Z" },
- { url = "https://files.pythonhosted.org/packages/77/79/d8cec175856ff810a19825d09ce700265f905c643c69f45d2b737e4a470a/ruff-0.11.13-py3-none-macosx_11_0_arm64.whl", hash = "sha256:53b15a9dfdce029c842e9a5aebc3855e9ab7771395979ff85b7c1dedb53ddc2b", size = 10436024, upload-time = "2025-06-05T20:59:39.741Z" },
- { url = "https://files.pythonhosted.org/packages/8b/5b/f6d94f2980fa1ee854b41568368a2e1252681b9238ab2895e133d303538f/ruff-0.11.13-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ab153241400789138d13f362c43f7edecc0edfffce2afa6a68434000ecd8f69a", size = 10646324, upload-time = "2025-06-05T20:59:42.185Z" },
- { url = "https://files.pythonhosted.org/packages/6c/9c/b4c2acf24ea4426016d511dfdc787f4ce1ceb835f3c5fbdbcb32b1c63bda/ruff-0.11.13-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6c51f93029d54a910d3d24f7dd0bb909e31b6cd989a5e4ac513f4eb41629f0dc", size = 10174416, upload-time = "2025-06-05T20:59:44.319Z" },
- { url = "https://files.pythonhosted.org/packages/f3/10/e2e62f77c65ede8cd032c2ca39c41f48feabedb6e282bfd6073d81bb671d/ruff-0.11.13-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1808b3ed53e1a777c2ef733aca9051dc9bf7c99b26ece15cb59a0320fbdbd629", size = 11724197, upload-time = "2025-06-05T20:59:46.935Z" },
- { url = "https://files.pythonhosted.org/packages/bb/f0/466fe8469b85c561e081d798c45f8a1d21e0b4a5ef795a1d7f1a9a9ec182/ruff-0.11.13-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:d28ce58b5ecf0f43c1b71edffabe6ed7f245d5336b17805803312ec9bc665933", size = 12511615, upload-time = "2025-06-05T20:59:49.534Z" },
- { url = "https://files.pythonhosted.org/packages/17/0e/cefe778b46dbd0cbcb03a839946c8f80a06f7968eb298aa4d1a4293f3448/ruff-0.11.13-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:55e4bc3a77842da33c16d55b32c6cac1ec5fb0fbec9c8c513bdce76c4f922165", size = 12117080, upload-time = "2025-06-05T20:59:51.654Z" },
- { url = "https://files.pythonhosted.org/packages/5d/2c/caaeda564cbe103bed145ea557cb86795b18651b0f6b3ff6a10e84e5a33f/ruff-0.11.13-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:633bf2c6f35678c56ec73189ba6fa19ff1c5e4807a78bf60ef487b9dd272cc71", size = 11326315, upload-time = "2025-06-05T20:59:54.469Z" },
- { url = "https://files.pythonhosted.org/packages/75/f0/782e7d681d660eda8c536962920c41309e6dd4ebcea9a2714ed5127d44bd/ruff-0.11.13-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4ffbc82d70424b275b089166310448051afdc6e914fdab90e08df66c43bb5ca9", size = 11555640, upload-time = "2025-06-05T20:59:56.986Z" },
- { url = "https://files.pythonhosted.org/packages/5d/d4/3d580c616316c7f07fb3c99dbecfe01fbaea7b6fd9a82b801e72e5de742a/ruff-0.11.13-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:4a9ddd3ec62a9a89578c85842b836e4ac832d4a2e0bfaad3b02243f930ceafcc", size = 10507364, upload-time = "2025-06-05T20:59:59.154Z" },
- { url = "https://files.pythonhosted.org/packages/5a/dc/195e6f17d7b3ea6b12dc4f3e9de575db7983db187c378d44606e5d503319/ruff-0.11.13-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:d237a496e0778d719efb05058c64d28b757c77824e04ffe8796c7436e26712b7", size = 10141462, upload-time = "2025-06-05T21:00:01.481Z" },
- { url = "https://files.pythonhosted.org/packages/f4/8e/39a094af6967faa57ecdeacb91bedfb232474ff8c3d20f16a5514e6b3534/ruff-0.11.13-py3-none-musllinux_1_2_i686.whl", hash = "sha256:26816a218ca6ef02142343fd24c70f7cd8c5aa6c203bca284407adf675984432", size = 11121028, upload-time = "2025-06-05T21:00:04.06Z" },
- { url = "https://files.pythonhosted.org/packages/5a/c0/b0b508193b0e8a1654ec683ebab18d309861f8bd64e3a2f9648b80d392cb/ruff-0.11.13-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:51c3f95abd9331dc5b87c47ac7f376db5616041173826dfd556cfe3d4977f492", size = 11602992, upload-time = "2025-06-05T21:00:06.249Z" },
- { url = "https://files.pythonhosted.org/packages/7c/91/263e33ab93ab09ca06ce4f8f8547a858cc198072f873ebc9be7466790bae/ruff-0.11.13-py3-none-win32.whl", hash = "sha256:96c27935418e4e8e77a26bb05962817f28b8ef3843a6c6cc49d8783b5507f250", size = 10474944, upload-time = "2025-06-05T21:00:08.459Z" },
- { url = "https://files.pythonhosted.org/packages/46/f4/7c27734ac2073aae8efb0119cae6931b6fb48017adf048fdf85c19337afc/ruff-0.11.13-py3-none-win_amd64.whl", hash = "sha256:29c3189895a8a6a657b7af4e97d330c8a3afd2c9c8f46c81e2fc5a31866517e3", size = 11548669, upload-time = "2025-06-05T21:00:11.147Z" },
- { url = "https://files.pythonhosted.org/packages/ec/bf/b273dd11673fed8a6bd46032c0ea2a04b2ac9bfa9c628756a5856ba113b0/ruff-0.11.13-py3-none-win_arm64.whl", hash = "sha256:b4385285e9179d608ff1d2fb9922062663c658605819a6876d8beef0c30b7f3b", size = 10683928, upload-time = "2025-06-05T21:00:13.758Z" },
+version = "0.12.3"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/c3/2a/43955b530c49684d3c38fcda18c43caf91e99204c2a065552528e0552d4f/ruff-0.12.3.tar.gz", hash = "sha256:f1b5a4b6668fd7b7ea3697d8d98857390b40c1320a63a178eee6be0899ea2d77", size = 4459341, upload-time = "2025-07-11T13:21:16.086Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/e2/fd/b44c5115539de0d598d75232a1cc7201430b6891808df111b8b0506aae43/ruff-0.12.3-py3-none-linux_armv6l.whl", hash = "sha256:47552138f7206454eaf0c4fe827e546e9ddac62c2a3d2585ca54d29a890137a2", size = 10430499, upload-time = "2025-07-11T13:20:26.321Z" },
+ { url = "https://files.pythonhosted.org/packages/43/c5/9eba4f337970d7f639a37077be067e4ec80a2ad359e4cc6c5b56805cbc66/ruff-0.12.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:0a9153b000c6fe169bb307f5bd1b691221c4286c133407b8827c406a55282041", size = 11213413, upload-time = "2025-07-11T13:20:30.017Z" },
+ { url = "https://files.pythonhosted.org/packages/e2/2c/fac3016236cf1fe0bdc8e5de4f24c76ce53c6dd9b5f350d902549b7719b2/ruff-0.12.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fa6b24600cf3b750e48ddb6057e901dd5b9aa426e316addb2a1af185a7509882", size = 10586941, upload-time = "2025-07-11T13:20:33.046Z" },
+ { url = "https://files.pythonhosted.org/packages/c5/0f/41fec224e9dfa49a139f0b402ad6f5d53696ba1800e0f77b279d55210ca9/ruff-0.12.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2506961bf6ead54887ba3562604d69cb430f59b42133d36976421bc8bd45901", size = 10783001, upload-time = "2025-07-11T13:20:35.534Z" },
+ { url = "https://files.pythonhosted.org/packages/0d/ca/dd64a9ce56d9ed6cad109606ac014860b1c217c883e93bf61536400ba107/ruff-0.12.3-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c4faaff1f90cea9d3033cbbcdf1acf5d7fb11d8180758feb31337391691f3df0", size = 10269641, upload-time = "2025-07-11T13:20:38.459Z" },
+ { url = "https://files.pythonhosted.org/packages/63/5c/2be545034c6bd5ce5bb740ced3e7014d7916f4c445974be11d2a406d5088/ruff-0.12.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40dced4a79d7c264389de1c59467d5d5cefd79e7e06d1dfa2c75497b5269a5a6", size = 11875059, upload-time = "2025-07-11T13:20:41.517Z" },
+ { url = "https://files.pythonhosted.org/packages/8e/d4/a74ef1e801ceb5855e9527dae105eaff136afcb9cc4d2056d44feb0e4792/ruff-0.12.3-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:0262d50ba2767ed0fe212aa7e62112a1dcbfd46b858c5bf7bbd11f326998bafc", size = 12658890, upload-time = "2025-07-11T13:20:44.442Z" },
+ { url = "https://files.pythonhosted.org/packages/13/c8/1057916416de02e6d7c9bcd550868a49b72df94e3cca0aeb77457dcd9644/ruff-0.12.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:12371aec33e1a3758597c5c631bae9a5286f3c963bdfb4d17acdd2d395406687", size = 12232008, upload-time = "2025-07-11T13:20:47.374Z" },
+ { url = "https://files.pythonhosted.org/packages/f5/59/4f7c130cc25220392051fadfe15f63ed70001487eca21d1796db46cbcc04/ruff-0.12.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:560f13b6baa49785665276c963edc363f8ad4b4fc910a883e2625bdb14a83a9e", size = 11499096, upload-time = "2025-07-11T13:20:50.348Z" },
+ { url = "https://files.pythonhosted.org/packages/d4/01/a0ad24a5d2ed6be03a312e30d32d4e3904bfdbc1cdbe63c47be9d0e82c79/ruff-0.12.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:023040a3499f6f974ae9091bcdd0385dd9e9eb4942f231c23c57708147b06311", size = 11688307, upload-time = "2025-07-11T13:20:52.945Z" },
+ { url = "https://files.pythonhosted.org/packages/93/72/08f9e826085b1f57c9a0226e48acb27643ff19b61516a34c6cab9d6ff3fa/ruff-0.12.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:883d844967bffff5ab28bba1a4d246c1a1b2933f48cb9840f3fdc5111c603b07", size = 10661020, upload-time = "2025-07-11T13:20:55.799Z" },
+ { url = "https://files.pythonhosted.org/packages/80/a0/68da1250d12893466c78e54b4a0ff381370a33d848804bb51279367fc688/ruff-0.12.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:2120d3aa855ff385e0e562fdee14d564c9675edbe41625c87eeab744a7830d12", size = 10246300, upload-time = "2025-07-11T13:20:58.222Z" },
+ { url = "https://files.pythonhosted.org/packages/6a/22/5f0093d556403e04b6fd0984fc0fb32fbb6f6ce116828fd54306a946f444/ruff-0.12.3-py3-none-musllinux_1_2_i686.whl", hash = "sha256:6b16647cbb470eaf4750d27dddc6ebf7758b918887b56d39e9c22cce2049082b", size = 11263119, upload-time = "2025-07-11T13:21:01.503Z" },
+ { url = "https://files.pythonhosted.org/packages/92/c9/f4c0b69bdaffb9968ba40dd5fa7df354ae0c73d01f988601d8fac0c639b1/ruff-0.12.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:e1417051edb436230023575b149e8ff843a324557fe0a265863b7602df86722f", size = 11746990, upload-time = "2025-07-11T13:21:04.524Z" },
+ { url = "https://files.pythonhosted.org/packages/fe/84/7cc7bd73924ee6be4724be0db5414a4a2ed82d06b30827342315a1be9e9c/ruff-0.12.3-py3-none-win32.whl", hash = "sha256:dfd45e6e926deb6409d0616078a666ebce93e55e07f0fb0228d4b2608b2c248d", size = 10589263, upload-time = "2025-07-11T13:21:07.148Z" },
+ { url = "https://files.pythonhosted.org/packages/07/87/c070f5f027bd81f3efee7d14cb4d84067ecf67a3a8efb43aadfc72aa79a6/ruff-0.12.3-py3-none-win_amd64.whl", hash = "sha256:a946cf1e7ba3209bdef039eb97647f1c77f6f540e5845ec9c114d3af8df873e7", size = 11695072, upload-time = "2025-07-11T13:21:11.004Z" },
+ { url = "https://files.pythonhosted.org/packages/e0/30/f3eaf6563c637b6e66238ed6535f6775480db973c836336e4122161986fc/ruff-0.12.3-py3-none-win_arm64.whl", hash = "sha256:5f9c7c9c8f84c2d7f27e93674d27136fbf489720251544c4da7fb3d742e011b1", size = 10805855, upload-time = "2025-07-11T13:21:13.547Z" },
]
[[package]]
diff --git a/docker/.env.example b/docker/.env.example
index 84b6152f0a..a05141569b 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -214,6 +214,10 @@ SQLALCHEMY_POOL_SIZE=30
SQLALCHEMY_POOL_RECYCLE=3600
# Whether to print SQL, default is false.
SQLALCHEMY_ECHO=false
+# If True, will test connections for liveness upon each checkout
+SQLALCHEMY_POOL_PRE_PING=false
+# Whether to enable the Last in first out option or use default FIFO queue if is false
+SQLALCHEMY_POOL_USE_LIFO=false
# Maximum number of connections to the database
# Default is 100
@@ -285,6 +289,7 @@ REDIS_CLUSTERS_PASSWORD=
# If use Redis Sentinel, format as follows: `sentinel://:@:/`
# Example: sentinel://localhost:26379/1;sentinel://localhost:26380/1;sentinel://localhost:26381/1
CELERY_BROKER_URL=redis://:difyai123456@redis:6379/1
+CELERY_BACKEND=redis
BROKER_USE_SSL=false
# If you are using Redis Sentinel for high availability, configure the following settings.
@@ -768,6 +773,8 @@ INVITE_EXPIRY_HOURS=72
# Reset password token valid time (minutes),
RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5
+CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES=5
+OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES=5
# The sandbox service endpoint.
CODE_EXECUTION_ENDPOINT=http://sandbox:8194
@@ -799,6 +806,19 @@ WORKFLOW_FILE_UPLOAD_LIMIT=10
# hybrid: Save new data to object storage, read from both object storage and RDBMS
WORKFLOW_NODE_EXECUTION_STORAGE=rdbms
+# Repository configuration
+# Core workflow execution repository implementation
+CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository
+
+# Core workflow node execution repository implementation
+CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository
+
+# API workflow node execution repository implementation
+API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository
+
+# API workflow run repository implementation
+API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository
+
# HTTP request node in workflow configuration
HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
@@ -1122,6 +1142,8 @@ PLUGIN_VOLCENGINE_TOS_REGION=
# OTLP Collector Configuration
# ------------------------------
ENABLE_OTEL=false
+OTLP_TRACE_ENDPOINT=
+OTLP_METRIC_ENDPOINT=
OTLP_BASE_ENDPOINT=http://localhost:4318
OTLP_API_KEY=
OTEL_EXPORTER_OTLP_PROTOCOL=
diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml
index ac9953aa33..5962adb079 100644
--- a/docker/docker-compose.yaml
+++ b/docker/docker-compose.yaml
@@ -56,6 +56,8 @@ x-shared-env: &shared-api-worker-env
SQLALCHEMY_POOL_SIZE: ${SQLALCHEMY_POOL_SIZE:-30}
SQLALCHEMY_POOL_RECYCLE: ${SQLALCHEMY_POOL_RECYCLE:-3600}
SQLALCHEMY_ECHO: ${SQLALCHEMY_ECHO:-false}
+ SQLALCHEMY_POOL_PRE_PING: ${SQLALCHEMY_POOL_PRE_PING:-false}
+ SQLALCHEMY_POOL_USE_LIFO: ${SQLALCHEMY_POOL_USE_LIFO:-false}
POSTGRES_MAX_CONNECTIONS: ${POSTGRES_MAX_CONNECTIONS:-100}
POSTGRES_SHARED_BUFFERS: ${POSTGRES_SHARED_BUFFERS:-128MB}
POSTGRES_WORK_MEM: ${POSTGRES_WORK_MEM:-4MB}
@@ -77,6 +79,7 @@ x-shared-env: &shared-api-worker-env
REDIS_CLUSTERS: ${REDIS_CLUSTERS:-}
REDIS_CLUSTERS_PASSWORD: ${REDIS_CLUSTERS_PASSWORD:-}
CELERY_BROKER_URL: ${CELERY_BROKER_URL:-redis://:difyai123456@redis:6379/1}
+ CELERY_BACKEND: ${CELERY_BACKEND:-redis}
BROKER_USE_SSL: ${BROKER_USE_SSL:-false}
CELERY_USE_SENTINEL: ${CELERY_USE_SENTINEL:-false}
CELERY_SENTINEL_MASTER_NAME: ${CELERY_SENTINEL_MASTER_NAME:-}
@@ -333,6 +336,8 @@ x-shared-env: &shared-api-worker-env
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: ${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:-4000}
INVITE_EXPIRY_HOURS: ${INVITE_EXPIRY_HOURS:-72}
RESET_PASSWORD_TOKEN_EXPIRY_MINUTES: ${RESET_PASSWORD_TOKEN_EXPIRY_MINUTES:-5}
+ CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES: ${CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES:-5}
+ OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES: ${OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES:-5}
CODE_EXECUTION_ENDPOINT: ${CODE_EXECUTION_ENDPOINT:-http://sandbox:8194}
CODE_EXECUTION_API_KEY: ${CODE_EXECUTION_API_KEY:-dify-sandbox}
CODE_MAX_NUMBER: ${CODE_MAX_NUMBER:-9223372036854775807}
@@ -354,6 +359,10 @@ x-shared-env: &shared-api-worker-env
WORKFLOW_PARALLEL_DEPTH_LIMIT: ${WORKFLOW_PARALLEL_DEPTH_LIMIT:-3}
WORKFLOW_FILE_UPLOAD_LIMIT: ${WORKFLOW_FILE_UPLOAD_LIMIT:-10}
WORKFLOW_NODE_EXECUTION_STORAGE: ${WORKFLOW_NODE_EXECUTION_STORAGE:-rdbms}
+ CORE_WORKFLOW_EXECUTION_REPOSITORY: ${CORE_WORKFLOW_EXECUTION_REPOSITORY:-core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository}
+ CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY:-core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository}
+ API_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${API_WORKFLOW_NODE_EXECUTION_REPOSITORY:-repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository}
+ API_WORKFLOW_RUN_REPOSITORY: ${API_WORKFLOW_RUN_REPOSITORY:-repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository}
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760}
HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576}
HTTP_REQUEST_NODE_SSL_VERIFY: ${HTTP_REQUEST_NODE_SSL_VERIFY:-True}
@@ -500,6 +509,8 @@ x-shared-env: &shared-api-worker-env
PLUGIN_VOLCENGINE_TOS_SECRET_KEY: ${PLUGIN_VOLCENGINE_TOS_SECRET_KEY:-}
PLUGIN_VOLCENGINE_TOS_REGION: ${PLUGIN_VOLCENGINE_TOS_REGION:-}
ENABLE_OTEL: ${ENABLE_OTEL:-false}
+ OTLP_TRACE_ENDPOINT: ${OTLP_TRACE_ENDPOINT:-}
+ OTLP_METRIC_ENDPOINT: ${OTLP_METRIC_ENDPOINT:-}
OTLP_BASE_ENDPOINT: ${OTLP_BASE_ENDPOINT:-http://localhost:4318}
OTLP_API_KEY: ${OTLP_API_KEY:-}
OTEL_EXPORTER_OTLP_PROTOCOL: ${OTEL_EXPORTER_OTLP_PROTOCOL:-}
diff --git a/sdks/python-client/README.md b/sdks/python-client/README.md
index 8949ef08fa..7401fd2fd4 100644
--- a/sdks/python-client/README.md
+++ b/sdks/python-client/README.md
@@ -183,3 +183,42 @@ rename_conversation_response.raise_for_status()
print('[rename result]')
print(rename_conversation_response.json())
```
+
+* Using the Workflow Client
+```python
+import json
+import requests
+from dify_client import WorkflowClient
+
+api_key = "your_api_key"
+
+# Initialize Workflow Client
+client = WorkflowClient(api_key)
+
+# Prepare parameters for Workflow Client
+user_id = "your_user_id"
+context = "previous user interaction / metadata"
+user_prompt = "What is the capital of France?"
+
+inputs = {
+ "context": context,
+ "user_prompt": user_prompt,
+ # Add other input fields expected by your workflow (e.g., additional context, task parameters)
+
+}
+
+# Set response mode (default: streaming)
+response_mode = "blocking"
+
+# Run the workflow
+response = client.run(inputs=inputs, response_mode=response_mode, user=user_id)
+response.raise_for_status()
+
+# Parse result
+result = json.loads(response.text)
+
+answer = result.get("data").get("outputs")
+
+print(answer["answer"])
+
+```
diff --git a/sdks/python-client/dify_client/__init__.py b/sdks/python-client/dify_client/__init__.py
index 6fa9d190e5..b557a9ce95 100644
--- a/sdks/python-client/dify_client/__init__.py
+++ b/sdks/python-client/dify_client/__init__.py
@@ -1 +1 @@
-from dify_client.client import ChatClient, CompletionClient, DifyClient
+from dify_client.client import ChatClient, CompletionClient, WorkflowClient, KnowledgeBaseClient, DifyClient
diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/workflow/page.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/workflow/page.tsx
index d5df70f004..15da0bbed2 100644
--- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/workflow/page.tsx
+++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/workflow/page.tsx
@@ -1,5 +1,3 @@
-'use client'
-
import WorkflowApp from '@/app/components/workflow-app'
const Page = () => {
diff --git a/web/app/(commonLayout)/apps/assets/add.svg b/web/app/(commonLayout)/apps/assets/add.svg
deleted file mode 100644
index 9958e855aa..0000000000
--- a/web/app/(commonLayout)/apps/assets/add.svg
+++ /dev/null
@@ -1,3 +0,0 @@
-
-
-
diff --git a/web/app/(commonLayout)/apps/assets/chat-solid.svg b/web/app/(commonLayout)/apps/assets/chat-solid.svg
deleted file mode 100644
index a793e982c0..0000000000
--- a/web/app/(commonLayout)/apps/assets/chat-solid.svg
+++ /dev/null
@@ -1,4 +0,0 @@
-
-
-
-
diff --git a/web/app/(commonLayout)/apps/assets/chat.svg b/web/app/(commonLayout)/apps/assets/chat.svg
deleted file mode 100644
index 0971349a53..0000000000
--- a/web/app/(commonLayout)/apps/assets/chat.svg
+++ /dev/null
@@ -1,3 +0,0 @@
-
-
-
diff --git a/web/app/(commonLayout)/apps/assets/completion-solid.svg b/web/app/(commonLayout)/apps/assets/completion-solid.svg
deleted file mode 100644
index a9dc7e3dc1..0000000000
--- a/web/app/(commonLayout)/apps/assets/completion-solid.svg
+++ /dev/null
@@ -1,4 +0,0 @@
-
-
-
-
diff --git a/web/app/(commonLayout)/apps/assets/completion.svg b/web/app/(commonLayout)/apps/assets/completion.svg
deleted file mode 100644
index 34af4417fe..0000000000
--- a/web/app/(commonLayout)/apps/assets/completion.svg
+++ /dev/null
@@ -1,3 +0,0 @@
-
-
-
diff --git a/web/app/(commonLayout)/apps/assets/discord.svg b/web/app/(commonLayout)/apps/assets/discord.svg
deleted file mode 100644
index 9f22a1ab59..0000000000
--- a/web/app/(commonLayout)/apps/assets/discord.svg
+++ /dev/null
@@ -1,3 +0,0 @@
-
-
-
diff --git a/web/app/(commonLayout)/apps/assets/github.svg b/web/app/(commonLayout)/apps/assets/github.svg
deleted file mode 100644
index f03798b5e1..0000000000
--- a/web/app/(commonLayout)/apps/assets/github.svg
+++ /dev/null
@@ -1,17 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
diff --git a/web/app/(commonLayout)/apps/assets/link-gray.svg b/web/app/(commonLayout)/apps/assets/link-gray.svg
deleted file mode 100644
index a293cfcf53..0000000000
--- a/web/app/(commonLayout)/apps/assets/link-gray.svg
+++ /dev/null
@@ -1,3 +0,0 @@
-
-
-
diff --git a/web/app/(commonLayout)/apps/assets/link.svg b/web/app/(commonLayout)/apps/assets/link.svg
deleted file mode 100644
index 2926c28b16..0000000000
--- a/web/app/(commonLayout)/apps/assets/link.svg
+++ /dev/null
@@ -1,3 +0,0 @@
-
-
-
diff --git a/web/app/(commonLayout)/apps/assets/right-arrow.svg b/web/app/(commonLayout)/apps/assets/right-arrow.svg
deleted file mode 100644
index a2c1cedf95..0000000000
--- a/web/app/(commonLayout)/apps/assets/right-arrow.svg
+++ /dev/null
@@ -1,3 +0,0 @@
-
-
-
diff --git a/web/app/(commonLayout)/apps/layout.tsx b/web/app/(commonLayout)/apps/layout.tsx
deleted file mode 100644
index 10d04a4188..0000000000
--- a/web/app/(commonLayout)/apps/layout.tsx
+++ /dev/null
@@ -1,12 +0,0 @@
-'use client'
-
-import useDocumentTitle from '@/hooks/use-document-title'
-import { useTranslation } from 'react-i18next'
-
-export default function DatasetsLayout({ children }: { children: React.ReactNode }) {
- const { t } = useTranslation()
- useDocumentTitle(t('common.menus.apps'))
- return (<>
- {children}
- >)
-}
diff --git a/web/app/(commonLayout)/apps/page.tsx b/web/app/(commonLayout)/apps/page.tsx
index 3f617d41c9..25b6d55d11 100644
--- a/web/app/(commonLayout)/apps/page.tsx
+++ b/web/app/(commonLayout)/apps/page.tsx
@@ -1,32 +1,8 @@
-'use client'
-import { useTranslation } from 'react-i18next'
-import { RiDiscordFill, RiGithubFill } from '@remixicon/react'
-import Link from 'next/link'
-import style from '../list.module.css'
-import Apps from './Apps'
-import { useEducationInit } from '@/app/education-apply/hooks'
-import { useGlobalPublicStore } from '@/context/global-public-context'
+import Apps from '@/app/components/apps'
const AppList = () => {
- const { t } = useTranslation()
- useEducationInit()
- const { systemFeatures } = useGlobalPublicStore()
return (
-
-
- {!systemFeatures.branding.enabled &&
}
-
+
)
}
diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx
index acaae3f720..426778c835 100644
--- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx
+++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx
@@ -62,7 +62,6 @@ const ExtraInfo = ({ isMobile, relatedApps, expand }: IExtraInfoProps) => {
{
diff --git a/web/app/(commonLayout)/explore/installed/[appId]/page.tsx b/web/app/(commonLayout)/explore/installed/[appId]/page.tsx
index 938a03992b..e288c62b5d 100644
--- a/web/app/(commonLayout)/explore/installed/[appId]/page.tsx
+++ b/web/app/(commonLayout)/explore/installed/[appId]/page.tsx
@@ -1,16 +1,18 @@
-import type { FC } from 'react'
import React from 'react'
import Main from '@/app/components/explore/installed-app'
export type IInstalledAppProps = {
- params: Promise<{
+ params: {
appId: string
- }>
+ }
}
-const InstalledApp: FC
= async ({ params }) => {
+// Using Next.js page convention for async server components
+async function InstalledApp({ params }: IInstalledAppProps) {
+ const appId = (await params).appId
return (
-
+
)
}
-export default React.memo(InstalledApp)
+
+export default InstalledApp
diff --git a/web/app/(commonLayout)/list.module.css b/web/app/(commonLayout)/list.module.css
deleted file mode 100644
index c4d3aec29f..0000000000
--- a/web/app/(commonLayout)/list.module.css
+++ /dev/null
@@ -1,217 +0,0 @@
-.listItem {
- @apply col-span-1 bg-white border-2 border-solid border-transparent rounded-xl shadow-xs min-h-[160px] flex flex-col transition-all duration-200 ease-in-out cursor-pointer hover:shadow-lg;
-}
-
-.listItem.newItemCard {
- @apply outline outline-1 outline-gray-200 -outline-offset-1 hover:shadow-sm hover:bg-white;
- background-color: rgba(229, 231, 235, 0.5);
-}
-
-.listItem.selectable {
- @apply relative bg-gray-50 outline outline-1 outline-gray-200 -outline-offset-1 shadow-none hover:bg-none hover:shadow-none hover:outline-primary-200 transition-colors;
-}
-
-.listItem.selectable * {
- @apply relative;
-}
-
-.listItem.selectable::before {
- content: "";
- @apply absolute top-0 left-0 block w-full h-full rounded-lg pointer-events-none opacity-0 transition-opacity duration-200 ease-in-out hover:opacity-100;
- background: linear-gradient(0deg,
- rgba(235, 245, 255, 0.5),
- rgba(235, 245, 255, 0.5)),
- #ffffff;
-}
-
-.listItem.selectable:hover::before {
- @apply opacity-100;
-}
-
-.listItem.selected {
- @apply border-primary-600 hover:border-primary-600 border-2;
-}
-
-.listItem.selected::before {
- @apply opacity-100;
-}
-
-.appIcon {
- @apply flex items-center justify-center w-8 h-8 bg-pink-100 rounded-lg grow-0 shrink-0;
-}
-
-.appIcon.medium {
- @apply w-9 h-9;
-}
-
-.appIcon.large {
- @apply w-10 h-10;
-}
-
-.newItemIcon {
- @apply flex items-center justify-center w-8 h-8 transition-colors duration-200 ease-in-out border border-gray-200 rounded-lg hover:bg-white grow-0 shrink-0;
-}
-
-.listItem:hover .newItemIcon {
- @apply bg-gray-50 border-primary-100;
-}
-
-.newItemCard .newItemIcon {
- @apply bg-gray-100;
-}
-
-.newItemCard:hover .newItemIcon {
- @apply bg-white;
-}
-
-.selectable .newItemIcon {
- @apply bg-gray-50;
-}
-
-.selectable:hover .newItemIcon {
- @apply bg-primary-50;
-}
-
-.newItemIconImage {
- @apply grow-0 shrink-0 block w-4 h-4 bg-center bg-contain transition-colors duration-200 ease-in-out;
- color: #1f2a37;
-}
-
-.listItem:hover .newIconImage {
- @apply text-primary-600;
-}
-
-.newItemIconAdd {
- background-image: url("./apps/assets/add.svg");
-}
-
-/* .newItemIconChat {
- background-image: url("~@/app/components/base/icons/assets/public/header-nav/studio/Robot.svg");
-}
-
-.selected .newItemIconChat {
- background-image: url("~@/app/components/base/icons/assets/public/header-nav/studio/Robot-Active.svg");
-} */
-
-.newItemIconComplete {
- background-image: url("./apps/assets/completion.svg");
-}
-
-.listItemTitle {
- @apply flex pt-[14px] px-[14px] pb-3 h-[66px] items-center gap-3 grow-0 shrink-0;
-}
-
-.listItemHeading {
- @apply relative h-8 text-sm font-medium leading-8 grow;
-}
-
-.listItemHeadingContent {
- @apply absolute top-0 left-0 w-full h-full overflow-hidden text-ellipsis whitespace-nowrap;
-}
-
-.actionIconWrapper {
- @apply hidden h-8 w-8 p-2 rounded-md border-none hover:bg-gray-100 !important;
-}
-
-.listItem:hover .actionIconWrapper {
- @apply !inline-flex;
-}
-
-.deleteDatasetIcon {
- @apply hidden grow-0 shrink-0 basis-8 w-8 h-8 rounded-lg transition-colors duration-200 ease-in-out bg-white border border-gray-200 hover:bg-gray-100 bg-center bg-no-repeat;
- background-size: 16px;
- background-image: url('~@/assets/delete.svg');
-}
-
-.listItem:hover .deleteDatasetIcon {
- @apply block;
-}
-
-.listItemDescription {
- @apply mb-3 px-[14px] h-9 text-xs leading-normal text-gray-500 line-clamp-2;
-}
-
-.listItemDescription.noClip {
- @apply line-clamp-none;
-}
-
-.listItemFooter {
- @apply flex items-center flex-wrap min-h-[42px] px-[14px] pt-2 pb-[10px];
-}
-
-.listItemFooter.datasetCardFooter {
- @apply flex items-center gap-4 text-xs text-gray-500;
-}
-
-.listItemStats {
- @apply flex items-center gap-1;
-}
-
-.listItemFooterIcon {
- @apply block w-3 h-3 bg-center bg-contain;
-}
-
-.solidChatIcon {
- background-image: url("./apps/assets/chat-solid.svg");
-}
-
-.solidCompletionIcon {
- background-image: url("./apps/assets/completion-solid.svg");
-}
-
-.newItemCardHeading {
- @apply transition-colors duration-200 ease-in-out;
-}
-
-.listItem:hover .newItemCardHeading {
- @apply text-primary-600;
-}
-
-.listItemLink {
- @apply inline-flex items-center gap-1 text-xs text-gray-400 transition-colors duration-200 ease-in-out;
-}
-
-.listItem:hover .listItemLink {
- @apply text-primary-600;
-}
-
-.linkIcon {
- @apply block w-[13px] h-[13px] bg-center bg-contain;
- background-image: url("./apps/assets/link.svg");
-}
-
-.linkIcon.grayLinkIcon {
- background-image: url("./apps/assets/link-gray.svg");
-}
-
-.listItem:hover .grayLinkIcon {
- background-image: url("./apps/assets/link.svg");
-}
-
-.rightIcon {
- @apply block w-[13px] h-[13px] bg-center bg-contain;
- background-image: url("./apps/assets/right-arrow.svg");
-}
-
-.socialMediaLink {
- @apply flex items-center justify-center w-8 h-8 cursor-pointer hover:opacity-80 transition-opacity duration-200 ease-in-out;
-}
-
-.socialMediaIcon {
- @apply block w-6 h-6 bg-center bg-contain;
-}
-
-/* #region new app dialog */
-.newItemCaption {
- @apply inline-flex items-center mb-2 text-sm font-medium;
-}
-
-/* #endregion new app dialog */
-
-.unavailable {
- @apply opacity-50;
-}
-
-.listItem:hover .unavailable {
- @apply opacity-100;
-}
diff --git a/web/app/(shareLayout)/chat/[token]/page.tsx b/web/app/(shareLayout)/chat/[token]/page.tsx
index 640c40378f..8ce67585f0 100644
--- a/web/app/(shareLayout)/chat/[token]/page.tsx
+++ b/web/app/(shareLayout)/chat/[token]/page.tsx
@@ -1,10 +1,13 @@
'use client'
import React from 'react'
import ChatWithHistoryWrap from '@/app/components/base/chat/chat-with-history'
+import AuthenticatedLayout from '../../components/authenticated-layout'
const Chat = () => {
return (
-
+
+
+
)
}
diff --git a/web/app/(shareLayout)/chatbot/[token]/page.tsx b/web/app/(shareLayout)/chatbot/[token]/page.tsx
index 6196afecc4..5323d0dacc 100644
--- a/web/app/(shareLayout)/chatbot/[token]/page.tsx
+++ b/web/app/(shareLayout)/chatbot/[token]/page.tsx
@@ -1,10 +1,13 @@
'use client'
import React from 'react'
import EmbeddedChatbot from '@/app/components/base/chat/embedded-chatbot'
+import AuthenticatedLayout from '../../components/authenticated-layout'
const Chatbot = () => {
return (
-
+
+
+
)
}
diff --git a/web/app/(shareLayout)/completion/[token]/page.tsx b/web/app/(shareLayout)/completion/[token]/page.tsx
index e8bc9d79f5..ae91338b9a 100644
--- a/web/app/(shareLayout)/completion/[token]/page.tsx
+++ b/web/app/(shareLayout)/completion/[token]/page.tsx
@@ -1,9 +1,12 @@
import React from 'react'
import Main from '@/app/components/share/text-generation'
+import AuthenticatedLayout from '../../components/authenticated-layout'
const Completion = () => {
return (
-
+
+
+
)
}
diff --git a/web/app/(shareLayout)/components/authenticated-layout.tsx b/web/app/(shareLayout)/components/authenticated-layout.tsx
new file mode 100644
index 0000000000..e3cfc8e6a8
--- /dev/null
+++ b/web/app/(shareLayout)/components/authenticated-layout.tsx
@@ -0,0 +1,84 @@
+'use client'
+
+import AppUnavailable from '@/app/components/base/app-unavailable'
+import Loading from '@/app/components/base/loading'
+import { removeAccessToken } from '@/app/components/share/utils'
+import { useWebAppStore } from '@/context/web-app-context'
+import { useGetUserCanAccessApp } from '@/service/access-control'
+import { useGetWebAppInfo, useGetWebAppMeta, useGetWebAppParams } from '@/service/use-share'
+import { usePathname, useRouter, useSearchParams } from 'next/navigation'
+import React, { useCallback, useEffect } from 'react'
+import { useTranslation } from 'react-i18next'
+
+const AuthenticatedLayout = ({ children }: { children: React.ReactNode }) => {
+ const { t } = useTranslation()
+ const updateAppInfo = useWebAppStore(s => s.updateAppInfo)
+ const updateAppParams = useWebAppStore(s => s.updateAppParams)
+ const updateWebAppMeta = useWebAppStore(s => s.updateWebAppMeta)
+ const updateUserCanAccessApp = useWebAppStore(s => s.updateUserCanAccessApp)
+ const { isFetching: isFetchingAppParams, data: appParams, error: appParamsError } = useGetWebAppParams()
+ const { isFetching: isFetchingAppInfo, data: appInfo, error: appInfoError } = useGetWebAppInfo()
+ const { isFetching: isFetchingAppMeta, data: appMeta, error: appMetaError } = useGetWebAppMeta()
+ const { data: userCanAccessApp, error: useCanAccessAppError } = useGetUserCanAccessApp({ appId: appInfo?.app_id, isInstalledApp: false })
+
+ useEffect(() => {
+ if (appInfo)
+ updateAppInfo(appInfo)
+ if (appParams)
+ updateAppParams(appParams)
+ if (appMeta)
+ updateWebAppMeta(appMeta)
+ updateUserCanAccessApp(Boolean(userCanAccessApp && userCanAccessApp?.result))
+ }, [appInfo, appMeta, appParams, updateAppInfo, updateAppParams, updateUserCanAccessApp, updateWebAppMeta, userCanAccessApp])
+
+ const router = useRouter()
+ const pathname = usePathname()
+ const searchParams = useSearchParams()
+ const getSigninUrl = useCallback(() => {
+ const params = new URLSearchParams(searchParams)
+ params.delete('message')
+ params.set('redirect_url', pathname)
+ return `/webapp-signin?${params.toString()}`
+ }, [searchParams, pathname])
+
+ const backToHome = useCallback(() => {
+ removeAccessToken()
+ const url = getSigninUrl()
+ router.replace(url)
+ }, [getSigninUrl, router])
+
+ if (appInfoError) {
+ return
+ }
+ if (appParamsError) {
+ return
+ }
+ if (appMetaError) {
+ return
+ }
+ if (useCanAccessAppError) {
+ return
+ }
+ if (userCanAccessApp && !userCanAccessApp.result) {
+ return
+
+
{t('common.userProfile.logout')}
+
+ }
+ if (isFetchingAppInfo || isFetchingAppParams || isFetchingAppMeta) {
+ return
+
+
+ }
+ return <>{children}>
+}
+
+export default React.memo(AuthenticatedLayout)
diff --git a/web/app/(shareLayout)/components/splash.tsx b/web/app/(shareLayout)/components/splash.tsx
new file mode 100644
index 0000000000..4fe9efe4dd
--- /dev/null
+++ b/web/app/(shareLayout)/components/splash.tsx
@@ -0,0 +1,80 @@
+'use client'
+import type { FC, PropsWithChildren } from 'react'
+import { useEffect } from 'react'
+import { useCallback } from 'react'
+import { useWebAppStore } from '@/context/web-app-context'
+import { useRouter, useSearchParams } from 'next/navigation'
+import AppUnavailable from '@/app/components/base/app-unavailable'
+import { checkOrSetAccessToken, removeAccessToken, setAccessToken } from '@/app/components/share/utils'
+import { useTranslation } from 'react-i18next'
+import { fetchAccessToken } from '@/service/share'
+import Loading from '@/app/components/base/loading'
+import { AccessMode } from '@/models/access-control'
+
+const Splash: FC = ({ children }) => {
+ const { t } = useTranslation()
+ const shareCode = useWebAppStore(s => s.shareCode)
+ const webAppAccessMode = useWebAppStore(s => s.webAppAccessMode)
+ const searchParams = useSearchParams()
+ const router = useRouter()
+ const redirectUrl = searchParams.get('redirect_url')
+ const tokenFromUrl = searchParams.get('web_sso_token')
+ const message = searchParams.get('message')
+ const code = searchParams.get('code')
+ const getSigninUrl = useCallback(() => {
+ const params = new URLSearchParams(searchParams)
+ params.delete('message')
+ params.delete('code')
+ return `/webapp-signin?${params.toString()}`
+ }, [searchParams])
+
+ const backToHome = useCallback(() => {
+ removeAccessToken()
+ const url = getSigninUrl()
+ router.replace(url)
+ }, [getSigninUrl, router])
+
+ useEffect(() => {
+ (async () => {
+ if (message)
+ return
+ if (shareCode && tokenFromUrl && redirectUrl) {
+ localStorage.setItem('webapp_access_token', tokenFromUrl)
+ const tokenResp = await fetchAccessToken({ appCode: shareCode, webAppAccessToken: tokenFromUrl })
+ await setAccessToken(shareCode, tokenResp.access_token)
+ router.replace(decodeURIComponent(redirectUrl))
+ return
+ }
+ if (shareCode && redirectUrl && localStorage.getItem('webapp_access_token')) {
+ const tokenResp = await fetchAccessToken({ appCode: shareCode, webAppAccessToken: localStorage.getItem('webapp_access_token') })
+ await setAccessToken(shareCode, tokenResp.access_token)
+ router.replace(decodeURIComponent(redirectUrl))
+ return
+ }
+ if (webAppAccessMode === AccessMode.PUBLIC && redirectUrl) {
+ await checkOrSetAccessToken(shareCode)
+ router.replace(decodeURIComponent(redirectUrl))
+ }
+ })()
+ }, [shareCode, redirectUrl, router, tokenFromUrl, message, webAppAccessMode])
+
+ if (message) {
+ return
+
+
{code === '403' ? t('common.userProfile.logout') : t('share.login.backToHome')}
+
+ }
+ if (tokenFromUrl) {
+ return
+
+
+ }
+ if (webAppAccessMode === AccessMode.PUBLIC && redirectUrl) {
+ return
+
+
+ }
+ return <>{children}>
+}
+
+export default Splash
diff --git a/web/app/(shareLayout)/layout.tsx b/web/app/(shareLayout)/layout.tsx
index d057ba7599..5af913cac9 100644
--- a/web/app/(shareLayout)/layout.tsx
+++ b/web/app/(shareLayout)/layout.tsx
@@ -1,54 +1,15 @@
-'use client'
-import React, { useEffect, useState } from 'react'
-import type { FC } from 'react'
-import { usePathname, useSearchParams } from 'next/navigation'
-import Loading from '../components/base/loading'
-import { useGlobalPublicStore } from '@/context/global-public-context'
-import { AccessMode } from '@/models/access-control'
-import { getAppAccessModeByAppCode } from '@/service/share'
+import type { FC, PropsWithChildren } from 'react'
+import WebAppStoreProvider from '@/context/web-app-context'
+import Splash from './components/splash'
-const Layout: FC<{
- children: React.ReactNode
-}> = ({ children }) => {
- const isGlobalPending = useGlobalPublicStore(s => s.isGlobalPending)
- const setWebAppAccessMode = useGlobalPublicStore(s => s.setWebAppAccessMode)
- const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
- const pathname = usePathname()
- const searchParams = useSearchParams()
- const redirectUrl = searchParams.get('redirect_url')
- const [isLoading, setIsLoading] = useState(true)
- useEffect(() => {
- (async () => {
- if (!isGlobalPending && !systemFeatures.webapp_auth.enabled) {
- setIsLoading(false)
- return
- }
-
- let appCode: string | null = null
- if (redirectUrl) {
- const url = new URL(`${window.location.origin}${decodeURIComponent(redirectUrl)}`)
- appCode = url.pathname.split('/').pop() || null
- }
- else {
- appCode = pathname.split('/').pop() || null
- }
-
- if (!appCode)
- return
- setIsLoading(true)
- const ret = await getAppAccessModeByAppCode(appCode)
- setWebAppAccessMode(ret?.accessMode || AccessMode.PUBLIC)
- setIsLoading(false)
- })()
- }, [pathname, redirectUrl, setWebAppAccessMode, isGlobalPending, systemFeatures.webapp_auth.enabled])
- if (isLoading || isGlobalPending) {
- return
-
-
- }
+const Layout: FC = ({ children }) => {
return (
- {children}
+
+
+ {children}
+
+
)
}
diff --git a/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx b/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx
index 9f9a8ad4e3..5e3f6fff1d 100644
--- a/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx
+++ b/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx
@@ -9,8 +9,7 @@ import Button from '@/app/components/base/button'
import { changeWebAppPasswordWithToken } from '@/service/common'
import Toast from '@/app/components/base/toast'
import Input from '@/app/components/base/input'
-
-const validPassword = /^(?=.*[a-zA-Z])(?=.*\d).{8,}$/
+import { validPassword } from '@/config'
const ChangePasswordForm = () => {
const { t } = useTranslation()
diff --git a/web/app/(shareLayout)/webapp-signin/layout.tsx b/web/app/(shareLayout)/webapp-signin/layout.tsx
index a03364d326..7649982072 100644
--- a/web/app/(shareLayout)/webapp-signin/layout.tsx
+++ b/web/app/(shareLayout)/webapp-signin/layout.tsx
@@ -3,10 +3,13 @@
import cn from '@/utils/classnames'
import { useGlobalPublicStore } from '@/context/global-public-context'
import useDocumentTitle from '@/hooks/use-document-title'
+import type { PropsWithChildren } from 'react'
+import { useTranslation } from 'react-i18next'
-export default function SignInLayout({ children }: any) {
- const { systemFeatures } = useGlobalPublicStore()
- useDocumentTitle('')
+export default function SignInLayout({ children }: PropsWithChildren) {
+ const { t } = useTranslation()
+ const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
+ useDocumentTitle(t('login.webapp.login'))
return <>
diff --git a/web/app/(shareLayout)/webapp-signin/normalForm.tsx b/web/app/(shareLayout)/webapp-signin/normalForm.tsx
index d6bdf607ba..44006a9f1e 100644
--- a/web/app/(shareLayout)/webapp-signin/normalForm.tsx
+++ b/web/app/(shareLayout)/webapp-signin/normalForm.tsx
@@ -1,3 +1,4 @@
+'use client'
import React, { useCallback, useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Link from 'next/link'
diff --git a/web/app/(shareLayout)/webapp-signin/page.tsx b/web/app/(shareLayout)/webapp-signin/page.tsx
index 967516c416..1c6209b902 100644
--- a/web/app/(shareLayout)/webapp-signin/page.tsx
+++ b/web/app/(shareLayout)/webapp-signin/page.tsx
@@ -1,36 +1,30 @@
'use client'
import { useRouter, useSearchParams } from 'next/navigation'
import type { FC } from 'react'
-import React, { useCallback, useEffect } from 'react'
+import React, { useCallback } from 'react'
import { useTranslation } from 'react-i18next'
-import Toast from '@/app/components/base/toast'
-import { removeAccessToken, setAccessToken } from '@/app/components/share/utils'
+import { removeAccessToken } from '@/app/components/share/utils'
import { useGlobalPublicStore } from '@/context/global-public-context'
-import Loading from '@/app/components/base/loading'
import AppUnavailable from '@/app/components/base/app-unavailable'
import NormalForm from './normalForm'
import { AccessMode } from '@/models/access-control'
import ExternalMemberSsoAuth from './components/external-member-sso-auth'
-import { fetchAccessToken } from '@/service/share'
+import { useWebAppStore } from '@/context/web-app-context'
const WebSSOForm: FC = () => {
const { t } = useTranslation()
const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
- const webAppAccessMode = useGlobalPublicStore(s => s.webAppAccessMode)
+ const webAppAccessMode = useWebAppStore(s => s.webAppAccessMode)
const searchParams = useSearchParams()
const router = useRouter()
const redirectUrl = searchParams.get('redirect_url')
- const tokenFromUrl = searchParams.get('web_sso_token')
- const message = searchParams.get('message')
- const code = searchParams.get('code')
const getSigninUrl = useCallback(() => {
- const params = new URLSearchParams(searchParams)
- params.delete('message')
- params.delete('code')
+ const params = new URLSearchParams()
+ params.append('redirect_url', redirectUrl || '')
return `/webapp-signin?${params.toString()}`
- }, [searchParams])
+ }, [redirectUrl])
const backToHome = useCallback(() => {
removeAccessToken()
@@ -38,73 +32,12 @@ const WebSSOForm: FC = () => {
router.replace(url)
}, [getSigninUrl, router])
- const showErrorToast = (msg: string) => {
- Toast.notify({
- type: 'error',
- message: msg,
- })
- }
-
- const getAppCodeFromRedirectUrl = useCallback(() => {
- if (!redirectUrl)
- return null
- const url = new URL(`${window.location.origin}${decodeURIComponent(redirectUrl)}`)
- const appCode = url.pathname.split('/').pop()
- if (!appCode)
- return null
-
- return appCode
- }, [redirectUrl])
-
- useEffect(() => {
- (async () => {
- if (message)
- return
-
- const appCode = getAppCodeFromRedirectUrl()
- if (appCode && tokenFromUrl && redirectUrl) {
- localStorage.setItem('webapp_access_token', tokenFromUrl)
- const tokenResp = await fetchAccessToken({ appCode, webAppAccessToken: tokenFromUrl })
- await setAccessToken(appCode, tokenResp.access_token)
- router.replace(decodeURIComponent(redirectUrl))
- return
- }
- if (appCode && redirectUrl && localStorage.getItem('webapp_access_token')) {
- const tokenResp = await fetchAccessToken({ appCode, webAppAccessToken: localStorage.getItem('webapp_access_token') })
- await setAccessToken(appCode, tokenResp.access_token)
- router.replace(decodeURIComponent(redirectUrl))
- }
- })()
- }, [getAppCodeFromRedirectUrl, redirectUrl, router, tokenFromUrl, message])
-
- useEffect(() => {
- if (webAppAccessMode && webAppAccessMode === AccessMode.PUBLIC && redirectUrl)
- router.replace(decodeURIComponent(redirectUrl))
- }, [webAppAccessMode, router, redirectUrl])
-
- if (tokenFromUrl) {
- return
-
-
- }
-
- if (message) {
- return
-
-
{code === '403' ? t('common.userProfile.logout') : t('share.login.backToHome')}
-
- }
if (!redirectUrl) {
- showErrorToast('redirect url is invalid.')
return
}
- if (webAppAccessMode && webAppAccessMode === AccessMode.PUBLIC) {
- return
-
-
- }
+
if (!systemFeatures.webapp_auth.enabled) {
return
{t('login.webapp.disabled')}
diff --git a/web/app/(shareLayout)/workflow/[token]/page.tsx b/web/app/(shareLayout)/workflow/[token]/page.tsx
index e93bc8c1af..4f5923e91f 100644
--- a/web/app/(shareLayout)/workflow/[token]/page.tsx
+++ b/web/app/(shareLayout)/workflow/[token]/page.tsx
@@ -1,10 +1,13 @@
import React from 'react'
import Main from '@/app/components/share/text-generation'
+import AuthenticatedLayout from '../../components/authenticated-layout'
const Workflow = () => {
return (
-
+
+
+
)
}
diff --git a/web/app/account/account-page/email-change-modal.tsx b/web/app/account/account-page/email-change-modal.tsx
new file mode 100644
index 0000000000..c3efad104a
--- /dev/null
+++ b/web/app/account/account-page/email-change-modal.tsx
@@ -0,0 +1,371 @@
+import React, { useState } from 'react'
+import { Trans, useTranslation } from 'react-i18next'
+import { useRouter } from 'next/navigation'
+import { useContext } from 'use-context-selector'
+import { ToastContext } from '@/app/components/base/toast'
+import { RiCloseLine } from '@remixicon/react'
+import Modal from '@/app/components/base/modal'
+import Button from '@/app/components/base/button'
+import Input from '@/app/components/base/input'
+import {
+ checkEmailExisted,
+ logout,
+ resetEmail,
+ sendVerifyCode,
+ verifyEmail,
+} from '@/service/common'
+import { noop } from 'lodash-es'
+
+type Props = {
+ show: boolean
+ onClose: () => void
+ email: string
+}
+
+enum STEP {
+ start = 'start',
+ verifyOrigin = 'verifyOrigin',
+ newEmail = 'newEmail',
+ verifyNew = 'verifyNew',
+}
+
+const EmailChangeModal = ({ onClose, email, show }: Props) => {
+ const { t } = useTranslation()
+ const { notify } = useContext(ToastContext)
+ const router = useRouter()
+ const [step, setStep] = useState
(STEP.start)
+ const [code, setCode] = useState('')
+ const [mail, setMail] = useState('')
+ const [time, setTime] = useState(0)
+ const [stepToken, setStepToken] = useState('')
+ const [newEmailExited, setNewEmailExited] = useState(false)
+ const [isCheckingEmail, setIsCheckingEmail] = useState(false)
+
+ const startCount = () => {
+ setTime(60)
+ const timer = setInterval(() => {
+ setTime((prev) => {
+ if (prev <= 0) {
+ clearInterval(timer)
+ return 0
+ }
+ return prev - 1
+ })
+ }, 1000)
+ }
+
+ const sendEmail = async (email: string, isOrigin: boolean, token?: string) => {
+ try {
+ const res = await sendVerifyCode({
+ email,
+ phase: isOrigin ? 'old_email' : 'new_email',
+ token,
+ })
+ startCount()
+ if (res.data)
+ setStepToken(res.data)
+ }
+ catch (error) {
+ notify({
+ type: 'error',
+ message: `Error sending verification code: ${error ? (error as any).message : ''}`,
+ })
+ }
+ }
+
+ const verifyEmailAddress = async (email: string, code: string, token: string, callback?: (data?: any) => void) => {
+ try {
+ const res = await verifyEmail({
+ email,
+ code,
+ token,
+ })
+ if (res.is_valid) {
+ setStepToken(res.token)
+ callback?.(res.token)
+ }
+ else {
+ notify({
+ type: 'error',
+ message: 'Verifying email failed',
+ })
+ }
+ }
+ catch (error) {
+ notify({
+ type: 'error',
+ message: `Error verifying email: ${error ? (error as any).message : ''}`,
+ })
+ }
+ }
+
+ const sendCodeToOriginEmail = async () => {
+ await sendEmail(
+ email,
+ true,
+ )
+ setStep(STEP.verifyOrigin)
+ }
+
+ const handleVerifyOriginEmail = async () => {
+ await verifyEmailAddress(email, code, stepToken, () => setStep(STEP.newEmail))
+ setCode('')
+ }
+
+ const isValidEmail = (email: string): boolean => {
+ const rfc5322emailRegex = /^[a-zA-Z0-9.!#$%&'*+/=?^_`{|}~-]+@[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$/
+ return rfc5322emailRegex.test(email) && email.length <= 254
+ }
+
+ const checkNewEmailExisted = async (email: string) => {
+ setIsCheckingEmail(true)
+ try {
+ await checkEmailExisted({
+ email,
+ })
+ setNewEmailExited(false)
+ }
+ catch {
+ setNewEmailExited(true)
+ }
+ finally {
+ setIsCheckingEmail(false)
+ }
+ }
+
+ const handleNewEmailValueChange = (mailAddress: string) => {
+ setMail(mailAddress)
+ setNewEmailExited(false)
+ if (isValidEmail(mailAddress))
+ checkNewEmailExisted(mailAddress)
+ }
+
+ const sendCodeToNewEmail = async () => {
+ if (!isValidEmail(mail)) {
+ notify({
+ type: 'error',
+ message: 'Invalid email format',
+ })
+ return
+ }
+ await sendEmail(
+ mail,
+ false,
+ stepToken,
+ )
+ setStep(STEP.verifyNew)
+ }
+
+ const handleLogout = async () => {
+ await logout({
+ url: '/logout',
+ params: {},
+ })
+
+ localStorage.removeItem('setup_status')
+ localStorage.removeItem('console_token')
+ localStorage.removeItem('refresh_token')
+
+ router.push('/signin')
+ }
+
+ const updateEmail = async (lastToken: string) => {
+ try {
+ await resetEmail({
+ new_email: mail,
+ token: lastToken,
+ })
+ handleLogout()
+ }
+ catch (error) {
+ notify({
+ type: 'error',
+ message: `Error changing email: ${error ? (error as any).message : ''}`,
+ })
+ }
+ }
+
+ const submitNewEmail = async () => {
+ await verifyEmailAddress(mail, code, stepToken, updateEmail)
+ }
+
+ return (
+
+
+
+
+ {step === STEP.start && (
+ <>
+ {t('common.account.changeEmail.title')}
+
+
{t('common.account.changeEmail.authTip')}
+
+ }}
+ values={{ email }}
+ />
+
+
+
+
+
+ {t('common.account.changeEmail.sendVerifyCode')}
+
+
+ {t('common.operation.cancel')}
+
+
+ >
+ )}
+ {step === STEP.verifyOrigin && (
+ <>
+ {t('common.account.changeEmail.verifyEmail')}
+
+
+ }}
+ values={{ email }}
+ />
+
+
+
+
{t('common.account.changeEmail.codeLabel')}
+
setCode(e.target.value)}
+ maxLength={6}
+ />
+
+
+
+ {t('common.account.changeEmail.continue')}
+
+
+ {t('common.operation.cancel')}
+
+
+
+ {t('common.account.changeEmail.resendTip')}
+ {time > 0 && (
+ {t('common.account.changeEmail.resendCount', { count: time })}
+ )}
+ {!time && (
+ {t('common.account.changeEmail.resend')}
+ )}
+
+ >
+ )}
+ {step === STEP.newEmail && (
+ <>
+ {t('common.account.changeEmail.newEmail')}
+
+
{t('common.account.changeEmail.content3')}
+
+
+
{t('common.account.changeEmail.emailLabel')}
+
handleNewEmailValueChange(e.target.value)}
+ destructive={newEmailExited}
+ />
+ {newEmailExited && (
+
{t('common.account.changeEmail.existingEmail')}
+ )}
+
+
+
+ {t('common.account.changeEmail.sendVerifyCode')}
+
+
+ {t('common.operation.cancel')}
+
+
+ >
+ )}
+ {step === STEP.verifyNew && (
+ <>
+ {t('common.account.changeEmail.verifyNew')}
+
+
+ }}
+ values={{ email: mail }}
+ />
+
+
+
+
{t('common.account.changeEmail.codeLabel')}
+
setCode(e.target.value)}
+ maxLength={6}
+ />
+
+
+
+ {t('common.account.changeEmail.changeTo', { email: mail })}
+
+
+ {t('common.operation.cancel')}
+
+
+
+ {t('common.account.changeEmail.resendTip')}
+ {time > 0 && (
+ {t('common.account.changeEmail.resendCount', { count: time })}
+ )}
+ {!time && (
+ {t('common.account.changeEmail.resend')}
+ )}
+
+ >
+ )}
+
+ )
+}
+
+export default EmailChangeModal
diff --git a/web/app/account/account-page/index.module.css b/web/app/account/account-page/index.module.css
deleted file mode 100644
index 949d1257e9..0000000000
--- a/web/app/account/account-page/index.module.css
+++ /dev/null
@@ -1,9 +0,0 @@
-.modal {
- padding: 24px 32px !important;
- width: 400px !important;
-}
-
-.bg {
- background: linear-gradient(180deg, rgba(217, 45, 32, 0.05) 0%, rgba(217, 45, 32, 0.00) 24.02%), #F9FAFB;
-}
-
diff --git a/web/app/account/account-page/index.tsx b/web/app/account/account-page/index.tsx
index 19c1e44236..55fa2983dd 100644
--- a/web/app/account/account-page/index.tsx
+++ b/web/app/account/account-page/index.tsx
@@ -6,7 +6,6 @@ import {
} from '@remixicon/react'
import { useContext } from 'use-context-selector'
import DeleteAccount from '../delete-account'
-import s from './index.module.css'
import AvatarWithEdit from './AvatarWithEdit'
import Collapse from '@/app/components/header/account-setting/collapse'
import type { IItem } from '@/app/components/header/account-setting/collapse'
@@ -21,6 +20,8 @@ import { IS_CE_EDITION } from '@/config'
import Input from '@/app/components/base/input'
import PremiumBadge from '@/app/components/base/premium-badge'
import { useGlobalPublicStore } from '@/context/global-public-context'
+import EmailChangeModal from './email-change-modal'
+import { validPassword } from '@/config'
const titleClassName = `
system-sm-semibold text-text-secondary
@@ -29,8 +30,6 @@ const descriptionClassName = `
mt-1 body-xs-regular text-text-tertiary
`
-const validPassword = /^(?=.*[a-zA-Z])(?=.*\d).{8,}$/
-
export default function AccountPage() {
const { t } = useTranslation()
const { systemFeatures } = useGlobalPublicStore()
@@ -48,6 +47,7 @@ export default function AccountPage() {
const [showCurrentPassword, setShowCurrentPassword] = useState(false)
const [showPassword, setShowPassword] = useState(false)
const [showConfirmPassword, setShowConfirmPassword] = useState(false)
+ const [showUpdateEmail, setShowUpdateEmail] = useState(false)
const handleEditName = () => {
setEditNameModalVisible(true)
@@ -123,10 +123,17 @@ export default function AccountPage() {
}
const renderAppItem = (item: IItem) => {
+ const { icon, icon_background, icon_type, icon_url } = item as any
return (
@@ -170,6 +177,11 @@ export default function AccountPage() {
{userProfile.email}
+ {systemFeatures.enable_change_email && (
+ setShowUpdateEmail(true)}>
+ {t('common.operation.change')}
+
+ )}
{
@@ -190,7 +202,7 @@ export default function AccountPage() {
{!!apps.length && (
({ key: app.id, name: app.name }))}
+ items={apps.map(app => ({ ...app, key: app.id, name: app.name }))}
renderItem={renderAppItem}
wrapperClassName='mt-2'
/>
@@ -202,7 +214,7 @@ export default function AccountPage() {
setEditNameModalVisible(false)}
- className={s.modal}
+ className='!w-[420px] !p-6'
>
{t('common.account.editName')}
{t('common.account.name')}
@@ -231,7 +243,7 @@ export default function AccountPage() {
setEditPasswordModalVisible(false)
resetPasswordForm()
}}
- className={s.modal}
+ className='!w-[420px] !p-6'
>
{userProfile.is_password_set ? t('common.account.resetPassword') : t('common.account.setPassword')}
{userProfile.is_password_set && (
@@ -316,6 +328,13 @@ export default function AccountPage() {
/>
)
}
+ {showUpdateEmail && (
+ setShowUpdateEmail(false)}
+ email={userProfile.email}
+ />
+ )}
>
)
}
diff --git a/web/app/components/app-sidebar/app-info.tsx b/web/app/components/app-sidebar/app-info.tsx
index c28cc20df5..e85eaa2f53 100644
--- a/web/app/components/app-sidebar/app-info.tsx
+++ b/web/app/components/app-sidebar/app-info.tsx
@@ -12,23 +12,17 @@ import {
RiFileUploadLine,
} from '@remixicon/react'
import AppIcon from '../base/app-icon'
-import SwitchAppModal from '../app/switch-app-modal'
import cn from '@/utils/classnames'
-import Confirm from '@/app/components/base/confirm'
import { useStore as useAppStore } from '@/app/components/app/store'
import { ToastContext } from '@/app/components/base/toast'
import AppsContext, { useAppContext } from '@/context/app-context'
import { useProviderContext } from '@/context/provider-context'
import { copyApp, deleteApp, exportAppConfig, updateAppInfo } from '@/service/apps'
-import DuplicateAppModal from '@/app/components/app/duplicate-modal'
import type { DuplicateAppModalProps } from '@/app/components/app/duplicate-modal'
-import CreateAppModal from '@/app/components/explore/create-app-modal'
import type { CreateAppModalProps } from '@/app/components/explore/create-app-modal'
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
import { getRedirection } from '@/utils/app-redirection'
-import UpdateDSLModal from '@/app/components/workflow/update-dsl-modal'
import type { EnvironmentVariable } from '@/app/components/workflow/types'
-import DSLExportConfirmModal from '@/app/components/workflow/dsl-export-confirm-modal'
import { fetchWorkflowDraft } from '@/service/workflow'
import ContentDialog from '@/app/components/base/content-dialog'
import Button from '@/app/components/base/button'
@@ -36,6 +30,26 @@ import CardView from '@/app/(commonLayout)/app/(appDetailLayout)/[appId]/overvie
import Divider from '../base/divider'
import type { Operation } from './app-operations'
import AppOperations from './app-operations'
+import dynamic from 'next/dynamic'
+
+const SwitchAppModal = dynamic(() => import('@/app/components/app/switch-app-modal'), {
+ ssr: false,
+})
+const CreateAppModal = dynamic(() => import('@/app/components/explore/create-app-modal'), {
+ ssr: false,
+})
+const DuplicateAppModal = dynamic(() => import('@/app/components/app/duplicate-modal'), {
+ ssr: false,
+})
+const Confirm = dynamic(() => import('@/app/components/base/confirm'), {
+ ssr: false,
+})
+const UpdateDSLModal = dynamic(() => import('@/app/components/workflow/update-dsl-modal'), {
+ ssr: false,
+})
+const DSLExportConfirmModal = dynamic(() => import('@/app/components/workflow/dsl-export-confirm-modal'), {
+ ssr: false,
+})
export type IAppInfoProps = {
expand: boolean
@@ -71,6 +85,7 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
icon_background,
description,
use_icon_as_answer_icon,
+ max_active_requests,
}) => {
if (!appDetail)
return
@@ -83,6 +98,7 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
icon_background,
description,
use_icon_as_answer_icon,
+ max_active_requests,
})
setShowEditModal(false)
notify({
@@ -308,13 +324,11 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
operations={operations}
/>
-
-
-
+
setShowEditModal(false)}
diff --git a/web/app/components/app/app-publisher/index.tsx b/web/app/components/app/app-publisher/index.tsx
index 83a7ffd553..cb98aa4950 100644
--- a/web/app/components/app/app-publisher/index.tsx
+++ b/web/app/components/app/app-publisher/index.tsx
@@ -6,6 +6,7 @@ import {
} from 'react'
import { useTranslation } from 'react-i18next'
import dayjs from 'dayjs'
+import relativeTime from 'dayjs/plugin/relativeTime'
import {
RiArrowDownSLine,
RiArrowRightSLine,
@@ -48,6 +49,7 @@ import { useAppWhiteListSubjects, useGetUserCanAccessApp } from '@/service/acces
import { AccessMode } from '@/models/access-control'
import { fetchAppDetail } from '@/service/apps'
import { useGlobalPublicStore } from '@/context/global-public-context'
+dayjs.extend(relativeTime)
export type AppPublisherProps = {
disabled?: boolean
@@ -116,6 +118,7 @@ const AppPublisher = ({
}
}, [appAccessSubjects, appDetail])
const language = useGetLanguage()
+
const formatTimeFromNow = useCallback((time: number) => {
return dayjs(time).locale(language === 'zh_Hans' ? 'zh-cn' : language.replace('_', '-')).fromNow()
}, [language])
@@ -180,8 +183,7 @@ const AppPublisher = ({
if (publishDisabled || published)
return
handlePublish()
- },
- { exactMatch: true, useCapture: true })
+ }, { exactMatch: true, useCapture: true })
return (
<>
diff --git a/web/app/components/app/app-publisher/suggested-action.tsx b/web/app/components/app/app-publisher/suggested-action.tsx
index 8d4ab3d39c..2535de6654 100644
--- a/web/app/components/app/app-publisher/suggested-action.tsx
+++ b/web/app/components/app/app-publisher/suggested-action.tsx
@@ -20,8 +20,8 @@ const SuggestedAction = ({ icon, link, disabled, children, className, onClick, .
target='_blank'
rel='noreferrer'
className={classNames(
- 'flex justify-start items-center gap-2 py-2 px-2.5 bg-background-section-burn rounded-lg text-text-secondary transition-colors [&:not(:first-child)]:mt-1',
- disabled ? 'shadow-xs opacity-30 cursor-not-allowed' : 'text-text-secondary hover:bg-state-accent-hover hover:text-text-accent cursor-pointer',
+ 'flex items-center justify-start gap-2 rounded-lg bg-background-section-burn px-2.5 py-2 text-text-secondary transition-colors [&:not(:first-child)]:mt-1',
+ disabled ? 'cursor-not-allowed opacity-30 shadow-xs' : 'cursor-pointer text-text-secondary hover:bg-state-accent-hover hover:text-text-accent',
className,
)}
onClick={handleClick}
diff --git a/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx
index 437e25fde4..e2d37bb9de 100644
--- a/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx
+++ b/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx
@@ -17,8 +17,8 @@ import PromptEditorHeightResizeWrap from './prompt-editor-height-resize-wrap'
import cn from '@/utils/classnames'
import type { PromptRole, PromptVariable } from '@/models/debug'
import {
- Clipboard,
- ClipboardCheck,
+ Copy,
+ CopyCheck,
} from '@/app/components/base/icons/src/vender/line/files'
import Button from '@/app/components/base/button'
import Tooltip from '@/app/components/base/tooltip'
@@ -188,13 +188,13 @@ const AdvancedPromptInput: FC = ({
)}
{!isCopied
? (
- {
+ {
copy(value)
setIsCopied(true)
}} />
)
: (
-
+
)}
diff --git a/web/app/components/app/configuration/config-var/config-modal/index.tsx b/web/app/components/app/configuration/config-var/config-modal/index.tsx
index 29cbc55b90..8fcc0f4c08 100644
--- a/web/app/components/app/configuration/config-var/config-modal/index.tsx
+++ b/web/app/components/app/configuration/config-var/config-modal/index.tsx
@@ -1,5 +1,5 @@
'use client'
-import type { FC } from 'react'
+import type { ChangeEvent, FC } from 'react'
import React, { useCallback, useEffect, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { useContext } from 'use-context-selector'
@@ -11,7 +11,7 @@ import SelectTypeItem from '../select-type-item'
import Field from './field'
import Input from '@/app/components/base/input'
import Toast from '@/app/components/base/toast'
-import { checkKeys, getNewVarInWorkflow } from '@/utils/var'
+import { checkKeys, getNewVarInWorkflow, replaceSpaceWithUnderscreInVarNameInput } from '@/utils/var'
import ConfigContext from '@/context/debug-configuration'
import type { InputVar, MoreInfo, UploadFileSetting } from '@/app/components/workflow/types'
import Modal from '@/app/components/base/modal'
@@ -109,6 +109,20 @@ const ConfigModal: FC = ({
})
}, [checkVariableName, tempPayload.label])
+ const handleVarNameChange = useCallback((e: ChangeEvent) => {
+ replaceSpaceWithUnderscreInVarNameInput(e.target)
+ const value = e.target.value
+ const { isValid, errorKey, errorMessageKey } = checkKeys([value], true)
+ if (!isValid) {
+ Toast.notify({
+ type: 'error',
+ message: t(`appDebug.varKeyError.${errorMessageKey}`, { key: errorKey }),
+ })
+ return
+ }
+ handlePayloadChange('variable')(e.target.value)
+ }, [handlePayloadChange, t])
+
const handleConfirm = () => {
const moreInfo = tempPayload.variable === payload?.variable
? undefined
@@ -200,7 +214,7 @@ const ConfigModal: FC = ({
handlePayloadChange('variable')(e.target.value)}
+ onChange={handleVarNameChange}
onBlur={handleVarKeyBlur}
placeholder={t('appDebug.variableConfig.inputPlaceholder')!}
/>
diff --git a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx
index 66fe85a170..b4711ea39a 100644
--- a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx
+++ b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx
@@ -18,7 +18,6 @@ import AppIcon from '@/app/components/base/app-icon'
import Button from '@/app/components/base/button'
import Indicator from '@/app/components/header/indicator'
import Switch from '@/app/components/base/switch'
-import Toast from '@/app/components/base/toast'
import ConfigContext from '@/context/debug-configuration'
import type { AgentTool } from '@/types/app'
import { type Collection, CollectionType } from '@/app/components/tools/types'
@@ -26,8 +25,6 @@ import { MAX_TOOLS_NUM } from '@/config'
import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback'
import Tooltip from '@/app/components/base/tooltip'
import { DefaultToolIcon } from '@/app/components/base/icons/src/public/other'
-import ConfigCredential from '@/app/components/tools/setting/build-in/config-credentials'
-import { updateBuiltInToolCredential } from '@/service/tools'
import cn from '@/utils/classnames'
import ToolPicker from '@/app/components/workflow/block-selector/tool-picker'
import type { ToolDefaultValue, ToolValue } from '@/app/components/workflow/block-selector/types'
@@ -57,13 +54,7 @@ const AgentTools: FC = () => {
const formattingChangedDispatcher = useFormattingChangedDispatcher()
const [currentTool, setCurrentTool] = useState(null)
- const currentCollection = useMemo(() => {
- if (!currentTool) return null
- const collection = collectionList.find(collection => canFindTool(collection.id, currentTool?.provider_id) && collection.type === currentTool?.provider_type)
- return collection
- }, [currentTool, collectionList])
const [isShowSettingTool, setIsShowSettingTool] = useState(false)
- const [isShowSettingAuth, setShowSettingAuth] = useState(false)
const tools = (modelConfig?.agentConfig?.tools as AgentTool[] || []).map((item) => {
const collection = collectionList.find(
collection =>
@@ -100,17 +91,6 @@ const AgentTools: FC = () => {
formattingChangedDispatcher()
}
- const handleToolAuthSetting = (value: AgentToolWithMoreInfo) => {
- const newModelConfig = produce(modelConfig, (draft) => {
- const tool = (draft.agentConfig.tools).find((item: any) => item.provider_id === value?.collection?.id && item.tool_name === value?.tool_name)
- if (tool)
- (tool as AgentTool).notAuthor = false
- })
- setModelConfig(newModelConfig)
- setIsShowSettingTool(false)
- formattingChangedDispatcher()
- }
-
const [isDeleting, setIsDeleting] = useState(-1)
const getToolValue = (tool: ToolDefaultValue) => {
return {
@@ -144,6 +124,20 @@ const AgentTools: FC = () => {
return item.provider_name
}
+ const handleAuthorizationItemClick = useCallback((credentialId: string) => {
+ const newModelConfig = produce(modelConfig, (draft) => {
+ const tool = (draft.agentConfig.tools).find((item: any) => item.provider_id === currentTool?.provider_id)
+ if (tool)
+ (tool as AgentTool).credential_id = credentialId
+ })
+ setCurrentTool({
+ ...currentTool,
+ credential_id: credentialId,
+ } as any)
+ setModelConfig(newModelConfig)
+ formattingChangedDispatcher()
+ }, [currentTool, modelConfig, setModelConfig, formattingChangedDispatcher])
+
return (
<>
{
{item.tool_label}
{!item.isDeleted && (
{item.tool_name}
@@ -232,7 +225,6 @@ const AgentTools: FC = () => {
@@ -259,7 +251,6 @@ const AgentTools: FC = () => {
{!item.notAuthor && (
{
setCurrentTool(item)
@@ -302,7 +293,7 @@ const AgentTools: FC = () => {
{item.notAuthor && (
{
setCurrentTool(item)
- setShowSettingAuth(true)
+ setIsShowSettingTool(true)
}}>
{t('tools.notAuthorized')}
@@ -322,21 +313,8 @@ const AgentTools: FC = () => {
isModel={currentTool?.collection?.type === CollectionType.model}
onSave={handleToolSettingChange}
onHide={() => setIsShowSettingTool(false)}
- />
- )}
- {isShowSettingAuth && (
- setShowSettingAuth(false)}
- onSaved={async (value) => {
- await updateBuiltInToolCredential((currentCollection as any).name, value)
- Toast.notify({
- type: 'success',
- message: t('common.api.actionSuccess'),
- })
- handleToolAuthSetting(currentTool)
- setShowSettingAuth(false)
- }}
+ credentialId={currentTool?.credential_id}
+ onAuthorizationItemClick={handleAuthorizationItemClick}
/>
)}
>
diff --git a/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx b/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx
index 1ad814c6e9..92f1525bd5 100644
--- a/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx
+++ b/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx
@@ -14,7 +14,6 @@ import Icon from '@/app/components/plugins/card/base/card-icon'
import OrgInfo from '@/app/components/plugins/card/base/org-info'
import Description from '@/app/components/plugins/card/base/description'
import TabSlider from '@/app/components/base/tab-slider-plain'
-
import Button from '@/app/components/base/button'
import Form from '@/app/components/header/account-setting/model-provider-page/model-modal/Form'
import { addDefaultValue, toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema'
@@ -25,6 +24,10 @@ import I18n from '@/context/i18n'
import { getLanguage } from '@/i18n/language'
import cn from '@/utils/classnames'
import type { ToolWithProvider } from '@/app/components/workflow/types'
+import {
+ AuthCategory,
+ PluginAuthInAgent,
+} from '@/app/components/plugins/plugin-auth'
type Props = {
showBackButton?: boolean
@@ -36,6 +39,8 @@ type Props = {
readonly?: boolean
onHide: () => void
onSave?: (value: Record) => void
+ credentialId?: string
+ onAuthorizationItemClick?: (id: string) => void
}
const SettingBuiltInTool: FC = ({
@@ -48,6 +53,8 @@ const SettingBuiltInTool: FC = ({
readonly,
onHide,
onSave,
+ credentialId,
+ onAuthorizationItemClick,
}) => {
const { locale } = useContext(I18n)
const language = getLanguage(locale)
@@ -197,8 +204,20 @@ const SettingBuiltInTool: FC = ({
{currTool?.label[language]}
{!!currTool?.description[language] && (
-
+
)}
+ {
+ collection.allow_delete && collection.type === CollectionType.builtIn && (
+
+ )
+ }
{/* form */}
diff --git a/web/app/components/app/configuration/config/agent/prompt-editor.tsx b/web/app/components/app/configuration/config/agent/prompt-editor.tsx
index 579b7c4d64..98b23e5379 100644
--- a/web/app/components/app/configuration/config/agent/prompt-editor.tsx
+++ b/web/app/components/app/configuration/config/agent/prompt-editor.tsx
@@ -6,8 +6,8 @@ import { useContext } from 'use-context-selector'
import { useTranslation } from 'react-i18next'
import cn from '@/utils/classnames'
import {
- Clipboard,
- ClipboardCheck,
+ Copy,
+ CopyCheck,
} from '@/app/components/base/icons/src/vender/line/files'
import PromptEditor from '@/app/components/base/prompt-editor'
import type { ExternalDataTool } from '@/models/common'
@@ -81,13 +81,13 @@ const Editor: FC
= ({
{!isCopied
? (
- {
+ {
copy(value)
setIsCopied(true)
}} />
)
: (
-
+
)}
diff --git a/web/app/components/app/configuration/config/config-audio.tsx b/web/app/components/app/configuration/config/config-audio.tsx
new file mode 100644
index 0000000000..5600f8cbb6
--- /dev/null
+++ b/web/app/components/app/configuration/config/config-audio.tsx
@@ -0,0 +1,78 @@
+'use client'
+import type { FC } from 'react'
+import React, { useCallback } from 'react'
+import { useTranslation } from 'react-i18next'
+import produce from 'immer'
+import { useContext } from 'use-context-selector'
+
+import { Microphone01 } from '@/app/components/base/icons/src/vender/features'
+import Tooltip from '@/app/components/base/tooltip'
+import ConfigContext from '@/context/debug-configuration'
+import { SupportUploadFileTypes } from '@/app/components/workflow/types'
+import { useFeatures, useFeaturesStore } from '@/app/components/base/features/hooks'
+import Switch from '@/app/components/base/switch'
+
+const ConfigAudio: FC = () => {
+ const { t } = useTranslation()
+ const file = useFeatures(s => s.features.file)
+ const featuresStore = useFeaturesStore()
+ const { isShowAudioConfig } = useContext(ConfigContext)
+
+ const isAudioEnabled = file?.allowed_file_types?.includes(SupportUploadFileTypes.audio) ?? false
+
+ const handleChange = useCallback((value: boolean) => {
+ const {
+ features,
+ setFeatures,
+ } = featuresStore!.getState()
+
+ const newFeatures = produce(features, (draft) => {
+ if (value) {
+ draft.file!.allowed_file_types = Array.from(new Set([
+ ...(draft.file?.allowed_file_types || []),
+ SupportUploadFileTypes.audio,
+ ]))
+ }
+ else {
+ draft.file!.allowed_file_types = draft.file!.allowed_file_types?.filter(
+ type => type !== SupportUploadFileTypes.audio,
+ )
+ }
+ if (draft.file)
+ draft.file.enabled = (draft.file.allowed_file_types?.length ?? 0) > 0
+ })
+ setFeatures(newFeatures)
+ }, [featuresStore])
+
+ if (!isShowAudioConfig)
+ return null
+
+ return (
+
+
+
+
{t('appDebug.feature.audioUpload.title')}
+
+ {t('appDebug.feature.audioUpload.description')}
+
+ }
+ />
+
+
+
+ )
+}
+export default React.memo(ConfigAudio)
diff --git a/web/app/components/app/configuration/config/index.tsx b/web/app/components/app/configuration/config/index.tsx
index dc2095502e..d0375c6de9 100644
--- a/web/app/components/app/configuration/config/index.tsx
+++ b/web/app/components/app/configuration/config/index.tsx
@@ -8,6 +8,7 @@ import DatasetConfig from '../dataset-config'
import HistoryPanel from '../config-prompt/conversation-history/history-panel'
import ConfigVision from '../config-vision'
import ConfigDocument from './config-document'
+import ConfigAudio from './config-audio'
import AgentTools from './agent/agent-tools'
import ConfigContext from '@/context/debug-configuration'
import ConfigPrompt from '@/app/components/app/configuration/config-prompt'
@@ -85,6 +86,8 @@ const Config: FC = () => {
+
+
{/* Chat History */}
{isAdvancedMode && isChatApp && modelModeType === ModelModeType.completion && (
{
const isShowVisionConfig = !!currModel?.features?.includes(ModelFeatureEnum.vision)
const isShowDocumentConfig = !!currModel?.features?.includes(ModelFeatureEnum.document)
+ const isShowAudioConfig = !!currModel?.features?.includes(ModelFeatureEnum.audio)
const isAllowVideoUpload = !!currModel?.features?.includes(ModelFeatureEnum.video)
// *** web app features ***
const featuresData: FeaturesData = useMemo(() => {
@@ -920,6 +921,7 @@ const Configuration: FC = () => {
setVisionConfig: handleSetVisionConfig,
isAllowVideoUpload,
isShowDocumentConfig,
+ isShowAudioConfig,
rerankSettingModalOpen,
setRerankSettingModalOpen,
}}
diff --git a/web/app/components/app/configuration/prompt-value-panel/index.tsx b/web/app/components/app/configuration/prompt-value-panel/index.tsx
index e509ee50e4..b36bf8848a 100644
--- a/web/app/components/app/configuration/prompt-value-panel/index.tsx
+++ b/web/app/components/app/configuration/prompt-value-panel/index.tsx
@@ -177,7 +177,7 @@ const PromptValuePanel: FC = ({
{t('common.operation.clear')}
{canNotRun && (
-
+
diff --git a/web/app/components/app/type-selector/index.tsx b/web/app/components/app/type-selector/index.tsx
index a57bac20db..99a76d7ac7 100644
--- a/web/app/components/app/type-selector/index.tsx
+++ b/web/app/components/app/type-selector/index.tsx
@@ -65,6 +65,44 @@ const AppTypeSelector = ({ value, onChange }: AppSelectorProps) => {
export default AppTypeSelector
+type AppTypeIconProps = {
+ type: AppMode
+ style?: React.CSSProperties
+ className?: string
+ wrapperClassName?: string
+}
+
+export const AppTypeIcon = React.memo(({ type, className, wrapperClassName, style }: AppTypeIconProps) => {
+ const wrapperClassNames = cn('inline-flex h-5 w-5 items-center justify-center rounded-md border border-divider-regular', wrapperClassName)
+ const iconClassNames = cn('h-3.5 w-3.5 text-components-avatar-shape-fill-stop-100', className)
+ if (type === 'chat') {
+ return
+
+
+ }
+ if (type === 'agent-chat') {
+ return
+
+
+ }
+ if (type === 'advanced-chat') {
+ return
+
+
+ }
+ if (type === 'workflow') {
+ return
+
+
+ }
+ if (type === 'completion') {
+ return
+
+
+ }
+ return null
+})
+
function AppTypeSelectTrigger({ values }: { values: AppSelectorProps['value'] }) {
const { t } = useTranslation()
if (!values || values.length === 0) {
@@ -108,44 +146,6 @@ function AppTypeSelectorItem({ checked, type, onClick }: AppTypeSelectorItemProp
}
-type AppTypeIconProps = {
- type: AppMode
- style?: React.CSSProperties
- className?: string
- wrapperClassName?: string
-}
-
-export function AppTypeIcon({ type, className, wrapperClassName, style }: AppTypeIconProps) {
- const wrapperClassNames = cn('inline-flex h-5 w-5 items-center justify-center rounded-md border border-divider-regular', wrapperClassName)
- const iconClassNames = cn('h-3.5 w-3.5 text-components-avatar-shape-fill-stop-100', className)
- if (type === 'chat') {
- return
-
-
- }
- if (type === 'agent-chat') {
- return
-
-
- }
- if (type === 'advanced-chat') {
- return
-
-
- }
- if (type === 'workflow') {
- return
-
-
- }
- if (type === 'completion') {
- return
-
-
- }
- return null
-}
-
type AppTypeLabelProps = {
type: AppMode
className?: string
diff --git a/web/app/(commonLayout)/apps/AppCard.tsx b/web/app/components/apps/app-card.tsx
similarity index 94%
rename from web/app/(commonLayout)/apps/AppCard.tsx
rename to web/app/components/apps/app-card.tsx
index f50cc10520..bfb7813bf4 100644
--- a/web/app/(commonLayout)/apps/AppCard.tsx
+++ b/web/app/components/apps/app-card.tsx
@@ -1,16 +1,14 @@
'use client'
+import React, { useCallback, useEffect, useMemo, useState } from 'react'
import { useContext, useContextSelector } from 'use-context-selector'
import { useRouter } from 'next/navigation'
-import { useCallback, useEffect, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { RiBuildingLine, RiGlobalLine, RiLockLine, RiMoreFill, RiVerifiedBadgeLine } from '@remixicon/react'
import cn from '@/utils/classnames'
import type { App } from '@/types/app'
-import Confirm from '@/app/components/base/confirm'
import Toast, { ToastContext } from '@/app/components/base/toast'
import { copyApp, deleteApp, exportAppConfig, updateAppInfo } from '@/service/apps'
-import DuplicateAppModal from '@/app/components/app/duplicate-modal'
import type { DuplicateAppModalProps } from '@/app/components/app/duplicate-modal'
import AppIcon from '@/app/components/base/app-icon'
import AppsContext, { useAppContext } from '@/context/app-context'
@@ -22,21 +20,37 @@ import { getRedirection } from '@/utils/app-redirection'
import { useProviderContext } from '@/context/provider-context'
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
import type { CreateAppModalProps } from '@/app/components/explore/create-app-modal'
-import EditAppModal from '@/app/components/explore/create-app-modal'
-import SwitchAppModal from '@/app/components/app/switch-app-modal'
import type { Tag } from '@/app/components/base/tag-management/constant'
import TagSelector from '@/app/components/base/tag-management/selector'
import type { EnvironmentVariable } from '@/app/components/workflow/types'
-import DSLExportConfirmModal from '@/app/components/workflow/dsl-export-confirm-modal'
import { fetchWorkflowDraft } from '@/service/workflow'
import { fetchInstalledAppList } from '@/service/explore'
import { AppTypeIcon } from '@/app/components/app/type-selector'
import Tooltip from '@/app/components/base/tooltip'
-import AccessControl from '@/app/components/app/app-access-control'
import { AccessMode } from '@/models/access-control'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { formatTime } from '@/utils/time'
import { useGetUserCanAccessApp } from '@/service/access-control'
+import dynamic from 'next/dynamic'
+
+const EditAppModal = dynamic(() => import('@/app/components/explore/create-app-modal'), {
+ ssr: false,
+})
+const DuplicateAppModal = dynamic(() => import('@/app/components/app/duplicate-modal'), {
+ ssr: false,
+})
+const SwitchAppModal = dynamic(() => import('@/app/components/app/switch-app-modal'), {
+ ssr: false,
+})
+const Confirm = dynamic(() => import('@/app/components/base/confirm'), {
+ ssr: false,
+})
+const DSLExportConfirmModal = dynamic(() => import('@/app/components/workflow/dsl-export-confirm-modal'), {
+ ssr: false,
+})
+const AccessControl = dynamic(() => import('@/app/components/app/app-access-control'), {
+ ssr: false,
+})
export type AppCardProps = {
app: App
@@ -88,6 +102,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
icon_background,
description,
use_icon_as_answer_icon,
+ max_active_requests,
}) => {
try {
await updateAppInfo({
@@ -98,6 +113,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
icon_background,
description,
use_icon_as_answer_icon,
+ max_active_requests,
})
setShowEditModal(false)
notify({
@@ -306,7 +322,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
const EditTimeText = useMemo(() => {
const timeText = formatTime({
date: (app.updated_at || app.created_at) * 1000,
- dateFormat: 'MM/DD/YYYY h:mm',
+ dateFormat: `${t('datasetDocuments.segment.dateTimeFormat')}`,
})
return `${t('datasetDocuments.segment.editedAt')} ${timeText}`
// eslint-disable-next-line react-hooks/exhaustive-deps
@@ -432,6 +448,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
appDescription={app.description}
appMode={app.mode}
appUseIconAsAnswerIcon={app.use_icon_as_answer_icon}
+ max_active_requests={app.max_active_requests ?? null}
show={showEditModal}
onConfirm={onEdit}
onHide={() => setShowEditModal(false)}
@@ -480,4 +497,4 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
)
}
-export default AppCard
+export default React.memo(AppCard)
diff --git a/web/app/components/apps/empty.tsx b/web/app/components/apps/empty.tsx
new file mode 100644
index 0000000000..e6b52294a2
--- /dev/null
+++ b/web/app/components/apps/empty.tsx
@@ -0,0 +1,35 @@
+import React from 'react'
+import { useTranslation } from 'react-i18next'
+
+const DefaultCards = React.memo(() => {
+ const renderArray = Array.from({ length: 36 })
+ return (
+ <>
+ {
+ renderArray.map((_, index) => (
+
+ ))
+ }
+ >
+ )
+})
+
+const Empty = () => {
+ const { t } = useTranslation()
+
+ return (
+ <>
+
+
+
+ {t('app.newApp.noAppsFound')}
+
+
+ >
+ )
+}
+
+export default React.memo(Empty)
diff --git a/web/app/components/apps/footer.tsx b/web/app/components/apps/footer.tsx
new file mode 100644
index 0000000000..7bee272342
--- /dev/null
+++ b/web/app/components/apps/footer.tsx
@@ -0,0 +1,46 @@
+import React from 'react'
+import Link from 'next/link'
+import { RiDiscordFill, RiGithubFill } from '@remixicon/react'
+import { useTranslation } from 'react-i18next'
+
+type CustomLinkProps = {
+ href: string
+ children: React.ReactNode
+}
+
+const CustomLink = React.memo(({
+ href,
+ children,
+}: CustomLinkProps) => {
+ return (
+
+ {children}
+
+ )
+})
+
+const Footer = () => {
+ const { t } = useTranslation()
+
+ return (
+
+ )
+}
+
+export default React.memo(Footer)
diff --git a/web/app/(commonLayout)/apps/hooks/use-apps-query-state.ts b/web/app/components/apps/hooks/use-apps-query-state.ts
similarity index 100%
rename from web/app/(commonLayout)/apps/hooks/use-apps-query-state.ts
rename to web/app/components/apps/hooks/use-apps-query-state.ts
diff --git a/web/app/(commonLayout)/apps/hooks/use-dsl-drag-drop.ts b/web/app/components/apps/hooks/use-dsl-drag-drop.ts
similarity index 97%
rename from web/app/(commonLayout)/apps/hooks/use-dsl-drag-drop.ts
rename to web/app/components/apps/hooks/use-dsl-drag-drop.ts
index 96942ec54e..dda5773062 100644
--- a/web/app/(commonLayout)/apps/hooks/use-dsl-drag-drop.ts
+++ b/web/app/components/apps/hooks/use-dsl-drag-drop.ts
@@ -2,7 +2,7 @@ import { useEffect, useState } from 'react'
type DSLDragDropHookProps = {
onDSLFileDropped: (file: File) => void
- containerRef: React.RefObject
+ containerRef: React.RefObject
enabled?: boolean
}
diff --git a/web/app/components/apps/index.tsx b/web/app/components/apps/index.tsx
new file mode 100644
index 0000000000..be81a77dc3
--- /dev/null
+++ b/web/app/components/apps/index.tsx
@@ -0,0 +1,26 @@
+'use client'
+import { useEducationInit } from '@/app/education-apply/hooks'
+import { useGlobalPublicStore } from '@/context/global-public-context'
+import List from './list'
+import Footer from './footer'
+import useDocumentTitle from '@/hooks/use-document-title'
+import { useTranslation } from 'react-i18next'
+
+const Apps = () => {
+ const { t } = useTranslation()
+ const { systemFeatures } = useGlobalPublicStore()
+
+ useDocumentTitle(t('common.menus.apps'))
+ useEducationInit()
+
+ return (
+
+
+ {!systemFeatures.branding.enabled && (
+
+ )}
+
+ )
+}
+
+export default Apps
diff --git a/web/app/(commonLayout)/apps/Apps.tsx b/web/app/components/apps/list.tsx
similarity index 90%
rename from web/app/(commonLayout)/apps/Apps.tsx
rename to web/app/components/apps/list.tsx
index 2aa192fb02..359eaeabd4 100644
--- a/web/app/(commonLayout)/apps/Apps.tsx
+++ b/web/app/components/apps/list.tsx
@@ -15,8 +15,8 @@ import {
RiMessage3Line,
RiRobot3Line,
} from '@remixicon/react'
-import AppCard from './AppCard'
-import NewAppCard from './NewAppCard'
+import AppCard from './app-card'
+import NewAppCard from './new-app-card'
import useAppsQueryState from './hooks/use-apps-query-state'
import { useDSLDragDrop } from './hooks/use-dsl-drag-drop'
import type { AppListResponse } from '@/models/app'
@@ -28,10 +28,17 @@ import TabSliderNew from '@/app/components/base/tab-slider-new'
import { useTabSearchParams } from '@/hooks/use-tab-searchparams'
import Input from '@/app/components/base/input'
import { useStore as useTagStore } from '@/app/components/base/tag-management/store'
-import TagManagementModal from '@/app/components/base/tag-management'
import TagFilter from '@/app/components/base/tag-management/filter'
import CheckboxWithLabel from '@/app/components/datasets/create/website/base/checkbox-with-label'
-import CreateFromDSLModal from '@/app/components/app/create-from-dsl-modal'
+import dynamic from 'next/dynamic'
+import Empty from './empty'
+
+const TagManagementModal = dynamic(() => import('@/app/components/base/tag-management'), {
+ ssr: false,
+})
+const CreateFromDSLModal = dynamic(() => import('@/app/components/app/create-from-dsl-modal'), {
+ ssr: false,
+})
const getKey = (
pageIndex: number,
@@ -57,7 +64,7 @@ const getKey = (
return null
}
-const Apps = () => {
+const List = () => {
const { t } = useTranslation()
const router = useRouter()
const { isCurrentWorkspaceEditor, isCurrentWorkspaceDatasetOperator } = useAppContext()
@@ -209,7 +216,7 @@ const Apps = () => {
:
{isCurrentWorkspaceEditor
&& }
-
+
}
{isCurrentWorkspaceEditor && (
@@ -248,22 +255,4 @@ const Apps = () => {
)
}
-export default Apps
-
-function NoAppsFound() {
- const { t } = useTranslation()
- function renderDefaultCard() {
- const defaultCards = Array.from({ length: 36 }, (_, index) => (
-
- ))
- return defaultCards
- }
- return (
- <>
- {renderDefaultCard()}
-
- {t('app.newApp.noAppsFound')}
-
- >
- )
-}
+export default List
diff --git a/web/app/(commonLayout)/apps/NewAppCard.tsx b/web/app/components/apps/new-app-card.tsx
similarity index 56%
rename from web/app/(commonLayout)/apps/NewAppCard.tsx
rename to web/app/components/apps/new-app-card.tsx
index 0b42577ee3..451d2ae326 100644
--- a/web/app/(commonLayout)/apps/NewAppCard.tsx
+++ b/web/app/components/apps/new-app-card.tsx
@@ -1,32 +1,38 @@
'use client'
-import { useMemo, useState } from 'react'
+import React, { useMemo, useState } from 'react'
import {
useRouter,
useSearchParams,
} from 'next/navigation'
import { useTranslation } from 'react-i18next'
-import CreateAppTemplateDialog from '@/app/components/app/create-app-dialog'
-import CreateAppModal from '@/app/components/app/create-app-modal'
-import CreateFromDSLModal, { CreateFromDSLModalTab } from '@/app/components/app/create-from-dsl-modal'
+import { CreateFromDSLModalTab } from '@/app/components/app/create-from-dsl-modal'
import { useProviderContext } from '@/context/provider-context'
import { FileArrow01, FilePlus01, FilePlus02 } from '@/app/components/base/icons/src/vender/line/files'
import cn from '@/utils/classnames'
+import dynamic from 'next/dynamic'
+
+const CreateAppModal = dynamic(() => import('@/app/components/app/create-app-modal'), {
+ ssr: false,
+})
+const CreateAppTemplateDialog = dynamic(() => import('@/app/components/app/create-app-dialog'), {
+ ssr: false,
+})
+const CreateFromDSLModal = dynamic(() => import('@/app/components/app/create-from-dsl-modal'), {
+ ssr: false,
+})
export type CreateAppCardProps = {
className?: string
onSuccess?: () => void
+ ref: React.RefObject
}
-const CreateAppCard = (
- {
- ref,
- className,
- onSuccess,
- }: CreateAppCardProps & {
- ref: React.RefObject;
- },
-) => {
+const CreateAppCard = ({
+ ref,
+ className,
+ onSuccess,
+}: CreateAppCardProps) => {
const { t } = useTranslation()
const { onPlanInfoChanged } = useProviderContext()
const searchParams = useSearchParams()
@@ -67,52 +73,58 @@ const CreateAppCard = (
- setShowNewAppModal(false)}
- onSuccess={() => {
- onPlanInfoChanged()
- if (onSuccess)
- onSuccess()
- }}
- onCreateFromTemplate={() => {
- setShowNewAppTemplateDialog(true)
- setShowNewAppModal(false)
- }}
- />
- setShowNewAppTemplateDialog(false)}
- onSuccess={() => {
- onPlanInfoChanged()
- if (onSuccess)
- onSuccess()
- }}
- onCreateFromBlank={() => {
- setShowNewAppModal(true)
- setShowNewAppTemplateDialog(false)
- }}
- />
- {
- setShowCreateFromDSLModal(false)
+ {showNewAppModal && (
+ setShowNewAppModal(false)}
+ onSuccess={() => {
+ onPlanInfoChanged()
+ if (onSuccess)
+ onSuccess()
+ }}
+ onCreateFromTemplate={() => {
+ setShowNewAppTemplateDialog(true)
+ setShowNewAppModal(false)
+ }}
+ />
+ )}
+ {showNewAppTemplateDialog && (
+ setShowNewAppTemplateDialog(false)}
+ onSuccess={() => {
+ onPlanInfoChanged()
+ if (onSuccess)
+ onSuccess()
+ }}
+ onCreateFromBlank={() => {
+ setShowNewAppModal(true)
+ setShowNewAppTemplateDialog(false)
+ }}
+ />
+ )}
+ {showCreateFromDSLModal && (
+ {
+ setShowCreateFromDSLModal(false)
- if (dslUrl)
- replace('/')
- }}
- activeTab={activeTab}
- dslUrl={dslUrl}
- onSuccess={() => {
- onPlanInfoChanged()
- if (onSuccess)
- onSuccess()
- }}
- />
+ if (dslUrl)
+ replace('/')
+ }}
+ activeTab={activeTab}
+ dslUrl={dslUrl}
+ onSuccess={() => {
+ onPlanInfoChanged()
+ if (onSuccess)
+ onSuccess()
+ }}
+ />
+ )}
)
}
CreateAppCard.displayName = 'CreateAppCard'
-export default CreateAppCard
-export { CreateAppCard }
+
+export default React.memo(CreateAppCard)
diff --git a/web/app/components/base/app-icon/index.tsx b/web/app/components/base/app-icon/index.tsx
index 003d929c8c..b4724ca5de 100644
--- a/web/app/components/base/app-icon/index.tsx
+++ b/web/app/components/base/app-icon/index.tsx
@@ -1,5 +1,6 @@
'use client'
+import React from 'react'
import type { FC } from 'react'
import { init } from 'emoji-mart'
import data from '@emoji-mart/data'
@@ -71,4 +72,4 @@ const AppIcon: FC