@@ -6,14 +6,15 @@ import uuid
 from collections.abc import Generator
 from concurrent.futures import ThreadPoolExecutor, wait
 from copy import copy, deepcopy
+from dataclasses import dataclass
 from datetime import UTC, datetime
 from typing import Any, Optional, cast

 from flask import Flask, current_app
+from pydantic import BaseModel

 from configs import dify_config
 from core.app.apps.exc import GenerateTaskStoppedError
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunResult
 from core.workflow.entities.variable_pool import VariablePool, VariableValue
 from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
@@ -27,6 +28,7 @@ from core.workflow.graph_engine.entities.event import (
     GraphRunPartialSucceededEvent,
     GraphRunStartedEvent,
     GraphRunSucceededEvent,
+    GraphRunSuspendedEvent,
     NodeRunExceptionEvent,
     NodeRunFailedEvent,
     NodeRunRetrieverResourceEvent,
@@ -53,9 +55,10 @@ from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
 from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
 from core.workflow.utils import variable_utils
 from libs.flask_utils import preserve_flask_contexts
 from models.enums import UserFrom
 from models.workflow import WorkflowType

+from .execution_decision import DecisionParams, ExecutionDecision, ExecutionDecisionHook

 logger = logging.getLogger(__name__)
@@ -86,6 +89,10 @@ class GraphEngineThreadPool(ThreadPoolExecutor):
             raise ValueError(f"Max submit count {self.max_submit_count} of workflow thread pool reached.")


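+# Default decision hook: it never suspends or stops, so a GraphEngine constructed
+# without an explicit hook behaves exactly like one with no suspension support.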
+def _default_hook(params: DecisionParams) -> ExecutionDecision:
+    return ExecutionDecision.CONTINUE
+
+
 class GraphEngine:
     workflow_thread_pool_mapping: dict[str, GraphEngineThreadPool] = {}

@@ -95,7 +102,12 @@
         graph_init_params: GraphInitParams,
         graph_runtime_state: GraphRuntimeState,
         thread_pool_id: Optional[str] = None,
+        execution_decision_hook: ExecutionDecisionHook = _default_hook,
     ) -> None:
+        """Create a graph engine from the given state.
+
+        The execution_decision_hook is consulted before each node runs and decides whether execution continues, suspends, or stops.
+        """
         thread_pool_max_submit_count = dify_config.MAX_SUBMIT_COUNT
         thread_pool_max_workers = 10
@@ -120,8 +132,7 @@
         self.graph_runtime_state = graph_runtime_state

-        self.max_execution_steps = max_execution_steps
-        self.max_execution_time = max_execution_time
+        self._exec_decision_hook = execution_decision_hook

     def run(self) -> Generator[GraphEngineEvent, None, None]:
         # trigger graph run start event
@@ -140,12 +151,18 @@
                 )

             # run graph
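+            # A fresh run starts at the graph root; a resumed run picks up at the node
+            # recorded in node_run_state when the engine was suspended.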
+            next_node_to_run = self.graph.root_node_id
+            if (next_node_id := self.graph_runtime_state.node_run_state.next_node_id) is not None:
+                next_node_to_run = next_node_id
             generator = stream_processor.process(
-                self._run(start_node_id=self.graph.root_node_id, handle_exceptions=handle_exceptions)
+                self._run(start_node_id=next_node_to_run, handle_exceptions=handle_exceptions)
             )
             for item in generator:
                 try:
                     yield item
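+                    # Suspension terminates this generator immediately; the caller is
+                    # expected to persist the engine state via save() and resume() later.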
+                    if isinstance(item, GraphRunSuspendedEvent):
+                        return
                     if isinstance(item, NodeRunFailedEvent):
                         yield GraphRunFailedEvent(
                             error=item.route_node_state.failed_reason or "Unknown error.",
@@ -209,22 +226,22 @@
         parent_parallel_start_node_id: Optional[str] = None,
         handle_exceptions: list[str] = [],
     ) -> Generator[GraphEngineEvent, None, None]:
+        # Hint: the `_run` method is used both when running the main graph
+        # and when running parallel branches.
         parallel_start_node_id = None
         if in_parallel_id:
             parallel_start_node_id = start_node_id

         next_node_id = start_node_id
         previous_route_node_state: Optional[RouteNodeState] = None

         while True:
             # max steps reached
-            if self.graph_runtime_state.node_run_steps > self.max_execution_steps:
-                raise GraphRunFailedError("Max steps {} reached.".format(self.max_execution_steps))
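+            # The step limit is now read from init_params, which is part of the state
+            # serialized by save(), so a resumed engine enforces the same bound.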
+            if self.graph_runtime_state.node_run_steps > self.init_params.max_execution_steps:
+                raise GraphRunFailedError("Max steps {} reached.".format(self.init_params.max_execution_steps))

             # or max execution time reached
-            if self._is_timed_out(
-                start_at=self.graph_runtime_state.start_at, max_execution_time=self.max_execution_time
-            ):
-                raise GraphRunFailedError("Max execution time {}s reached.".format(self.max_execution_time))
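+            # Timeout bookkeeping has moved onto graph_runtime_state, which is also
+            # part of the serialized engine state.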
+            if self.graph_runtime_state.is_timed_out(self.init_params.max_execution_time):
+                raise GraphRunFailedError("Max execution time {}s reached.".format(self.init_params.max_execution_time))

             # init route node state
             route_node_state = self.graph_runtime_state.node_run_state.create_node_state(node_id=next_node_id)
@@ -257,6 +274,23 @@
                 thread_pool_id=self.thread_pool_id,
             )
             node.init_node_data(node_config.get("data", {}))
+            # Determine if the execution should be suspended or stopped at this point.
+            # If so, yield the corresponding event.
+            #
+            # Note: Suspension is not allowed while the graph engine is running in parallel mode.
+            if in_parallel_id is None:
+                hook_result = self._exec_decision_hook(DecisionParams(next_node_instance=node))
+                if hook_result == ExecutionDecision.SUSPEND:
+                    self.graph_runtime_state.record_suspend_state(next_node_id)
+                    yield GraphRunSuspendedEvent(next_node_id=next_node_id)
+                    return
+                elif hook_result == ExecutionDecision.STOP:
+                    # TODO: stop the execution of the workflow.
+                    return
+                elif hook_result == ExecutionDecision.CONTINUE:
+                    pass
+                else:
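+                    # Defensive: a custom hook returning anything outside the
+                    # ExecutionDecision enum is a programming error.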
raise AssertionError("unreachable statement.")
|
|
|
|
|
try:
|
|
|
|
|
# run node
|
|
|
|
|
generator = self._run_node(
|
|
|
|
|
@@ -383,8 +417,8 @@
                 )

                 for parallel_result in parallel_generator:
-                    if isinstance(parallel_result, str):
-                        final_node_id = parallel_result
+                    if isinstance(parallel_result, _ParallelBranchResult):
+                        final_node_id = parallel_result.final_node_id
                     else:
                         yield parallel_result
@@ -409,8 +443,8 @@
                 )

                 for generated_item in parallel_generator:
-                    if isinstance(generated_item, str):
-                        final_node_id = generated_item
+                    if isinstance(generated_item, _ParallelBranchResult):
+                        final_node_id = generated_item.final_node_id
                     else:
                         yield generated_item
@@ -428,7 +462,7 @@
         in_parallel_id: Optional[str] = None,
         parallel_start_node_id: Optional[str] = None,
         handle_exceptions: list[str] = [],
-    ) -> Generator[GraphEngineEvent | str, None, None]:
+    ) -> Generator["GraphEngineEvent | _ParallelBranchResult", None, None]:
         # if nodes has no run conditions, parallel run all nodes
         parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
         if not parallel_id:
@@ -506,7 +540,7 @@
         # get final node id
         final_node_id = parallel.end_to_node_id
         if final_node_id:
-            yield final_node_id
+            yield _ParallelBranchResult(final_node_id)

     def _run_parallel_node(
         self,
@@ -908,7 +942,41 @@
             )
             return error_result

+    def save(self) -> str:
+        """save serializes the state inside this graph engine.
+
+        This method should be called when suspension of the execution is necessary.
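+
+        Example (sketch; assumes the caller still holds the original Graph object)::
+
+            state = engine.save()  # after observing a GraphRunSuspendedEvent
+            engine = GraphEngine.resume(state, graph=graph)
+            events = engine.run()  # picks up at the recorded next node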
"""
|
|
|
|
|
state = _GraphEngineState(init_params=self.init_params, graph_runtime_state=self.graph_runtime_state)
|
|
|
|
|
return state.model_dump_json()
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def resume(
|
|
|
|
|
cls,
|
|
|
|
|
state: str,
|
|
|
|
|
graph: Graph,
|
|
|
|
|
execution_decision_hook: ExecutionDecisionHook = _default_hook,
|
|
|
|
|
) -> "GraphEngine":
|
|
|
|
|
"""`resume` continues a suspended execution."""
|
|
|
|
|
state_ = _GraphEngineState.model_validate_json(state)
|
|
|
|
|
return cls(
|
|
|
|
|
graph=graph,
|
|
|
|
|
graph_init_params=state_.init_params,
|
|
|
|
|
graph_runtime_state=state_.graph_runtime_state,
|
|
|
|
|
execution_decision_hook=execution_decision_hook,
|
|
|
|
|
)


 class GraphRunFailedError(Exception):
     def __init__(self, error: str):
         self.error = error
+
+
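+# Wrapper for the final node id produced by a parallel branch, so it can be told
+# apart from GraphEngineEvent items in the mixed generator stream (see the
+# _run_parallel_branches return type above).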
+@dataclass
+class _ParallelBranchResult:
+    final_node_id: str
+
+
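+# Snapshot written by GraphEngine.save() and re-hydrated by GraphEngine.resume().
+# Note that the graph topology itself is not serialized; resume() receives it as
+# an argument.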
+class _GraphEngineState(BaseModel):
+    init_params: GraphInitParams
+    graph_runtime_state: GraphRuntimeState