@ -1,12 +1,20 @@
import logging
import uuid
from collections . abc import Generator , Mapping , Sequence
from concurrent . futures import Future , wait
from datetime import datetime , timezone
from typing import Any , cast
from queue import Empty , Queue
from typing import TYPE_CHECKING , Any , Optional , cast
from flask import Flask , current_app
from configs import dify_config
from core . model_runtime . utils . encoders import jsonable_encoder
from core . variables import IntegerSegment
from core . workflow . entities . node_entities import NodeRunMetadataKey , NodeRunResult
from core . workflow . entities . node_entities import (
NodeRunMetadataKey ,
NodeRunResult ,
)
from core . workflow . entities . variable_pool import VariablePool
from core . workflow . graph_engine . entities . event import (
BaseGraphEvent ,
BaseNodeEvent ,
@ -17,6 +25,9 @@ from core.workflow.graph_engine.entities.event import (
IterationRunNextEvent ,
IterationRunStartedEvent ,
IterationRunSucceededEvent ,
NodeInIterationFailedEvent ,
NodeRunFailedEvent ,
NodeRunStartedEvent ,
NodeRunStreamChunkEvent ,
NodeRunSucceededEvent ,
)
@ -24,9 +35,11 @@ from core.workflow.graph_engine.entities.graph import Graph
from core . workflow . nodes . base import BaseNode
from core . workflow . nodes . enums import NodeType
from core . workflow . nodes . event import NodeEvent , RunCompletedEvent
from core . workflow . nodes . iteration . entities import IterationNodeData
from core . workflow . nodes . iteration . entities import ErrorHandleMode, IterationNodeData
from models . workflow import WorkflowNodeExecutionStatus
if TYPE_CHECKING :
from core . workflow . graph_engine . graph_engine import GraphEngine
logger = logging . getLogger ( __name__ )
@ -38,6 +51,17 @@ class IterationNode(BaseNode[IterationNodeData]):
_node_data_cls = IterationNodeData
_node_type = NodeType . ITERATION
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
    """
    Build the default configuration of an iteration node.

    :param filters: accepted for signature compatibility; not consulted here
    :return: mapping with the node ``type`` and its default ``config``:
        sequential execution, a ``parallel_nums`` of 10, and the
        ``TERMINATED`` error handle mode
    """
    default_settings = {
        "is_parallel": False,
        "parallel_nums": 10,
        "error_handle_mode": ErrorHandleMode.TERMINATED.value,
    }
    default_config = {"type": "iteration", "config": default_settings}
    return default_config
def _run ( self ) - > Generator [ NodeEvent | InNodeEvent , None , None ] :
"""
Run the node .
@ -83,7 +107,7 @@ class IterationNode(BaseNode[IterationNodeData]):
variable_pool . add ( [ self . node_id , " item " ] , iterator_list_value [ 0 ] )
# init graph engine
from core . workflow . graph_engine . graph_engine import GraphEngine
from core . workflow . graph_engine . graph_engine import GraphEngine , GraphEngineThreadPool
graph_engine = GraphEngine (
tenant_id = self . tenant_id ,
@ -123,108 +147,64 @@ class IterationNode(BaseNode[IterationNodeData]):
index = 0 ,
pre_iteration_output = None ,
)
outputs : list [ Any ] = [ ]
try :
for _ in range ( len ( iterator_list_value ) ) :
# run workflow
rst = graph_engine . run ( )
for event in rst :
if isinstance ( event , ( BaseNodeEvent | BaseParallelBranchEvent ) ) and not event . in_iteration_id :
event . in_iteration_id = self . node_id
if (
isinstance ( event , BaseNodeEvent )
and event . node_type == NodeType . ITERATION_START
and not isinstance ( event , NodeRunStreamChunkEvent )
) :
continue
if isinstance ( event , NodeRunSucceededEvent ) :
if event . route_node_state . node_run_result :
metadata = event . route_node_state . node_run_result . metadata
if not metadata :
metadata = { }
if NodeRunMetadataKey . ITERATION_ID not in metadata :
metadata [ NodeRunMetadataKey . ITERATION_ID ] = self . node_id
index_variable = variable_pool . get ( [ self . node_id , " index " ] )
if not isinstance ( index_variable , IntegerSegment ) :
yield RunCompletedEvent (
run_result = NodeRunResult (
status = WorkflowNodeExecutionStatus . FAILED ,
error = f " Invalid index variable type: { type ( index_variable ) } " ,
)
)
return
metadata [ NodeRunMetadataKey . ITERATION_INDEX ] = index_variable . value
event . route_node_state . node_run_result . metadata = metadata
yield event
elif isinstance ( event , BaseGraphEvent ) :
if isinstance ( event , GraphRunFailedEvent ) :
# iteration run failed
yield IterationRunFailedEvent (
iteration_id = self . id ,
iteration_node_id = self . node_id ,
iteration_node_type = self . node_type ,
iteration_node_data = self . node_data ,
start_at = start_at ,
inputs = inputs ,
outputs = { " output " : jsonable_encoder ( outputs ) } ,
steps = len ( iterator_list_value ) ,
metadata = { " total_tokens " : graph_engine . graph_runtime_state . total_tokens } ,
error = event . error ,
)
yield RunCompletedEvent (
run_result = NodeRunResult (
status = WorkflowNodeExecutionStatus . FAILED ,
error = event . error ,
)
)
return
else :
event = cast ( InNodeEvent , event )
if self . node_data . is_parallel :
futures : list [ Future ] = [ ]
q = Queue ( )
thread_pool = GraphEngineThreadPool ( max_workers = self . node_data . parallel_nums , max_submit_count = 100 )
for index , item in enumerate ( iterator_list_value ) :
future : Future = thread_pool . submit (
self . _run_single_iter_parallel ,
current_app . _get_current_object ( ) ,
q ,
iterator_list_value ,
inputs ,
outputs ,
start_at ,
graph_engine ,
iteration_graph ,
index ,
item ,
)
future . add_done_callback ( thread_pool . task_done_callback )
futures . append ( future )
succeeded_count = 0
while True :
try :
event = q . get ( timeout = 1 )
if event is None :
break
if isinstance ( event , IterationRunNextEvent ) :
succeeded_count + = 1
if succeeded_count == len ( futures ) :
q . put ( None )
yield event
if isinstance ( event , RunCompletedEvent ) :
q . put ( None )
for f in futures :
if not f . done ( ) :
f . cancel ( )
yield event
if isinstance ( event , IterationRunFailedEvent ) :
q . put ( None )
yield event
except Empty :
continue
# append to iteration output variable list
current_iteration_output_variable = variable_pool . get ( self . node_data . output_selector )
if current_iteration_output_variable is None :
yield RunCompletedEvent (
run_result = NodeRunResult (
status = WorkflowNodeExecutionStatus . FAILED ,
error = f " Iteration output variable { self . node_data . output_selector } not found " ,
)
# wait all threads
wait ( futures )
else :
for _ in range ( len ( iterator_list_value ) ) :
yield from self . _run_single_iter (
iterator_list_value ,
variable_pool ,
inputs ,
outputs ,
start_at ,
graph_engine ,
iteration_graph ,
)
return
current_iteration_output = current_iteration_output_variable . to_object ( )
outputs . append ( current_iteration_output )
# remove all nodes outputs from variable pool
for node_id in iteration_graph . node_ids :
variable_pool . remove ( [ node_id ] )
# move to next iteration
current_index_variable = variable_pool . get ( [ self . node_id , " index " ] )
if not isinstance ( current_index_variable , IntegerSegment ) :
raise ValueError ( f " iteration { self . node_id } current index not found " )
next_index = current_index_variable . value + 1
variable_pool . add ( [ self . node_id , " index " ] , next_index )
if next_index < len ( iterator_list_value ) :
variable_pool . add ( [ self . node_id , " item " ] , iterator_list_value [ next_index ] )
yield IterationRunNextEvent (
iteration_id = self . id ,
iteration_node_id = self . node_id ,
iteration_node_type = self . node_type ,
iteration_node_data = self . node_data ,
index = next_index ,
pre_iteration_output = jsonable_encoder ( current_iteration_output ) ,
)
yield IterationRunSucceededEvent (
iteration_id = self . id ,
iteration_node_id = self . node_id ,
@ -330,3 +310,231 @@ class IterationNode(BaseNode[IterationNodeData]):
}
return variable_mapping
def _handle_event_metadata(
    self, event: BaseNodeEvent, iter_run_index: int, parallel_mode_run_id: Optional[str]
) -> NodeRunStartedEvent | BaseNodeEvent:
    """
    Add iteration metadata to an event emitted while running one iteration.

    The event is mutated in place and returned.

    :param event: event coming out of the inner graph run; anything that is
        not a ``BaseNodeEvent`` is returned untouched
    :param iter_run_index: index of the iteration that produced the event
        (callers pass the integer value of the ``index`` variable)
    :param parallel_mode_run_id: id of the parallel run that produced the
        event; ``None`` when the iteration runs sequentially
    :return: the same event object
    """
    if not isinstance(event, BaseNodeEvent):
        return event

    if self.node_data.is_parallel and isinstance(event, NodeRunStartedEvent):
        # In parallel mode a started event is tagged directly on the event;
        # its run-result metadata is not touched.
        event.parallel_mode_run_id = parallel_mode_run_id
        return event

    node_run_result = event.route_node_state.node_run_result
    if node_run_result:
        metadata = node_run_result.metadata or {}

        # Never overwrite an iteration id that is already present.
        if NodeRunMetadataKey.ITERATION_ID not in metadata:
            metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
            # The run id (parallel) or index (sequential) is recorded only
            # together with a freshly written iteration id.
            if self.node_data.is_parallel:
                metadata[NodeRunMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id
            else:
                metadata[NodeRunMetadataKey.ITERATION_INDEX] = iter_run_index
            node_run_result.metadata = metadata

    return event
def _run_single_iter(
    self,
    iterator_list_value: list[str],
    variable_pool: VariablePool,
    inputs: dict[str, list],
    outputs: list,
    start_at: datetime,
    graph_engine: "GraphEngine",
    iteration_graph: Graph,
    parallel_mode_run_id: Optional[str] = None,
) -> Generator[NodeEvent | InNodeEvent, None, None]:
    """
    Run a single iteration and stream the events it produces.

    Events from ``graph_engine.run()`` are tagged with this node's id and
    iteration metadata and re-yielded. When the iteration finishes, its
    output is inserted into ``outputs`` at the current index, the outputs of
    the inner nodes are removed from ``variable_pool``, the ``index`` /
    ``item`` variables are advanced and an ``IterationRunNextEvent`` is
    yielded.

    A failed inner node is handled according to
    ``self.node_data.error_handle_mode``:

    * ``CONTINUE_ON_ERROR``: record ``None`` as this iteration's output and
      move on to the next item;
    * ``REMOVE_ABNORMAL_OUTPUT``: record nothing and move on to the next item;
    * ``TERMINATED``: yield an ``IterationRunFailedEvent`` and keep
      forwarding events.

    A ``GraphRunFailedEvent``, or any exception raised while running, ends
    the iteration with an ``IterationRunFailedEvent`` followed by a failed
    ``RunCompletedEvent``.

    :param iterator_list_value: the full list being iterated over
    :param variable_pool: variable pool read and updated by this iteration
    :param inputs: iteration inputs, echoed in failure events
    :param outputs: shared list of iteration outputs, mutated in place
    :param start_at: start time of the whole iteration node
    :param graph_engine: engine used to run the iteration sub graph
    :param iteration_graph: sub graph whose node outputs are cleared afterwards
    :param parallel_mode_run_id: id of the parallel run, ``None`` when sequential
    """
    try:
        rst = graph_engine.run()

        # get current iteration index; validate it *before* using it so a
        # missing variable produces the intended error message
        index_segment = variable_pool.get([self.node_id, "index"])
        current_index = None if index_segment is None else index_segment.value
        if current_index is None:
            raise ValueError(f"iteration {self.node_id} current index not found")
        next_index = int(current_index) + 1

        for event in rst:
            if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
                event.in_iteration_id = self.node_id

            # events of the iteration start node are swallowed, except stream chunks
            if (
                isinstance(event, BaseNodeEvent)
                and event.node_type == NodeType.ITERATION_START
                and not isinstance(event, NodeRunStreamChunkEvent)
            ):
                continue

            if isinstance(event, NodeRunSucceededEvent):
                yield self._handle_event_metadata(event, current_index, parallel_mode_run_id)
            elif isinstance(event, BaseGraphEvent):
                if isinstance(event, GraphRunFailedEvent):
                    # iteration run failed; the run id is only attached in parallel mode
                    run_id_kwargs: dict[str, Any] = {}
                    if self.node_data.is_parallel:
                        run_id_kwargs["parallel_mode_run_id"] = parallel_mode_run_id
                    yield IterationRunFailedEvent(
                        iteration_id=self.id,
                        iteration_node_id=self.node_id,
                        iteration_node_type=self.node_type,
                        iteration_node_data=self.node_data,
                        start_at=start_at,
                        inputs=inputs,
                        outputs={"output": jsonable_encoder(outputs)},
                        steps=len(iterator_list_value),
                        metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
                        error=event.error,
                        **run_id_kwargs,
                    )
                    yield RunCompletedEvent(
                        run_result=NodeRunResult(
                            status=WorkflowNodeExecutionStatus.FAILED,
                            error=event.error,
                        )
                    )
                    return
            else:
                event = cast(InNodeEvent, event)
                metadata_event = self._handle_event_metadata(event, current_index, parallel_mode_run_id)
                if isinstance(event, NodeRunFailedEvent):
                    error_handle_mode = self.node_data.error_handle_mode
                    if error_handle_mode in (
                        ErrorHandleMode.CONTINUE_ON_ERROR,
                        ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT,
                    ):
                        yield NodeInIterationFailedEvent(
                            **metadata_event.model_dump(),
                        )
                        # only CONTINUE_ON_ERROR keeps a placeholder for the failed item
                        if error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR:
                            outputs.insert(current_index, None)
                        variable_pool.add([self.node_id, "index"], next_index)
                        if next_index < len(iterator_list_value):
                            variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
                        yield IterationRunNextEvent(
                            iteration_id=self.id,
                            iteration_node_id=self.node_id,
                            iteration_node_type=self.node_type,
                            iteration_node_data=self.node_data,
                            index=next_index,
                            parallel_mode_run_id=parallel_mode_run_id,
                            pre_iteration_output=None,
                        )
                        return
                    elif error_handle_mode == ErrorHandleMode.TERMINATED:
                        yield IterationRunFailedEvent(
                            iteration_id=self.id,
                            iteration_node_id=self.node_id,
                            iteration_node_type=self.node_type,
                            iteration_node_data=self.node_data,
                            start_at=start_at,
                            inputs=inputs,
                            outputs={"output": None},
                            steps=len(iterator_list_value),
                            metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
                            error=event.error,
                        )
                yield metadata_event

        # append to iteration output variable list
        output_segment = variable_pool.get(self.node_data.output_selector)
        if output_segment is None:
            # raised inside the try block so it is reported through the
            # regular failure events below, with a meaningful message
            raise ValueError(f"Iteration output variable {self.node_data.output_selector} not found")
        current_iteration_output = output_segment.value
        outputs.insert(current_index, current_iteration_output)

        # remove all nodes outputs from variable pool
        for node_id in iteration_graph.node_ids:
            variable_pool.remove([node_id])

        # move to next iteration
        variable_pool.add([self.node_id, "index"], next_index)
        if next_index < len(iterator_list_value):
            variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])

        # compare against None so falsy outputs (0, "", [], False) are not
        # reported as a missing output
        if current_iteration_output is not None:
            pre_iteration_output = jsonable_encoder(current_iteration_output)
        else:
            pre_iteration_output = None
        yield IterationRunNextEvent(
            iteration_id=self.id,
            iteration_node_id=self.node_id,
            iteration_node_type=self.node_type,
            iteration_node_data=self.node_data,
            index=next_index,
            parallel_mode_run_id=parallel_mode_run_id,
            pre_iteration_output=pre_iteration_output,
        )
    except Exception as e:
        # boundary for the whole iteration: log the traceback and convert the
        # error into failure events instead of propagating it
        logger.exception("Iteration run failed: %s", e)
        yield IterationRunFailedEvent(
            iteration_id=self.id,
            iteration_node_id=self.node_id,
            iteration_node_type=self.node_type,
            iteration_node_data=self.node_data,
            start_at=start_at,
            inputs=inputs,
            outputs={"output": None},
            steps=len(iterator_list_value),
            metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
            error=str(e),
        )
        yield RunCompletedEvent(
            run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error=str(e),
            )
        )
def _run_single_iter_parallel(
    self,
    flask_app: Flask,
    q: Queue,
    iterator_list_value: list[str],
    inputs: dict[str, list],
    outputs: list,
    start_at: datetime,
    graph_engine: "GraphEngine",
    iteration_graph: Graph,
    index: int,
    item: Any,
) -> None:
    """
    Run a single iteration in parallel mode.

    The iteration runs on a copy of ``graph_engine`` (obtained from
    ``create_copy()``) whose variable pool is seeded with this iteration's
    ``index`` and ``item``. Every event produced by ``_run_single_iter`` is
    pushed onto ``q``; nothing is returned.

    If anything raises outside of ``_run_single_iter``'s own error handling
    (for example while copying the engine), a failed ``RunCompletedEvent`` is
    queued before the exception is re-raised, so whoever drains ``q`` is
    told about the failure instead of waiting for events that never arrive.

    :param flask_app: Flask application whose app context wraps the run
    :param q: queue that receives the events of this iteration
    :param iterator_list_value: the full list being iterated over
    :param inputs: iteration inputs, forwarded to ``_run_single_iter``
    :param outputs: shared output list, forwarded to ``_run_single_iter``
    :param start_at: start time of the whole iteration node
    :param graph_engine: engine to copy for this iteration
    :param iteration_graph: sub graph executed by the iteration
    :param index: position of ``item`` in ``iterator_list_value``
    :param item: the element handled by this iteration
    """
    with flask_app.app_context():
        try:
            parallel_mode_run_id = uuid.uuid4().hex
            graph_engine_copy = graph_engine.create_copy()
            variable_pool_copy = graph_engine_copy.graph_runtime_state.variable_pool
            variable_pool_copy.add([self.node_id, "index"], index)
            variable_pool_copy.add([self.node_id, "item"], item)
            for event in self._run_single_iter(
                iterator_list_value=iterator_list_value,
                variable_pool=variable_pool_copy,
                inputs=inputs,
                outputs=outputs,
                start_at=start_at,
                graph_engine=graph_engine_copy,
                iteration_graph=iteration_graph,
                parallel_mode_run_id=parallel_mode_run_id,
            ):
                q.put(event)
        except Exception as e:
            logger.exception("Parallel iteration %s of node %s failed", index, self.node_id)
            q.put(
                RunCompletedEvent(
                    run_result=NodeRunResult(
                        status=WorkflowNodeExecutionStatus.FAILED,
                        error=str(e),
                    )
                )
            )
            # keep the exception visible on the Future, as before
            raise