refactor(graph_engine): Take GraphRuntimeState out of GraphEngine

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/21882/head
-LAN- 11 months ago
parent ed54bd5121
commit c6590aef1e
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -103,7 +103,7 @@ class GraphEngine:
call_depth: int, call_depth: int,
graph: Graph, graph: Graph,
graph_config: Mapping[str, Any], graph_config: Mapping[str, Any],
variable_pool: VariablePool, graph_runtime_state: GraphRuntimeState,
max_execution_steps: int, max_execution_steps: int,
max_execution_time: int, max_execution_time: int,
thread_pool_id: Optional[str] = None, thread_pool_id: Optional[str] = None,
@ -140,7 +140,7 @@ class GraphEngine:
call_depth=call_depth, call_depth=call_depth,
) )
self.graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) self.graph_runtime_state = graph_runtime_state
self.max_execution_steps = max_execution_steps self.max_execution_steps = max_execution_steps
self.max_execution_time = max_execution_time self.max_execution_time = max_execution_time

@ -133,8 +133,13 @@ class IterationNode(BaseNode[IterationNodeData]):
variable_pool.add([self.node_id, "item"], iterator_list_value[0]) variable_pool.add([self.node_id, "item"], iterator_list_value[0])
# init graph engine # init graph engine
import time
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine( graph_engine = GraphEngine(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
app_id=self.app_id, app_id=self.app_id,
@ -146,7 +151,7 @@ class IterationNode(BaseNode[IterationNodeData]):
call_depth=self.workflow_call_depth, call_depth=self.workflow_call_depth,
graph=iteration_graph, graph=iteration_graph,
graph_config=graph_config, graph_config=graph_config,
variable_pool=variable_pool, graph_runtime_state=graph_runtime_state,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
thread_pool_id=self.thread_pool_id, thread_pool_id=self.thread_pool_id,

@ -101,8 +101,13 @@ class LoopNode(BaseNode[LoopNodeData]):
loop_variable_selectors[loop_variable.label] = variable_selector loop_variable_selectors[loop_variable.label] = variable_selector
inputs[loop_variable.label] = processed_segment.value inputs[loop_variable.label] = processed_segment.value
import time
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.graph_engine import GraphEngine from core.workflow.graph_engine.graph_engine import GraphEngine
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine( graph_engine = GraphEngine(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
app_id=self.app_id, app_id=self.app_id,
@ -114,7 +119,7 @@ class LoopNode(BaseNode[LoopNodeData]):
call_depth=self.workflow_call_depth, call_depth=self.workflow_call_depth,
graph=loop_graph, graph=loop_graph,
graph_config=self.graph_config, graph_config=self.graph_config,
variable_pool=variable_pool, graph_runtime_state=graph_runtime_state,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
thread_pool_id=self.thread_pool_id, thread_pool_id=self.thread_pool_id,

@ -69,6 +69,7 @@ class WorkflowEntry:
raise ValueError("Max workflow call depth {} reached.".format(workflow_call_max_depth)) raise ValueError("Max workflow call depth {} reached.".format(workflow_call_max_depth))
# init workflow run state # init workflow run state
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
self.graph_engine = GraphEngine( self.graph_engine = GraphEngine(
tenant_id=tenant_id, tenant_id=tenant_id,
app_id=app_id, app_id=app_id,
@ -80,7 +81,7 @@ class WorkflowEntry:
call_depth=call_depth, call_depth=call_depth,
graph=graph, graph=graph,
graph_config=graph_config, graph_config=graph_config,
variable_pool=variable_pool, graph_runtime_state=graph_runtime_state,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
thread_pool_id=thread_pool_id, thread_pool_id=thread_pool_id,

@ -1,3 +1,4 @@
import time
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
@ -19,6 +20,7 @@ from core.workflow.graph_engine.entities.event import (
NodeRunSucceededEvent, NodeRunSucceededEvent,
) )
from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.graph_engine.graph_engine import GraphEngine from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.code.code_node import CodeNode
@ -172,6 +174,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"} system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"}
) )
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine( graph_engine = GraphEngine(
tenant_id="111", tenant_id="111",
app_id="222", app_id="222",
@ -183,7 +186,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
invoke_from=InvokeFrom.WEB_APP, invoke_from=InvokeFrom.WEB_APP,
call_depth=0, call_depth=0,
graph=graph, graph=graph,
variable_pool=variable_pool, graph_runtime_state=graph_runtime_state,
max_execution_steps=500, max_execution_steps=500,
max_execution_time=1200, max_execution_time=1200,
) )
@ -299,6 +302,7 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
user_inputs={}, user_inputs={},
) )
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine( graph_engine = GraphEngine(
tenant_id="111", tenant_id="111",
app_id="222", app_id="222",
@ -310,7 +314,7 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
invoke_from=InvokeFrom.WEB_APP, invoke_from=InvokeFrom.WEB_APP,
call_depth=0, call_depth=0,
graph=graph, graph=graph,
variable_pool=variable_pool, graph_runtime_state=graph_runtime_state,
max_execution_steps=500, max_execution_steps=500,
max_execution_time=1200, max_execution_time=1200,
) )
@ -479,6 +483,7 @@ def test_run_branch(mock_close, mock_remove):
user_inputs={"uid": "takato"}, user_inputs={"uid": "takato"},
) )
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine( graph_engine = GraphEngine(
tenant_id="111", tenant_id="111",
app_id="222", app_id="222",
@ -490,7 +495,7 @@ def test_run_branch(mock_close, mock_remove):
invoke_from=InvokeFrom.WEB_APP, invoke_from=InvokeFrom.WEB_APP,
call_depth=0, call_depth=0,
graph=graph, graph=graph,
variable_pool=variable_pool, graph_runtime_state=graph_runtime_state,
max_execution_steps=500, max_execution_steps=500,
max_execution_time=1200, max_execution_time=1200,
) )
@ -813,6 +818,7 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"} system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"}
) )
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine( graph_engine = GraphEngine(
tenant_id="111", tenant_id="111",
app_id="222", app_id="222",
@ -824,7 +830,7 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
invoke_from=InvokeFrom.WEB_APP, invoke_from=InvokeFrom.WEB_APP,
call_depth=0, call_depth=0,
graph=graph, graph=graph,
variable_pool=variable_pool, graph_runtime_state=graph_runtime_state,
max_execution_steps=500, max_execution_steps=500,
max_execution_time=1200, max_execution_time=1200,
) )

Loading…
Cancel
Save