From db91643915a226da3d4a31aebb1433906f050eef Mon Sep 17 00:00:00 2001 From: xuzijie1995 <18852951350@163.com> Date: Tue, 22 Jul 2025 10:42:03 +0800 Subject: [PATCH] refactor(workflow): Rearchitect stream dependency logic for complex graphs This commit addresses a critical issue where streaming output would fail in workflows with complex topologies, particularly those involving multiple conditional branches (if/else) that converge on a common node before the LLM and Answer nodes. The root cause was twofold: 1. A bug in the branch pruning logic (`_remove_unreachable_nodes`) that would incorrectly remove shared downstream nodes, leading to a premature emptying of the `rest_node_ids` list. 2. A flawed static dependency analysis (`_fetch_answers_dependencies`) that could not correctly resolve dependencies for nodes that were part of multiple, mutually exclusive execution paths. This refactor introduces a new, robust architecture for streaming dependency management based on the principle of "Static Pre-pruning + Dynamic Adjudication": - **Fix**: The branch pruning logic in `_remove_unreachable_nodes` is now non-recursive and conservative. It only prunes the immediate first node of an unreachable branch, preserving the integrity of shared downstream paths and join points. - **Refactor**: The old static dependency analysis has been completely removed. This includes deleting the `answer_dependencies` attribute from the `AnswerStreamGenerateRoute` entity and deleting the associated recursive dependency fetching methods (`_fetch_answers_dependencies`, `_recursive_fetch_answer_dependencies`). - **Feat**: A new method, `_is_dynamic_dependencies_met`, has been implemented in `AnswerStreamProcessor`. This method performs a real-time, backward traversal of the graph from the streaming node, querying the runtime execution state of the graph to dynamically validate if the *actual* dependency path has been successfully completed. This ensures that streaming decisions are based on the ground truth of the current execution, not a flawed static prediction. - **Doc**: Added comprehensive docstrings and comments to the modified components to explain the new architecture and the rationale behind the changes. 
--- .../answer/answer_stream_generate_router.py | 101 ++---------------- .../nodes/answer/answer_stream_processor.py | 27 +++-- .../nodes/answer/base_stream_processor.py | 10 ++ api/core/workflow/nodes/answer/entities.py | 3 - 4 files changed, 36 insertions(+), 105 deletions(-) diff --git a/api/core/workflow/nodes/answer/answer_stream_generate_router.py b/api/core/workflow/nodes/answer/answer_stream_generate_router.py index 7f9ebf756f..bf3a72f56d 100644 --- a/api/core/workflow/nodes/answer/answer_stream_generate_router.py +++ b/api/core/workflow/nodes/answer/answer_stream_generate_router.py @@ -6,7 +6,7 @@ from core.workflow.nodes.answer.entities import ( TextGenerateRouteChunk, VarGenerateRouteChunk, ) -from core.workflow.nodes.enums import ErrorStrategy, NodeType +from core.workflow.nodes.enums import NodeType from core.workflow.utils.variable_template_parser import VariableTemplateParser @@ -18,8 +18,10 @@ class AnswerStreamGeneratorRouter: reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] ) -> AnswerStreamGenerateRoute: """ - Get stream generate routes. - :return: + Initializes the stream generation routes for all Answer nodes in the workflow. + This method performs a static analysis of the graph to parse the answer templates. + The old logic for pre-calculating static dependencies has been deprecated and removed, + as the decision logic is now handled dynamically at runtime by the AnswerStreamProcessor. 
""" # parse stream output node value selectors of answer nodes answer_generate_route: dict[str, list[GenerateRouteChunk]] = {} @@ -31,16 +33,8 @@ class AnswerStreamGeneratorRouter: generate_route = cls._extract_generate_route_selectors(node_config) answer_generate_route[answer_node_id] = generate_route - # fetch answer dependencies - answer_node_ids = list(answer_generate_route.keys()) - answer_dependencies = cls._fetch_answers_dependencies( - answer_node_ids=answer_node_ids, - reverse_edge_mapping=reverse_edge_mapping, - node_id_config_mapping=node_id_config_mapping, - ) - return AnswerStreamGenerateRoute( - answer_generate_route=answer_generate_route, answer_dependencies=answer_dependencies + answer_generate_route=answer_generate_route ) @classmethod @@ -99,86 +93,3 @@ class AnswerStreamGeneratorRouter: def _is_variable(cls, part, variable_keys): cleaned_part = part.replace("{{", "").replace("}}", "") return part.startswith("{{") and cleaned_part in variable_keys - - @classmethod - def _fetch_answers_dependencies( - cls, - answer_node_ids: list[str], - reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] - node_id_config_mapping: dict[str, dict], - ) -> dict[str, list[str]]: - """ - Fetch answer dependencies - :param answer_node_ids: answer node ids - :param reverse_edge_mapping: reverse edge mapping - :param node_id_config_mapping: node id config mapping - :return: - """ - answer_dependencies: dict[str, list[str]] = {} - for answer_node_id in answer_node_ids: - if answer_dependencies.get(answer_node_id) is None: - answer_dependencies[answer_node_id] = [] - - cls._recursive_fetch_answer_dependencies( - current_node_id=answer_node_id, - answer_node_id=answer_node_id, - node_id_config_mapping=node_id_config_mapping, - reverse_edge_mapping=reverse_edge_mapping, - answer_dependencies=answer_dependencies, - ) - - return answer_dependencies - - @classmethod - def _recursive_fetch_answer_dependencies( - cls, - current_node_id: str, - 
answer_node_id: str, - node_id_config_mapping: dict[str, dict], - reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] - answer_dependencies: dict[str, list[str]], - ) -> None: - """ - Recursive fetch answer dependencies - :param current_node_id: current node id - :param answer_node_id: answer node id - :param node_id_config_mapping: node id config mapping - :param reverse_edge_mapping: reverse edge mapping - :param answer_dependencies: answer dependencies - :return: - """ - reverse_edges = reverse_edge_mapping.get(current_node_id, []) - - # If the current node has more than one incoming edge, it's a join point. - # We should add it as a dependency and stop tracing up further. - if len(reverse_edges) > 1: - answer_dependencies[answer_node_id].append(current_node_id) - return - - for edge in reverse_edges: - source_node_id = edge.source_node_id - if source_node_id not in node_id_config_mapping: - continue - source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type") - source_node_data = node_id_config_mapping[source_node_id].get("data", {}) - if ( - source_node_type - in { - NodeType.ANSWER, - NodeType.IF_ELSE, - NodeType.QUESTION_CLASSIFIER, - NodeType.ITERATION, - NodeType.LOOP, - NodeType.VARIABLE_ASSIGNER, - } - or source_node_data.get("error_strategy") == ErrorStrategy.FAIL_BRANCH - ): - answer_dependencies[answer_node_id].append(source_node_id) - else: - cls._recursive_fetch_answer_dependencies( - current_node_id=source_node_id, - answer_node_id=answer_node_id, - node_id_config_mapping=node_id_config_mapping, - reverse_edge_mapping=reverse_edge_mapping, - answer_dependencies=answer_dependencies, - ) diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py index d7b6d3d671..51c08a0543 100644 --- a/api/core/workflow/nodes/answer/answer_stream_processor.py +++ b/api/core/workflow/nodes/answer/answer_stream_processor.py @@ -76,8 +76,24 @@ 
class AnswerStreamProcessor(StreamProcessor): def _is_dynamic_dependencies_met(self, start_node_id: str) -> bool: """ - Check if all dynamic dependencies are met for a given node by traversing backwards. - This method is based on the runtime state of the graph. + Performs a dynamic, runtime dependency check by traversing backwards from a given start_node_id. + + This method is the core of the new streaming architecture. Instead of relying on a pre-calculated, + static dependency map, it validates the actual execution path at the moment a stream event is received. + It queries the runtime state of the graph ('the logbook') to ensure that a valid, uninterrupted, + and logically sound path exists from the start_node_id all the way back to the graph's entry point. + + The traversal logic handles: + - Basic node completion states (SUCCEEDED, FAILED, RUNNING). + - Complex branch nodes (If/Else), by checking which branch was actually taken during the run. + Paths from branches that were not taken are considered irrelevant ("parallel universes") and ignored. + + This approach correctly handles complex topologies with join points (nodes with multiple inputs), + ensuring that streaming is only permitted when the true, logical dependency chain for the *current run* + has been successfully completed. + + :param start_node_id: The node ID from which to begin the backward traversal (e.g., the LLM node). + :return: True if all dependencies on the active path are met, False otherwise. 
""" # Use a queue for BFS and a set to track visited nodes to prevent cycles queue = [start_node_id] @@ -144,10 +160,7 @@ class AnswerStreamProcessor(StreamProcessor): # all depends on answer node id not in rest node ids if event.route_node_state.node_id != answer_node_id and ( answer_node_id not in self.rest_node_ids - or not all( - dep_id not in self.rest_node_ids - for dep_id in self.generate_routes.answer_dependencies.get(answer_node_id, []) - ) + or not self._is_dynamic_dependencies_met(answer_node_id) # Using dynamic check for final output as well ): continue @@ -216,7 +229,7 @@ class AnswerStreamProcessor(StreamProcessor): if answer_node_id not in self.rest_node_ids: continue - # New dynamic dependency check + # New dynamic dependency check, replacing the old static dependency list. source_node_id_for_check = event.from_variable_selector[0] all_deps_finished = self._is_dynamic_dependencies_met(start_node_id=source_node_id_for_check) diff --git a/api/core/workflow/nodes/answer/base_stream_processor.py b/api/core/workflow/nodes/answer/base_stream_processor.py index 08c356758c..278b26134a 100644 --- a/api/core/workflow/nodes/answer/base_stream_processor.py +++ b/api/core/workflow/nodes/answer/base_stream_processor.py @@ -23,6 +23,16 @@ class StreamProcessor(ABC): raise NotImplementedError def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent | NodeRunExceptionEvent) -> None: + """ + Prunes unreachable branches from the `rest_node_ids` list after a branch node has executed. + + This method implements a conservative, non-recursive pruning strategy to prevent a critical bug + where the pruning process would incorrectly "spread" across join points (nodes with multiple inputs) + and erroneously remove shared downstream nodes that should have been preserved. + + By only removing the immediate first node of each determined unreachable branch, we ensure that + the integrity of shared paths in complex graph topologies is maintained. 
+ """ finished_node_id = event.route_node_state.node_id if finished_node_id not in self.rest_node_ids: return diff --git a/api/core/workflow/nodes/answer/entities.py b/api/core/workflow/nodes/answer/entities.py index a05cc44c99..0c939fece8 100644 --- a/api/core/workflow/nodes/answer/entities.py +++ b/api/core/workflow/nodes/answer/entities.py @@ -57,9 +57,6 @@ class AnswerStreamGenerateRoute(BaseModel): AnswerStreamGenerateRoute entity """ - answer_dependencies: dict[str, list[str]] = Field( - ..., description="answer dependencies (answer node id -> dependent answer node ids)" - ) answer_generate_route: dict[str, list[GenerateRouteChunk]] = Field( ..., description="answer generate route (answer node id -> generate route chunks)" )