refactor(workflow): Rearchitect stream dependency logic for complex graphs

This commit addresses a critical issue where streaming output would fail in workflows with complex topologies, particularly those involving multiple conditional branches (if/else) that converge on a common node before the LLM and Answer nodes.

The root cause was twofold:
1. A bug in the branch pruning logic (`_remove_unreachable_nodes`) that would incorrectly remove shared downstream nodes, leading to a premature emptying of the `rest_node_ids` list.
2. A flawed static dependency analysis (`_fetch_answers_dependencies`) that could not correctly resolve dependencies for nodes that were part of multiple, mutually exclusive execution paths.

This refactor introduces a new, robust architecture for streaming dependency management based on the principle of "Static Pre-pruning + Dynamic Adjudication":

- **Fix**: The branch pruning logic in `_remove_unreachable_nodes` is now non-recursive and conservative. It only prunes the immediate first node of an unreachable branch, preserving the integrity of shared downstream paths and join points.

- **Refactor**: The old static dependency analysis has been completely removed. This includes deleting the `answer_dependencies` attribute from the `AnswerStreamGenerateRoute` entity and deleting the associated recursive dependency fetching methods (`_fetch_answers_dependencies`, `_recursive_fetch_answer_dependencies`).

- **Feat**: A new method, `_is_dynamic_dependencies_met`, has been implemented in `AnswerStreamProcessor`. This method performs a real-time, backward traversal of the graph from the streaming node, querying the runtime execution state of the graph to dynamically validate if the *actual* dependency path has been successfully completed. This ensures that streaming decisions are based on the ground truth of the current execution, not a flawed static prediction.

- **Doc**: Added comprehensive docstrings and comments to the modified components to explain the new architecture and the rationale behind the changes.
pull/22771/head
xuzijie1995 7 months ago
parent b13ae784a4
commit db91643915

@ -6,7 +6,7 @@ from core.workflow.nodes.answer.entities import (
TextGenerateRouteChunk,
VarGenerateRouteChunk,
)
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.enums import NodeType
from core.workflow.utils.variable_template_parser import VariableTemplateParser
@ -18,8 +18,10 @@ class AnswerStreamGeneratorRouter:
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
) -> AnswerStreamGenerateRoute:
"""
Get stream generate routes.
:return:
Initializes the stream generation routes for all Answer nodes in the workflow.
This method performs a static analysis of the graph to parse the answer templates.
The old logic for pre-calculating static dependencies has been deprecated and removed,
as the decision logic is now handled dynamically at runtime by the AnswerStreamProcessor.
"""
# parse stream output node value selectors of answer nodes
answer_generate_route: dict[str, list[GenerateRouteChunk]] = {}
@ -31,16 +33,8 @@ class AnswerStreamGeneratorRouter:
generate_route = cls._extract_generate_route_selectors(node_config)
answer_generate_route[answer_node_id] = generate_route
# fetch answer dependencies
answer_node_ids = list(answer_generate_route.keys())
answer_dependencies = cls._fetch_answers_dependencies(
answer_node_ids=answer_node_ids,
reverse_edge_mapping=reverse_edge_mapping,
node_id_config_mapping=node_id_config_mapping,
)
return AnswerStreamGenerateRoute(
answer_generate_route=answer_generate_route, answer_dependencies=answer_dependencies
answer_generate_route=answer_generate_route
)
@classmethod
@ -99,86 +93,3 @@ class AnswerStreamGeneratorRouter:
def _is_variable(cls, part, variable_keys):
cleaned_part = part.replace("{{", "").replace("}}", "")
return part.startswith("{{") and cleaned_part in variable_keys
@classmethod
def _fetch_answers_dependencies(
    cls,
    answer_node_ids: list[str],
    reverse_edge_mapping: dict[str, list["GraphEdge"]],  # type: ignore[name-defined]
    node_id_config_mapping: dict[str, dict],
) -> dict[str, list[str]]:
    """
    Build the static dependency map for every Answer node in the workflow graph.

    For each answer node id, walks the graph backwards (via
    ``_recursive_fetch_answer_dependencies``) and records the node ids that the
    answer node depends on.

    :param answer_node_ids: ids of all Answer nodes in the workflow
    :param reverse_edge_mapping: target node id -> list of incoming edges
    :param node_id_config_mapping: node id -> raw node config dict
    :return: mapping of answer node id -> list of dependency node ids
    """
    answer_dependencies: dict[str, list[str]] = {}
    for answer_node_id in answer_node_ids:
        # Ensure an accumulator list exists; the recursive helper appends into it.
        if answer_dependencies.get(answer_node_id) is None:
            answer_dependencies[answer_node_id] = []

        cls._recursive_fetch_answer_dependencies(
            current_node_id=answer_node_id,
            answer_node_id=answer_node_id,
            node_id_config_mapping=node_id_config_mapping,
            reverse_edge_mapping=reverse_edge_mapping,
            answer_dependencies=answer_dependencies,
        )

    return answer_dependencies
@classmethod
def _recursive_fetch_answer_dependencies(
    cls,
    current_node_id: str,
    answer_node_id: str,
    node_id_config_mapping: dict[str, dict],
    reverse_edge_mapping: dict[str, list["GraphEdge"]],  # type: ignore[name-defined]
    answer_dependencies: dict[str, list[str]],
) -> None:
    """
    Recursively walk the graph backwards from ``current_node_id``, appending
    dependency node ids for ``answer_node_id`` into ``answer_dependencies``.

    Mutates ``answer_dependencies`` in place and returns nothing.

    :param current_node_id: node currently being inspected
    :param answer_node_id: the Answer node whose dependency list is being built
    :param node_id_config_mapping: node id -> raw node config dict
    :param reverse_edge_mapping: target node id -> list of incoming edges
    :param answer_dependencies: accumulator, answer node id -> dependency ids
    """
    reverse_edges = reverse_edge_mapping.get(current_node_id, [])
    # If the current node has more than one incoming edge, it's a join point.
    # We should add it as a dependency and stop tracing up further.
    if len(reverse_edges) > 1:
        answer_dependencies[answer_node_id].append(current_node_id)
        return
    for edge in reverse_edges:
        source_node_id = edge.source_node_id
        # Edge sources missing from the config mapping are ignored rather
        # than treated as an error.
        if source_node_id not in node_id_config_mapping:
            continue
        source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
        source_node_data = node_id_config_mapping[source_node_id].get("data", {})
        # Branch-like / container / assigner node types — or any node whose
        # error strategy routes to a fail branch — terminate the walk: the
        # node itself is recorded as the dependency.
        if (
            source_node_type
            in {
                NodeType.ANSWER,
                NodeType.IF_ELSE,
                NodeType.QUESTION_CLASSIFIER,
                NodeType.ITERATION,
                NodeType.LOOP,
                NodeType.VARIABLE_ASSIGNER,
            }
            or source_node_data.get("error_strategy") == ErrorStrategy.FAIL_BRANCH
        ):
            answer_dependencies[answer_node_id].append(source_node_id)
        else:
            # Plain pass-through node: keep tracing upstream.
            cls._recursive_fetch_answer_dependencies(
                current_node_id=source_node_id,
                answer_node_id=answer_node_id,
                node_id_config_mapping=node_id_config_mapping,
                reverse_edge_mapping=reverse_edge_mapping,
                answer_dependencies=answer_dependencies,
            )

@ -76,8 +76,24 @@ class AnswerStreamProcessor(StreamProcessor):
def _is_dynamic_dependencies_met(self, start_node_id: str) -> bool:
"""
Check if all dynamic dependencies are met for a given node by traversing backwards.
This method is based on the runtime state of the graph.
Performs a dynamic, runtime dependency check by traversing backwards from a given start_node_id.
This method is the core of the new streaming architecture. Instead of relying on a pre-calculated,
static dependency map, it validates the actual execution path at the moment a stream event is received.
It queries the runtime state of the graph ('the logbook') to ensure that a valid, uninterrupted,
and logically sound path exists from the start_node_id all the way back to the graph's entry point.
The traversal logic handles:
- Basic node completion states (SUCCEEDED, FAILED, RUNNING).
- Complex branch nodes (If/Else), by checking which branch was actually taken during the run.
Paths from branches that were not taken are considered irrelevant ("parallel universes") and ignored.
This approach correctly handles complex topologies with join points (nodes with multiple inputs),
ensuring that streaming is only permitted when the true, logical dependency chain for the *current run*
has been successfully completed.
:param start_node_id: The node ID from which to begin the backward traversal (e.g., the LLM node).
:return: True if all dependencies on the active path are met, False otherwise.
"""
# Use a queue for BFS and a set to track visited nodes to prevent cycles
queue = [start_node_id]
@ -144,10 +160,7 @@ class AnswerStreamProcessor(StreamProcessor):
# all depends on answer node id not in rest node ids
if event.route_node_state.node_id != answer_node_id and (
answer_node_id not in self.rest_node_ids
or not all(
dep_id not in self.rest_node_ids
for dep_id in self.generate_routes.answer_dependencies.get(answer_node_id, [])
)
or not self._is_dynamic_dependencies_met(answer_node_id) # Using dynamic check for final output as well
):
continue
@ -216,7 +229,7 @@ class AnswerStreamProcessor(StreamProcessor):
if answer_node_id not in self.rest_node_ids:
continue
# New dynamic dependency check
# New dynamic dependency check, replacing the old static dependency list.
source_node_id_for_check = event.from_variable_selector[0]
all_deps_finished = self._is_dynamic_dependencies_met(start_node_id=source_node_id_for_check)

@ -23,6 +23,16 @@ class StreamProcessor(ABC):
raise NotImplementedError
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent | NodeRunExceptionEvent) -> None:
"""
Prunes unreachable branches from the `rest_node_ids` list after a branch node has executed.
This method implements a conservative, non-recursive pruning strategy to prevent a critical bug
where the pruning process would incorrectly "spread" across join points (nodes with multiple inputs)
and erroneously remove shared downstream nodes that should have been preserved.
By only removing the immediate first node of each determined unreachable branch, we ensure that
the integrity of shared paths in complex graph topologies is maintained.
"""
finished_node_id = event.route_node_state.node_id
if finished_node_id not in self.rest_node_ids:
return

@ -57,9 +57,6 @@ class AnswerStreamGenerateRoute(BaseModel):
AnswerStreamGenerateRoute entity
"""
answer_dependencies: dict[str, list[str]] = Field(
..., description="answer dependencies (answer node id -> dependent answer node ids)"
)
answer_generate_route: dict[str, list[GenerateRouteChunk]] = Field(
..., description="answer generate route (answer node id -> generate route chunks)"
)

Loading…
Cancel
Save