pull/20113/merge
Nguyen Tran Gia Bao 11 months ago committed by GitHub
commit 6fc29bcb74
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -71,7 +71,7 @@ class AnswerStreamProcessor(StreamProcessor):
self.route_position = {}
for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items():
self.route_position[answer_node_id] = 0
self.rest_node_ids = self.graph.node_ids.copy()
self.rest_node_ids = set(self.graph.node_ids)
self.current_stream_chunk_generating_node_ids = {}
def _generate_stream_outputs_when_node_finished(

@ -14,7 +14,7 @@ class StreamProcessor(ABC):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
self.graph = graph
self.variable_pool = variable_pool
self.rest_node_ids = graph.node_ids.copy()
self.rest_node_ids = set(graph.node_ids)
@abstractmethod
def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
@ -33,8 +33,8 @@ class StreamProcessor(ABC):
return
if run_result.edge_source_handle:
reachable_node_ids: list[str] = []
unreachable_first_node_ids: list[str] = []
reachable_node_ids: set[str] = set()
unreachable_first_node_ids: set[str] = set()
if finished_node_id not in self.graph.edge_mapping:
logger.warning(f"node {finished_node_id} has no edge mapping")
return
@ -57,9 +57,9 @@ class StreamProcessor(ABC):
# The branch_identify parameter is added to ensure that
# only nodes in the correct logical branch are included.
reachable_node_ids.append(edge.target_node_id)
reachable_node_ids.add(edge.target_node_id)
ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id, run_result.edge_source_handle)
reachable_node_ids.extend(ids)
reachable_node_ids |= ids
else:
# if the condition edge in parallel, and the target node is not in parallel, we should not remove it
# Issues: #13626
@ -68,13 +68,16 @@ class StreamProcessor(ABC):
and edge.target_node_id not in self.graph.node_parallel_mapping
):
continue
unreachable_first_node_ids.append(edge.target_node_id)
unreachable_first_node_ids = list(set(unreachable_first_node_ids) - set(reachable_node_ids))
unreachable_first_node_ids.add(edge.target_node_id)
self.rest_node_ids |= reachable_node_ids
unreachable_first_node_ids -= reachable_node_ids
for node_id in unreachable_first_node_ids:
self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: Optional[str] = None) -> list[str]:
node_ids = []
def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: Optional[str] = None) -> set[str]:
node_ids = set()
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id == self.graph.root_node_id:
continue
@ -84,22 +87,17 @@ class StreamProcessor(ABC):
if not branch_identify or edge.run_condition.branch_identify != branch_identify:
continue
node_ids.append(edge.target_node_id)
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id, branch_identify))
node_ids.add(edge.target_node_id)
node_ids |= self._fetch_node_ids_in_reachable_branch(edge.target_node_id, branch_identify)
return node_ids
def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: set[str]) -> None:
"""
remove target node ids until merge
"""
if node_id not in self.rest_node_ids:
return
if node_id in reachable_node_ids:
return
self.rest_node_ids.remove(node_id)
self.rest_node_ids.extend(set(reachable_node_ids) - set(self.rest_node_ids))
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id in reachable_node_ids:

@ -80,7 +80,7 @@ class EndStreamProcessor(StreamProcessor):
self.route_position = {}
for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items():
self.route_position[end_node_id] = 0
self.rest_node_ids = self.graph.node_ids.copy()
self.rest_node_ids = set(self.graph.node_ids)
self.current_stream_chunk_generating_node_ids = {}
def _generate_stream_outputs_when_node_finished(

Loading…
Cancel
Save