fix(workflow): Resolve streaming failure on conditional join points

Improves the robustness of the workflow engine's streaming output by fixing two core issues that caused streaming to fail in complex topologies where multiple conditional branches merge.

**1. Corrected Runtime State Management ("Pruning"):**

The primary bug was located in the `_remove_unreachable_nodes` method. Its aggressive recursive "pruning" algorithm incorrectly removed shared downstream nodes (including LLM and Answer nodes) when handling conditional branches that led to a join point. This prematurely emptied the `rest_node_ids` list, causing the stream processor to fail its initial state check.

The fix replaces the recursive logic with a more conservative, non-recursive approach that only prunes the immediate first node of an unreachable branch. This ensures the integrity of the `rest_node_ids` list throughout the workflow execution.

**2. Improved Static Dependency Analysis:**

A secondary, underlying issue was found in the static dependency analysis (`_recursive_fetch_answer_dependencies`). It incorrectly identified all upstream, mutually exclusive `If/Else` nodes as parallel dependencies of the Answer node.

The fix enhances this analysis by adding "join point awareness". The upward trace now stops when it encounters a node with more than one incoming edge, correctly identifying the join point itself as the dependency rather than its upstream branches.

Together, these changes ensure that streaming output remains reliable and predictable, even in complex workflows with reusable, multi-input nodes.
pull/22771/head
xuzijie1995 7 months ago
parent 383a79772c
commit b13ae784a4

@ -152,7 +152,9 @@ class GraphEngine:
try:
if self.init_params.workflow_type == WorkflowType.CHAT:
stream_processor = AnswerStreamProcessor(
graph=self.graph, variable_pool=self.graph_runtime_state.variable_pool
graph=self.graph,
variable_pool=self.graph_runtime_state.variable_pool,
node_run_state=self.graph_runtime_state.node_run_state,
)
else:
stream_processor = EndStreamProcessor(

@ -90,6 +90,9 @@ class AnswerStreamGeneratorRouter:
:return:
"""
node_data = AnswerNodeData(**config.get("data", {}))
# Trim whitespace from the answer template to prevent parsing issues with leading/trailing spaces.
if node_data.answer:
node_data.answer = node_data.answer.strip()
return cls.extract_generate_route_from_node_data(node_data)
@classmethod
@ -145,6 +148,13 @@ class AnswerStreamGeneratorRouter:
:return:
"""
reverse_edges = reverse_edge_mapping.get(current_node_id, [])
# If the current node has more than one incoming edge, it's a join point.
# We should add it as a dependency and stop tracing up further.
if len(reverse_edges) > 1:
answer_dependencies[answer_node_id].append(current_node_id)
return
for edge in reverse_edges:
source_node_id = edge.source_node_id
if source_node_id not in node_id_config_mapping:

@ -11,6 +11,7 @@ from core.workflow.graph_engine.entities.event import (
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState, RuntimeRouteState
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk
@ -18,8 +19,8 @@ logger = logging.getLogger(__name__)
class AnswerStreamProcessor(StreamProcessor):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
super().__init__(graph, variable_pool)
def __init__(self, graph: Graph, variable_pool: VariablePool, node_run_state: RuntimeRouteState) -> None:
super().__init__(graph, variable_pool, node_run_state)
self.generate_routes = graph.answer_stream_generate_routes
self.route_position = {}
for answer_node_id in self.generate_routes.answer_generate_route:
@ -73,6 +74,64 @@ class AnswerStreamProcessor(StreamProcessor):
self.rest_node_ids = self.graph.node_ids.copy()
self.current_stream_chunk_generating_node_ids = {}
def _is_dynamic_dependencies_met(self, start_node_id: str) -> bool:
    """
    Check if all dynamic dependencies of ``start_node_id`` are met by
    traversing the graph backwards (towards the roots) using the runtime
    node-run state.

    A dependency is unmet when any relevant upstream node is still
    running, failed, raised an exception, or has no recorded run state
    yet. Parents that are branch nodes only count as dependencies when
    the branch they actually executed leads towards the current node.

    :param start_node_id: node id to start the backwards traversal from
    :return: True if every relevant upstream node finished successfully
    """
    from collections import deque

    # BFS upwards; the visited set guards against cycles and duplicate work.
    queue = deque([start_node_id])
    visited = {start_node_id}
    while queue:
        current_node_id = queue.popleft()
        # Edges pointing INTO the current node, i.e. its parents.
        parent_edges = self.graph.reverse_edge_mapping.get(current_node_id, [])
        if not parent_edges:
            continue
        for edge in parent_edges:
            parent_node_id = edge.source_node_id
            if parent_node_id in visited:
                continue
            visited.add(parent_node_id)
            # Find the LATEST execution state of the parent node in the
            # current run. node_state_mapping preserves insertion order,
            # so the last matching entry is the most recent one — do NOT
            # break on the first match, which would return the earliest
            # state for nodes that executed more than once.
            parent_node_run_state = None
            for state in self.node_run_state.node_state_mapping.values():
                if state.node_id == parent_node_id:
                    parent_node_run_state = state
            # No recorded state, or still running -> dependency not met yet.
            if not parent_node_run_state or parent_node_run_state.status == RouteNodeState.Status.RUNNING:
                return False
            if parent_node_run_state.status in [RouteNodeState.Status.FAILED, RouteNodeState.Status.EXCEPTION]:
                return False
            # If the parent is a branch node, check whether the branch it
            # actually executed leads to the current node.
            parent_node_config = self.graph.node_id_config_mapping.get(parent_node_id, {})
            parent_node_type = parent_node_config.get('data', {}).get('type')
            is_branch_node = parent_node_type in ['if-else', 'question-classifier']  # Example branch types
            if is_branch_node:
                run_result = parent_node_run_state.node_run_result
                chosen_handle = run_result.edge_source_handle if run_result else None
                required_handle = edge.run_condition.branch_identify if edge.run_condition else None
                # If the chosen branch does not match the path we are
                # traversing, this dependency path is irrelevant.
                if chosen_handle and required_handle and chosen_handle != required_handle:
                    continue  # This path was not taken, so it's not a dependency
            # All checks passed; keep traversing upwards from the parent.
            queue.append(parent_node_id)
    return True
def _generate_stream_outputs_when_node_finished(
self, event: NodeRunSucceededEvent
) -> Generator[GraphEngineEvent, None, None]:
@ -87,7 +146,7 @@ class AnswerStreamProcessor(StreamProcessor):
answer_node_id not in self.rest_node_ids
or not all(
dep_id not in self.rest_node_ids
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]
for dep_id in self.generate_routes.answer_dependencies.get(answer_node_id, [])
)
):
continue
@ -156,33 +215,13 @@ class AnswerStreamProcessor(StreamProcessor):
for answer_node_id, route_position in self.route_position.items():
if answer_node_id not in self.rest_node_ids:
continue
# Remove current node id from answer dependencies to support stream output if it is a success branch
answer_dependencies = self.generate_routes.answer_dependencies
edge_mapping = self.graph.edge_mapping.get(event.node_id)
success_edge = (
next(
(
edge
for edge in edge_mapping
if edge.run_condition
and edge.run_condition.type == "branch_identify"
and edge.run_condition.branch_identify == "success-branch"
),
None,
)
if edge_mapping
else None
)
if (
event.node_id in answer_dependencies[answer_node_id]
and success_edge
and success_edge.target_node_id == answer_node_id
):
answer_dependencies[answer_node_id].remove(event.node_id)
answer_dependencies_ids = answer_dependencies.get(answer_node_id, [])
# all depends on answer node id not in rest node ids
if all(dep_id not in self.rest_node_ids for dep_id in answer_dependencies_ids):
if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]):
# New dynamic dependency check
source_node_id_for_check = event.from_variable_selector[0]
all_deps_finished = self._is_dynamic_dependencies_met(start_node_id=source_node_id_for_check)
if all_deps_finished:
if route_position >= len(self.generate_routes.answer_generate_route.get(answer_node_id, [])):
continue
route_chunk = self.generate_routes.answer_generate_route[answer_node_id][route_position]

@ -6,14 +6,16 @@ from typing import Optional
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunExceptionEvent, NodeRunSucceededEvent
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
logger = logging.getLogger(__name__)
class StreamProcessor(ABC):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
def __init__(self, graph: Graph, variable_pool: VariablePool, node_run_state: RuntimeRouteState) -> None:
    """
    Initialize shared stream-processor state.

    :param graph: the workflow graph being executed
    :param variable_pool: variable pool for the current workflow run
    :param node_run_state: runtime route state shared with the graph engine
    """
    self.node_run_state = node_run_state
    self.variable_pool = variable_pool
    self.graph = graph
    # Independent copy of all node ids; pruned as branches are ruled out
    # so that modifying it never mutates the graph's own node list.
    self.rest_node_ids = list(graph.node_ids)
@abstractmethod
@ -68,9 +70,14 @@ class StreamProcessor(ABC):
):
continue
unreachable_first_node_ids.append(edge.target_node_id)
unreachable_first_node_ids = list(set(unreachable_first_node_ids) - set(reachable_node_ids))
for node_id in unreachable_first_node_ids:
self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
# Instead of recursively removing the entire unreachable branch,
# which can cause issues with complex join points,
# we will only remove the immediate first node of the unreachable branch.
# This prevents the removal logic from incorrectly pruning shared paths downstream.
for node_id in list(set(unreachable_first_node_ids) - set(reachable_node_ids)):
if node_id in self.rest_node_ids:
self.rest_node_ids.remove(node_id)
def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: Optional[str] = None) -> list[str]:
if node_id not in self.rest_node_ids:

Loading…
Cancel
Save