diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index b315129763..daf6efb952 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -152,7 +152,9 @@ class GraphEngine: try: if self.init_params.workflow_type == WorkflowType.CHAT: stream_processor = AnswerStreamProcessor( - graph=self.graph, variable_pool=self.graph_runtime_state.variable_pool + graph=self.graph, + variable_pool=self.graph_runtime_state.variable_pool, + node_run_state=self.graph_runtime_state.node_run_state, ) else: stream_processor = EndStreamProcessor( diff --git a/api/core/workflow/nodes/answer/answer_stream_generate_router.py b/api/core/workflow/nodes/answer/answer_stream_generate_router.py index 1d9c3e9b96..bf3a72f56d 100644 --- a/api/core/workflow/nodes/answer/answer_stream_generate_router.py +++ b/api/core/workflow/nodes/answer/answer_stream_generate_router.py @@ -6,7 +6,7 @@ from core.workflow.nodes.answer.entities import ( TextGenerateRouteChunk, VarGenerateRouteChunk, ) -from core.workflow.nodes.enums import ErrorStrategy, NodeType +from core.workflow.nodes.enums import NodeType from core.workflow.utils.variable_template_parser import VariableTemplateParser @@ -18,8 +18,10 @@ class AnswerStreamGeneratorRouter: reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] ) -> AnswerStreamGenerateRoute: """ - Get stream generate routes. - :return: + Initializes the stream generation routes for all Answer nodes in the workflow. + This method performs a static analysis of the graph to parse the answer templates. + The old logic for pre-calculating static dependencies has been deprecated and removed, + as the decision logic is now handled dynamically at runtime by the AnswerStreamProcessor. """ # parse stream output node value selectors of answer nodes answer_generate_route: dict[str, list[GenerateRouteChunk]] = {} @@ -31,16 +33,8 @@ class AnswerStreamGeneratorRouter: generate_route = cls._extract_generate_route_selectors(node_config) answer_generate_route[answer_node_id] = generate_route - # fetch answer dependencies - answer_node_ids = list(answer_generate_route.keys()) - answer_dependencies = cls._fetch_answers_dependencies( - answer_node_ids=answer_node_ids, - reverse_edge_mapping=reverse_edge_mapping, - node_id_config_mapping=node_id_config_mapping, - ) - return AnswerStreamGenerateRoute( - answer_generate_route=answer_generate_route, answer_dependencies=answer_dependencies + answer_generate_route=answer_generate_route ) @classmethod @@ -90,85 +84,12 @@ class AnswerStreamGeneratorRouter: :return: """ node_data = AnswerNodeData(**config.get("data", {})) + # Trim whitespace from the answer template to prevent parsing issues with leading/trailing spaces. + if node_data.answer: + node_data.answer = node_data.answer.strip() return cls.extract_generate_route_from_node_data(node_data) @classmethod def _is_variable(cls, part, variable_keys): cleaned_part = part.replace("{{", "").replace("}}", "") return part.startswith("{{") and cleaned_part in variable_keys - - @classmethod - def _fetch_answers_dependencies( - cls, - answer_node_ids: list[str], - reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] - node_id_config_mapping: dict[str, dict], - ) -> dict[str, list[str]]: - """ - Fetch answer dependencies - :param answer_node_ids: answer node ids - :param reverse_edge_mapping: reverse edge mapping - :param node_id_config_mapping: node id config mapping - :return: - """ - answer_dependencies: dict[str, list[str]] = {} - for answer_node_id in answer_node_ids: - if answer_dependencies.get(answer_node_id) is None: - answer_dependencies[answer_node_id] = [] - - cls._recursive_fetch_answer_dependencies( - current_node_id=answer_node_id, - answer_node_id=answer_node_id, - node_id_config_mapping=node_id_config_mapping, - reverse_edge_mapping=reverse_edge_mapping, - answer_dependencies=answer_dependencies, - ) - - return answer_dependencies - - @classmethod - def _recursive_fetch_answer_dependencies( - cls, - current_node_id: str, - answer_node_id: str, - node_id_config_mapping: dict[str, dict], - reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] - answer_dependencies: dict[str, list[str]], - ) -> None: - """ - Recursive fetch answer dependencies - :param current_node_id: current node id - :param answer_node_id: answer node id - :param node_id_config_mapping: node id config mapping - :param reverse_edge_mapping: reverse edge mapping - :param answer_dependencies: answer dependencies - :return: - """ - reverse_edges = reverse_edge_mapping.get(current_node_id, []) - for edge in reverse_edges: - source_node_id = edge.source_node_id - if source_node_id not in node_id_config_mapping: - continue - source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type") - source_node_data = node_id_config_mapping[source_node_id].get("data", {}) - if ( - source_node_type - in { - NodeType.ANSWER, - NodeType.IF_ELSE, - NodeType.QUESTION_CLASSIFIER, - NodeType.ITERATION, - NodeType.LOOP, - NodeType.VARIABLE_ASSIGNER, - } - or source_node_data.get("error_strategy") == ErrorStrategy.FAIL_BRANCH - ): - answer_dependencies[answer_node_id].append(source_node_id) - else: - cls._recursive_fetch_answer_dependencies( - current_node_id=source_node_id, - answer_node_id=answer_node_id, - node_id_config_mapping=node_id_config_mapping, - reverse_edge_mapping=reverse_edge_mapping, - answer_dependencies=answer_dependencies, - ) diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py index 97666fad05..51c08a0543 100644 --- a/api/core/workflow/nodes/answer/answer_stream_processor.py +++ b/api/core/workflow/nodes/answer/answer_stream_processor.py @@ -11,6 +11,7 @@ from core.workflow.graph_engine.entities.event import ( NodeRunSucceededEvent, ) from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState, RuntimeRouteState from core.workflow.nodes.answer.base_stream_processor import StreamProcessor from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk @@ -18,8 +19,8 @@ logger = logging.getLogger(__name__) class AnswerStreamProcessor(StreamProcessor): - def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: - super().__init__(graph, variable_pool) + def __init__(self, graph: Graph, variable_pool: VariablePool, node_run_state: RuntimeRouteState) -> None: + super().__init__(graph, variable_pool, node_run_state) self.generate_routes = graph.answer_stream_generate_routes self.route_position = {} for answer_node_id in self.generate_routes.answer_generate_route: @@ -73,6 +74,80 @@ class AnswerStreamProcessor(StreamProcessor): self.rest_node_ids = self.graph.node_ids.copy() self.current_stream_chunk_generating_node_ids = {} + def _is_dynamic_dependencies_met(self, start_node_id: str) -> bool: + """ + Performs a dynamic, runtime dependency check by traversing backwards from a given start_node_id. + + This method is the core of the new streaming architecture. Instead of relying on a pre-calculated, + static dependency map, it validates the actual execution path at the moment a stream event is received. + It queries the runtime state of the graph ('the logbook') to ensure that a valid, uninterrupted, + and logically sound path exists from the start_node_id all the way back to the graph's entry point. + + The traversal logic handles: + - Basic node completion states (SUCCEEDED, FAILED, RUNNING). + - Complex branch nodes (If/Else), by checking which branch was actually taken during the run. + Paths from branches that were not taken are considered irrelevant ("parallel universes") and ignored. + + This approach correctly handles complex topologies with join points (nodes with multiple inputs), + ensuring that streaming is only permitted when the true, logical dependency chain for the *current run* + has been successfully completed. + + :param start_node_id: The node ID from which to begin the backward traversal (e.g., the LLM node). + :return: True if all dependencies on the active path are met, False otherwise. + """ + # Use a queue for BFS and a set to track visited nodes to prevent cycles + queue = [start_node_id] + visited = {start_node_id} + + while queue: + current_node_id = queue.pop(0) + + # Get the edges leading to the current node + parent_edges = self.graph.reverse_edge_mapping.get(current_node_id, []) + if not parent_edges: + continue + + for edge in parent_edges: + parent_node_id = edge.source_node_id + + if parent_node_id in visited: + continue + + visited.add(parent_node_id) + + # Find the latest execution state of the parent node in the current run + parent_node_run_state = None + for state in self.node_run_state.node_state_mapping.values(): + if state.node_id == parent_node_id: + parent_node_run_state = state + break # Assume the last found state is the latest for simplicity + + if not parent_node_run_state or parent_node_run_state.status == RouteNodeState.Status.RUNNING: + return False + + if parent_node_run_state.status in [RouteNodeState.Status.FAILED, RouteNodeState.Status.EXCEPTION]: + return False + + # If the parent is a branch node, check if the executed branch leads to the current node + parent_node_config = self.graph.node_id_config_mapping.get(parent_node_id, {}) + parent_node_type = parent_node_config.get('data', {}).get('type') + + is_branch_node = parent_node_type in ['if-else', 'question-classifier'] # Example branch types + + if is_branch_node: + run_result = parent_node_run_state.node_run_result + chosen_handle = run_result.edge_source_handle if run_result else None + required_handle = edge.run_condition.branch_identify if edge.run_condition else None + + # If the chosen branch does not match the path we are traversing, this dependency path is irrelevant + if chosen_handle and required_handle and chosen_handle != required_handle: + continue # This path was not taken, so it's not a dependency + + # If all checks pass, add the parent to the queue to continue traversing up + queue.append(parent_node_id) + + return True + def _generate_stream_outputs_when_node_finished( self, event: NodeRunSucceededEvent ) -> Generator[GraphEngineEvent, None, None]: @@ -85,10 +160,7 @@ class AnswerStreamProcessor(StreamProcessor): # all depends on answer node id not in rest node ids if event.route_node_state.node_id != answer_node_id and ( answer_node_id not in self.rest_node_ids - or not all( - dep_id not in self.rest_node_ids - for dep_id in self.generate_routes.answer_dependencies[answer_node_id] - ) + or not self._is_dynamic_dependencies_met(answer_node_id) # Using dynamic check for final output as well ): continue @@ -156,33 +228,13 @@ class AnswerStreamProcessor(StreamProcessor): for answer_node_id, route_position in self.route_position.items(): if answer_node_id not in self.rest_node_ids: continue - # Remove current node id from answer dependencies to support stream output if it is a success branch - answer_dependencies = self.generate_routes.answer_dependencies - edge_mapping = self.graph.edge_mapping.get(event.node_id) - success_edge = ( - next( - ( - edge - for edge in edge_mapping - if edge.run_condition - and edge.run_condition.type == "branch_identify" - and edge.run_condition.branch_identify == "success-branch" - ), - None, - ) - if edge_mapping - else None - ) - if ( - event.node_id in answer_dependencies[answer_node_id] - and success_edge - and success_edge.target_node_id == answer_node_id - ): - answer_dependencies[answer_node_id].remove(event.node_id) - answer_dependencies_ids = answer_dependencies.get(answer_node_id, []) - # all depends on answer node id not in rest node ids - if all(dep_id not in self.rest_node_ids for dep_id in answer_dependencies_ids): - if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]): + + # New dynamic dependency check, replacing the old static dependency list. + source_node_id_for_check = event.from_variable_selector[0] + all_deps_finished = self._is_dynamic_dependencies_met(start_node_id=source_node_id_for_check) + + if all_deps_finished: + if route_position >= len(self.generate_routes.answer_generate_route.get(answer_node_id, [])): continue route_chunk = self.generate_routes.answer_generate_route[answer_node_id][route_position] diff --git a/api/core/workflow/nodes/answer/base_stream_processor.py b/api/core/workflow/nodes/answer/base_stream_processor.py index 09d5464d7a..278b26134a 100644 --- a/api/core/workflow/nodes/answer/base_stream_processor.py +++ b/api/core/workflow/nodes/answer/base_stream_processor.py @@ -6,14 +6,16 @@ from typing import Optional from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunExceptionEvent, NodeRunSucceededEvent from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState logger = logging.getLogger(__name__) class StreamProcessor(ABC): - def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: + def __init__(self, graph: Graph, variable_pool: VariablePool, node_run_state: RuntimeRouteState) -> None: self.graph = graph self.variable_pool = variable_pool + self.node_run_state = node_run_state self.rest_node_ids = graph.node_ids.copy() @abstractmethod @@ -21,6 +23,16 @@ class StreamProcessor(ABC): raise NotImplementedError def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent | NodeRunExceptionEvent) -> None: + """ + Prunes unreachable branches from the `rest_node_ids` list after a branch node has executed. + + This method implements a conservative, non-recursive pruning strategy to prevent a critical bug + where the pruning process would incorrectly "spread" across join points (nodes with multiple inputs) + and erroneously remove shared downstream nodes that should have been preserved. + + By only removing the immediate first node of each determined unreachable branch, we ensure that + the integrity of shared paths in complex graph topologies is maintained. + """ finished_node_id = event.route_node_state.node_id if finished_node_id not in self.rest_node_ids: return @@ -68,9 +80,14 @@ class StreamProcessor(ABC): ): continue unreachable_first_node_ids.append(edge.target_node_id) - unreachable_first_node_ids = list(set(unreachable_first_node_ids) - set(reachable_node_ids)) - for node_id in unreachable_first_node_ids: - self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids) + + # Instead of recursively removing the entire unreachable branch, + # which can cause issues with complex join points, + # we will only remove the immediate first node of the unreachable branch. + # This prevents the removal logic from incorrectly pruning shared paths downstream. + for node_id in list(set(unreachable_first_node_ids) - set(reachable_node_ids)): + if node_id in self.rest_node_ids: + self.rest_node_ids.remove(node_id) def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: Optional[str] = None) -> list[str]: if node_id not in self.rest_node_ids: diff --git a/api/core/workflow/nodes/answer/entities.py b/api/core/workflow/nodes/answer/entities.py index a05cc44c99..0c939fece8 100644 --- a/api/core/workflow/nodes/answer/entities.py +++ b/api/core/workflow/nodes/answer/entities.py @@ -57,9 +57,6 @@ class AnswerStreamGenerateRoute(BaseModel): AnswerStreamGenerateRoute entity """ - answer_dependencies: dict[str, list[str]] = Field( - ..., description="answer dependencies (answer node id -> dependent answer node ids)" - ) answer_generate_route: dict[str, list[GenerateRouteChunk]] = Field( ..., description="answer generate route (answer node id -> generate route chunks)" )