From b30db668587748e7ff1bcbecd94dcf5f7c289efa Mon Sep 17 00:00:00 2001 From: Gia Bao Date: Thu, 22 May 2025 15:37:23 +0700 Subject: [PATCH] fix: Resolve streaming response issue in case of nested conditions node --- .../workflow/nodes/answer/base_stream_processor.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/api/core/workflow/nodes/answer/base_stream_processor.py b/api/core/workflow/nodes/answer/base_stream_processor.py index f104eed9e6..f58cc5d7a1 100644 --- a/api/core/workflow/nodes/answer/base_stream_processor.py +++ b/api/core/workflow/nodes/answer/base_stream_processor.py @@ -59,7 +59,7 @@ class StreamProcessor(ABC): # only nodes in the correct logical branch are included. reachable_node_ids.add(edge.target_node_id) ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id, run_result.edge_source_handle) - reachable_node_ids.update(ids) + reachable_node_ids |= ids else: # if the condition edge in parallel, and the target node is not in parallel, we should not remove it # Issues: #13626 @@ -69,7 +69,10 @@ class StreamProcessor(ABC): ): continue unreachable_first_node_ids.add(edge.target_node_id) - unreachable_first_node_ids = unreachable_first_node_ids - reachable_node_ids + + self.rest_node_ids |= reachable_node_ids + + unreachable_first_node_ids -= reachable_node_ids for node_id in unreachable_first_node_ids: self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids) @@ -85,7 +88,7 @@ class StreamProcessor(ABC): continue node_ids.add(edge.target_node_id) - node_ids.update(self._fetch_node_ids_in_reachable_branch(edge.target_node_id, branch_identify)) + node_ids |= self._fetch_node_ids_in_reachable_branch(edge.target_node_id, branch_identify) return node_ids def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: set[str]) -> None: @@ -95,7 +98,6 @@ class StreamProcessor(ABC): if node_id not in self.rest_node_ids: return self.rest_node_ids.remove(node_id) - self.rest_node_ids.update(reachable_node_ids - self.rest_node_ids) for edge in self.graph.edge_mapping.get(node_id, []): if edge.target_node_id in reachable_node_ids: