diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index b315129763..daf6efb952 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -152,7 +152,9 @@ class GraphEngine: try: if self.init_params.workflow_type == WorkflowType.CHAT: stream_processor = AnswerStreamProcessor( - graph=self.graph, variable_pool=self.graph_runtime_state.variable_pool + graph=self.graph, + variable_pool=self.graph_runtime_state.variable_pool, + node_run_state=self.graph_runtime_state.node_run_state, ) else: stream_processor = EndStreamProcessor( diff --git a/api/core/workflow/nodes/answer/answer_stream_generate_router.py b/api/core/workflow/nodes/answer/answer_stream_generate_router.py index 1d9c3e9b96..7f9ebf756f 100644 --- a/api/core/workflow/nodes/answer/answer_stream_generate_router.py +++ b/api/core/workflow/nodes/answer/answer_stream_generate_router.py @@ -90,6 +90,9 @@ class AnswerStreamGeneratorRouter: :return: """ node_data = AnswerNodeData(**config.get("data", {})) + # Trim whitespace from the answer template to prevent parsing issues with leading/trailing spaces. + if node_data.answer: + node_data.answer = node_data.answer.strip() return cls.extract_generate_route_from_node_data(node_data) @classmethod @@ -145,6 +148,13 @@ class AnswerStreamGeneratorRouter: :return: """ reverse_edges = reverse_edge_mapping.get(current_node_id, []) + + # If the current node has more than one incoming edge, it's a join point. + # We should add it as a dependency and stop tracing up further. + if len(reverse_edges) > 1: + answer_dependencies[answer_node_id].append(current_node_id) + return + for edge in reverse_edges: source_node_id = edge.source_node_id if source_node_id not in node_id_config_mapping: diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py index 97666fad05..d7b6d3d671 100644 --- a/api/core/workflow/nodes/answer/answer_stream_processor.py +++ b/api/core/workflow/nodes/answer/answer_stream_processor.py @@ -11,6 +11,7 @@ from core.workflow.graph_engine.entities.event import ( NodeRunSucceededEvent, ) from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState, RuntimeRouteState from core.workflow.nodes.answer.base_stream_processor import StreamProcessor from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk @@ -18,8 +19,8 @@ logger = logging.getLogger(__name__) class AnswerStreamProcessor(StreamProcessor): - def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: - super().__init__(graph, variable_pool) + def __init__(self, graph: Graph, variable_pool: VariablePool, node_run_state: RuntimeRouteState) -> None: + super().__init__(graph, variable_pool, node_run_state) self.generate_routes = graph.answer_stream_generate_routes self.route_position = {} for answer_node_id in self.generate_routes.answer_generate_route: @@ -73,6 +74,64 @@ class AnswerStreamProcessor(StreamProcessor): self.rest_node_ids = self.graph.node_ids.copy() self.current_stream_chunk_generating_node_ids = {} + def _is_dynamic_dependencies_met(self, start_node_id: str) -> bool: + """ + Check if all dynamic dependencies are met for a given node by traversing backwards. + This method is based on the runtime state of the graph. + """ + # Use a queue for BFS and a set to track visited nodes to prevent cycles + queue = [start_node_id] + visited = {start_node_id} + + while queue: + current_node_id = queue.pop(0) + + # Get the edges leading to the current node + parent_edges = self.graph.reverse_edge_mapping.get(current_node_id, []) + if not parent_edges: + continue + + for edge in parent_edges: + parent_node_id = edge.source_node_id + + if parent_node_id in visited: + continue + + visited.add(parent_node_id) + + # Find the latest execution state of the parent node in the current run + parent_node_run_state = None + for state in self.node_run_state.node_state_mapping.values(): + if state.node_id == parent_node_id: + parent_node_run_state = state + break # Assume the last found state is the latest for simplicity + + if not parent_node_run_state or parent_node_run_state.status == RouteNodeState.Status.RUNNING: + return False + + if parent_node_run_state.status in [RouteNodeState.Status.FAILED, RouteNodeState.Status.EXCEPTION]: + return False + + # If the parent is a branch node, check if the executed branch leads to the current node + parent_node_config = self.graph.node_id_config_mapping.get(parent_node_id, {}) + parent_node_type = parent_node_config.get('data', {}).get('type') + + is_branch_node = parent_node_type in ['if-else', 'question-classifier'] # Example branch types + + if is_branch_node: + run_result = parent_node_run_state.node_run_result + chosen_handle = run_result.edge_source_handle if run_result else None + required_handle = edge.run_condition.branch_identify if edge.run_condition else None + + # If the chosen branch does not match the path we are traversing, this dependency path is irrelevant + if chosen_handle and required_handle and chosen_handle != required_handle: + continue # This path was not taken, so it's not a dependency + + # If all checks pass, add the parent to the queue to continue traversing up + queue.append(parent_node_id) + + return True + def _generate_stream_outputs_when_node_finished( self, event: NodeRunSucceededEvent ) -> Generator[GraphEngineEvent, None, None]: @@ -87,7 +146,7 @@ class AnswerStreamProcessor(StreamProcessor): answer_node_id not in self.rest_node_ids or not all( dep_id not in self.rest_node_ids - for dep_id in self.generate_routes.answer_dependencies[answer_node_id] + for dep_id in self.generate_routes.answer_dependencies.get(answer_node_id, []) ) ): continue @@ -156,33 +215,13 @@ class AnswerStreamProcessor(StreamProcessor): for answer_node_id, route_position in self.route_position.items(): if answer_node_id not in self.rest_node_ids: continue - # Remove current node id from answer dependencies to support stream output if it is a success branch - answer_dependencies = self.generate_routes.answer_dependencies - edge_mapping = self.graph.edge_mapping.get(event.node_id) - success_edge = ( - next( - ( - edge - for edge in edge_mapping - if edge.run_condition - and edge.run_condition.type == "branch_identify" - and edge.run_condition.branch_identify == "success-branch" - ), - None, - ) - if edge_mapping - else None - ) - if ( - event.node_id in answer_dependencies[answer_node_id] - and success_edge - and success_edge.target_node_id == answer_node_id - ): - answer_dependencies[answer_node_id].remove(event.node_id) - answer_dependencies_ids = answer_dependencies.get(answer_node_id, []) - # all depends on answer node id not in rest node ids - if all(dep_id not in self.rest_node_ids for dep_id in answer_dependencies_ids): - if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]): + + # New dynamic dependency check + source_node_id_for_check = event.from_variable_selector[0] + all_deps_finished = self._is_dynamic_dependencies_met(start_node_id=source_node_id_for_check) + + if all_deps_finished: + if route_position >= len(self.generate_routes.answer_generate_route.get(answer_node_id, [])): continue route_chunk = self.generate_routes.answer_generate_route[answer_node_id][route_position] diff --git a/api/core/workflow/nodes/answer/base_stream_processor.py b/api/core/workflow/nodes/answer/base_stream_processor.py index 09d5464d7a..08c356758c 100644 --- a/api/core/workflow/nodes/answer/base_stream_processor.py +++ b/api/core/workflow/nodes/answer/base_stream_processor.py @@ -6,14 +6,16 @@ from typing import Optional from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunExceptionEvent, NodeRunSucceededEvent from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState logger = logging.getLogger(__name__) class StreamProcessor(ABC): - def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: + def __init__(self, graph: Graph, variable_pool: VariablePool, node_run_state: RuntimeRouteState) -> None: self.graph = graph self.variable_pool = variable_pool + self.node_run_state = node_run_state self.rest_node_ids = graph.node_ids.copy() @abstractmethod @@ -68,9 +70,14 @@ class StreamProcessor(ABC): ): continue unreachable_first_node_ids.append(edge.target_node_id) - unreachable_first_node_ids = list(set(unreachable_first_node_ids) - set(reachable_node_ids)) - for node_id in unreachable_first_node_ids: - self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids) + + # Instead of recursively removing the entire unreachable branch, + # which can cause issues with complex join points, + # we will only remove the immediate first node of the unreachable branch. + # This prevents the removal logic from incorrectly pruning shared paths downstream. + for node_id in list(set(unreachable_first_node_ids) - set(reachable_node_ids)): + if node_id in self.rest_node_ids: + self.rest_node_ids.remove(node_id) def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: Optional[str] = None) -> list[str]: if node_id not in self.rest_node_ids: