@ -11,6 +11,7 @@ from core.workflow.graph_engine.entities.event import (
NodeRunSucceededEvent ,
NodeRunSucceededEvent ,
)
)
from core . workflow . graph_engine . entities . graph import Graph
from core . workflow . graph_engine . entities . graph import Graph
from core . workflow . graph_engine . entities . runtime_route_state import RouteNodeState , RuntimeRouteState
from core . workflow . nodes . answer . base_stream_processor import StreamProcessor
from core . workflow . nodes . answer . base_stream_processor import StreamProcessor
from core . workflow . nodes . answer . entities import GenerateRouteChunk , TextGenerateRouteChunk , VarGenerateRouteChunk
from core . workflow . nodes . answer . entities import GenerateRouteChunk , TextGenerateRouteChunk , VarGenerateRouteChunk
@ -18,8 +19,8 @@ logger = logging.getLogger(__name__)
class AnswerStreamProcessor ( StreamProcessor ) :
class AnswerStreamProcessor ( StreamProcessor ) :
def __init__ ( self , graph : Graph , variable_pool : VariablePool ) - > None :
def __init__ ( self , graph : Graph , variable_pool : VariablePool , node_run_state : RuntimeRouteState ) - > None :
super ( ) . __init__ ( graph , variable_pool )
super ( ) . __init__ ( graph , variable_pool , node_run_state )
self . generate_routes = graph . answer_stream_generate_routes
self . generate_routes = graph . answer_stream_generate_routes
self . route_position = { }
self . route_position = { }
for answer_node_id in self . generate_routes . answer_generate_route :
for answer_node_id in self . generate_routes . answer_generate_route :
@ -73,6 +74,64 @@ class AnswerStreamProcessor(StreamProcessor):
self . rest_node_ids = self . graph . node_ids . copy ( )
self . rest_node_ids = self . graph . node_ids . copy ( )
self . current_stream_chunk_generating_node_ids = { }
self . current_stream_chunk_generating_node_ids = { }
def _is_dynamic_dependencies_met ( self , start_node_id : str ) - > bool :
"""
Check if all dynamic dependencies are met for a given node by traversing backwards .
This method is based on the runtime state of the graph .
"""
# Use a queue for BFS and a set to track visited nodes to prevent cycles
queue = [ start_node_id ]
visited = { start_node_id }
while queue :
current_node_id = queue . pop ( 0 )
# Get the edges leading to the current node
parent_edges = self . graph . reverse_edge_mapping . get ( current_node_id , [ ] )
if not parent_edges :
continue
for edge in parent_edges :
parent_node_id = edge . source_node_id
if parent_node_id in visited :
continue
visited . add ( parent_node_id )
# Find the latest execution state of the parent node in the current run
parent_node_run_state = None
for state in self . node_run_state . node_state_mapping . values ( ) :
if state . node_id == parent_node_id :
parent_node_run_state = state
break # Assume the last found state is the latest for simplicity
if not parent_node_run_state or parent_node_run_state . status == RouteNodeState . Status . RUNNING :
return False
if parent_node_run_state . status in [ RouteNodeState . Status . FAILED , RouteNodeState . Status . EXCEPTION ] :
return False
# If the parent is a branch node, check if the executed branch leads to the current node
parent_node_config = self . graph . node_id_config_mapping . get ( parent_node_id , { } )
parent_node_type = parent_node_config . get ( ' data ' , { } ) . get ( ' type ' )
is_branch_node = parent_node_type in [ ' if-else ' , ' question-classifier ' ] # Example branch types
if is_branch_node :
run_result = parent_node_run_state . node_run_result
chosen_handle = run_result . edge_source_handle if run_result else None
required_handle = edge . run_condition . branch_identify if edge . run_condition else None
# If the chosen branch does not match the path we are traversing, this dependency path is irrelevant
if chosen_handle and required_handle and chosen_handle != required_handle :
continue # This path was not taken, so it's not a dependency
# If all checks pass, add the parent to the queue to continue traversing up
queue . append ( parent_node_id )
return True
def _generate_stream_outputs_when_node_finished (
def _generate_stream_outputs_when_node_finished (
self , event : NodeRunSucceededEvent
self , event : NodeRunSucceededEvent
) - > Generator [ GraphEngineEvent , None , None ] :
) - > Generator [ GraphEngineEvent , None , None ] :
@ -87,7 +146,7 @@ class AnswerStreamProcessor(StreamProcessor):
answer_node_id not in self . rest_node_ids
answer_node_id not in self . rest_node_ids
or not all (
or not all (
dep_id not in self . rest_node_ids
dep_id not in self . rest_node_ids
for dep_id in self . generate_routes . answer_dependencies [ answer_node_id ]
for dep_id in self . generate_routes . answer_dependencies . get ( answer_node_id , [ ] )
)
)
) :
) :
continue
continue
@ -156,33 +215,13 @@ class AnswerStreamProcessor(StreamProcessor):
for answer_node_id , route_position in self . route_position . items ( ) :
for answer_node_id , route_position in self . route_position . items ( ) :
if answer_node_id not in self . rest_node_ids :
if answer_node_id not in self . rest_node_ids :
continue
continue
# Remove current node id from answer dependencies to support stream output if it is a success branch
answer_dependencies = self . generate_routes . answer_dependencies
# New dynamic dependency check
edge_mapping = self . graph . edge_mapping . get ( event . node_id )
source_node_id_for_check = event . from_variable_selector [ 0 ]
success_edge = (
all_deps_finished = self . _is_dynamic_dependencies_met ( start_node_id = source_node_id_for_check )
next (
(
if all_deps_finished :
edge
if route_position > = len ( self . generate_routes . answer_generate_route . get ( answer_node_id , [ ] ) ) :
for edge in edge_mapping
if edge . run_condition
and edge . run_condition . type == " branch_identify "
and edge . run_condition . branch_identify == " success-branch "
) ,
None ,
)
if edge_mapping
else None
)
if (
event . node_id in answer_dependencies [ answer_node_id ]
and success_edge
and success_edge . target_node_id == answer_node_id
) :
answer_dependencies [ answer_node_id ] . remove ( event . node_id )
answer_dependencies_ids = answer_dependencies . get ( answer_node_id , [ ] )
# all depends on answer node id not in rest node ids
if all ( dep_id not in self . rest_node_ids for dep_id in answer_dependencies_ids ) :
if route_position > = len ( self . generate_routes . answer_generate_route [ answer_node_id ] ) :
continue
continue
route_chunk = self . generate_routes . answer_generate_route [ answer_node_id ] [ route_position ]
route_chunk = self . generate_routes . answer_generate_route [ answer_node_id ] [ route_position ]