refactor(api): Simplify the constructor of `GraphEngine`

Move most contextual arguments into `GraphInitParams`.
pull/22621/head
QuantumGhost 10 months ago
parent d99ad77837
commit f900a92ee7

@ -3,7 +3,7 @@ import logging
import queue import queue
import time import time
import uuid import uuid
from collections.abc import Generator, Mapping from collections.abc import Generator
from concurrent.futures import ThreadPoolExecutor, wait from concurrent.futures import ThreadPoolExecutor, wait
from copy import copy, deepcopy from copy import copy, deepcopy
from datetime import UTC, datetime from datetime import UTC, datetime
@ -91,19 +91,9 @@ class GraphEngine:
def __init__( def __init__(
self, self,
tenant_id: str,
app_id: str,
workflow_type: WorkflowType,
workflow_id: str,
user_id: str,
user_from: UserFrom,
invoke_from: InvokeFrom,
call_depth: int,
graph: Graph, graph: Graph,
graph_config: Mapping[str, Any], graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState, graph_runtime_state: GraphRuntimeState,
max_execution_steps: int,
max_execution_time: int,
thread_pool_id: Optional[str] = None, thread_pool_id: Optional[str] = None,
) -> None: ) -> None:
thread_pool_max_submit_count = dify_config.MAX_SUBMIT_COUNT thread_pool_max_submit_count = dify_config.MAX_SUBMIT_COUNT
@ -126,17 +116,7 @@ class GraphEngine:
GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id] = self.thread_pool GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id] = self.thread_pool
self.graph = graph self.graph = graph
self.init_params = GraphInitParams( self.init_params = graph_init_params
tenant_id=tenant_id,
app_id=app_id,
workflow_type=workflow_type,
workflow_id=workflow_id,
graph_config=graph_config,
user_id=user_id,
user_from=user_from,
invoke_from=invoke_from,
call_depth=call_depth,
)
self.graph_runtime_state = graph_runtime_state self.graph_runtime_state = graph_runtime_state

@ -32,9 +32,6 @@ class BaseNode:
self.id = id self.id = id
self.tenant_id = graph_init_params.tenant_id self.tenant_id = graph_init_params.tenant_id
self.app_id = graph_init_params.app_id self.app_id = graph_init_params.app_id
self.workflow_type = graph_init_params.workflow_type
self.workflow_id = graph_init_params.workflow_id
self.graph_config = graph_init_params.graph_config
self.user_id = graph_init_params.user_id self.user_id = graph_init_params.user_id
self.user_from = graph_init_params.user_from self.user_from = graph_init_params.user_from
self.invoke_from = graph_init_params.invoke_from self.invoke_from = graph_init_params.invoke_from
@ -43,6 +40,7 @@ class BaseNode:
self.graph_runtime_state = graph_runtime_state self.graph_runtime_state = graph_runtime_state
self.previous_node_id = previous_node_id self.previous_node_id = previous_node_id
self.thread_pool_id = thread_pool_id self.thread_pool_id = thread_pool_id
self._init_params = graph_init_params
node_id = config.get("id") node_id = config.get("id")
if not node_id: if not node_id:

@ -1,6 +1,5 @@
import contextvars import contextvars
import logging import logging
import time
import uuid import uuid
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from concurrent.futures import Future, wait from concurrent.futures import Future, wait
@ -137,15 +136,13 @@ class IterationNode(BaseNode):
inputs = {"iterator_selector": iterator_list_value} inputs = {"iterator_selector": iterator_list_value}
graph_config = self.graph_config
if not self._node_data.start_node_id: if not self._node_data.start_node_id:
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self.node_id} not found") raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self.node_id} not found")
root_node_id = self._node_data.start_node_id root_node_id = self._node_data.start_node_id
# init graph # init graph
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id) iteration_graph = Graph.init(graph_config=self._init_params.graph_config, root_node_id=root_node_id)
if not iteration_graph: if not iteration_graph:
raise IterationGraphNotFoundError("iteration graph not found") raise IterationGraphNotFoundError("iteration graph not found")
@ -163,19 +160,9 @@ class IterationNode(BaseNode):
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool)
graph_engine = GraphEngine( graph_engine = GraphEngine(
tenant_id=self.tenant_id,
app_id=self.app_id,
workflow_type=self.workflow_type,
workflow_id=self.workflow_id,
user_id=self.user_id,
user_from=self.user_from,
invoke_from=self.invoke_from,
call_depth=self.workflow_call_depth,
graph=iteration_graph, graph=iteration_graph,
graph_config=graph_config,
graph_runtime_state=graph_runtime_state, graph_runtime_state=graph_runtime_state,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, graph_init_params=self._init_params,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
thread_pool_id=self.thread_pool_id, thread_pool_id=self.thread_pool_id,
) )

@ -1,11 +1,9 @@
import json import json
import logging import logging
import time
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, Literal, Optional, cast from typing import TYPE_CHECKING, Any, Literal, Optional, cast
from configs import dify_config
from core.variables import ( from core.variables import (
IntegerSegment, IntegerSegment,
Segment, Segment,
@ -91,7 +89,7 @@ class LoopNode(BaseNode):
raise ValueError(f"field start_node_id in loop {self.node_id} not found") raise ValueError(f"field start_node_id in loop {self.node_id} not found")
# Initialize graph # Initialize graph
loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self._node_data.start_node_id) loop_graph = Graph.init(graph_config=self._init_params.graph_config, root_node_id=self._node_data.start_node_id)
if not loop_graph: if not loop_graph:
raise ValueError("loop graph not found") raise ValueError("loop graph not found")
@ -127,19 +125,9 @@ class LoopNode(BaseNode):
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool)
graph_engine = GraphEngine( graph_engine = GraphEngine(
tenant_id=self.tenant_id,
app_id=self.app_id,
workflow_type=self.workflow_type,
workflow_id=self.workflow_id,
user_id=self.user_id,
user_from=self.user_from,
invoke_from=self.invoke_from,
call_depth=self.workflow_call_depth,
graph=loop_graph, graph=loop_graph,
graph_config=self.graph_config, graph_init_params=self._init_params,
graph_runtime_state=graph_runtime_state, graph_runtime_state=graph_runtime_state,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
thread_pool_id=self.thread_pool_id, thread_pool_id=self.thread_pool_id,
) )

@ -71,7 +71,7 @@ class WorkflowEntry:
# init workflow run state # init workflow run state
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool)
self.graph_engine = GraphEngine( graph_init_params = GraphInitParams(
tenant_id=tenant_id, tenant_id=tenant_id,
app_id=app_id, app_id=app_id,
workflow_type=workflow_type, workflow_type=workflow_type,
@ -80,11 +80,14 @@ class WorkflowEntry:
user_from=user_from, user_from=user_from,
invoke_from=invoke_from, invoke_from=invoke_from,
call_depth=call_depth, call_depth=call_depth,
graph=graph,
graph_config=graph_config, graph_config=graph_config,
graph_runtime_state=graph_runtime_state,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
)
self.graph_engine = GraphEngine(
graph=graph,
graph_runtime_state=graph_runtime_state,
graph_init_params=graph_init_params,
thread_pool_id=thread_pool_id, thread_pool_id=thread_pool_id,
) )

@ -1,4 +1,3 @@
import time
from decimal import Decimal from decimal import Decimal
from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.entities.llm_entities import LLMUsage

@ -178,6 +178,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
graph_runtime_state = GraphRuntimeState( graph_runtime_state = GraphRuntimeState(
variable_pool=variable_pool, variable_pool=variable_pool,
) )
init_params = GraphInitParams(
tenant_id="111", tenant_id="111",
app_id="222", app_id="222",
workflow_type=WorkflowType.WORKFLOW, workflow_type=WorkflowType.WORKFLOW,
@ -187,11 +188,14 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
user_from=UserFrom.ACCOUNT, user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.WEB_APP, invoke_from=InvokeFrom.WEB_APP,
call_depth=0, call_depth=0,
graph=graph,
graph_runtime_state=graph_runtime_state,
max_execution_steps=500, max_execution_steps=500,
max_execution_time=1200, max_execution_time=1200,
) )
graph_engine = GraphEngine(
graph=graph,
graph_runtime_state=graph_runtime_state,
graph_init_params=init_params,
)
def llm_generator(self): def llm_generator(self):
contents = ["hi", "bye", "good morning"] contents = ["hi", "bye", "good morning"]
@ -307,6 +311,7 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
graph_runtime_state = GraphRuntimeState( graph_runtime_state = GraphRuntimeState(
variable_pool=variable_pool, variable_pool=variable_pool,
) )
graph_init_params = GraphInitParams(
tenant_id="111", tenant_id="111",
app_id="222", app_id="222",
workflow_type=WorkflowType.CHAT, workflow_type=WorkflowType.CHAT,
@ -316,11 +321,14 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
user_from=UserFrom.ACCOUNT, user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.WEB_APP, invoke_from=InvokeFrom.WEB_APP,
call_depth=0, call_depth=0,
graph=graph,
graph_runtime_state=graph_runtime_state,
max_execution_steps=500, max_execution_steps=500,
max_execution_time=1200, max_execution_time=1200,
) )
graph_engine = GraphEngine(
graph=graph,
graph_runtime_state=graph_runtime_state,
graph_init_params=graph_init_params,
)
# print("") # print("")
@ -489,6 +497,7 @@ def test_run_branch(mock_close, mock_remove):
graph_runtime_state = GraphRuntimeState( graph_runtime_state = GraphRuntimeState(
variable_pool=variable_pool, variable_pool=variable_pool,
) )
graph_init_params = GraphInitParams(
tenant_id="111", tenant_id="111",
app_id="222", app_id="222",
workflow_type=WorkflowType.CHAT, workflow_type=WorkflowType.CHAT,
@ -498,11 +507,14 @@ def test_run_branch(mock_close, mock_remove):
user_from=UserFrom.ACCOUNT, user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.WEB_APP, invoke_from=InvokeFrom.WEB_APP,
call_depth=0, call_depth=0,
graph=graph,
graph_runtime_state=graph_runtime_state,
max_execution_steps=500, max_execution_steps=500,
max_execution_time=1200, max_execution_time=1200,
) )
graph_engine = GraphEngine(
graph=graph,
graph_runtime_state=graph_runtime_state,
graph_init_params=graph_init_params,
)
# print("") # print("")
@ -829,6 +841,7 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
graph_runtime_state = GraphRuntimeState( graph_runtime_state = GraphRuntimeState(
variable_pool=variable_pool, variable_pool=variable_pool,
) )
graph_init_params = GraphInitParams(
tenant_id="111", tenant_id="111",
app_id="222", app_id="222",
workflow_type=WorkflowType.CHAT, workflow_type=WorkflowType.CHAT,
@ -838,11 +851,14 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
user_from=UserFrom.ACCOUNT, user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.WEB_APP, invoke_from=InvokeFrom.WEB_APP,
call_depth=0, call_depth=0,
graph=graph,
graph_runtime_state=graph_runtime_state,
max_execution_steps=500, max_execution_steps=500,
max_execution_time=1200, max_execution_time=1200,
) )
graph_engine = GraphEngine(
graph=graph,
graph_runtime_state=graph_runtime_state,
graph_init_params=graph_init_params,
)
def qc_generator(self): def qc_generator(self):
yield RunCompletedEvent( yield RunCompletedEvent(

@ -1,4 +1,3 @@
import time
import uuid import uuid
from unittest.mock import MagicMock from unittest.mock import MagicMock
@ -71,7 +70,9 @@ def test_execute_answer():
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
graph_init_params=init_params, graph_init_params=init_params,
graph=graph, graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), graph_runtime_state=GraphRuntimeState(
variable_pool=pool,
),
config=node_config, config=node_config,
) )

Loading…
Cancel
Save