|
|
|
|
@ -14,7 +14,7 @@ class StreamProcessor(ABC):
|
|
|
|
|
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
|
|
|
|
|
self.graph = graph
|
|
|
|
|
self.variable_pool = variable_pool
|
|
|
|
|
self.rest_node_ids = graph.node_ids.copy()
|
|
|
|
|
self.rest_node_ids = set(graph.node_ids)
|
|
|
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
|
def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
|
|
|
|
|
@ -33,8 +33,8 @@ class StreamProcessor(ABC):
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
if run_result.edge_source_handle:
|
|
|
|
|
reachable_node_ids: list[str] = []
|
|
|
|
|
unreachable_first_node_ids: list[str] = []
|
|
|
|
|
reachable_node_ids: set[str] = set()
|
|
|
|
|
unreachable_first_node_ids: set[str] = set()
|
|
|
|
|
if finished_node_id not in self.graph.edge_mapping:
|
|
|
|
|
logger.warning(f"node {finished_node_id} has no edge mapping")
|
|
|
|
|
return
|
|
|
|
|
@ -57,9 +57,9 @@ class StreamProcessor(ABC):
|
|
|
|
|
|
|
|
|
|
# The branch_identify parameter is added to ensure that
|
|
|
|
|
# only nodes in the correct logical branch are included.
|
|
|
|
|
reachable_node_ids.append(edge.target_node_id)
|
|
|
|
|
reachable_node_ids.add(edge.target_node_id)
|
|
|
|
|
ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id, run_result.edge_source_handle)
|
|
|
|
|
reachable_node_ids.extend(ids)
|
|
|
|
|
reachable_node_ids |= ids
|
|
|
|
|
else:
|
|
|
|
|
# if the condition edge in parallel, and the target node is not in parallel, we should not remove it
|
|
|
|
|
# Issues: #13626
|
|
|
|
|
@ -68,13 +68,16 @@ class StreamProcessor(ABC):
|
|
|
|
|
and edge.target_node_id not in self.graph.node_parallel_mapping
|
|
|
|
|
):
|
|
|
|
|
continue
|
|
|
|
|
unreachable_first_node_ids.append(edge.target_node_id)
|
|
|
|
|
unreachable_first_node_ids = list(set(unreachable_first_node_ids) - set(reachable_node_ids))
|
|
|
|
|
unreachable_first_node_ids.add(edge.target_node_id)
|
|
|
|
|
|
|
|
|
|
self.rest_node_ids |= reachable_node_ids
|
|
|
|
|
|
|
|
|
|
unreachable_first_node_ids -= reachable_node_ids
|
|
|
|
|
for node_id in unreachable_first_node_ids:
|
|
|
|
|
self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
|
|
|
|
|
|
|
|
|
|
def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: Optional[str] = None) -> list[str]:
|
|
|
|
|
node_ids = []
|
|
|
|
|
def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: Optional[str] = None) -> set[str]:
|
|
|
|
|
node_ids = set()
|
|
|
|
|
for edge in self.graph.edge_mapping.get(node_id, []):
|
|
|
|
|
if edge.target_node_id == self.graph.root_node_id:
|
|
|
|
|
continue
|
|
|
|
|
@ -84,22 +87,17 @@ class StreamProcessor(ABC):
|
|
|
|
|
if not branch_identify or edge.run_condition.branch_identify != branch_identify:
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
node_ids.append(edge.target_node_id)
|
|
|
|
|
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id, branch_identify))
|
|
|
|
|
node_ids.add(edge.target_node_id)
|
|
|
|
|
node_ids |= self._fetch_node_ids_in_reachable_branch(edge.target_node_id, branch_identify)
|
|
|
|
|
return node_ids
|
|
|
|
|
|
|
|
|
|
def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
|
|
|
|
|
def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: set[str]) -> None:
|
|
|
|
|
"""
|
|
|
|
|
remove target node ids until merge
|
|
|
|
|
"""
|
|
|
|
|
if node_id not in self.rest_node_ids:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
if node_id in reachable_node_ids:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
self.rest_node_ids.remove(node_id)
|
|
|
|
|
self.rest_node_ids.extend(set(reachable_node_ids) - set(self.rest_node_ids))
|
|
|
|
|
|
|
|
|
|
for edge in self.graph.edge_mapping.get(node_id, []):
|
|
|
|
|
if edge.target_node_id in reachable_node_ids:
|
|
|
|
|
|