chore(list_operator): refine exception handling for error specificity (#10206)

pull/12372/head
-LAN- 1 year ago committed by Joel
parent ada7f5c30f
commit 762dec2dc4

@ -0,0 +1,16 @@
class ListOperatorError(ValueError):
"""Base class for all ListOperator errors."""
pass
class InvalidFilterValueError(ListOperatorError):
pass
class InvalidKeyError(ListOperatorError):
pass
class InvalidConditionError(ListOperatorError):
pass

@ -1,5 +1,5 @@
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from typing import Literal from typing import Literal, Union
from core.file import File from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
@ -9,6 +9,7 @@ from core.workflow.nodes.enums import NodeType
from models.workflow import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionStatus
from .entities import ListOperatorNodeData from .entities import ListOperatorNodeData
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError
class ListOperatorNode(BaseNode[ListOperatorNodeData]): class ListOperatorNode(BaseNode[ListOperatorNodeData]):
@ -26,7 +27,17 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
) )
if variable.value and not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment): if not variable.value:
inputs = {"variable": []}
process_data = {"variable": []}
outputs = {"result": [], "first_record": None, "last_record": None}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
outputs=outputs,
)
if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
error_message = ( error_message = (
f"Variable {self.node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment " f"Variable {self.node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment "
"or ArrayStringSegment" "or ArrayStringSegment"
@ -36,70 +47,98 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
) )
if isinstance(variable, ArrayFileSegment): if isinstance(variable, ArrayFileSegment):
inputs = {"variable": [item.to_dict() for item in variable.value]}
process_data["variable"] = [item.to_dict() for item in variable.value] process_data["variable"] = [item.to_dict() for item in variable.value]
else: else:
inputs = {"variable": variable.value}
process_data["variable"] = variable.value process_data["variable"] = variable.value
# Filter try:
if self.node_data.filter_by.enabled: # Filter
for condition in self.node_data.filter_by.conditions: if self.node_data.filter_by.enabled:
if isinstance(variable, ArrayStringSegment): variable = self._apply_filter(variable)
if not isinstance(condition.value, str):
raise ValueError(f"Invalid filter value: {condition.value}") # Order
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text if self.node_data.order_by.enabled:
filter_func = _get_string_filter_func(condition=condition.comparison_operator, value=value) variable = self._apply_order(variable)
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result}) # Slice
elif isinstance(variable, ArrayNumberSegment): if self.node_data.limit.enabled:
if not isinstance(condition.value, str): variable = self._apply_slice(variable)
raise ValueError(f"Invalid filter value: {condition.value}")
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text outputs = {
filter_func = _get_number_filter_func(condition=condition.comparison_operator, value=float(value)) "result": variable.value,
result = list(filter(filter_func, variable.value)) "first_record": variable.value[0] if variable.value else None,
variable = variable.model_copy(update={"value": result}) "last_record": variable.value[-1] if variable.value else None,
elif isinstance(variable, ArrayFileSegment): }
if isinstance(condition.value, str): return NodeRunResult(
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text status=WorkflowNodeExecutionStatus.SUCCEEDED,
else: inputs=inputs,
value = condition.value process_data=process_data,
filter_func = _get_file_filter_func( outputs=outputs,
key=condition.key, )
condition=condition.comparison_operator, except ListOperatorError as e:
value=value, return NodeRunResult(
) status=WorkflowNodeExecutionStatus.FAILED,
result = list(filter(filter_func, variable.value)) error=str(e),
variable = variable.model_copy(update={"value": result}) inputs=inputs,
process_data=process_data,
# Order outputs=outputs,
if self.node_data.order_by.enabled: )
def _apply_filter(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
for condition in self.node_data.filter_by.conditions:
if isinstance(variable, ArrayStringSegment): if isinstance(variable, ArrayStringSegment):
result = _order_string(order=self.node_data.order_by.value, array=variable.value) if not isinstance(condition.value, str):
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
filter_func = _get_string_filter_func(condition=condition.comparison_operator, value=value)
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result}) variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayNumberSegment): elif isinstance(variable, ArrayNumberSegment):
result = _order_number(order=self.node_data.order_by.value, array=variable.value) if not isinstance(condition.value, str):
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
filter_func = _get_number_filter_func(condition=condition.comparison_operator, value=float(value))
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result}) variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayFileSegment): elif isinstance(variable, ArrayFileSegment):
result = _order_file( if isinstance(condition.value, str):
order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
else:
value = condition.value
filter_func = _get_file_filter_func(
key=condition.key,
condition=condition.comparison_operator,
value=value,
) )
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result}) variable = variable.model_copy(update={"value": result})
return variable
# Slice def _apply_order(
if self.node_data.limit.enabled: self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
result = variable.value[: self.node_data.limit.size] ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
if isinstance(variable, ArrayStringSegment):
result = _order_string(order=self.node_data.order_by.value, array=variable.value)
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayNumberSegment):
result = _order_number(order=self.node_data.order_by.value, array=variable.value)
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayFileSegment):
result = _order_file(
order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value
)
variable = variable.model_copy(update={"value": result}) variable = variable.model_copy(update={"value": result})
return variable
outputs = { def _apply_slice(
"result": variable.value, self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
"first_record": variable.value[0] if variable.value else None, ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
"last_record": variable.value[-1] if variable.value else None, result = variable.value[: self.node_data.limit.size]
} return variable.model_copy(update={"value": result})
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
outputs=outputs,
)
def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]: def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]:
@ -107,7 +146,7 @@ def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]:
case "size": case "size":
return lambda x: x.size return lambda x: x.size
case _: case _:
raise ValueError(f"Invalid key: {key}") raise InvalidKeyError(f"Invalid key: {key}")
def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]: def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]:
@ -125,7 +164,7 @@ def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]:
case "url": case "url":
return lambda x: x.remote_url or "" return lambda x: x.remote_url or ""
case _: case _:
raise ValueError(f"Invalid key: {key}") raise InvalidKeyError(f"Invalid key: {key}")
def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bool]: def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bool]:
@ -151,7 +190,7 @@ def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bo
case "not empty": case "not empty":
return lambda x: x != "" return lambda x: x != ""
case _: case _:
raise ValueError(f"Invalid condition: {condition}") raise InvalidConditionError(f"Invalid condition: {condition}")
def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callable[[str], bool]: def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callable[[str], bool]:
@ -161,7 +200,7 @@ def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callab
case "not in": case "not in":
return lambda x: not _in(value)(x) return lambda x: not _in(value)(x)
case _: case _:
raise ValueError(f"Invalid condition: {condition}") raise InvalidConditionError(f"Invalid condition: {condition}")
def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[int | float], bool]: def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[int | float], bool]:
@ -179,7 +218,7 @@ def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[
case "": case "":
return _ge(value) return _ge(value)
case _: case _:
raise ValueError(f"Invalid condition: {condition}") raise InvalidConditionError(f"Invalid condition: {condition}")
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]: def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
@ -193,7 +232,7 @@ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str
extract_func = _get_file_extract_number_func(key=key) extract_func = _get_file_extract_number_func(key=key)
return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_func(x)) return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_func(x))
else: else:
raise ValueError(f"Invalid key: {key}") raise InvalidKeyError(f"Invalid key: {key}")
def _contains(value: str): def _contains(value: str):

Loading…
Cancel
Save