|
|
|
@ -64,28 +64,28 @@ class IterationNode(BaseNode):
|
|
|
|
|
|
|
|
|
|
|
|
_node_type = NodeType.ITERATION
|
|
|
|
_node_type = NodeType.ITERATION
|
|
|
|
|
|
|
|
|
|
|
|
node_data: IterationNodeData
|
|
|
|
_node_data: IterationNodeData
|
|
|
|
|
|
|
|
|
|
|
|
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
|
|
|
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
|
|
|
self.node_data = IterationNodeData.model_validate(data)
|
|
|
|
self._node_data = IterationNodeData.model_validate(data)
|
|
|
|
|
|
|
|
|
|
|
|
def get_error_strategy(self) -> Optional[ErrorStrategy]:
|
|
|
|
def get_error_strategy(self) -> Optional[ErrorStrategy]:
|
|
|
|
return self.node_data.error_strategy
|
|
|
|
return self._node_data.error_strategy
|
|
|
|
|
|
|
|
|
|
|
|
def get_retry_config(self) -> RetryConfig:
|
|
|
|
def get_retry_config(self) -> RetryConfig:
|
|
|
|
return self.node_data.retry_config
|
|
|
|
return self._node_data.retry_config
|
|
|
|
|
|
|
|
|
|
|
|
def get_title(self) -> str:
|
|
|
|
def get_title(self) -> str:
|
|
|
|
return self.node_data.title
|
|
|
|
return self._node_data.title
|
|
|
|
|
|
|
|
|
|
|
|
def get_description(self) -> Optional[str]:
|
|
|
|
def get_description(self) -> Optional[str]:
|
|
|
|
return self.node_data.desc
|
|
|
|
return self._node_data.desc
|
|
|
|
|
|
|
|
|
|
|
|
def get_default_value_dict(self) -> dict[str, Any]:
|
|
|
|
def get_default_value_dict(self) -> dict[str, Any]:
|
|
|
|
return self.node_data.default_value_dict
|
|
|
|
return self._node_data.default_value_dict
|
|
|
|
|
|
|
|
|
|
|
|
def get_base_node_data(self) -> BaseNodeData:
|
|
|
|
def get_base_node_data(self) -> BaseNodeData:
|
|
|
|
return self.node_data
|
|
|
|
return self._node_data
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
|
|
|
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
|
|
|
@ -106,10 +106,10 @@ class IterationNode(BaseNode):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Run the node.
|
|
|
|
Run the node.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
variable = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
|
|
|
|
variable = self.graph_runtime_state.variable_pool.get(self._node_data.iterator_selector)
|
|
|
|
|
|
|
|
|
|
|
|
if not variable:
|
|
|
|
if not variable:
|
|
|
|
raise IteratorVariableNotFoundError(f"iterator variable {self.node_data.iterator_selector} not found")
|
|
|
|
raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found")
|
|
|
|
|
|
|
|
|
|
|
|
if not isinstance(variable, ArrayVariable) and not isinstance(variable, NoneVariable):
|
|
|
|
if not isinstance(variable, ArrayVariable) and not isinstance(variable, NoneVariable):
|
|
|
|
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
|
|
|
|
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
|
|
|
|
@ -139,10 +139,10 @@ class IterationNode(BaseNode):
|
|
|
|
|
|
|
|
|
|
|
|
graph_config = self.graph_config
|
|
|
|
graph_config = self.graph_config
|
|
|
|
|
|
|
|
|
|
|
|
if not self.node_data.start_node_id:
|
|
|
|
if not self._node_data.start_node_id:
|
|
|
|
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self.node_id} not found")
|
|
|
|
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self.node_id} not found")
|
|
|
|
|
|
|
|
|
|
|
|
root_node_id = self.node_data.start_node_id
|
|
|
|
root_node_id = self._node_data.start_node_id
|
|
|
|
|
|
|
|
|
|
|
|
# init graph
|
|
|
|
# init graph
|
|
|
|
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id)
|
|
|
|
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id)
|
|
|
|
@ -185,7 +185,7 @@ class IterationNode(BaseNode):
|
|
|
|
iteration_id=self.id,
|
|
|
|
iteration_id=self.id,
|
|
|
|
iteration_node_id=self.node_id,
|
|
|
|
iteration_node_id=self.node_id,
|
|
|
|
iteration_node_type=self.node_type,
|
|
|
|
iteration_node_type=self.node_type,
|
|
|
|
iteration_node_data=self.node_data,
|
|
|
|
iteration_node_data=self._node_data,
|
|
|
|
start_at=start_at,
|
|
|
|
start_at=start_at,
|
|
|
|
inputs=inputs,
|
|
|
|
inputs=inputs,
|
|
|
|
metadata={"iterator_length": len(iterator_list_value)},
|
|
|
|
metadata={"iterator_length": len(iterator_list_value)},
|
|
|
|
@ -196,7 +196,7 @@ class IterationNode(BaseNode):
|
|
|
|
iteration_id=self.id,
|
|
|
|
iteration_id=self.id,
|
|
|
|
iteration_node_id=self.node_id,
|
|
|
|
iteration_node_id=self.node_id,
|
|
|
|
iteration_node_type=self.node_type,
|
|
|
|
iteration_node_type=self.node_type,
|
|
|
|
iteration_node_data=self.node_data,
|
|
|
|
iteration_node_data=self._node_data,
|
|
|
|
index=0,
|
|
|
|
index=0,
|
|
|
|
pre_iteration_output=None,
|
|
|
|
pre_iteration_output=None,
|
|
|
|
duration=None,
|
|
|
|
duration=None,
|
|
|
|
@ -204,11 +204,11 @@ class IterationNode(BaseNode):
|
|
|
|
iter_run_map: dict[str, float] = {}
|
|
|
|
iter_run_map: dict[str, float] = {}
|
|
|
|
outputs: list[Any] = [None] * len(iterator_list_value)
|
|
|
|
outputs: list[Any] = [None] * len(iterator_list_value)
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
if self.node_data.is_parallel:
|
|
|
|
if self._node_data.is_parallel:
|
|
|
|
futures: list[Future] = []
|
|
|
|
futures: list[Future] = []
|
|
|
|
q: Queue = Queue()
|
|
|
|
q: Queue = Queue()
|
|
|
|
thread_pool = GraphEngineThreadPool(
|
|
|
|
thread_pool = GraphEngineThreadPool(
|
|
|
|
max_workers=self.node_data.parallel_nums, max_submit_count=dify_config.MAX_SUBMIT_COUNT
|
|
|
|
max_workers=self._node_data.parallel_nums, max_submit_count=dify_config.MAX_SUBMIT_COUNT
|
|
|
|
)
|
|
|
|
)
|
|
|
|
for index, item in enumerate(iterator_list_value):
|
|
|
|
for index, item in enumerate(iterator_list_value):
|
|
|
|
future: Future = thread_pool.submit(
|
|
|
|
future: Future = thread_pool.submit(
|
|
|
|
@ -265,7 +265,7 @@ class IterationNode(BaseNode):
|
|
|
|
iteration_graph=iteration_graph,
|
|
|
|
iteration_graph=iteration_graph,
|
|
|
|
iter_run_map=iter_run_map,
|
|
|
|
iter_run_map=iter_run_map,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
|
|
|
|
if self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
|
|
|
|
outputs = [output for output in outputs if output is not None]
|
|
|
|
outputs = [output for output in outputs if output is not None]
|
|
|
|
|
|
|
|
|
|
|
|
# Flatten the list of lists
|
|
|
|
# Flatten the list of lists
|
|
|
|
@ -277,7 +277,7 @@ class IterationNode(BaseNode):
|
|
|
|
iteration_id=self.id,
|
|
|
|
iteration_id=self.id,
|
|
|
|
iteration_node_id=self.node_id,
|
|
|
|
iteration_node_id=self.node_id,
|
|
|
|
iteration_node_type=self.node_type,
|
|
|
|
iteration_node_type=self.node_type,
|
|
|
|
iteration_node_data=self.node_data,
|
|
|
|
iteration_node_data=self._node_data,
|
|
|
|
start_at=start_at,
|
|
|
|
start_at=start_at,
|
|
|
|
inputs=inputs,
|
|
|
|
inputs=inputs,
|
|
|
|
outputs={"output": outputs},
|
|
|
|
outputs={"output": outputs},
|
|
|
|
@ -302,7 +302,7 @@ class IterationNode(BaseNode):
|
|
|
|
iteration_id=self.id,
|
|
|
|
iteration_id=self.id,
|
|
|
|
iteration_node_id=self.node_id,
|
|
|
|
iteration_node_id=self.node_id,
|
|
|
|
iteration_node_type=self.node_type,
|
|
|
|
iteration_node_type=self.node_type,
|
|
|
|
iteration_node_data=self.node_data,
|
|
|
|
iteration_node_data=self._node_data,
|
|
|
|
start_at=start_at,
|
|
|
|
start_at=start_at,
|
|
|
|
inputs=inputs,
|
|
|
|
inputs=inputs,
|
|
|
|
outputs={"output": outputs},
|
|
|
|
outputs={"output": outputs},
|
|
|
|
@ -394,7 +394,7 @@ class IterationNode(BaseNode):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
if not isinstance(event, BaseNodeEvent):
|
|
|
|
if not isinstance(event, BaseNodeEvent):
|
|
|
|
return event
|
|
|
|
return event
|
|
|
|
if self.node_data.is_parallel and isinstance(event, NodeRunStartedEvent):
|
|
|
|
if self._node_data.is_parallel and isinstance(event, NodeRunStartedEvent):
|
|
|
|
event.parallel_mode_run_id = parallel_mode_run_id
|
|
|
|
event.parallel_mode_run_id = parallel_mode_run_id
|
|
|
|
|
|
|
|
|
|
|
|
iter_metadata = {
|
|
|
|
iter_metadata = {
|
|
|
|
@ -457,12 +457,12 @@ class IterationNode(BaseNode):
|
|
|
|
elif isinstance(event, BaseGraphEvent):
|
|
|
|
elif isinstance(event, BaseGraphEvent):
|
|
|
|
if isinstance(event, GraphRunFailedEvent):
|
|
|
|
if isinstance(event, GraphRunFailedEvent):
|
|
|
|
# iteration run failed
|
|
|
|
# iteration run failed
|
|
|
|
if self.node_data.is_parallel:
|
|
|
|
if self._node_data.is_parallel:
|
|
|
|
yield IterationRunFailedEvent(
|
|
|
|
yield IterationRunFailedEvent(
|
|
|
|
iteration_id=self.id,
|
|
|
|
iteration_id=self.id,
|
|
|
|
iteration_node_id=self.node_id,
|
|
|
|
iteration_node_id=self.node_id,
|
|
|
|
iteration_node_type=self.node_type,
|
|
|
|
iteration_node_type=self.node_type,
|
|
|
|
iteration_node_data=self.node_data,
|
|
|
|
iteration_node_data=self._node_data,
|
|
|
|
parallel_mode_run_id=parallel_mode_run_id,
|
|
|
|
parallel_mode_run_id=parallel_mode_run_id,
|
|
|
|
start_at=start_at,
|
|
|
|
start_at=start_at,
|
|
|
|
inputs=inputs,
|
|
|
|
inputs=inputs,
|
|
|
|
@ -476,7 +476,7 @@ class IterationNode(BaseNode):
|
|
|
|
iteration_id=self.id,
|
|
|
|
iteration_id=self.id,
|
|
|
|
iteration_node_id=self.node_id,
|
|
|
|
iteration_node_id=self.node_id,
|
|
|
|
iteration_node_type=self.node_type,
|
|
|
|
iteration_node_type=self.node_type,
|
|
|
|
iteration_node_data=self.node_data,
|
|
|
|
iteration_node_data=self._node_data,
|
|
|
|
start_at=start_at,
|
|
|
|
start_at=start_at,
|
|
|
|
inputs=inputs,
|
|
|
|
inputs=inputs,
|
|
|
|
outputs={"output": outputs},
|
|
|
|
outputs={"output": outputs},
|
|
|
|
@ -497,7 +497,7 @@ class IterationNode(BaseNode):
|
|
|
|
event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id
|
|
|
|
event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id
|
|
|
|
)
|
|
|
|
)
|
|
|
|
if isinstance(event, NodeRunFailedEvent):
|
|
|
|
if isinstance(event, NodeRunFailedEvent):
|
|
|
|
if self.node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR:
|
|
|
|
if self._node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR:
|
|
|
|
yield NodeInIterationFailedEvent(
|
|
|
|
yield NodeInIterationFailedEvent(
|
|
|
|
**metadata_event.model_dump(),
|
|
|
|
**metadata_event.model_dump(),
|
|
|
|
)
|
|
|
|
)
|
|
|
|
@ -511,14 +511,14 @@ class IterationNode(BaseNode):
|
|
|
|
iteration_id=self.id,
|
|
|
|
iteration_id=self.id,
|
|
|
|
iteration_node_id=self.node_id,
|
|
|
|
iteration_node_id=self.node_id,
|
|
|
|
iteration_node_type=self.node_type,
|
|
|
|
iteration_node_type=self.node_type,
|
|
|
|
iteration_node_data=self.node_data,
|
|
|
|
iteration_node_data=self._node_data,
|
|
|
|
index=next_index,
|
|
|
|
index=next_index,
|
|
|
|
parallel_mode_run_id=parallel_mode_run_id,
|
|
|
|
parallel_mode_run_id=parallel_mode_run_id,
|
|
|
|
pre_iteration_output=None,
|
|
|
|
pre_iteration_output=None,
|
|
|
|
duration=duration,
|
|
|
|
duration=duration,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
return
|
|
|
|
return
|
|
|
|
elif self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
|
|
|
|
elif self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
|
|
|
|
yield NodeInIterationFailedEvent(
|
|
|
|
yield NodeInIterationFailedEvent(
|
|
|
|
**metadata_event.model_dump(),
|
|
|
|
**metadata_event.model_dump(),
|
|
|
|
)
|
|
|
|
)
|
|
|
|
@ -532,14 +532,14 @@ class IterationNode(BaseNode):
|
|
|
|
iteration_id=self.id,
|
|
|
|
iteration_id=self.id,
|
|
|
|
iteration_node_id=self.node_id,
|
|
|
|
iteration_node_id=self.node_id,
|
|
|
|
iteration_node_type=self.node_type,
|
|
|
|
iteration_node_type=self.node_type,
|
|
|
|
iteration_node_data=self.node_data,
|
|
|
|
iteration_node_data=self._node_data,
|
|
|
|
index=next_index,
|
|
|
|
index=next_index,
|
|
|
|
parallel_mode_run_id=parallel_mode_run_id,
|
|
|
|
parallel_mode_run_id=parallel_mode_run_id,
|
|
|
|
pre_iteration_output=None,
|
|
|
|
pre_iteration_output=None,
|
|
|
|
duration=duration,
|
|
|
|
duration=duration,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
return
|
|
|
|
return
|
|
|
|
elif self.node_data.error_handle_mode == ErrorHandleMode.TERMINATED:
|
|
|
|
elif self._node_data.error_handle_mode == ErrorHandleMode.TERMINATED:
|
|
|
|
yield NodeInIterationFailedEvent(
|
|
|
|
yield NodeInIterationFailedEvent(
|
|
|
|
**metadata_event.model_dump(),
|
|
|
|
**metadata_event.model_dump(),
|
|
|
|
)
|
|
|
|
)
|
|
|
|
@ -550,12 +550,12 @@ class IterationNode(BaseNode):
|
|
|
|
variable_pool.remove([node_id])
|
|
|
|
variable_pool.remove([node_id])
|
|
|
|
|
|
|
|
|
|
|
|
# iteration run failed
|
|
|
|
# iteration run failed
|
|
|
|
if self.node_data.is_parallel:
|
|
|
|
if self._node_data.is_parallel:
|
|
|
|
yield IterationRunFailedEvent(
|
|
|
|
yield IterationRunFailedEvent(
|
|
|
|
iteration_id=self.id,
|
|
|
|
iteration_id=self.id,
|
|
|
|
iteration_node_id=self.node_id,
|
|
|
|
iteration_node_id=self.node_id,
|
|
|
|
iteration_node_type=self.node_type,
|
|
|
|
iteration_node_type=self.node_type,
|
|
|
|
iteration_node_data=self.node_data,
|
|
|
|
iteration_node_data=self._node_data,
|
|
|
|
parallel_mode_run_id=parallel_mode_run_id,
|
|
|
|
parallel_mode_run_id=parallel_mode_run_id,
|
|
|
|
start_at=start_at,
|
|
|
|
start_at=start_at,
|
|
|
|
inputs=inputs,
|
|
|
|
inputs=inputs,
|
|
|
|
@ -569,7 +569,7 @@ class IterationNode(BaseNode):
|
|
|
|
iteration_id=self.id,
|
|
|
|
iteration_id=self.id,
|
|
|
|
iteration_node_id=self.node_id,
|
|
|
|
iteration_node_id=self.node_id,
|
|
|
|
iteration_node_type=self.node_type,
|
|
|
|
iteration_node_type=self.node_type,
|
|
|
|
iteration_node_data=self.node_data,
|
|
|
|
iteration_node_data=self._node_data,
|
|
|
|
start_at=start_at,
|
|
|
|
start_at=start_at,
|
|
|
|
inputs=inputs,
|
|
|
|
inputs=inputs,
|
|
|
|
outputs={"output": outputs},
|
|
|
|
outputs={"output": outputs},
|
|
|
|
@ -588,7 +588,7 @@ class IterationNode(BaseNode):
|
|
|
|
return
|
|
|
|
return
|
|
|
|
yield metadata_event
|
|
|
|
yield metadata_event
|
|
|
|
|
|
|
|
|
|
|
|
current_output_segment = variable_pool.get(self.node_data.output_selector)
|
|
|
|
current_output_segment = variable_pool.get(self._node_data.output_selector)
|
|
|
|
if current_output_segment is None:
|
|
|
|
if current_output_segment is None:
|
|
|
|
raise IterationNodeError("iteration output selector not found")
|
|
|
|
raise IterationNodeError("iteration output selector not found")
|
|
|
|
current_iteration_output = current_output_segment.value
|
|
|
|
current_iteration_output = current_output_segment.value
|
|
|
|
@ -608,7 +608,7 @@ class IterationNode(BaseNode):
|
|
|
|
iteration_id=self.id,
|
|
|
|
iteration_id=self.id,
|
|
|
|
iteration_node_id=self.node_id,
|
|
|
|
iteration_node_id=self.node_id,
|
|
|
|
iteration_node_type=self.node_type,
|
|
|
|
iteration_node_type=self.node_type,
|
|
|
|
iteration_node_data=self.node_data,
|
|
|
|
iteration_node_data=self._node_data,
|
|
|
|
index=next_index,
|
|
|
|
index=next_index,
|
|
|
|
parallel_mode_run_id=parallel_mode_run_id,
|
|
|
|
parallel_mode_run_id=parallel_mode_run_id,
|
|
|
|
pre_iteration_output=current_iteration_output or None,
|
|
|
|
pre_iteration_output=current_iteration_output or None,
|
|
|
|
@ -621,7 +621,7 @@ class IterationNode(BaseNode):
|
|
|
|
iteration_id=self.id,
|
|
|
|
iteration_id=self.id,
|
|
|
|
iteration_node_id=self.node_id,
|
|
|
|
iteration_node_id=self.node_id,
|
|
|
|
iteration_node_type=self.node_type,
|
|
|
|
iteration_node_type=self.node_type,
|
|
|
|
iteration_node_data=self.node_data,
|
|
|
|
iteration_node_data=self._node_data,
|
|
|
|
start_at=start_at,
|
|
|
|
start_at=start_at,
|
|
|
|
inputs=inputs,
|
|
|
|
inputs=inputs,
|
|
|
|
outputs={"output": None},
|
|
|
|
outputs={"output": None},
|
|
|
|
|