From f900a92ee7bc76a028fb989dcf02f638eb2d5b04 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Mon, 21 Jul 2025 14:20:32 +0800 Subject: [PATCH] refactor(api): Simplify the constructor of `GraphEngine` Move most contextual arguments into `GraphInitParams`. --- .../workflow/graph_engine/graph_engine.py | 26 ++------------- api/core/workflow/nodes/base/node.py | 4 +-- .../nodes/iteration/iteration_node.py | 17 ++-------- api/core/workflow/nodes/loop/loop_node.py | 16 ++-------- api/core/workflow/workflow_entry.py | 9 ++++-- .../entities/test_graph_runtime_state.py | 1 - .../graph_engine/test_graph_engine.py | 32 ++++++++++++++----- .../core/workflow/nodes/answer/test_answer.py | 5 +-- 8 files changed, 41 insertions(+), 69 deletions(-) diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index b315129763..920d30b39b 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -3,7 +3,7 @@ import logging import queue import time import uuid -from collections.abc import Generator, Mapping +from collections.abc import Generator from concurrent.futures import ThreadPoolExecutor, wait from copy import copy, deepcopy from datetime import UTC, datetime @@ -91,19 +91,9 @@ class GraphEngine: def __init__( self, - tenant_id: str, - app_id: str, - workflow_type: WorkflowType, - workflow_id: str, - user_id: str, - user_from: UserFrom, - invoke_from: InvokeFrom, - call_depth: int, graph: Graph, - graph_config: Mapping[str, Any], + graph_init_params: GraphInitParams, graph_runtime_state: GraphRuntimeState, - max_execution_steps: int, - max_execution_time: int, thread_pool_id: Optional[str] = None, ) -> None: thread_pool_max_submit_count = dify_config.MAX_SUBMIT_COUNT @@ -126,17 +116,7 @@ class GraphEngine: GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id] = self.thread_pool self.graph = graph - self.init_params = GraphInitParams( - tenant_id=tenant_id, - app_id=app_id, - workflow_type=workflow_type, - workflow_id=workflow_id, - graph_config=graph_config, - user_id=user_id, - user_from=user_from, - invoke_from=invoke_from, - call_depth=call_depth, - ) + self.init_params = graph_init_params self.graph_runtime_state = graph_runtime_state diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index fb5ec55453..b7a5e3eeec 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -32,9 +32,6 @@ class BaseNode: self.id = id self.tenant_id = graph_init_params.tenant_id self.app_id = graph_init_params.app_id - self.workflow_type = graph_init_params.workflow_type - self.workflow_id = graph_init_params.workflow_id - self.graph_config = graph_init_params.graph_config self.user_id = graph_init_params.user_id self.user_from = graph_init_params.user_from self.invoke_from = graph_init_params.invoke_from @@ -43,6 +40,7 @@ class BaseNode: self.graph_runtime_state = graph_runtime_state self.previous_node_id = previous_node_id self.thread_pool_id = thread_pool_id + self._init_params = graph_init_params node_id = config.get("id") if not node_id: diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 84a5732fdd..d07a9971b9 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -1,6 +1,5 @@ import contextvars import logging -import time import uuid from collections.abc import Generator, Mapping, Sequence from concurrent.futures import Future, wait @@ -137,15 +136,13 @@ class IterationNode(BaseNode): inputs = {"iterator_selector": iterator_list_value} - graph_config = self.graph_config - if not self._node_data.start_node_id: raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self.node_id} not found") root_node_id = self._node_data.start_node_id # init graph - iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id) + iteration_graph = Graph.init(graph_config=self._init_params.graph_config, root_node_id=root_node_id) if not iteration_graph: raise IterationGraphNotFoundError("iteration graph not found") @@ -163,19 +160,9 @@ class IterationNode(BaseNode): graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool) graph_engine = GraphEngine( - tenant_id=self.tenant_id, - app_id=self.app_id, - workflow_type=self.workflow_type, - workflow_id=self.workflow_id, - user_id=self.user_id, - user_from=self.user_from, - invoke_from=self.invoke_from, - call_depth=self.workflow_call_depth, graph=iteration_graph, - graph_config=graph_config, graph_runtime_state=graph_runtime_state, - max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, - max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, + graph_init_params=self._init_params, thread_pool_id=self.thread_pool_id, ) diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 53ac04d6dc..cd12e809d5 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -1,11 +1,9 @@ import json import logging -import time from collections.abc import Generator, Mapping, Sequence from datetime import UTC, datetime from typing import TYPE_CHECKING, Any, Literal, Optional, cast -from configs import dify_config from core.variables import ( IntegerSegment, Segment, @@ -91,7 +89,7 @@ class LoopNode(BaseNode): raise ValueError(f"field start_node_id in loop {self.node_id} not found") # Initialize graph - loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self._node_data.start_node_id) + loop_graph = Graph.init(graph_config=self._init_params.graph_config, root_node_id=self._node_data.start_node_id) if not loop_graph: raise ValueError("loop graph not found") @@ -127,19 +125,9 @@ class LoopNode(BaseNode): graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool) graph_engine = GraphEngine( - tenant_id=self.tenant_id, - app_id=self.app_id, - workflow_type=self.workflow_type, - workflow_id=self.workflow_id, - user_id=self.user_id, - user_from=self.user_from, - invoke_from=self.invoke_from, - call_depth=self.workflow_call_depth, graph=loop_graph, - graph_config=self.graph_config, + graph_init_params=self._init_params, graph_runtime_state=graph_runtime_state, - max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, - max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, thread_pool_id=self.thread_pool_id, ) diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 224e8036f7..a1b6f69289 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -71,7 +71,7 @@ class WorkflowEntry: # init workflow run state graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool) - self.graph_engine = GraphEngine( + graph_init_params = GraphInitParams( tenant_id=tenant_id, app_id=app_id, workflow_type=workflow_type, @@ -80,11 +80,14 @@ class WorkflowEntry: user_from=user_from, invoke_from=invoke_from, call_depth=call_depth, - graph=graph, graph_config=graph_config, - graph_runtime_state=graph_runtime_state, max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, + ) + self.graph_engine = GraphEngine( + graph=graph, + graph_runtime_state=graph_runtime_state, + graph_init_params=graph_init_params, thread_pool_id=thread_pool_id, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/entities/test_graph_runtime_state.py b/api/tests/unit_tests/core/workflow/graph_engine/entities/test_graph_runtime_state.py index cf2dbcceb3..86c842b78d 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/entities/test_graph_runtime_state.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/entities/test_graph_runtime_state.py @@ -1,4 +1,3 @@ -import time from decimal import Decimal from core.model_runtime.entities.llm_entities import LLMUsage diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index c8a55538aa..408b09e30d 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -178,6 +178,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove): graph_runtime_state = GraphRuntimeState( variable_pool=variable_pool, ) + init_params = GraphInitParams( tenant_id="111", app_id="222", workflow_type=WorkflowType.WORKFLOW, @@ -187,11 +188,14 @@ def test_run_parallel_in_workflow(mock_close, mock_remove): user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.WEB_APP, call_depth=0, - graph=graph, - graph_runtime_state=graph_runtime_state, max_execution_steps=500, max_execution_time=1200, ) + graph_engine = GraphEngine( + graph=graph, + graph_runtime_state=graph_runtime_state, + graph_init_params=init_params, + ) def llm_generator(self): contents = ["hi", "bye", "good morning"] @@ -307,6 +311,7 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove): graph_runtime_state = GraphRuntimeState( variable_pool=variable_pool, ) + graph_init_params = GraphInitParams( tenant_id="111", app_id="222", workflow_type=WorkflowType.CHAT, @@ -316,11 +321,14 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove): user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.WEB_APP, call_depth=0, - graph=graph, - graph_runtime_state=graph_runtime_state, max_execution_steps=500, max_execution_time=1200, ) + graph_engine = GraphEngine( + graph=graph, + graph_runtime_state=graph_runtime_state, + graph_init_params=graph_init_params, + ) # print("") @@ -489,6 +497,7 @@ def test_run_branch(mock_close, mock_remove): graph_runtime_state = GraphRuntimeState( variable_pool=variable_pool, ) + graph_init_params = GraphInitParams( tenant_id="111", app_id="222", workflow_type=WorkflowType.CHAT, @@ -498,11 +507,14 @@ def test_run_branch(mock_close, mock_remove): user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.WEB_APP, call_depth=0, - graph=graph, - graph_runtime_state=graph_runtime_state, max_execution_steps=500, max_execution_time=1200, ) + graph_engine = GraphEngine( + graph=graph, + graph_runtime_state=graph_runtime_state, + graph_init_params=graph_init_params, + ) # print("") @@ -829,6 +841,7 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app): graph_runtime_state = GraphRuntimeState( variable_pool=variable_pool, ) + graph_init_params = GraphInitParams( tenant_id="111", app_id="222", workflow_type=WorkflowType.CHAT, @@ -838,11 +851,14 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app): user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.WEB_APP, call_depth=0, - graph=graph, - graph_runtime_state=graph_runtime_state, max_execution_steps=500, max_execution_time=1200, ) + graph_engine = GraphEngine( + graph=graph, + graph_runtime_state=graph_runtime_state, + graph_init_params=graph_init_params, + ) def qc_generator(self): yield RunCompletedEvent( diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py index 1ef024f46b..d9f21e1460 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -1,4 +1,3 @@ -import time import uuid from unittest.mock import MagicMock @@ -71,7 +70,9 @@ def test_execute_answer(): id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, - graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + graph_runtime_state=GraphRuntimeState( + variable_pool=pool, + ), config=node_config, )