From b13ae784a452a43bb8098b02319f69383afd1000 Mon Sep 17 00:00:00 2001 From: xuzijie1995 <18852951350@163.com> Date: Mon, 21 Jul 2025 11:18:51 +0800 Subject: [PATCH 1/2] fix(workflow): Resolve streaming failure on conditional join points Improves the robustness of the workflow engine's streaming output by fixing two core issues that caused streaming to fail in complex topologies where multiple conditional branches merge. **1. Corrected Runtime State Management ("Pruning"):** The primary bug was located in the `_remove_unreachable_nodes` method. Its aggressive recursive "pruning" algorithm incorrectly removed shared downstream nodes (including LLM and Answer) when handling conditional branches that led to a join point. This prematurely emptied the `rest_node_ids` list, causing the stream processor to fail its initial state check. The fix replaces the recursive logic with a more conservative, non-recursive approach that only prunes the immediate first node of an unreachable branch. This ensures the integrity of the `rest_node_ids` list throughout the workflow execution. **2. Improved Static Dependency Analysis:** A secondary, underlying issue was found in the static dependency analysis (`_recursive_fetch_answer_dependencies`). It incorrectly identified all upstream, mutually exclusive `If/Else` nodes as parallel dependencies of the Answer node. The fix enhances this analysis by adding "join point awareness". The upward trace now stops when it encounters a node with more than one incoming edge, correctly identifying the join point itself as the dependency rather than its upstream branches. Together, these changes ensure that streaming output remains reliable and predictable, even in complex workflows with reusable, multi-input nodes. --- .../workflow/graph_engine/graph_engine.py | 4 +- .../answer/answer_stream_generate_router.py | 10 ++ .../nodes/answer/answer_stream_processor.py | 99 +++++++++++++------ .../nodes/answer/base_stream_processor.py | 15 ++- 4 files changed, 93 insertions(+), 35 deletions(-) diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index b315129763..daf6efb952 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -152,7 +152,9 @@ class GraphEngine: try: if self.init_params.workflow_type == WorkflowType.CHAT: stream_processor = AnswerStreamProcessor( - graph=self.graph, variable_pool=self.graph_runtime_state.variable_pool + graph=self.graph, + variable_pool=self.graph_runtime_state.variable_pool, + node_run_state=self.graph_runtime_state.node_run_state, ) else: stream_processor = EndStreamProcessor( diff --git a/api/core/workflow/nodes/answer/answer_stream_generate_router.py b/api/core/workflow/nodes/answer/answer_stream_generate_router.py index 1d9c3e9b96..7f9ebf756f 100644 --- a/api/core/workflow/nodes/answer/answer_stream_generate_router.py +++ b/api/core/workflow/nodes/answer/answer_stream_generate_router.py @@ -90,6 +90,9 @@ class AnswerStreamGeneratorRouter: :return: """ node_data = AnswerNodeData(**config.get("data", {})) + # Trim whitespace from the answer template to prevent parsing issues with leading/trailing spaces. + if node_data.answer: + node_data.answer = node_data.answer.strip() return cls.extract_generate_route_from_node_data(node_data) @classmethod @@ -145,6 +148,13 @@ class AnswerStreamGeneratorRouter: :return: """ reverse_edges = reverse_edge_mapping.get(current_node_id, []) + + # If the current node has more than one incoming edge, it's a join point. + # We should add it as a dependency and stop tracing up further. + if len(reverse_edges) > 1: + answer_dependencies[answer_node_id].append(current_node_id) + return + for edge in reverse_edges: source_node_id = edge.source_node_id if source_node_id not in node_id_config_mapping: diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py index 97666fad05..d7b6d3d671 100644 --- a/api/core/workflow/nodes/answer/answer_stream_processor.py +++ b/api/core/workflow/nodes/answer/answer_stream_processor.py @@ -11,6 +11,7 @@ from core.workflow.graph_engine.entities.event import ( NodeRunSucceededEvent, ) from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState, RuntimeRouteState from core.workflow.nodes.answer.base_stream_processor import StreamProcessor from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk @@ -18,8 +19,8 @@ logger = logging.getLogger(__name__) class AnswerStreamProcessor(StreamProcessor): - def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: - super().__init__(graph, variable_pool) + def __init__(self, graph: Graph, variable_pool: VariablePool, node_run_state: RuntimeRouteState) -> None: + super().__init__(graph, variable_pool, node_run_state) self.generate_routes = graph.answer_stream_generate_routes self.route_position = {} for answer_node_id in self.generate_routes.answer_generate_route: @@ -73,6 +74,64 @@ class AnswerStreamProcessor(StreamProcessor): self.rest_node_ids = self.graph.node_ids.copy() self.current_stream_chunk_generating_node_ids = {} + def _is_dynamic_dependencies_met(self, start_node_id: str) -> bool: + """ + Check if all dynamic dependencies are met for a given node by traversing backwards. + This method is based on the runtime state of the graph. + """ + # Use a queue for BFS and a set to track visited nodes to prevent cycles + queue = [start_node_id] + visited = {start_node_id} + + while queue: + current_node_id = queue.pop(0) + + # Get the edges leading to the current node + parent_edges = self.graph.reverse_edge_mapping.get(current_node_id, []) + if not parent_edges: + continue + + for edge in parent_edges: + parent_node_id = edge.source_node_id + + if parent_node_id in visited: + continue + + visited.add(parent_node_id) + + # Find the latest execution state of the parent node in the current run + parent_node_run_state = None + for state in self.node_run_state.node_state_mapping.values(): + if state.node_id == parent_node_id: + parent_node_run_state = state + break # Assume the last found state is the latest for simplicity + + if not parent_node_run_state or parent_node_run_state.status == RouteNodeState.Status.RUNNING: + return False + + if parent_node_run_state.status in [RouteNodeState.Status.FAILED, RouteNodeState.Status.EXCEPTION]: + return False + + # If the parent is a branch node, check if the executed branch leads to the current node + parent_node_config = self.graph.node_id_config_mapping.get(parent_node_id, {}) + parent_node_type = parent_node_config.get('data', {}).get('type') + + is_branch_node = parent_node_type in ['if-else', 'question-classifier'] # Example branch types + + if is_branch_node: + run_result = parent_node_run_state.node_run_result + chosen_handle = run_result.edge_source_handle if run_result else None + required_handle = edge.run_condition.branch_identify if edge.run_condition else None + + # If the chosen branch does not match the path we are traversing, this dependency path is irrelevant + if chosen_handle and required_handle and chosen_handle != required_handle: + continue # This path was not taken, so it's not a dependency + + # If all checks pass, add the parent to the queue to continue traversing up + queue.append(parent_node_id) + + return True + def _generate_stream_outputs_when_node_finished( self, event: NodeRunSucceededEvent ) -> Generator[GraphEngineEvent, None, None]: @@ -87,7 +146,7 @@ class AnswerStreamProcessor(StreamProcessor): answer_node_id not in self.rest_node_ids or not all( dep_id not in self.rest_node_ids - for dep_id in self.generate_routes.answer_dependencies[answer_node_id] + for dep_id in self.generate_routes.answer_dependencies.get(answer_node_id, []) ) ): continue @@ -156,33 +215,13 @@ class AnswerStreamProcessor(StreamProcessor): for answer_node_id, route_position in self.route_position.items(): if answer_node_id not in self.rest_node_ids: continue - # Remove current node id from answer dependencies to support stream output if it is a success branch - answer_dependencies = self.generate_routes.answer_dependencies - edge_mapping = self.graph.edge_mapping.get(event.node_id) - success_edge = ( - next( - ( - edge - for edge in edge_mapping - if edge.run_condition - and edge.run_condition.type == "branch_identify" - and edge.run_condition.branch_identify == "success-branch" - ), - None, - ) - if edge_mapping - else None - ) - if ( - event.node_id in answer_dependencies[answer_node_id] - and success_edge - and success_edge.target_node_id == answer_node_id - ): - answer_dependencies[answer_node_id].remove(event.node_id) - answer_dependencies_ids = answer_dependencies.get(answer_node_id, []) - # all depends on answer node id not in rest node ids - if all(dep_id not in self.rest_node_ids for dep_id in answer_dependencies_ids): - if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]): + + # New dynamic dependency check + source_node_id_for_check = event.from_variable_selector[0] + all_deps_finished = self._is_dynamic_dependencies_met(start_node_id=source_node_id_for_check) + + if all_deps_finished: + if route_position >= len(self.generate_routes.answer_generate_route.get(answer_node_id, [])): continue route_chunk = self.generate_routes.answer_generate_route[answer_node_id][route_position] diff --git a/api/core/workflow/nodes/answer/base_stream_processor.py b/api/core/workflow/nodes/answer/base_stream_processor.py index 09d5464d7a..08c356758c 100644 --- a/api/core/workflow/nodes/answer/base_stream_processor.py +++ b/api/core/workflow/nodes/answer/base_stream_processor.py @@ -6,14 +6,16 @@ from typing import Optional from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunExceptionEvent, NodeRunSucceededEvent from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState logger = logging.getLogger(__name__) class StreamProcessor(ABC): - def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: + def __init__(self, graph: Graph, variable_pool: VariablePool, node_run_state: RuntimeRouteState) -> None: self.graph = graph self.variable_pool = variable_pool + self.node_run_state = node_run_state self.rest_node_ids = graph.node_ids.copy() @abstractmethod @@ -68,9 +70,14 @@ class StreamProcessor(ABC): ): continue unreachable_first_node_ids.append(edge.target_node_id) - unreachable_first_node_ids = list(set(unreachable_first_node_ids) - set(reachable_node_ids)) - for node_id in unreachable_first_node_ids: - self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids) + + # Instead of recursively removing the entire unreachable branch, + # which can cause issues with complex join points, + # we will only remove the immediate first node of the unreachable branch. + # This prevents the removal logic from incorrectly pruning shared paths downstream. + for node_id in list(set(unreachable_first_node_ids) - set(reachable_node_ids)): + if node_id in self.rest_node_ids: + self.rest_node_ids.remove(node_id) def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: Optional[str] = None) -> list[str]: if node_id not in self.rest_node_ids: From db91643915a226da3d4a31aebb1433906f050eef Mon Sep 17 00:00:00 2001 From: xuzijie1995 <18852951350@163.com> Date: Tue, 22 Jul 2025 10:42:03 +0800 Subject: [PATCH 2/2] refactor(workflow): Rearchitect stream dependency logic for complex graphs This commit addresses a critical issue where streaming output would fail in workflows with complex topologies, particularly those involving multiple conditional branches (if/else) that converge on a common node before the LLM and Answer nodes. The root cause was twofold: 1. A bug in the branch pruning logic () that would incorrectly remove shared downstream nodes, leading to a premature emptying of the list. 2. A flawed static dependency analysis () that could not correctly resolve dependencies for nodes that were part of multiple, mutually exclusive execution paths. This refactor introduces a new, robust architecture for streaming dependency management based on the principle of "Static Pre-pruning + Dynamic Adjudication": - **Fix**: The branch pruning logic in is now non-recursive and conservative. It only prunes the immediate first node of an unreachable branch, preserving the integrity of shared downstream paths and join points. - **Refactor**: The old static dependency analysis has been completely removed. This includes deleting the attribute from the entity and deleting the associated recursive dependency fetching methods (, ). - **Feat**: A new method, , has been implemented in . This method performs a real-time, backward traversal of the graph from the streaming node, querying the runtime execution state () to dynamically validate if the *actual* dependency path has been successfully completed. This ensures that streaming decisions are based on the ground truth of the current execution, not a flawed static prediction. - **Doc**: Added comprehensive docstrings and comments to the modified components to explain the new architecture and the rationale behind the changes. --- .../answer/answer_stream_generate_router.py | 101 ++---------------- .../nodes/answer/answer_stream_processor.py | 27 +++-- .../nodes/answer/base_stream_processor.py | 10 ++ api/core/workflow/nodes/answer/entities.py | 3 - 4 files changed, 36 insertions(+), 105 deletions(-) diff --git a/api/core/workflow/nodes/answer/answer_stream_generate_router.py b/api/core/workflow/nodes/answer/answer_stream_generate_router.py index 7f9ebf756f..bf3a72f56d 100644 --- a/api/core/workflow/nodes/answer/answer_stream_generate_router.py +++ b/api/core/workflow/nodes/answer/answer_stream_generate_router.py @@ -6,7 +6,7 @@ from core.workflow.nodes.answer.entities import ( TextGenerateRouteChunk, VarGenerateRouteChunk, ) -from core.workflow.nodes.enums import ErrorStrategy, NodeType +from core.workflow.nodes.enums import NodeType from core.workflow.utils.variable_template_parser import VariableTemplateParser @@ -18,8 +18,10 @@ class AnswerStreamGeneratorRouter: reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] ) -> AnswerStreamGenerateRoute: """ - Get stream generate routes. - :return: + Initializes the stream generation routes for all Answer nodes in the workflow. + This method performs a static analysis of the graph to parse the answer templates. + The old logic for pre-calculating static dependencies has been deprecated and removed, + as the decision logic is now handled dynamically at runtime by the AnswerStreamProcessor. """ # parse stream output node value selectors of answer nodes answer_generate_route: dict[str, list[GenerateRouteChunk]] = {} @@ -31,16 +33,8 @@ class AnswerStreamGeneratorRouter: generate_route = cls._extract_generate_route_selectors(node_config) answer_generate_route[answer_node_id] = generate_route - # fetch answer dependencies - answer_node_ids = list(answer_generate_route.keys()) - answer_dependencies = cls._fetch_answers_dependencies( - answer_node_ids=answer_node_ids, - reverse_edge_mapping=reverse_edge_mapping, - node_id_config_mapping=node_id_config_mapping, - ) - return AnswerStreamGenerateRoute( - answer_generate_route=answer_generate_route, answer_dependencies=answer_dependencies + answer_generate_route=answer_generate_route ) @classmethod @@ -99,86 +93,3 @@ class AnswerStreamGeneratorRouter: def _is_variable(cls, part, variable_keys): cleaned_part = part.replace("{{", "").replace("}}", "") return part.startswith("{{") and cleaned_part in variable_keys - - @classmethod - def _fetch_answers_dependencies( - cls, - answer_node_ids: list[str], - reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] - node_id_config_mapping: dict[str, dict], - ) -> dict[str, list[str]]: - """ - Fetch answer dependencies - :param answer_node_ids: answer node ids - :param reverse_edge_mapping: reverse edge mapping - :param node_id_config_mapping: node id config mapping - :return: - """ - answer_dependencies: dict[str, list[str]] = {} - for answer_node_id in answer_node_ids: - if answer_dependencies.get(answer_node_id) is None: - answer_dependencies[answer_node_id] = [] - - cls._recursive_fetch_answer_dependencies( - current_node_id=answer_node_id, - answer_node_id=answer_node_id, - node_id_config_mapping=node_id_config_mapping, - reverse_edge_mapping=reverse_edge_mapping, - answer_dependencies=answer_dependencies, - ) - - return answer_dependencies - - @classmethod - def _recursive_fetch_answer_dependencies( - cls, - current_node_id: str, - answer_node_id: str, - node_id_config_mapping: dict[str, dict], - reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] - answer_dependencies: dict[str, list[str]], - ) -> None: - """ - Recursive fetch answer dependencies - :param current_node_id: current node id - :param answer_node_id: answer node id - :param node_id_config_mapping: node id config mapping - :param reverse_edge_mapping: reverse edge mapping - :param answer_dependencies: answer dependencies - :return: - """ - reverse_edges = reverse_edge_mapping.get(current_node_id, []) - - # If the current node has more than one incoming edge, it's a join point. - # We should add it as a dependency and stop tracing up further. - if len(reverse_edges) > 1: - answer_dependencies[answer_node_id].append(current_node_id) - return - - for edge in reverse_edges: - source_node_id = edge.source_node_id - if source_node_id not in node_id_config_mapping: - continue - source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type") - source_node_data = node_id_config_mapping[source_node_id].get("data", {}) - if ( - source_node_type - in { - NodeType.ANSWER, - NodeType.IF_ELSE, - NodeType.QUESTION_CLASSIFIER, - NodeType.ITERATION, - NodeType.LOOP, - NodeType.VARIABLE_ASSIGNER, - } - or source_node_data.get("error_strategy") == ErrorStrategy.FAIL_BRANCH - ): - answer_dependencies[answer_node_id].append(source_node_id) - else: - cls._recursive_fetch_answer_dependencies( - current_node_id=source_node_id, - answer_node_id=answer_node_id, - node_id_config_mapping=node_id_config_mapping, - reverse_edge_mapping=reverse_edge_mapping, - answer_dependencies=answer_dependencies, - ) diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py index d7b6d3d671..51c08a0543 100644 --- a/api/core/workflow/nodes/answer/answer_stream_processor.py +++ b/api/core/workflow/nodes/answer/answer_stream_processor.py @@ -76,8 +76,24 @@ class AnswerStreamProcessor(StreamProcessor): def _is_dynamic_dependencies_met(self, start_node_id: str) -> bool: """ - Check if all dynamic dependencies are met for a given node by traversing backwards. - This method is based on the runtime state of the graph. + Performs a dynamic, runtime dependency check by traversing backwards from a given start_node_id. + + This method is the core of the new streaming architecture. Instead of relying on a pre-calculated, + static dependency map, it validates the actual execution path at the moment a stream event is received. + It queries the runtime state of the graph ('the logbook') to ensure that a valid, uninterrupted, + and logically sound path exists from the start_node_id all the way back to the graph's entry point. + + The traversal logic handles: + - Basic node completion states (SUCCEEDED, FAILED, RUNNING). + - Complex branch nodes (If/Else), by checking which branch was actually taken during the run. + Paths from branches that were not taken are considered irrelevant ("parallel universes") and ignored. + + This approach correctly handles complex topologies with join points (nodes with multiple inputs), + ensuring that streaming is only permitted when the true, logical dependency chain for the *current run* + has been successfully completed. + + :param start_node_id: The node ID from which to begin the backward traversal (e.g., the LLM node). + :return: True if all dependencies on the active path are met, False otherwise. """ # Use a queue for BFS and a set to track visited nodes to prevent cycles queue = [start_node_id] @@ -144,10 +160,7 @@ class AnswerStreamProcessor(StreamProcessor): # all depends on answer node id not in rest node ids if event.route_node_state.node_id != answer_node_id and ( answer_node_id not in self.rest_node_ids - or not all( - dep_id not in self.rest_node_ids - for dep_id in self.generate_routes.answer_dependencies.get(answer_node_id, []) - ) + or not self._is_dynamic_dependencies_met(answer_node_id) # Using dynamic check for final output as well ): continue @@ -216,7 +229,7 @@ class AnswerStreamProcessor(StreamProcessor): if answer_node_id not in self.rest_node_ids: continue - # New dynamic dependency check + # New dynamic dependency check, replacing the old static dependency list. source_node_id_for_check = event.from_variable_selector[0] all_deps_finished = self._is_dynamic_dependencies_met(start_node_id=source_node_id_for_check) diff --git a/api/core/workflow/nodes/answer/base_stream_processor.py b/api/core/workflow/nodes/answer/base_stream_processor.py index 08c356758c..278b26134a 100644 --- a/api/core/workflow/nodes/answer/base_stream_processor.py +++ b/api/core/workflow/nodes/answer/base_stream_processor.py @@ -23,6 +23,16 @@ class StreamProcessor(ABC): raise NotImplementedError def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent | NodeRunExceptionEvent) -> None: + """ + Prunes unreachable branches from the `rest_node_ids` list after a branch node has executed. + + This method implements a conservative, non-recursive pruning strategy to prevent a critical bug + where the pruning process would incorrectly "spread" across join points (nodes with multiple inputs) + and erroneously remove shared downstream nodes that should have been preserved. + + By only removing the immediate first node of each determined unreachable branch, we ensure that + the integrity of shared paths in complex graph topologies is maintained. + """ finished_node_id = event.route_node_state.node_id if finished_node_id not in self.rest_node_ids: return diff --git a/api/core/workflow/nodes/answer/entities.py b/api/core/workflow/nodes/answer/entities.py index a05cc44c99..0c939fece8 100644 --- a/api/core/workflow/nodes/answer/entities.py +++ b/api/core/workflow/nodes/answer/entities.py @@ -57,9 +57,6 @@ class AnswerStreamGenerateRoute(BaseModel): AnswerStreamGenerateRoute entity """ - answer_dependencies: dict[str, list[str]] = Field( - ..., description="answer dependencies (answer node id -> dependent answer node ids)" - ) answer_generate_route: dict[str, list[GenerateRouteChunk]] = Field( ..., description="answer generate route (answer node id -> generate route chunks)" )