refactor(api): Simplify the constructor of `GraphEngine`

Move most contextual arguments into `GraphInitParams`.
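
Call sites no longer pass each piece of context to `GraphEngine` individually; they build one `GraphInitParams` and hand it over together with the graph and runtime state. A minimal sketch of the new pattern (variable names mirror the `WorkflowEntry` change below; values come from the surrounding caller):

    # Sketch only: bundle the per-run context once, then construct the engine.
    graph_init_params = GraphInitParams(
        tenant_id=tenant_id,
        app_id=app_id,
        workflow_type=workflow_type,
        workflow_id=workflow_id,
        graph_config=graph_config,
        user_id=user_id,
        user_from=user_from,
        invoke_from=invoke_from,
        call_depth=call_depth,
    )
    graph_engine = GraphEngine(
        graph=graph,
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
        thread_pool_id=thread_pool_id,
    )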
pull/22621/head
QuantumGhost 7 months ago
parent d99ad77837
commit f900a92ee7

@@ -3,7 +3,7 @@ import logging
 import queue
 import time
 import uuid
-from collections.abc import Generator, Mapping
+from collections.abc import Generator
 from concurrent.futures import ThreadPoolExecutor, wait
 from copy import copy, deepcopy
 from datetime import UTC, datetime
@@ -91,19 +91,9 @@ class GraphEngine:
     def __init__(
         self,
-        tenant_id: str,
-        app_id: str,
-        workflow_type: WorkflowType,
-        workflow_id: str,
-        user_id: str,
-        user_from: UserFrom,
-        invoke_from: InvokeFrom,
-        call_depth: int,
         graph: Graph,
-        graph_config: Mapping[str, Any],
+        graph_init_params: GraphInitParams,
         graph_runtime_state: GraphRuntimeState,
-        max_execution_steps: int,
-        max_execution_time: int,
         thread_pool_id: Optional[str] = None,
     ) -> None:
         thread_pool_max_submit_count = dify_config.MAX_SUBMIT_COUNT
@@ -126,17 +116,7 @@ class GraphEngine:
             GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id] = self.thread_pool
         self.graph = graph
-        self.init_params = GraphInitParams(
-            tenant_id=tenant_id,
-            app_id=app_id,
-            workflow_type=workflow_type,
-            workflow_id=workflow_id,
-            graph_config=graph_config,
-            user_id=user_id,
-            user_from=user_from,
-            invoke_from=invoke_from,
-            call_depth=call_depth,
-        )
+        self.init_params = graph_init_params
         self.graph_runtime_state = graph_runtime_state

@@ -32,9 +32,6 @@ class BaseNode:
         self.id = id
         self.tenant_id = graph_init_params.tenant_id
         self.app_id = graph_init_params.app_id
-        self.workflow_type = graph_init_params.workflow_type
-        self.workflow_id = graph_init_params.workflow_id
-        self.graph_config = graph_init_params.graph_config
         self.user_id = graph_init_params.user_id
         self.user_from = graph_init_params.user_from
         self.invoke_from = graph_init_params.invoke_from
@@ -43,6 +40,7 @@ class BaseNode:
         self.graph_runtime_state = graph_runtime_state
         self.previous_node_id = previous_node_id
         self.thread_pool_id = thread_pool_id
+        self._init_params = graph_init_params
         node_id = config.get("id")
         if not node_id:

@@ -1,6 +1,5 @@
 import contextvars
 import logging
-import time
 import uuid
 from collections.abc import Generator, Mapping, Sequence
 from concurrent.futures import Future, wait
@@ -137,15 +136,13 @@ class IterationNode(BaseNode):
         inputs = {"iterator_selector": iterator_list_value}
-        graph_config = self.graph_config
         if not self._node_data.start_node_id:
             raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self.node_id} not found")
         root_node_id = self._node_data.start_node_id
         # init graph
-        iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id)
+        iteration_graph = Graph.init(graph_config=self._init_params.graph_config, root_node_id=root_node_id)
         if not iteration_graph:
             raise IterationGraphNotFoundError("iteration graph not found")
@@ -163,19 +160,9 @@ class IterationNode(BaseNode):
         graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool)
         graph_engine = GraphEngine(
-            tenant_id=self.tenant_id,
-            app_id=self.app_id,
-            workflow_type=self.workflow_type,
-            workflow_id=self.workflow_id,
-            user_id=self.user_id,
-            user_from=self.user_from,
-            invoke_from=self.invoke_from,
-            call_depth=self.workflow_call_depth,
             graph=iteration_graph,
-            graph_config=graph_config,
             graph_runtime_state=graph_runtime_state,
-            max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
-            max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
+            graph_init_params=self._init_params,
             thread_pool_id=self.thread_pool_id,
         )

@@ -1,11 +1,9 @@
 import json
 import logging
-import time
 from collections.abc import Generator, Mapping, Sequence
 from datetime import UTC, datetime
 from typing import TYPE_CHECKING, Any, Literal, Optional, cast
-from configs import dify_config
 from core.variables import (
     IntegerSegment,
     Segment,
@@ -91,7 +89,7 @@ class LoopNode(BaseNode):
             raise ValueError(f"field start_node_id in loop {self.node_id} not found")
         # Initialize graph
-        loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self._node_data.start_node_id)
+        loop_graph = Graph.init(graph_config=self._init_params.graph_config, root_node_id=self._node_data.start_node_id)
         if not loop_graph:
             raise ValueError("loop graph not found")
@@ -127,19 +125,9 @@ class LoopNode(BaseNode):
         graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool)
         graph_engine = GraphEngine(
-            tenant_id=self.tenant_id,
-            app_id=self.app_id,
-            workflow_type=self.workflow_type,
-            workflow_id=self.workflow_id,
-            user_id=self.user_id,
-            user_from=self.user_from,
-            invoke_from=self.invoke_from,
-            call_depth=self.workflow_call_depth,
             graph=loop_graph,
-            graph_config=self.graph_config,
+            graph_init_params=self._init_params,
             graph_runtime_state=graph_runtime_state,
-            max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
-            max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
             thread_pool_id=self.thread_pool_id,
         )

@@ -71,7 +71,7 @@ class WorkflowEntry:
         # init workflow run state
         graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool)
-        self.graph_engine = GraphEngine(
+        graph_init_params = GraphInitParams(
             tenant_id=tenant_id,
             app_id=app_id,
             workflow_type=workflow_type,
@@ -80,11 +80,14 @@ class WorkflowEntry:
             user_from=user_from,
             invoke_from=invoke_from,
             call_depth=call_depth,
-            graph=graph,
             graph_config=graph_config,
-            graph_runtime_state=graph_runtime_state,
-            max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
-            max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
+        )
+        self.graph_engine = GraphEngine(
+            graph=graph,
+            graph_runtime_state=graph_runtime_state,
+            graph_init_params=graph_init_params,
             thread_pool_id=thread_pool_id,
         )

@@ -1,4 +1,3 @@
-import time
 from decimal import Decimal
 from core.model_runtime.entities.llm_entities import LLMUsage

@@ -178,6 +178,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
     graph_runtime_state = GraphRuntimeState(
         variable_pool=variable_pool,
     )
+    init_params = GraphInitParams(
         tenant_id="111",
         app_id="222",
         workflow_type=WorkflowType.WORKFLOW,
@@ -187,11 +188,14 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
         user_from=UserFrom.ACCOUNT,
         invoke_from=InvokeFrom.WEB_APP,
         call_depth=0,
-        graph=graph,
-        graph_runtime_state=graph_runtime_state,
-        max_execution_steps=500,
-        max_execution_time=1200,
+    )
+    graph_engine = GraphEngine(
+        graph=graph,
+        graph_runtime_state=graph_runtime_state,
+        graph_init_params=init_params,
     )
     def llm_generator(self):
         contents = ["hi", "bye", "good morning"]
@@ -307,6 +311,7 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
     graph_runtime_state = GraphRuntimeState(
         variable_pool=variable_pool,
     )
+    graph_init_params = GraphInitParams(
         tenant_id="111",
         app_id="222",
         workflow_type=WorkflowType.CHAT,
@@ -316,11 +321,14 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
         user_from=UserFrom.ACCOUNT,
         invoke_from=InvokeFrom.WEB_APP,
         call_depth=0,
-        graph=graph,
-        graph_runtime_state=graph_runtime_state,
-        max_execution_steps=500,
-        max_execution_time=1200,
+    )
+    graph_engine = GraphEngine(
+        graph=graph,
+        graph_runtime_state=graph_runtime_state,
+        graph_init_params=graph_init_params,
     )
     # print("")
@@ -489,6 +497,7 @@ def test_run_branch(mock_close, mock_remove):
     graph_runtime_state = GraphRuntimeState(
         variable_pool=variable_pool,
    )
+    graph_init_params = GraphInitParams(
         tenant_id="111",
         app_id="222",
         workflow_type=WorkflowType.CHAT,
@@ -498,11 +507,14 @@ def test_run_branch(mock_close, mock_remove):
         user_from=UserFrom.ACCOUNT,
         invoke_from=InvokeFrom.WEB_APP,
         call_depth=0,
-        graph=graph,
-        graph_runtime_state=graph_runtime_state,
-        max_execution_steps=500,
-        max_execution_time=1200,
+    )
+    graph_engine = GraphEngine(
+        graph=graph,
+        graph_runtime_state=graph_runtime_state,
+        graph_init_params=graph_init_params,
     )
     # print("")
@@ -829,6 +841,7 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
     graph_runtime_state = GraphRuntimeState(
         variable_pool=variable_pool,
     )
+    graph_init_params = GraphInitParams(
         tenant_id="111",
         app_id="222",
         workflow_type=WorkflowType.CHAT,
@@ -838,11 +851,14 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
         user_from=UserFrom.ACCOUNT,
         invoke_from=InvokeFrom.WEB_APP,
         call_depth=0,
-        graph=graph,
-        graph_runtime_state=graph_runtime_state,
-        max_execution_steps=500,
-        max_execution_time=1200,
+    )
+    graph_engine = GraphEngine(
+        graph=graph,
+        graph_runtime_state=graph_runtime_state,
+        graph_init_params=graph_init_params,
     )
     def qc_generator(self):
         yield RunCompletedEvent(

@@ -1,4 +1,3 @@
-import time
 import uuid
 from unittest.mock import MagicMock
@@ -71,7 +70,9 @@ def test_execute_answer():
         id=str(uuid.uuid4()),
         graph_init_params=init_params,
         graph=graph,
-        graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
+        graph_runtime_state=GraphRuntimeState(
+            variable_pool=pool,
+        ),
         config=node_config,
     )
