diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py index ba6ba16e36..dc48a0f4e7 100644 --- a/api/core/workflow/nodes/answer/answer_stream_processor.py +++ b/api/core/workflow/nodes/answer/answer_stream_processor.py @@ -71,7 +71,7 @@ class AnswerStreamProcessor(StreamProcessor): self.route_position = {} for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items(): self.route_position[answer_node_id] = 0 - self.rest_node_ids = self.graph.node_ids.copy() + self.rest_node_ids = set(self.graph.node_ids) self.current_stream_chunk_generating_node_ids = {} def _generate_stream_outputs_when_node_finished( diff --git a/api/core/workflow/nodes/answer/base_stream_processor.py b/api/core/workflow/nodes/answer/base_stream_processor.py index 6671ff0746..f104eed9e6 100644 --- a/api/core/workflow/nodes/answer/base_stream_processor.py +++ b/api/core/workflow/nodes/answer/base_stream_processor.py @@ -14,7 +14,7 @@ class StreamProcessor(ABC): def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: self.graph = graph self.variable_pool = variable_pool - self.rest_node_ids = graph.node_ids.copy() + self.rest_node_ids = set(graph.node_ids) @abstractmethod def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]: @@ -33,8 +33,8 @@ class StreamProcessor(ABC): return if run_result.edge_source_handle: - reachable_node_ids: list[str] = [] - unreachable_first_node_ids: list[str] = [] + reachable_node_ids: set[str] = set() + unreachable_first_node_ids: set[str] = set() if finished_node_id not in self.graph.edge_mapping: logger.warning(f"node {finished_node_id} has no edge mapping") return @@ -57,9 +57,9 @@ class StreamProcessor(ABC): # The branch_identify parameter is added to ensure that # only nodes in the correct logical branch are included. - reachable_node_ids.append(edge.target_node_id) + reachable_node_ids.add(edge.target_node_id) ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id, run_result.edge_source_handle) - reachable_node_ids.extend(ids) + reachable_node_ids.update(ids) else: # if the condition edge in parallel, and the target node is not in parallel, we should not remove it # Issues: #13626 @@ -68,13 +68,13 @@ class StreamProcessor(ABC): and edge.target_node_id not in self.graph.node_parallel_mapping ): continue - unreachable_first_node_ids.append(edge.target_node_id) - unreachable_first_node_ids = list(set(unreachable_first_node_ids) - set(reachable_node_ids)) + unreachable_first_node_ids.add(edge.target_node_id) + unreachable_first_node_ids = unreachable_first_node_ids - reachable_node_ids for node_id in unreachable_first_node_ids: self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids) - def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: Optional[str] = None) -> list[str]: - node_ids = [] + def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: Optional[str] = None) -> set[str]: + node_ids = set() for edge in self.graph.edge_mapping.get(node_id, []): if edge.target_node_id == self.graph.root_node_id: continue @@ -84,22 +84,18 @@ class StreamProcessor(ABC): if not branch_identify or edge.run_condition.branch_identify != branch_identify: continue - node_ids.append(edge.target_node_id) - node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id, branch_identify)) + node_ids.add(edge.target_node_id) + node_ids.update(self._fetch_node_ids_in_reachable_branch(edge.target_node_id, branch_identify)) return node_ids - def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None: + def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: set[str]) -> None: """ remove target node ids until merge """ if node_id not in self.rest_node_ids: return - - if node_id in reachable_node_ids: - return - self.rest_node_ids.remove(node_id) - self.rest_node_ids.extend(set(reachable_node_ids) - set(self.rest_node_ids)) + self.rest_node_ids.update(reachable_node_ids - self.rest_node_ids) for edge in self.graph.edge_mapping.get(node_id, []): if edge.target_node_id in reachable_node_ids: diff --git a/api/core/workflow/nodes/end/end_stream_processor.py b/api/core/workflow/nodes/end/end_stream_processor.py index 3ae5af7137..d0555136de 100644 --- a/api/core/workflow/nodes/end/end_stream_processor.py +++ b/api/core/workflow/nodes/end/end_stream_processor.py @@ -80,7 +80,7 @@ class EndStreamProcessor(StreamProcessor): self.route_position = {} for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items(): self.route_position[end_node_id] = 0 - self.rest_node_ids = self.graph.node_ids.copy() + self.rest_node_ids = set(self.graph.node_ids) self.current_stream_chunk_generating_node_ids = {} def _generate_stream_outputs_when_node_finished(