From c6590aef1e13d8556f8be6bd9f3e7ed44269d5e5 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 3 Jul 2025 16:55:17 +0800 Subject: [PATCH] refactor(graph_engine): Take GraphRuntimeState out of GraphEngine Signed-off-by: -LAN- --- api/core/workflow/graph_engine/graph_engine.py | 4 ++-- .../workflow/nodes/iteration/iteration_node.py | 7 ++++++- api/core/workflow/nodes/loop/loop_node.py | 7 ++++++- api/core/workflow/workflow_entry.py | 3 ++- .../workflow/graph_engine/test_graph_engine.py | 14 ++++++++++---- 5 files changed, 26 insertions(+), 9 deletions(-) diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 61a7a26652..5a2915e2d3 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -103,7 +103,7 @@ class GraphEngine: call_depth: int, graph: Graph, graph_config: Mapping[str, Any], - variable_pool: VariablePool, + graph_runtime_state: GraphRuntimeState, max_execution_steps: int, max_execution_time: int, thread_pool_id: Optional[str] = None, @@ -140,7 +140,7 @@ class GraphEngine: call_depth=call_depth, ) - self.graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + self.graph_runtime_state = graph_runtime_state self.max_execution_steps = max_execution_steps self.max_execution_time = max_execution_time diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 151efc28ec..98246d4249 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -133,8 +133,13 @@ class IterationNode(BaseNode[IterationNodeData]): variable_pool.add([self.node_id, "item"], iterator_list_value[0]) # init graph engine + import time + + from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + graph_engine = GraphEngine( tenant_id=self.tenant_id, app_id=self.app_id, @@ -146,7 +151,7 @@ class IterationNode(BaseNode[IterationNodeData]): call_depth=self.workflow_call_depth, graph=iteration_graph, graph_config=graph_config, - variable_pool=variable_pool, + graph_runtime_state=graph_runtime_state, max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, thread_pool_id=self.thread_pool_id, diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 368d662a75..cca2806a23 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -101,8 +101,13 @@ class LoopNode(BaseNode[LoopNodeData]): loop_variable_selectors[loop_variable.label] = variable_selector inputs[loop_variable.label] = processed_segment.value + import time + + from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.graph_engine import GraphEngine + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + graph_engine = GraphEngine( tenant_id=self.tenant_id, app_id=self.app_id, @@ -114,7 +119,7 @@ class LoopNode(BaseNode[LoopNodeData]): call_depth=self.workflow_call_depth, graph=loop_graph, graph_config=self.graph_config, - variable_pool=variable_pool, + graph_runtime_state=graph_runtime_state, max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, thread_pool_id=self.thread_pool_id, diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index c0e98db3db..2868dcb7de 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -69,6 +69,7 @@ class WorkflowEntry: raise ValueError("Max workflow call depth {} reached.".format(workflow_call_max_depth)) # init workflow run state + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) self.graph_engine = GraphEngine( tenant_id=tenant_id, app_id=app_id, @@ -80,7 +81,7 @@ class WorkflowEntry: call_depth=call_depth, graph=graph, graph_config=graph_config, - variable_pool=variable_pool, + graph_runtime_state=graph_runtime_state, max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, thread_pool_id=thread_pool_id, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index 7535ec4866..c288a5fa13 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -1,3 +1,4 @@ +import time from unittest.mock import patch import pytest @@ -19,6 +20,7 @@ from core.workflow.graph_engine.entities.event import ( NodeRunSucceededEvent, ) from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState from core.workflow.graph_engine.graph_engine import GraphEngine from core.workflow.nodes.code.code_node import CodeNode @@ -172,6 +174,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove): system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"} ) + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) graph_engine = GraphEngine( tenant_id="111", app_id="222", @@ -183,7 +186,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove): invoke_from=InvokeFrom.WEB_APP, call_depth=0, graph=graph, - variable_pool=variable_pool, + graph_runtime_state=graph_runtime_state, max_execution_steps=500, max_execution_time=1200, ) @@ -299,6 +302,7 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove): user_inputs={}, ) + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) graph_engine = GraphEngine( tenant_id="111", app_id="222", @@ -310,7 +314,7 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove): invoke_from=InvokeFrom.WEB_APP, call_depth=0, graph=graph, - variable_pool=variable_pool, + graph_runtime_state=graph_runtime_state, max_execution_steps=500, max_execution_time=1200, ) @@ -479,6 +483,7 @@ def test_run_branch(mock_close, mock_remove): user_inputs={"uid": "takato"}, ) + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) graph_engine = GraphEngine( tenant_id="111", app_id="222", @@ -490,7 +495,7 @@ def test_run_branch(mock_close, mock_remove): invoke_from=InvokeFrom.WEB_APP, call_depth=0, graph=graph, - variable_pool=variable_pool, + graph_runtime_state=graph_runtime_state, max_execution_steps=500, max_execution_time=1200, ) @@ -813,6 +818,7 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app): system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"} ) + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) graph_engine = GraphEngine( tenant_id="111", app_id="222", @@ -824,7 +830,7 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app): invoke_from=InvokeFrom.WEB_APP, call_depth=0, graph=graph, - variable_pool=variable_pool, + graph_runtime_state=graph_runtime_state, max_execution_steps=500, max_execution_time=1200, )