refactor(nodes): rename `node_data` to `_node_data`

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/22581/head
-LAN- 10 months ago
parent 956f9158c7
commit ba40054f44
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -63,28 +63,28 @@ class AgentNode(BaseNode):
""" """
_node_type = NodeType.AGENT _node_type = NodeType.AGENT
node_data: AgentNodeData _node_data: AgentNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None: def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = AgentNodeData.model_validate(data) self._node_data = AgentNodeData.model_validate(data)
def get_error_strategy(self) -> Optional[ErrorStrategy]: def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig: def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config return self._node_data.retry_config
def get_title(self) -> str: def get_title(self) -> str:
return self.node_data.title return self._node_data.title
def get_description(self) -> Optional[str]: def get_description(self) -> Optional[str]:
return self.node_data.desc return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]: def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData: def get_base_node_data(self) -> BaseNodeData:
return self.node_data return self._node_data
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
@ -94,7 +94,7 @@ class AgentNode(BaseNode):
""" """
Run the agent node Run the agent node
""" """
node_data = cast(AgentNodeData, self.node_data) node_data = cast(AgentNodeData, self._node_data)
try: try:
strategy = get_plugin_agent_strategy( strategy = get_plugin_agent_strategy(
@ -160,18 +160,18 @@ class AgentNode(BaseNode):
type=ToolInvokeMessage.MessageType.LOG, type=ToolInvokeMessage.MessageType.LOG,
message=ToolInvokeMessage.LogMessage( message=ToolInvokeMessage.LogMessage(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
label=f"Agent Strategy: {cast(AgentNodeData, self.node_data).agent_strategy_name}", label=f"Agent Strategy: {cast(AgentNodeData, self._node_data).agent_strategy_name}",
parent_id=None, parent_id=None,
error=None, error=None,
status=ToolInvokeMessage.LogMessage.LogStatus.START, status=ToolInvokeMessage.LogMessage.LogStatus.START,
data={ data={
"strategy": cast(AgentNodeData, self.node_data).agent_strategy_name, "strategy": cast(AgentNodeData, self._node_data).agent_strategy_name,
"parameters": parameters_for_log, "parameters": parameters_for_log,
"thought_process": "Agent strategy execution started", "thought_process": "Agent strategy execution started",
}, },
metadata={ metadata={
"icon": self.agent_strategy_icon, "icon": self.agent_strategy_icon,
"agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name, "agent_strategy": cast(AgentNodeData, self._node_data).agent_strategy_name,
}, },
), ),
) )
@ -180,7 +180,7 @@ class AgentNode(BaseNode):
messages=message_stream, messages=message_stream,
tool_info={ tool_info={
"icon": self.agent_strategy_icon, "icon": self.agent_strategy_icon,
"agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name, "agent_strategy": cast(AgentNodeData, self._node_data).agent_strategy_name,
}, },
parameters_for_log=parameters_for_log, parameters_for_log=parameters_for_log,
user_id=self.user_id, user_id=self.user_id,
@ -299,7 +299,7 @@ class AgentNode(BaseNode):
) )
extra = tool.get("extra", {}) extra = tool.get("extra", {})
runtime_variable_pool = variable_pool if self.node_data.version != "1" else None runtime_variable_pool = variable_pool if self._node_data.version != "1" else None
tool_runtime = ToolManager.get_agent_tool_runtime( tool_runtime = ToolManager.get_agent_tool_runtime(
self.tenant_id, self.app_id, entity, self.invoke_from, runtime_variable_pool self.tenant_id, self.app_id, entity, self.invoke_from, runtime_variable_pool
) )
@ -415,7 +415,7 @@ class AgentNode(BaseNode):
plugin plugin
for plugin in plugins for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}" if f"{plugin.plugin_id}/{plugin.name}"
== cast(AgentNodeData, self.node_data).agent_strategy_provider_name == cast(AgentNodeData, self._node_data).agent_strategy_provider_name
) )
icon = current_plugin.declaration.icon icon = current_plugin.declaration.icon
except StopIteration: except StopIteration:

@ -20,28 +20,28 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser
class AnswerNode(BaseNode): class AnswerNode(BaseNode):
_node_type = NodeType.ANSWER _node_type = NodeType.ANSWER
node_data: AnswerNodeData _node_data: AnswerNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None: def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = AnswerNodeData.model_validate(data) self._node_data = AnswerNodeData.model_validate(data)
def get_error_strategy(self) -> Optional[ErrorStrategy]: def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig: def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config return self._node_data.retry_config
def get_title(self) -> str: def get_title(self) -> str:
return self.node_data.title return self._node_data.title
def get_description(self) -> Optional[str]: def get_description(self) -> Optional[str]:
return self.node_data.desc return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]: def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData: def get_base_node_data(self) -> BaseNodeData:
return self.node_data return self._node_data
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
@ -53,7 +53,7 @@ class AnswerNode(BaseNode):
:return: :return:
""" """
# generate routes # generate routes
generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(self.node_data) generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(self._node_data)
answer = "" answer = ""
files = [] files = []

@ -25,28 +25,28 @@ from .exc import (
class CodeNode(BaseNode): class CodeNode(BaseNode):
_node_type = NodeType.CODE _node_type = NodeType.CODE
node_data: CodeNodeData _node_data: CodeNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None: def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = CodeNodeData.model_validate(data) self._node_data = CodeNodeData.model_validate(data)
def get_error_strategy(self) -> Optional[ErrorStrategy]: def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig: def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config return self._node_data.retry_config
def get_title(self) -> str: def get_title(self) -> str:
return self.node_data.title return self._node_data.title
def get_description(self) -> Optional[str]: def get_description(self) -> Optional[str]:
return self.node_data.desc return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]: def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData: def get_base_node_data(self) -> BaseNodeData:
return self.node_data return self._node_data
@classmethod @classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict: def get_default_config(cls, filters: Optional[dict] = None) -> dict:
@ -70,12 +70,12 @@ class CodeNode(BaseNode):
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
# Get code language # Get code language
code_language = self.node_data.code_language code_language = self._node_data.code_language
code = self.node_data.code code = self._node_data.code
# Get variables # Get variables
variables = {} variables = {}
for variable_selector in self.node_data.variables: for variable_selector in self._node_data.variables:
variable_name = variable_selector.variable variable_name = variable_selector.variable
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
if isinstance(variable, ArrayFileSegment): if isinstance(variable, ArrayFileSegment):
@ -91,7 +91,7 @@ class CodeNode(BaseNode):
) )
# Transform result # Transform result
result = self._transform_result(result=result, output_schema=self.node_data.outputs) result = self._transform_result(result=result, output_schema=self._node_data.outputs)
except (CodeExecutionError, CodeNodeError) as e: except (CodeExecutionError, CodeNodeError) as e:
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__ status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
@ -369,8 +369,8 @@ class CodeNode(BaseNode):
@property @property
def continue_on_error(self) -> bool: def continue_on_error(self) -> bool:
return self.node_data.error_strategy is not None return self._node_data.error_strategy is not None
@property @property
def retry(self) -> bool: def retry(self) -> bool:
return self.node_data.retry_config.retry_enabled return self._node_data.retry_config.retry_enabled

@ -45,35 +45,35 @@ class DocumentExtractorNode(BaseNode):
_node_type = NodeType.DOCUMENT_EXTRACTOR _node_type = NodeType.DOCUMENT_EXTRACTOR
node_data: DocumentExtractorNodeData _node_data: DocumentExtractorNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None: def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = DocumentExtractorNodeData.model_validate(data) self._node_data = DocumentExtractorNodeData.model_validate(data)
def get_error_strategy(self) -> Optional[ErrorStrategy]: def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig: def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config return self._node_data.retry_config
def get_title(self) -> str: def get_title(self) -> str:
return self.node_data.title return self._node_data.title
def get_description(self) -> Optional[str]: def get_description(self) -> Optional[str]:
return self.node_data.desc return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]: def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData: def get_base_node_data(self) -> BaseNodeData:
return self.node_data return self._node_data
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"
def _run(self): def _run(self):
variable_selector = self.node_data.variable_selector variable_selector = self._node_data.variable_selector
variable = self.graph_runtime_state.variable_pool.get(variable_selector) variable = self.graph_runtime_state.variable_pool.get(variable_selector)
if variable is None: if variable is None:

@ -12,28 +12,28 @@ from core.workflow.nodes.enums import ErrorStrategy, NodeType
class EndNode(BaseNode): class EndNode(BaseNode):
_node_type = NodeType.END _node_type = NodeType.END
node_data: EndNodeData _node_data: EndNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None: def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = EndNodeData(**data) self._node_data = EndNodeData(**data)
def get_error_strategy(self) -> Optional[ErrorStrategy]: def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig: def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config return self._node_data.retry_config
def get_title(self) -> str: def get_title(self) -> str:
return self.node_data.title return self._node_data.title
def get_description(self) -> Optional[str]: def get_description(self) -> Optional[str]:
return self.node_data.desc return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]: def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData: def get_base_node_data(self) -> BaseNodeData:
return self.node_data return self._node_data
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
@ -44,7 +44,7 @@ class EndNode(BaseNode):
Run node Run node
:return: :return:
""" """
output_variables = self.node_data.outputs output_variables = self._node_data.outputs
outputs = {} outputs = {}
for variable_selector in output_variables: for variable_selector in output_variables:

@ -36,28 +36,28 @@ logger = logging.getLogger(__name__)
class HttpRequestNode(BaseNode): class HttpRequestNode(BaseNode):
_node_type = NodeType.HTTP_REQUEST _node_type = NodeType.HTTP_REQUEST
node_data: HttpRequestNodeData _node_data: HttpRequestNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None: def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = HttpRequestNodeData.model_validate(data) self._node_data = HttpRequestNodeData.model_validate(data)
def get_error_strategy(self) -> Optional[ErrorStrategy]: def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig: def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config return self._node_data.retry_config
def get_title(self) -> str: def get_title(self) -> str:
return self.node_data.title return self._node_data.title
def get_description(self) -> Optional[str]: def get_description(self) -> Optional[str]:
return self.node_data.desc return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]: def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData: def get_base_node_data(self) -> BaseNodeData:
return self.node_data return self._node_data
@classmethod @classmethod
def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict: def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict:
@ -92,8 +92,8 @@ class HttpRequestNode(BaseNode):
process_data = {} process_data = {}
try: try:
http_executor = Executor( http_executor = Executor(
node_data=self.node_data, node_data=self._node_data,
timeout=self._get_request_timeout(self.node_data), timeout=self._get_request_timeout(self._node_data),
variable_pool=self.graph_runtime_state.variable_pool, variable_pool=self.graph_runtime_state.variable_pool,
max_retries=0, max_retries=0,
) )
@ -246,8 +246,8 @@ class HttpRequestNode(BaseNode):
@property @property
def continue_on_error(self) -> bool: def continue_on_error(self) -> bool:
return self.node_data.error_strategy is not None return self._node_data.error_strategy is not None
@property @property
def retry(self) -> bool: def retry(self) -> bool:
return self.node_data.retry_config.retry_enabled return self._node_data.retry_config.retry_enabled

@ -17,28 +17,28 @@ from core.workflow.utils.condition.processor import ConditionProcessor
class IfElseNode(BaseNode): class IfElseNode(BaseNode):
_node_type = NodeType.IF_ELSE _node_type = NodeType.IF_ELSE
node_data: IfElseNodeData _node_data: IfElseNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None: def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = IfElseNodeData.model_validate(data) self._node_data = IfElseNodeData.model_validate(data)
def get_error_strategy(self) -> Optional[ErrorStrategy]: def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig: def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config return self._node_data.retry_config
def get_title(self) -> str: def get_title(self) -> str:
return self.node_data.title return self._node_data.title
def get_description(self) -> Optional[str]: def get_description(self) -> Optional[str]:
return self.node_data.desc return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]: def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData: def get_base_node_data(self) -> BaseNodeData:
return self.node_data return self._node_data
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
@ -59,8 +59,8 @@ class IfElseNode(BaseNode):
condition_processor = ConditionProcessor() condition_processor = ConditionProcessor()
try: try:
# Check if the new cases structure is used # Check if the new cases structure is used
if self.node_data.cases: if self._node_data.cases:
for case in self.node_data.cases: for case in self._node_data.cases:
input_conditions, group_result, final_result = condition_processor.process_conditions( input_conditions, group_result, final_result = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool, variable_pool=self.graph_runtime_state.variable_pool,
conditions=case.conditions, conditions=case.conditions,
@ -86,8 +86,8 @@ class IfElseNode(BaseNode):
input_conditions, group_result, final_result = _should_not_use_old_function( input_conditions, group_result, final_result = _should_not_use_old_function(
condition_processor=condition_processor, condition_processor=condition_processor,
variable_pool=self.graph_runtime_state.variable_pool, variable_pool=self.graph_runtime_state.variable_pool,
conditions=self.node_data.conditions or [], conditions=self._node_data.conditions or [],
operator=self.node_data.logical_operator or "and", operator=self._node_data.logical_operator or "and",
) )
selected_case_id = "true" if final_result else "false" selected_case_id = "true" if final_result else "false"

@ -64,28 +64,28 @@ class IterationNode(BaseNode):
_node_type = NodeType.ITERATION _node_type = NodeType.ITERATION
node_data: IterationNodeData _node_data: IterationNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None: def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = IterationNodeData.model_validate(data) self._node_data = IterationNodeData.model_validate(data)
def get_error_strategy(self) -> Optional[ErrorStrategy]: def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig: def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config return self._node_data.retry_config
def get_title(self) -> str: def get_title(self) -> str:
return self.node_data.title return self._node_data.title
def get_description(self) -> Optional[str]: def get_description(self) -> Optional[str]:
return self.node_data.desc return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]: def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData: def get_base_node_data(self) -> BaseNodeData:
return self.node_data return self._node_data
@classmethod @classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict: def get_default_config(cls, filters: Optional[dict] = None) -> dict:
@ -106,10 +106,10 @@ class IterationNode(BaseNode):
""" """
Run the node. Run the node.
""" """
variable = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector) variable = self.graph_runtime_state.variable_pool.get(self._node_data.iterator_selector)
if not variable: if not variable:
raise IteratorVariableNotFoundError(f"iterator variable {self.node_data.iterator_selector} not found") raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found")
if not isinstance(variable, ArrayVariable) and not isinstance(variable, NoneVariable): if not isinstance(variable, ArrayVariable) and not isinstance(variable, NoneVariable):
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.") raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
@ -139,10 +139,10 @@ class IterationNode(BaseNode):
graph_config = self.graph_config graph_config = self.graph_config
if not self.node_data.start_node_id: if not self._node_data.start_node_id:
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self.node_id} not found") raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self.node_id} not found")
root_node_id = self.node_data.start_node_id root_node_id = self._node_data.start_node_id
# init graph # init graph
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id) iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id)
@ -185,7 +185,7 @@ class IterationNode(BaseNode):
iteration_id=self.id, iteration_id=self.id,
iteration_node_id=self.node_id, iteration_node_id=self.node_id,
iteration_node_type=self.node_type, iteration_node_type=self.node_type,
iteration_node_data=self.node_data, iteration_node_data=self._node_data,
start_at=start_at, start_at=start_at,
inputs=inputs, inputs=inputs,
metadata={"iterator_length": len(iterator_list_value)}, metadata={"iterator_length": len(iterator_list_value)},
@ -196,7 +196,7 @@ class IterationNode(BaseNode):
iteration_id=self.id, iteration_id=self.id,
iteration_node_id=self.node_id, iteration_node_id=self.node_id,
iteration_node_type=self.node_type, iteration_node_type=self.node_type,
iteration_node_data=self.node_data, iteration_node_data=self._node_data,
index=0, index=0,
pre_iteration_output=None, pre_iteration_output=None,
duration=None, duration=None,
@ -204,11 +204,11 @@ class IterationNode(BaseNode):
iter_run_map: dict[str, float] = {} iter_run_map: dict[str, float] = {}
outputs: list[Any] = [None] * len(iterator_list_value) outputs: list[Any] = [None] * len(iterator_list_value)
try: try:
if self.node_data.is_parallel: if self._node_data.is_parallel:
futures: list[Future] = [] futures: list[Future] = []
q: Queue = Queue() q: Queue = Queue()
thread_pool = GraphEngineThreadPool( thread_pool = GraphEngineThreadPool(
max_workers=self.node_data.parallel_nums, max_submit_count=dify_config.MAX_SUBMIT_COUNT max_workers=self._node_data.parallel_nums, max_submit_count=dify_config.MAX_SUBMIT_COUNT
) )
for index, item in enumerate(iterator_list_value): for index, item in enumerate(iterator_list_value):
future: Future = thread_pool.submit( future: Future = thread_pool.submit(
@ -265,7 +265,7 @@ class IterationNode(BaseNode):
iteration_graph=iteration_graph, iteration_graph=iteration_graph,
iter_run_map=iter_run_map, iter_run_map=iter_run_map,
) )
if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: if self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
outputs = [output for output in outputs if output is not None] outputs = [output for output in outputs if output is not None]
# Flatten the list of lists # Flatten the list of lists
@ -277,7 +277,7 @@ class IterationNode(BaseNode):
iteration_id=self.id, iteration_id=self.id,
iteration_node_id=self.node_id, iteration_node_id=self.node_id,
iteration_node_type=self.node_type, iteration_node_type=self.node_type,
iteration_node_data=self.node_data, iteration_node_data=self._node_data,
start_at=start_at, start_at=start_at,
inputs=inputs, inputs=inputs,
outputs={"output": outputs}, outputs={"output": outputs},
@ -302,7 +302,7 @@ class IterationNode(BaseNode):
iteration_id=self.id, iteration_id=self.id,
iteration_node_id=self.node_id, iteration_node_id=self.node_id,
iteration_node_type=self.node_type, iteration_node_type=self.node_type,
iteration_node_data=self.node_data, iteration_node_data=self._node_data,
start_at=start_at, start_at=start_at,
inputs=inputs, inputs=inputs,
outputs={"output": outputs}, outputs={"output": outputs},
@ -394,7 +394,7 @@ class IterationNode(BaseNode):
""" """
if not isinstance(event, BaseNodeEvent): if not isinstance(event, BaseNodeEvent):
return event return event
if self.node_data.is_parallel and isinstance(event, NodeRunStartedEvent): if self._node_data.is_parallel and isinstance(event, NodeRunStartedEvent):
event.parallel_mode_run_id = parallel_mode_run_id event.parallel_mode_run_id = parallel_mode_run_id
iter_metadata = { iter_metadata = {
@ -457,12 +457,12 @@ class IterationNode(BaseNode):
elif isinstance(event, BaseGraphEvent): elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent): if isinstance(event, GraphRunFailedEvent):
# iteration run failed # iteration run failed
if self.node_data.is_parallel: if self._node_data.is_parallel:
yield IterationRunFailedEvent( yield IterationRunFailedEvent(
iteration_id=self.id, iteration_id=self.id,
iteration_node_id=self.node_id, iteration_node_id=self.node_id,
iteration_node_type=self.node_type, iteration_node_type=self.node_type,
iteration_node_data=self.node_data, iteration_node_data=self._node_data,
parallel_mode_run_id=parallel_mode_run_id, parallel_mode_run_id=parallel_mode_run_id,
start_at=start_at, start_at=start_at,
inputs=inputs, inputs=inputs,
@ -476,7 +476,7 @@ class IterationNode(BaseNode):
iteration_id=self.id, iteration_id=self.id,
iteration_node_id=self.node_id, iteration_node_id=self.node_id,
iteration_node_type=self.node_type, iteration_node_type=self.node_type,
iteration_node_data=self.node_data, iteration_node_data=self._node_data,
start_at=start_at, start_at=start_at,
inputs=inputs, inputs=inputs,
outputs={"output": outputs}, outputs={"output": outputs},
@ -497,7 +497,7 @@ class IterationNode(BaseNode):
event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id
) )
if isinstance(event, NodeRunFailedEvent): if isinstance(event, NodeRunFailedEvent):
if self.node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR: if self._node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR:
yield NodeInIterationFailedEvent( yield NodeInIterationFailedEvent(
**metadata_event.model_dump(), **metadata_event.model_dump(),
) )
@ -511,14 +511,14 @@ class IterationNode(BaseNode):
iteration_id=self.id, iteration_id=self.id,
iteration_node_id=self.node_id, iteration_node_id=self.node_id,
iteration_node_type=self.node_type, iteration_node_type=self.node_type,
iteration_node_data=self.node_data, iteration_node_data=self._node_data,
index=next_index, index=next_index,
parallel_mode_run_id=parallel_mode_run_id, parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=None, pre_iteration_output=None,
duration=duration, duration=duration,
) )
return return
elif self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: elif self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
yield NodeInIterationFailedEvent( yield NodeInIterationFailedEvent(
**metadata_event.model_dump(), **metadata_event.model_dump(),
) )
@ -532,14 +532,14 @@ class IterationNode(BaseNode):
iteration_id=self.id, iteration_id=self.id,
iteration_node_id=self.node_id, iteration_node_id=self.node_id,
iteration_node_type=self.node_type, iteration_node_type=self.node_type,
iteration_node_data=self.node_data, iteration_node_data=self._node_data,
index=next_index, index=next_index,
parallel_mode_run_id=parallel_mode_run_id, parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=None, pre_iteration_output=None,
duration=duration, duration=duration,
) )
return return
elif self.node_data.error_handle_mode == ErrorHandleMode.TERMINATED: elif self._node_data.error_handle_mode == ErrorHandleMode.TERMINATED:
yield NodeInIterationFailedEvent( yield NodeInIterationFailedEvent(
**metadata_event.model_dump(), **metadata_event.model_dump(),
) )
@ -550,12 +550,12 @@ class IterationNode(BaseNode):
variable_pool.remove([node_id]) variable_pool.remove([node_id])
# iteration run failed # iteration run failed
if self.node_data.is_parallel: if self._node_data.is_parallel:
yield IterationRunFailedEvent( yield IterationRunFailedEvent(
iteration_id=self.id, iteration_id=self.id,
iteration_node_id=self.node_id, iteration_node_id=self.node_id,
iteration_node_type=self.node_type, iteration_node_type=self.node_type,
iteration_node_data=self.node_data, iteration_node_data=self._node_data,
parallel_mode_run_id=parallel_mode_run_id, parallel_mode_run_id=parallel_mode_run_id,
start_at=start_at, start_at=start_at,
inputs=inputs, inputs=inputs,
@ -569,7 +569,7 @@ class IterationNode(BaseNode):
iteration_id=self.id, iteration_id=self.id,
iteration_node_id=self.node_id, iteration_node_id=self.node_id,
iteration_node_type=self.node_type, iteration_node_type=self.node_type,
iteration_node_data=self.node_data, iteration_node_data=self._node_data,
start_at=start_at, start_at=start_at,
inputs=inputs, inputs=inputs,
outputs={"output": outputs}, outputs={"output": outputs},
@ -588,7 +588,7 @@ class IterationNode(BaseNode):
return return
yield metadata_event yield metadata_event
current_output_segment = variable_pool.get(self.node_data.output_selector) current_output_segment = variable_pool.get(self._node_data.output_selector)
if current_output_segment is None: if current_output_segment is None:
raise IterationNodeError("iteration output selector not found") raise IterationNodeError("iteration output selector not found")
current_iteration_output = current_output_segment.value current_iteration_output = current_output_segment.value
@ -608,7 +608,7 @@ class IterationNode(BaseNode):
iteration_id=self.id, iteration_id=self.id,
iteration_node_id=self.node_id, iteration_node_id=self.node_id,
iteration_node_type=self.node_type, iteration_node_type=self.node_type,
iteration_node_data=self.node_data, iteration_node_data=self._node_data,
index=next_index, index=next_index,
parallel_mode_run_id=parallel_mode_run_id, parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=current_iteration_output or None, pre_iteration_output=current_iteration_output or None,
@ -621,7 +621,7 @@ class IterationNode(BaseNode):
iteration_id=self.id, iteration_id=self.id,
iteration_node_id=self.node_id, iteration_node_id=self.node_id,
iteration_node_type=self.node_type, iteration_node_type=self.node_type,
iteration_node_data=self.node_data, iteration_node_data=self._node_data,
start_at=start_at, start_at=start_at,
inputs=inputs, inputs=inputs,
outputs={"output": None}, outputs={"output": None},

@ -86,7 +86,7 @@ default_retrieval_model = {
class KnowledgeRetrievalNode(BaseNode): class KnowledgeRetrievalNode(BaseNode):
_node_type = NodeType.KNOWLEDGE_RETRIEVAL _node_type = NodeType.KNOWLEDGE_RETRIEVAL
node_data: KnowledgeRetrievalNodeData _node_data: KnowledgeRetrievalNodeData
# Instance attributes specific to LLMNode. # Instance attributes specific to LLMNode.
# Output variable for file # Output variable for file
@ -126,32 +126,32 @@ class KnowledgeRetrievalNode(BaseNode):
self._llm_file_saver = llm_file_saver self._llm_file_saver = llm_file_saver
def init_node_data(self, data: Mapping[str, Any]) -> None: def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = KnowledgeRetrievalNodeData.model_validate(data) self._node_data = KnowledgeRetrievalNodeData.model_validate(data)
def get_error_strategy(self) -> Optional[ErrorStrategy]: def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig: def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config return self._node_data.retry_config
def get_title(self) -> str: def get_title(self) -> str:
return self.node_data.title return self._node_data.title
def get_description(self) -> Optional[str]: def get_description(self) -> Optional[str]:
return self.node_data.desc return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]: def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData: def get_base_node_data(self) -> BaseNodeData:
return self.node_data return self._node_data
@classmethod @classmethod
def version(cls): def version(cls):
return "1" return "1"
def _run(self) -> NodeRunResult: # type: ignore def _run(self) -> NodeRunResult: # type: ignore
node_data = cast(KnowledgeRetrievalNodeData, self.node_data) node_data = cast(KnowledgeRetrievalNodeData, self._node_data)
# extract variables # extract variables
variable = self.graph_runtime_state.variable_pool.get(node_data.query_variable_selector) variable = self.graph_runtime_state.variable_pool.get(node_data.query_variable_selector)
if not isinstance(variable, StringSegment): if not isinstance(variable, StringSegment):
@ -545,7 +545,7 @@ class KnowledgeRetrievalNode(BaseNode):
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
stop=stop, stop=stop,
user_id=self.user_id, user_id=self.user_id,
structured_output_enabled=self.node_data.structured_output_enabled, structured_output_enabled=self._node_data.structured_output_enabled,
structured_output=None, structured_output=None,
file_saver=self._llm_file_saver, file_saver=self._llm_file_saver,
file_outputs=self._file_outputs, file_outputs=self._file_outputs,

@ -17,28 +17,28 @@ from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError
class ListOperatorNode(BaseNode): class ListOperatorNode(BaseNode):
_node_type = NodeType.LIST_OPERATOR _node_type = NodeType.LIST_OPERATOR
node_data: ListOperatorNodeData _node_data: ListOperatorNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None: def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = ListOperatorNodeData(**data) self._node_data = ListOperatorNodeData(**data)
def get_error_strategy(self) -> Optional[ErrorStrategy]: def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig: def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config return self._node_data.retry_config
def get_title(self) -> str: def get_title(self) -> str:
return self.node_data.title return self._node_data.title
def get_description(self) -> Optional[str]: def get_description(self) -> Optional[str]:
return self.node_data.desc return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]: def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData: def get_base_node_data(self) -> BaseNodeData:
return self.node_data return self._node_data
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
@ -49,9 +49,9 @@ class ListOperatorNode(BaseNode):
process_data: dict[str, list] = {} process_data: dict[str, list] = {}
outputs: dict[str, Any] = {} outputs: dict[str, Any] = {}
variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable) variable = self.graph_runtime_state.variable_pool.get(self._node_data.variable)
if variable is None: if variable is None:
error_message = f"Variable not found for selector: {self.node_data.variable}" error_message = f"Variable not found for selector: {self._node_data.variable}"
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
) )
@ -71,7 +71,7 @@ class ListOperatorNode(BaseNode):
) )
if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment): if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
error_message = ( error_message = (
f"Variable {self.node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment " f"Variable {self._node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment "
"or ArrayStringSegment" "or ArrayStringSegment"
) )
return NodeRunResult( return NodeRunResult(
@ -87,19 +87,19 @@ class ListOperatorNode(BaseNode):
try: try:
# Filter # Filter
if self.node_data.filter_by.enabled: if self._node_data.filter_by.enabled:
variable = self._apply_filter(variable) variable = self._apply_filter(variable)
# Extract # Extract
if self.node_data.extract_by.enabled: if self._node_data.extract_by.enabled:
variable = self._extract_slice(variable) variable = self._extract_slice(variable)
# Order # Order
if self.node_data.order_by.enabled: if self._node_data.order_by.enabled:
variable = self._apply_order(variable) variable = self._apply_order(variable)
# Slice # Slice
if self.node_data.limit.enabled: if self._node_data.limit.enabled:
variable = self._apply_slice(variable) variable = self._apply_slice(variable)
outputs = { outputs = {
@ -127,7 +127,7 @@ class ListOperatorNode(BaseNode):
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
filter_func: Callable[[Any], bool] filter_func: Callable[[Any], bool]
result: list[Any] = [] result: list[Any] = []
for condition in self.node_data.filter_by.conditions: for condition in self._node_data.filter_by.conditions:
if isinstance(variable, ArrayStringSegment): if isinstance(variable, ArrayStringSegment):
if not isinstance(condition.value, str): if not isinstance(condition.value, str):
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}") raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
@ -160,14 +160,14 @@ class ListOperatorNode(BaseNode):
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
if isinstance(variable, ArrayStringSegment): if isinstance(variable, ArrayStringSegment):
result = _order_string(order=self.node_data.order_by.value, array=variable.value) result = _order_string(order=self._node_data.order_by.value, array=variable.value)
variable = variable.model_copy(update={"value": result}) variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayNumberSegment): elif isinstance(variable, ArrayNumberSegment):
result = _order_number(order=self.node_data.order_by.value, array=variable.value) result = _order_number(order=self._node_data.order_by.value, array=variable.value)
variable = variable.model_copy(update={"value": result}) variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayFileSegment): elif isinstance(variable, ArrayFileSegment):
result = _order_file( result = _order_file(
order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value order=self._node_data.order_by.value, order_by=self._node_data.order_by.key, array=variable.value
) )
variable = variable.model_copy(update={"value": result}) variable = variable.model_copy(update={"value": result})
return variable return variable
@ -175,13 +175,13 @@ class ListOperatorNode(BaseNode):
def _apply_slice( def _apply_slice(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
result = variable.value[: self.node_data.limit.size] result = variable.value[: self._node_data.limit.size]
return variable.model_copy(update={"value": result}) return variable.model_copy(update={"value": result})
def _extract_slice( def _extract_slice(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
value = int(self.graph_runtime_state.variable_pool.convert_template(self.node_data.extract_by.serial).text) value = int(self.graph_runtime_state.variable_pool.convert_template(self._node_data.extract_by.serial).text)
if value < 1: if value < 1:
raise ValueError(f"Invalid serial index: must be >= 1, got {value}") raise ValueError(f"Invalid serial index: must be >= 1, got {value}")
value -= 1 value -= 1

@ -99,7 +99,7 @@ logger = logging.getLogger(__name__)
class LLMNode(BaseNode): class LLMNode(BaseNode):
_node_type = NodeType.LLM _node_type = NodeType.LLM
node_data: LLMNodeData _node_data: LLMNodeData
# Instance attributes specific to LLMNode. # Instance attributes specific to LLMNode.
# Output variable for file # Output variable for file
@ -139,25 +139,25 @@ class LLMNode(BaseNode):
self._llm_file_saver = llm_file_saver self._llm_file_saver = llm_file_saver
def init_node_data(self, data: Mapping[str, Any]) -> None: def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = LLMNodeData.model_validate(data) self._node_data = LLMNodeData.model_validate(data)
def get_error_strategy(self) -> Optional[ErrorStrategy]: def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig: def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config return self._node_data.retry_config
def get_title(self) -> str: def get_title(self) -> str:
return self.node_data.title return self._node_data.title
def get_description(self) -> Optional[str]: def get_description(self) -> Optional[str]:
return self.node_data.desc return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]: def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData: def get_base_node_data(self) -> BaseNodeData:
return self.node_data return self._node_data
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
@ -173,13 +173,13 @@ class LLMNode(BaseNode):
try: try:
# init messages template # init messages template
self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template) self._node_data.prompt_template = self._transform_chat_messages(self._node_data.prompt_template)
# fetch variables and fetch values from variable pool # fetch variables and fetch values from variable pool
inputs = self._fetch_inputs(node_data=self.node_data) inputs = self._fetch_inputs(node_data=self._node_data)
# fetch jinja2 inputs # fetch jinja2 inputs
jinja_inputs = self._fetch_jinja_inputs(node_data=self.node_data) jinja_inputs = self._fetch_jinja_inputs(node_data=self._node_data)
# merge inputs # merge inputs
inputs.update(jinja_inputs) inputs.update(jinja_inputs)
@ -190,9 +190,9 @@ class LLMNode(BaseNode):
files = ( files = (
llm_utils.fetch_files( llm_utils.fetch_files(
variable_pool=variable_pool, variable_pool=variable_pool,
selector=self.node_data.vision.configs.variable_selector, selector=self._node_data.vision.configs.variable_selector,
) )
if self.node_data.vision.enabled if self._node_data.vision.enabled
else [] else []
) )
@ -200,7 +200,7 @@ class LLMNode(BaseNode):
node_inputs["#files#"] = [file.to_dict() for file in files] node_inputs["#files#"] = [file.to_dict() for file in files]
# fetch context value # fetch context value
generator = self._fetch_context(node_data=self.node_data) generator = self._fetch_context(node_data=self._node_data)
context = None context = None
for event in generator: for event in generator:
if isinstance(event, RunRetrieverResourceEvent): if isinstance(event, RunRetrieverResourceEvent):
@ -211,7 +211,7 @@ class LLMNode(BaseNode):
# fetch model config # fetch model config
model_instance, model_config = LLMNode._fetch_model_config( model_instance, model_config = LLMNode._fetch_model_config(
node_data_model=self.node_data.model, node_data_model=self._node_data.model,
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
) )
@ -219,13 +219,13 @@ class LLMNode(BaseNode):
memory = llm_utils.fetch_memory( memory = llm_utils.fetch_memory(
variable_pool=variable_pool, variable_pool=variable_pool,
app_id=self.app_id, app_id=self.app_id,
node_data_memory=self.node_data.memory, node_data_memory=self._node_data.memory,
model_instance=model_instance, model_instance=model_instance,
) )
query = None query = None
if self.node_data.memory: if self._node_data.memory:
query = self.node_data.memory.query_prompt_template query = self._node_data.memory.query_prompt_template
if not query and ( if not query and (
query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)) query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
): ):
@ -237,24 +237,24 @@ class LLMNode(BaseNode):
context=context, context=context,
memory=memory, memory=memory,
model_config=model_config, model_config=model_config,
prompt_template=self.node_data.prompt_template, prompt_template=self._node_data.prompt_template,
memory_config=self.node_data.memory, memory_config=self._node_data.memory,
vision_enabled=self.node_data.vision.enabled, vision_enabled=self._node_data.vision.enabled,
vision_detail=self.node_data.vision.configs.detail, vision_detail=self._node_data.vision.configs.detail,
variable_pool=variable_pool, variable_pool=variable_pool,
jinja2_variables=self.node_data.prompt_config.jinja2_variables, jinja2_variables=self._node_data.prompt_config.jinja2_variables,
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
) )
# handle invoke result # handle invoke result
generator = LLMNode.invoke_llm( generator = LLMNode.invoke_llm(
node_data_model=self.node_data.model, node_data_model=self._node_data.model,
model_instance=model_instance, model_instance=model_instance,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
stop=stop, stop=stop,
user_id=self.user_id, user_id=self.user_id,
structured_output_enabled=self.node_data.structured_output_enabled, structured_output_enabled=self._node_data.structured_output_enabled,
structured_output=self.node_data.structured_output, structured_output=self._node_data.structured_output,
file_saver=self._llm_file_saver, file_saver=self._llm_file_saver,
file_outputs=self._file_outputs, file_outputs=self._file_outputs,
node_id=self.node_id, node_id=self.node_id,
@ -1010,7 +1010,7 @@ class LLMNode(BaseNode):
""" """
Fetch model schema Fetch model schema
""" """
model_name = self.node_data.model.name model_name = self._node_data.model.name
model_manager = ModelManager() model_manager = ModelManager()
model_instance = model_manager.get_model_instance( model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider, model=model_name tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider, model=model_name
@ -1089,11 +1089,11 @@ class LLMNode(BaseNode):
@property @property
def continue_on_error(self) -> bool: def continue_on_error(self) -> bool:
return self.node_data.error_strategy is not None return self._node_data.error_strategy is not None
@property @property
def retry(self) -> bool: def retry(self) -> bool:
return self.node_data.retry_config.retry_enabled return self._node_data.retry_config.retry_enabled
def _combine_message_content_with_role( def _combine_message_content_with_role(

@ -16,28 +16,28 @@ class LoopEndNode(BaseNode):
_node_type = NodeType.LOOP_END _node_type = NodeType.LOOP_END
node_data: LoopEndNodeData _node_data: LoopEndNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None: def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = LoopEndNodeData(**data) self._node_data = LoopEndNodeData(**data)
def get_error_strategy(self) -> Optional[ErrorStrategy]: def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig: def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config return self._node_data.retry_config
def get_title(self) -> str: def get_title(self) -> str:
return self.node_data.title return self._node_data.title
def get_description(self) -> Optional[str]: def get_description(self) -> Optional[str]:
return self.node_data.desc return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]: def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData: def get_base_node_data(self) -> BaseNodeData:
return self.node_data return self._node_data
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:

@ -51,28 +51,28 @@ class LoopNode(BaseNode):
_node_type = NodeType.LOOP _node_type = NodeType.LOOP
node_data: LoopNodeData _node_data: LoopNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None: def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = LoopNodeData.model_validate(data) self._node_data = LoopNodeData.model_validate(data)
def get_error_strategy(self) -> Optional[ErrorStrategy]: def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig: def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config return self._node_data.retry_config
def get_title(self) -> str: def get_title(self) -> str:
return self.node_data.title return self._node_data.title
def get_description(self) -> Optional[str]: def get_description(self) -> Optional[str]:
return self.node_data.desc return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]: def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData: def get_base_node_data(self) -> BaseNodeData:
return self.node_data return self._node_data
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
@ -81,17 +81,17 @@ class LoopNode(BaseNode):
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""Run the node.""" """Run the node."""
# Get inputs # Get inputs
loop_count = self.node_data.loop_count loop_count = self._node_data.loop_count
break_conditions = self.node_data.break_conditions break_conditions = self._node_data.break_conditions
logical_operator = self.node_data.logical_operator logical_operator = self._node_data.logical_operator
inputs = {"loop_count": loop_count} inputs = {"loop_count": loop_count}
if not self.node_data.start_node_id: if not self._node_data.start_node_id:
raise ValueError(f"field start_node_id in loop {self.node_id} not found") raise ValueError(f"field start_node_id in loop {self.node_id} not found")
# Initialize graph # Initialize graph
loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self.node_data.start_node_id) loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self._node_data.start_node_id)
if not loop_graph: if not loop_graph:
raise ValueError("loop graph not found") raise ValueError("loop graph not found")
@ -101,8 +101,8 @@ class LoopNode(BaseNode):
# Initialize loop variables # Initialize loop variables
loop_variable_selectors = {} loop_variable_selectors = {}
if self.node_data.loop_variables: if self._node_data.loop_variables:
for loop_variable in self.node_data.loop_variables: for loop_variable in self._node_data.loop_variables:
value_processor = { value_processor = {
"constant": lambda var=loop_variable: self._get_segment_for_constant(var.var_type, var.value), "constant": lambda var=loop_variable: self._get_segment_for_constant(var.var_type, var.value),
"variable": lambda var=loop_variable: variable_pool.get(var.value), "variable": lambda var=loop_variable: variable_pool.get(var.value),
@ -151,7 +151,7 @@ class LoopNode(BaseNode):
loop_id=self.id, loop_id=self.id,
loop_node_id=self.node_id, loop_node_id=self.node_id,
loop_node_type=self.node_type, loop_node_type=self.node_type,
loop_node_data=self.node_data, loop_node_data=self._node_data,
start_at=start_at, start_at=start_at,
inputs=inputs, inputs=inputs,
metadata={"loop_length": loop_count}, metadata={"loop_length": loop_count},
@ -208,10 +208,10 @@ class LoopNode(BaseNode):
loop_id=self.id, loop_id=self.id,
loop_node_id=self.node_id, loop_node_id=self.node_id,
loop_node_type=self.node_type, loop_node_type=self.node_type,
loop_node_data=self.node_data, loop_node_data=self._node_data,
start_at=start_at, start_at=start_at,
inputs=inputs, inputs=inputs,
outputs=self.node_data.outputs, outputs=self._node_data.outputs,
steps=loop_count, steps=loop_count,
metadata={ metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
@ -229,7 +229,7 @@ class LoopNode(BaseNode):
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
}, },
outputs=self.node_data.outputs, outputs=self._node_data.outputs,
inputs=inputs, inputs=inputs,
) )
) )
@ -241,7 +241,7 @@ class LoopNode(BaseNode):
loop_id=self.id, loop_id=self.id,
loop_node_id=self.node_id, loop_node_id=self.node_id,
loop_node_type=self.node_type, loop_node_type=self.node_type,
loop_node_data=self.node_data, loop_node_data=self._node_data,
start_at=start_at, start_at=start_at,
inputs=inputs, inputs=inputs,
steps=loop_count, steps=loop_count,
@ -344,7 +344,7 @@ class LoopNode(BaseNode):
loop_id=self.id, loop_id=self.id,
loop_node_id=self.node_id, loop_node_id=self.node_id,
loop_node_type=self.node_type, loop_node_type=self.node_type,
loop_node_data=self.node_data, loop_node_data=self._node_data,
start_at=start_at, start_at=start_at,
inputs=inputs, inputs=inputs,
steps=current_index, steps=current_index,
@ -375,7 +375,7 @@ class LoopNode(BaseNode):
loop_id=self.id, loop_id=self.id,
loop_node_id=self.node_id, loop_node_id=self.node_id,
loop_node_type=self.node_type, loop_node_type=self.node_type,
loop_node_data=self.node_data, loop_node_data=self._node_data,
start_at=start_at, start_at=start_at,
inputs=inputs, inputs=inputs,
steps=current_index, steps=current_index,
@ -411,7 +411,7 @@ class LoopNode(BaseNode):
_outputs[loop_variable_key] = None _outputs[loop_variable_key] = None
_outputs["loop_round"] = current_index + 1 _outputs["loop_round"] = current_index + 1
self.node_data.outputs = _outputs self._node_data.outputs = _outputs
if check_break_result: if check_break_result:
return {"check_break_result": True} return {"check_break_result": True}
@ -424,9 +424,9 @@ class LoopNode(BaseNode):
loop_id=self.id, loop_id=self.id,
loop_node_id=self.node_id, loop_node_id=self.node_id,
loop_node_type=self.node_type, loop_node_type=self.node_type,
loop_node_data=self.node_data, loop_node_data=self._node_data,
index=next_index, index=next_index,
pre_loop_output=self.node_data.outputs, pre_loop_output=self._node_data.outputs,
) )
return {"check_break_result": False} return {"check_break_result": False}

@ -16,28 +16,28 @@ class LoopStartNode(BaseNode):
_node_type = NodeType.LOOP_START _node_type = NodeType.LOOP_START
node_data: LoopStartNodeData _node_data: LoopStartNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None: def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = LoopStartNodeData(**data) self._node_data = LoopStartNodeData(**data)
def get_error_strategy(self) -> Optional[ErrorStrategy]: def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig: def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config return self._node_data.retry_config
def get_title(self) -> str: def get_title(self) -> str:
return self.node_data.title return self._node_data.title
def get_description(self) -> Optional[str]: def get_description(self) -> Optional[str]:
return self.node_data.desc return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]: def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData: def get_base_node_data(self) -> BaseNodeData:
return self.node_data return self._node_data
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:

@ -94,28 +94,28 @@ class ParameterExtractorNode(BaseNode):
_node_type = NodeType.PARAMETER_EXTRACTOR _node_type = NodeType.PARAMETER_EXTRACTOR
node_data: ParameterExtractorNodeData _node_data: ParameterExtractorNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None: def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = ParameterExtractorNodeData.model_validate(data) self._node_data = ParameterExtractorNodeData.model_validate(data)
def get_error_strategy(self) -> Optional[ErrorStrategy]: def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig: def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config return self._node_data.retry_config
def get_title(self) -> str: def get_title(self) -> str:
return self.node_data.title return self._node_data.title
def get_description(self) -> Optional[str]: def get_description(self) -> Optional[str]:
return self.node_data.desc return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]: def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData: def get_base_node_data(self) -> BaseNodeData:
return self.node_data return self._node_data
_model_instance: Optional[ModelInstance] = None _model_instance: Optional[ModelInstance] = None
_model_config: Optional[ModelConfigWithCredentialsEntity] = None _model_config: Optional[ModelConfigWithCredentialsEntity] = None
@ -141,7 +141,7 @@ class ParameterExtractorNode(BaseNode):
""" """
Run the node. Run the node.
""" """
node_data = cast(ParameterExtractorNodeData, self.node_data) node_data = cast(ParameterExtractorNodeData, self._node_data)
variable = self.graph_runtime_state.variable_pool.get(node_data.query) variable = self.graph_runtime_state.variable_pool.get(node_data.query)
query = variable.text if variable else "" query = variable.text if variable else ""

@ -47,7 +47,7 @@ if TYPE_CHECKING:
class QuestionClassifierNode(BaseNode): class QuestionClassifierNode(BaseNode):
_node_type = NodeType.QUESTION_CLASSIFIER _node_type = NodeType.QUESTION_CLASSIFIER
node_data: QuestionClassifierNodeData _node_data: QuestionClassifierNodeData
_file_outputs: list["File"] _file_outputs: list["File"]
_llm_file_saver: LLMFileSaver _llm_file_saver: LLMFileSaver
@ -84,32 +84,32 @@ class QuestionClassifierNode(BaseNode):
self._llm_file_saver = llm_file_saver self._llm_file_saver = llm_file_saver
def init_node_data(self, data: Mapping[str, Any]) -> None: def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = QuestionClassifierNodeData.model_validate(data) self._node_data = QuestionClassifierNodeData.model_validate(data)
def get_error_strategy(self) -> Optional[ErrorStrategy]: def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig: def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config return self._node_data.retry_config
def get_title(self) -> str: def get_title(self) -> str:
return self.node_data.title return self._node_data.title
def get_description(self) -> Optional[str]: def get_description(self) -> Optional[str]:
return self.node_data.desc return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]: def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData: def get_base_node_data(self) -> BaseNodeData:
return self.node_data return self._node_data
@classmethod @classmethod
def version(cls): def version(cls):
return "1" return "1"
def _run(self): def _run(self):
node_data = cast(QuestionClassifierNodeData, self.node_data) node_data = cast(QuestionClassifierNodeData, self._node_data)
variable_pool = self.graph_runtime_state.variable_pool variable_pool = self.graph_runtime_state.variable_pool
# extract variables # extract variables

@ -13,28 +13,28 @@ from core.workflow.nodes.start.entities import StartNodeData
class StartNode(BaseNode): class StartNode(BaseNode):
_node_type = NodeType.START _node_type = NodeType.START
node_data: StartNodeData _node_data: StartNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None: def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = StartNodeData(**data) self._node_data = StartNodeData(**data)
def get_error_strategy(self) -> Optional[ErrorStrategy]: def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig: def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config return self._node_data.retry_config
def get_title(self) -> str: def get_title(self) -> str:
return self.node_data.title return self._node_data.title
def get_description(self) -> Optional[str]: def get_description(self) -> Optional[str]:
return self.node_data.desc return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]: def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData: def get_base_node_data(self) -> BaseNodeData:
return self.node_data return self._node_data
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:

@ -16,28 +16,28 @@ MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MA
class TemplateTransformNode(BaseNode): class TemplateTransformNode(BaseNode):
_node_type = NodeType.TEMPLATE_TRANSFORM _node_type = NodeType.TEMPLATE_TRANSFORM
node_data: TemplateTransformNodeData _node_data: TemplateTransformNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None: def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = TemplateTransformNodeData.model_validate(data) self._node_data = TemplateTransformNodeData.model_validate(data)
def get_error_strategy(self) -> Optional[ErrorStrategy]: def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig: def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config return self._node_data.retry_config
def get_title(self) -> str: def get_title(self) -> str:
return self.node_data.title return self._node_data.title
def get_description(self) -> Optional[str]: def get_description(self) -> Optional[str]:
return self.node_data.desc return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]: def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData: def get_base_node_data(self) -> BaseNodeData:
return self.node_data return self._node_data
@classmethod @classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict: def get_default_config(cls, filters: Optional[dict] = None) -> dict:
@ -58,14 +58,14 @@ class TemplateTransformNode(BaseNode):
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
# Get variables # Get variables
variables = {} variables = {}
for variable_selector in self.node_data.variables: for variable_selector in self._node_data.variables:
variable_name = variable_selector.variable variable_name = variable_selector.variable
value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
variables[variable_name] = value.to_object() if value else None variables[variable_name] = value.to_object() if value else None
# Run code # Run code
try: try:
result = CodeExecutor.execute_workflow_code_template( result = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables language=CodeLanguage.JINJA2, code=self._node_data.template, inputs=variables
) )
except CodeExecutionError as e: except CodeExecutionError as e:
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e)) return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))

@ -43,10 +43,10 @@ class ToolNode(BaseNode):
_node_type = NodeType.TOOL _node_type = NodeType.TOOL
node_data: ToolNodeData _node_data: ToolNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None: def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = ToolNodeData.model_validate(data) self._node_data = ToolNodeData.model_validate(data)
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
@ -57,7 +57,7 @@ class ToolNode(BaseNode):
Run the tool node Run the tool node
""" """
node_data = cast(ToolNodeData, self.node_data) node_data = cast(ToolNodeData, self._node_data)
# fetch tool icon # fetch tool icon
tool_info = { tool_info = {
@ -70,9 +70,9 @@ class ToolNode(BaseNode):
try: try:
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
variable_pool = self.graph_runtime_state.variable_pool if self.node_data.version != "1" else None variable_pool = self.graph_runtime_state.variable_pool if self._node_data.version != "1" else None
tool_runtime = ToolManager.get_workflow_tool_runtime( tool_runtime = ToolManager.get_workflow_tool_runtime(
self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from, variable_pool self.tenant_id, self.app_id, self.node_id, self._node_data, self.invoke_from, variable_pool
) )
except ToolNodeError as e: except ToolNodeError as e:
yield RunCompletedEvent( yield RunCompletedEvent(
@ -91,12 +91,12 @@ class ToolNode(BaseNode):
parameters = self._generate_parameters( parameters = self._generate_parameters(
tool_parameters=tool_parameters, tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool, variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data, node_data=self._node_data,
) )
parameters_for_log = self._generate_parameters( parameters_for_log = self._generate_parameters(
tool_parameters=tool_parameters, tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool, variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data, node_data=self._node_data,
for_log=True, for_log=True,
) )
# get conversation id # get conversation id
@ -404,27 +404,27 @@ class ToolNode(BaseNode):
return result return result
def get_error_strategy(self) -> Optional[ErrorStrategy]: def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig: def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config return self._node_data.retry_config
def get_title(self) -> str: def get_title(self) -> str:
return self.node_data.title return self._node_data.title
def get_description(self) -> Optional[str]: def get_description(self) -> Optional[str]:
return self.node_data.desc return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]: def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData: def get_base_node_data(self) -> BaseNodeData:
return self.node_data return self._node_data
@property @property
def continue_on_error(self) -> bool: def continue_on_error(self) -> bool:
return self.node_data.error_strategy is not None return self._node_data.error_strategy is not None
@property @property
def retry(self) -> bool: def retry(self) -> bool:
return self.node_data.retry_config.retry_enabled return self._node_data.retry_config.retry_enabled

@ -13,28 +13,28 @@ from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNod
class VariableAggregatorNode(BaseNode): class VariableAggregatorNode(BaseNode):
_node_type = NodeType.VARIABLE_AGGREGATOR _node_type = NodeType.VARIABLE_AGGREGATOR
node_data: VariableAssignerNodeData _node_data: VariableAssignerNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None: def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = VariableAssignerNodeData(**data) self._node_data = VariableAssignerNodeData(**data)
def get_error_strategy(self) -> Optional[ErrorStrategy]: def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig: def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config return self._node_data.retry_config
def get_title(self) -> str: def get_title(self) -> str:
return self.node_data.title return self._node_data.title
def get_description(self) -> Optional[str]: def get_description(self) -> Optional[str]:
return self.node_data.desc return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]: def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData: def get_base_node_data(self) -> BaseNodeData:
return self.node_data return self._node_data
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
@ -45,8 +45,8 @@ class VariableAggregatorNode(BaseNode):
outputs: dict[str, Segment | Mapping[str, Segment]] = {} outputs: dict[str, Segment | Mapping[str, Segment]] = {}
inputs = {} inputs = {}
if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled: if not self._node_data.advanced_settings or not self._node_data.advanced_settings.group_enabled:
for selector in self.node_data.variables: for selector in self._node_data.variables:
variable = self.graph_runtime_state.variable_pool.get(selector) variable = self.graph_runtime_state.variable_pool.get(selector)
if variable is not None: if variable is not None:
outputs = {"output": variable} outputs = {"output": variable}
@ -54,7 +54,7 @@ class VariableAggregatorNode(BaseNode):
inputs = {".".join(selector[1:]): variable.to_object()} inputs = {".".join(selector[1:]): variable.to_object()}
break break
else: else:
for group in self.node_data.advanced_settings.groups: for group in self._node_data.advanced_settings.groups:
for selector in group.variables: for selector in group.variables:
variable = self.graph_runtime_state.variable_pool.get(selector) variable = self.graph_runtime_state.variable_pool.get(selector)

@ -27,28 +27,28 @@ class VariableAssignerNode(BaseNode):
_node_type = NodeType.VARIABLE_ASSIGNER _node_type = NodeType.VARIABLE_ASSIGNER
_conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY _conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY
node_data: VariableAssignerData _node_data: VariableAssignerData
def init_node_data(self, data: Mapping[str, Any]) -> None: def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = VariableAssignerData.model_validate(data) self._node_data = VariableAssignerData.model_validate(data)
def get_error_strategy(self) -> Optional[ErrorStrategy]: def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig: def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config return self._node_data.retry_config
def get_title(self) -> str: def get_title(self) -> str:
return self.node_data.title return self._node_data.title
def get_description(self) -> Optional[str]: def get_description(self) -> Optional[str]:
return self.node_data.desc return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]: def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData: def get_base_node_data(self) -> BaseNodeData:
return self.node_data return self._node_data
def __init__( def __init__(
self, self,
@ -100,21 +100,21 @@ class VariableAssignerNode(BaseNode):
return mapping return mapping
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
assigned_variable_selector = self.node_data.assigned_variable_selector assigned_variable_selector = self._node_data.assigned_variable_selector
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector) original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
if not isinstance(original_variable, Variable): if not isinstance(original_variable, Variable):
raise VariableOperatorNodeError("assigned variable not found") raise VariableOperatorNodeError("assigned variable not found")
match self.node_data.write_mode: match self._node_data.write_mode:
case WriteMode.OVER_WRITE: case WriteMode.OVER_WRITE:
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector) income_value = self.graph_runtime_state.variable_pool.get(self._node_data.input_variable_selector)
if not income_value: if not income_value:
raise VariableOperatorNodeError("input value not found") raise VariableOperatorNodeError("input value not found")
updated_variable = original_variable.model_copy(update={"value": income_value.value}) updated_variable = original_variable.model_copy(update={"value": income_value.value})
case WriteMode.APPEND: case WriteMode.APPEND:
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector) income_value = self.graph_runtime_state.variable_pool.get(self._node_data.input_variable_selector)
if not income_value: if not income_value:
raise VariableOperatorNodeError("input value not found") raise VariableOperatorNodeError("input value not found")
updated_value = original_variable.value + [income_value.value] updated_value = original_variable.value + [income_value.value]
@ -127,7 +127,7 @@ class VariableAssignerNode(BaseNode):
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()}) updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
case _: case _:
raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}") raise VariableOperatorNodeError(f"unsupported write mode: {self._node_data.write_mode}")
# Over write the variable. # Over write the variable.
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable) self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)

@ -1,6 +1,6 @@
import json import json
from collections.abc import Callable, Mapping, MutableMapping, Sequence from collections.abc import Mapping, MutableMapping, Sequence
from typing import Any, Optional, TypeAlias, cast from typing import Any, Optional, cast
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import SegmentType, Variable from core.variables import SegmentType, Variable
@ -29,8 +29,6 @@ from .exc import (
VariableNotFoundError, VariableNotFoundError,
) )
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem): def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem):
selector_node_id = item.variable_selector[0] selector_node_id = item.variable_selector[0]
@ -58,28 +56,28 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_
class VariableAssignerNode(BaseNode): class VariableAssignerNode(BaseNode):
_node_type = NodeType.VARIABLE_ASSIGNER _node_type = NodeType.VARIABLE_ASSIGNER
node_data: VariableAssignerNodeData _node_data: VariableAssignerNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None: def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = VariableAssignerNodeData.model_validate(data) self._node_data = VariableAssignerNodeData.model_validate(data)
def get_error_strategy(self) -> Optional[ErrorStrategy]: def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig: def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config return self._node_data.retry_config
def get_title(self) -> str: def get_title(self) -> str:
return self.node_data.title return self._node_data.title
def get_description(self) -> Optional[str]: def get_description(self) -> Optional[str]:
return self.node_data.desc return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]: def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData: def get_base_node_data(self) -> BaseNodeData:
return self.node_data return self._node_data
def _conv_var_updater_factory(self) -> ConversationVariableUpdater: def _conv_var_updater_factory(self) -> ConversationVariableUpdater:
return conversation_variable_updater_factory() return conversation_variable_updater_factory()
@ -106,13 +104,13 @@ class VariableAssignerNode(BaseNode):
return var_mapping return var_mapping
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
inputs = self.node_data.model_dump() inputs = self._node_data.model_dump()
process_data: dict[str, Any] = {} process_data: dict[str, Any] = {}
# NOTE: This node has no outputs # NOTE: This node has no outputs
updated_variable_selectors: list[Sequence[str]] = [] updated_variable_selectors: list[Sequence[str]] = []
try: try:
for item in self.node_data.items: for item in self._node_data.items:
variable = self.graph_runtime_state.variable_pool.get(item.variable_selector) variable = self.graph_runtime_state.variable_pool.get(item.variable_selector)
# ==================== Validation Part # ==================== Validation Part

@ -234,10 +234,10 @@ def test_execute_code_output_validator_depth():
"object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}}, "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}},
} }
node.node_data = cast(CodeNodeData, node.node_data) node._node_data = cast(CodeNodeData, node._node_data)
# validate # validate
node._transform_result(result, node.node_data.outputs) node._transform_result(result, node._node_data.outputs)
# construct result # construct result
result = { result = {
@ -250,7 +250,7 @@ def test_execute_code_output_validator_depth():
# validate # validate
with pytest.raises(ValueError): with pytest.raises(ValueError):
node._transform_result(result, node.node_data.outputs) node._transform_result(result, node._node_data.outputs)
# construct result # construct result
result = { result = {
@ -263,7 +263,7 @@ def test_execute_code_output_validator_depth():
# validate # validate
with pytest.raises(ValueError): with pytest.raises(ValueError):
node._transform_result(result, node.node_data.outputs) node._transform_result(result, node._node_data.outputs)
# construct result # construct result
result = { result = {
@ -276,7 +276,7 @@ def test_execute_code_output_validator_depth():
# validate # validate
with pytest.raises(ValueError): with pytest.raises(ValueError):
node._transform_result(result, node.node_data.outputs) node._transform_result(result, node._node_data.outputs)
def test_execute_code_output_object_list(): def test_execute_code_output_object_list():
@ -330,10 +330,10 @@ def test_execute_code_output_object_list():
] ]
} }
node.node_data = cast(CodeNodeData, node.node_data) node._node_data = cast(CodeNodeData, node._node_data)
# validate # validate
node._transform_result(result, node.node_data.outputs) node._transform_result(result, node._node_data.outputs)
# construct result # construct result
result = { result = {
@ -353,7 +353,7 @@ def test_execute_code_output_object_list():
# validate # validate
with pytest.raises(ValueError): with pytest.raises(ValueError):
node._transform_result(result, node.node_data.outputs) node._transform_result(result, node._node_data.outputs)
def test_execute_code_scientific_notation(): def test_execute_code_scientific_notation():

@ -665,8 +665,8 @@ def test_iteration_run_in_parallel_mode():
# execute node # execute node
parallel_result = parallel_iteration_node._run() parallel_result = parallel_iteration_node._run()
sequential_result = sequential_iteration_node._run() sequential_result = sequential_iteration_node._run()
assert parallel_iteration_node.node_data.parallel_nums == 10 assert parallel_iteration_node._node_data.parallel_nums == 10
assert parallel_iteration_node.node_data.error_handle_mode == ErrorHandleMode.TERMINATED assert parallel_iteration_node._node_data.error_handle_mode == ErrorHandleMode.TERMINATED
count = 0 count = 0
parallel_arr = [] parallel_arr = []
sequential_arr = [] sequential_arr = []
@ -876,7 +876,7 @@ def test_iteration_run_error_handle():
assert count == 14 assert count == 14
# execute remove abnormal output # execute remove abnormal output
iteration_node.node_data.error_handle_mode = ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT iteration_node._node_data.error_handle_mode = ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT
result = iteration_node._run() result = iteration_node._run()
count = 0 count = 0
for item in result: for item in result:

Loading…
Cancel
Save