chore(workflow): Optimize the iteration when selecting a variable from a branch in the output variable causes iteration index err (#8440)

pull/8387/head
takatost 2 years ago committed by GitHub
parent d882348f39
commit 88c9834ef2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -689,23 +689,11 @@ class Graph(BaseModel):
parallel_start_node_ids[graph_edge.source_node_id].append(branch_node_id) parallel_start_node_ids[graph_edge.source_node_id].append(branch_node_id)
parallel_start_node_id = None for _, branch_node_ids in parallel_start_node_ids.items():
for p_start_node_id, branch_node_ids in parallel_start_node_ids.items():
if set(branch_node_ids) == set(routes_node_ids.keys()): if set(branch_node_ids) == set(routes_node_ids.keys()):
parallel_start_node_id = p_start_node_id
return True return True
if not parallel_start_node_id: return False
raise Exception("Parallel start node id not found")
for graph_edge in reverse_edge_mapping[start_node_id]:
if (
graph_edge.source_node_id not in all_routes_node_ids
or graph_edge.source_node_id != parallel_start_node_id
):
return False
return True
@classmethod @classmethod
def _is_node2_after_node1(cls, node1_id: str, node2_id: str, edge_mapping: dict[str, list[GraphEdge]]) -> bool: def _is_node2_after_node1(cls, node1_id: str, node2_id: str, edge_mapping: dict[str, list[GraphEdge]]) -> bool:

@ -20,11 +20,9 @@ from core.workflow.graph_engine.entities.event import (
NodeRunSucceededEvent, NodeRunSucceededEvent,
) )
from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunCompletedEvent, RunEvent from core.workflow.nodes.event import RunCompletedEvent, RunEvent
from core.workflow.nodes.iteration.entities import IterationNodeData from core.workflow.nodes.iteration.entities import IterationNodeData
from core.workflow.utils.condition.entities import Condition
from models.workflow import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionStatus
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -68,38 +66,6 @@ class IterationNode(BaseNode):
if not iteration_graph: if not iteration_graph:
raise ValueError("iteration graph not found") raise ValueError("iteration graph not found")
leaf_node_ids = iteration_graph.get_leaf_node_ids()
iteration_leaf_node_ids = []
for leaf_node_id in leaf_node_ids:
node_config = iteration_graph.node_id_config_mapping.get(leaf_node_id)
if not node_config:
continue
leaf_node_iteration_id = node_config.get("data", {}).get("iteration_id")
if not leaf_node_iteration_id:
continue
if leaf_node_iteration_id != self.node_id:
continue
iteration_leaf_node_ids.append(leaf_node_id)
# add condition of end nodes to root node
iteration_graph.add_extra_edge(
source_node_id=leaf_node_id,
target_node_id=root_node_id,
run_condition=RunCondition(
type="condition",
conditions=[
Condition(
variable_selector=[self.node_id, "index"],
comparison_operator="<",
value=str(len(iterator_list_value)),
)
],
),
)
variable_pool = self.graph_runtime_state.variable_pool variable_pool = self.graph_runtime_state.variable_pool
# append iteration variable (item, index) to variable pool # append iteration variable (item, index) to variable pool
@ -149,91 +115,90 @@ class IterationNode(BaseNode):
outputs: list[Any] = [] outputs: list[Any] = []
try: try:
# run workflow for _ in range(len(iterator_list_value)):
rst = graph_engine.run() # run workflow
for event in rst: rst = graph_engine.run()
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id: for event in rst:
event.in_iteration_id = self.node_id if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
event.in_iteration_id = self.node_id
if (
isinstance(event, BaseNodeEvent) if (
and event.node_type == NodeType.ITERATION_START isinstance(event, BaseNodeEvent)
and not isinstance(event, NodeRunStreamChunkEvent) and event.node_type == NodeType.ITERATION_START
): and not isinstance(event, NodeRunStreamChunkEvent)
continue ):
continue
if isinstance(event, NodeRunSucceededEvent):
if event.route_node_state.node_run_result: if isinstance(event, NodeRunSucceededEvent):
metadata = event.route_node_state.node_run_result.metadata if event.route_node_state.node_run_result:
if not metadata: metadata = event.route_node_state.node_run_result.metadata
metadata = {} if not metadata:
metadata = {}
if NodeRunMetadataKey.ITERATION_ID not in metadata:
metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id if NodeRunMetadataKey.ITERATION_ID not in metadata:
metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any( metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
[self.node_id, "index"] metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any(
) [self.node_id, "index"]
event.route_node_state.node_run_result.metadata = metadata )
event.route_node_state.node_run_result.metadata = metadata
yield event
yield event
# handle iteration run result elif isinstance(event, BaseGraphEvent):
if event.route_node_state.node_id in iteration_leaf_node_ids: if isinstance(event, GraphRunFailedEvent):
# append to iteration output variable list # iteration run failed
current_iteration_output = variable_pool.get_any(self.node_data.output_selector) yield IterationRunFailedEvent(
outputs.append(current_iteration_output) iteration_id=self.id,
iteration_node_id=self.node_id,
# remove all nodes outputs from variable pool iteration_node_type=self.node_type,
for node_id in iteration_graph.node_ids: iteration_node_data=self.node_data,
variable_pool.remove_node(node_id) start_at=start_at,
inputs=inputs,
# move to next iteration outputs={"output": jsonable_encoder(outputs)},
current_index = variable_pool.get([self.node_id, "index"]) steps=len(iterator_list_value),
if current_index is None: metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
raise ValueError(f"iteration {self.node_id} current index not found")
next_index = int(current_index.to_object()) + 1
variable_pool.add([self.node_id, "index"], next_index)
if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
index=next_index,
pre_iteration_output=jsonable_encoder(current_iteration_output)
if current_iteration_output
else None,
)
elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent):
# iteration run failed
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": jsonable_encoder(outputs)},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error,
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.error, error=event.error,
) )
)
break yield RunCompletedEvent(
else: run_result=NodeRunResult(
event = cast(InNodeEvent, event) status=WorkflowNodeExecutionStatus.FAILED,
yield event error=event.error,
)
)
return
else:
event = cast(InNodeEvent, event)
yield event
# append to iteration output variable list
current_iteration_output = variable_pool.get_any(self.node_data.output_selector)
outputs.append(current_iteration_output)
# remove all nodes outputs from variable pool
for node_id in iteration_graph.node_ids:
variable_pool.remove_node(node_id)
# move to next iteration
current_index = variable_pool.get([self.node_id, "index"])
if current_index is None:
raise ValueError(f"iteration {self.node_id} current index not found")
next_index = int(current_index.to_object()) + 1
variable_pool.add([self.node_id, "index"], next_index)
if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
index=next_index,
pre_iteration_output=jsonable_encoder(current_iteration_output)
if current_iteration_output
else None,
)
yield IterationRunSucceededEvent( yield IterationRunSucceededEvent(
iteration_id=self.id, iteration_id=self.id,

Loading…
Cancel
Save