refactor(graph_engine): Take GraphRuntimeState out of GraphEngine

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/21882/head
-LAN- 11 months ago
parent ed54bd5121
commit c6590aef1e
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -103,7 +103,7 @@ class GraphEngine:
call_depth: int,
graph: Graph,
graph_config: Mapping[str, Any],
variable_pool: VariablePool,
graph_runtime_state: GraphRuntimeState,
max_execution_steps: int,
max_execution_time: int,
thread_pool_id: Optional[str] = None,
@ -140,7 +140,7 @@ class GraphEngine:
call_depth=call_depth,
)
self.graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
self.graph_runtime_state = graph_runtime_state
self.max_execution_steps = max_execution_steps
self.max_execution_time = max_execution_time

@ -133,8 +133,13 @@ class IterationNode(BaseNode[IterationNodeData]):
variable_pool.add([self.node_id, "item"], iterator_list_value[0])
# init graph engine
import time
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine(
tenant_id=self.tenant_id,
app_id=self.app_id,
@ -146,7 +151,7 @@ class IterationNode(BaseNode[IterationNodeData]):
call_depth=self.workflow_call_depth,
graph=iteration_graph,
graph_config=graph_config,
variable_pool=variable_pool,
graph_runtime_state=graph_runtime_state,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
thread_pool_id=self.thread_pool_id,

@ -101,8 +101,13 @@ class LoopNode(BaseNode[LoopNodeData]):
loop_variable_selectors[loop_variable.label] = variable_selector
inputs[loop_variable.label] = processed_segment.value
import time
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.graph_engine import GraphEngine
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine(
tenant_id=self.tenant_id,
app_id=self.app_id,
@ -114,7 +119,7 @@ class LoopNode(BaseNode[LoopNodeData]):
call_depth=self.workflow_call_depth,
graph=loop_graph,
graph_config=self.graph_config,
variable_pool=variable_pool,
graph_runtime_state=graph_runtime_state,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
thread_pool_id=self.thread_pool_id,

@ -69,6 +69,7 @@ class WorkflowEntry:
raise ValueError("Max workflow call depth {} reached.".format(workflow_call_max_depth))
# init workflow run state
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
self.graph_engine = GraphEngine(
tenant_id=tenant_id,
app_id=app_id,
@ -80,7 +81,7 @@ class WorkflowEntry:
call_depth=call_depth,
graph=graph,
graph_config=graph_config,
variable_pool=variable_pool,
graph_runtime_state=graph_runtime_state,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
thread_pool_id=thread_pool_id,

@ -1,3 +1,4 @@
import time
from unittest.mock import patch
import pytest
@ -19,6 +20,7 @@ from core.workflow.graph_engine.entities.event import (
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.code.code_node import CodeNode
@ -172,6 +174,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"}
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine(
tenant_id="111",
app_id="222",
@ -183,7 +186,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
graph=graph,
variable_pool=variable_pool,
graph_runtime_state=graph_runtime_state,
max_execution_steps=500,
max_execution_time=1200,
)
@ -299,6 +302,7 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
user_inputs={},
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine(
tenant_id="111",
app_id="222",
@ -310,7 +314,7 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
graph=graph,
variable_pool=variable_pool,
graph_runtime_state=graph_runtime_state,
max_execution_steps=500,
max_execution_time=1200,
)
@ -479,6 +483,7 @@ def test_run_branch(mock_close, mock_remove):
user_inputs={"uid": "takato"},
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine(
tenant_id="111",
app_id="222",
@ -490,7 +495,7 @@ def test_run_branch(mock_close, mock_remove):
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
graph=graph,
variable_pool=variable_pool,
graph_runtime_state=graph_runtime_state,
max_execution_steps=500,
max_execution_time=1200,
)
@ -813,6 +818,7 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"}
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine(
tenant_id="111",
app_id="222",
@ -824,7 +830,7 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
graph=graph,
variable_pool=variable_pool,
graph_runtime_state=graph_runtime_state,
max_execution_steps=500,
max_execution_time=1200,
)

Loading…
Cancel
Save