Remove unused functions in class Graph

pull/22791/head
Yongtao Huang 7 months ago
parent c987001a19
commit 8a731f6dbe

@ -204,47 +204,6 @@ class Graph(BaseModel):
return graph
def add_extra_edge(
self, source_node_id: str, target_node_id: str, run_condition: Optional[RunCondition] = None
) -> None:
"""
Add extra edge to the graph
:param source_node_id: source node id
:param target_node_id: target node id
:param run_condition: run condition
"""
if source_node_id not in self.node_ids or target_node_id not in self.node_ids:
return
if source_node_id not in self.edge_mapping:
self.edge_mapping[source_node_id] = []
if target_node_id in [graph_edge.target_node_id for graph_edge in self.edge_mapping[source_node_id]]:
return
graph_edge = GraphEdge(
source_node_id=source_node_id, target_node_id=target_node_id, run_condition=run_condition
)
self.edge_mapping[source_node_id].append(graph_edge)
def get_leaf_node_ids(self) -> list[str]:
"""
Get leaf node ids of the graph
:return: leaf node ids
"""
leaf_node_ids = []
for node_id in self.node_ids:
if node_id not in self.edge_mapping or (
len(self.edge_mapping[node_id]) == 1
and self.edge_mapping[node_id][0].target_node_id == self.root_node_id
):
leaf_node_ids.append(node_id)
return leaf_node_ids
@classmethod
def _recursively_add_node_ids(
cls, node_ids: list[str], edge_mapping: dict[str, list[GraphEdge]], node_id: str

@ -162,14 +162,6 @@ def test__init_iteration_graph():
}
graph = Graph.init(graph_config=graph_config, root_node_id="template-transform-in-iteration")
graph.add_extra_edge(
source_node_id="answer-in-iteration",
target_node_id="template-transform-in-iteration",
run_condition=RunCondition(
type="condition",
conditions=[Condition(variable_selector=["iteration", "index"], comparison_operator="", value="5")],
),
)
# iteration:
# [template-transform-in-iteration -> llm-in-iteration -> answer-in-iteration]
@ -177,7 +169,6 @@ def test__init_iteration_graph():
assert graph.root_node_id == "template-transform-in-iteration"
assert graph.edge_mapping.get("template-transform-in-iteration")[0].target_node_id == "llm-in-iteration"
assert graph.edge_mapping.get("llm-in-iteration")[0].target_node_id == "answer-in-iteration"
assert graph.edge_mapping.get("answer-in-iteration")[0].target_node_id == "template-transform-in-iteration"
def test_parallels_graph():

Loading…
Cancel
Save