refactor(nodes): rename `node_data` to `_node_data`

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/22581/head
-LAN- 10 months ago
parent 956f9158c7
commit ba40054f44
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -63,28 +63,28 @@ class AgentNode(BaseNode):
"""
_node_type = NodeType.AGENT
node_data: AgentNodeData
_node_data: AgentNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = AgentNodeData.model_validate(data)
self._node_data = AgentNodeData.model_validate(data)
def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy
return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config
return self._node_data.retry_config
def get_title(self) -> str:
return self.node_data.title
return self._node_data.title
def get_description(self) -> Optional[str]:
return self.node_data.desc
return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self.node_data
return self._node_data
@classmethod
def version(cls) -> str:
@ -94,7 +94,7 @@ class AgentNode(BaseNode):
"""
Run the agent node
"""
node_data = cast(AgentNodeData, self.node_data)
node_data = cast(AgentNodeData, self._node_data)
try:
strategy = get_plugin_agent_strategy(
@ -160,18 +160,18 @@ class AgentNode(BaseNode):
type=ToolInvokeMessage.MessageType.LOG,
message=ToolInvokeMessage.LogMessage(
id=str(uuid.uuid4()),
label=f"Agent Strategy: {cast(AgentNodeData, self.node_data).agent_strategy_name}",
label=f"Agent Strategy: {cast(AgentNodeData, self._node_data).agent_strategy_name}",
parent_id=None,
error=None,
status=ToolInvokeMessage.LogMessage.LogStatus.START,
data={
"strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
"strategy": cast(AgentNodeData, self._node_data).agent_strategy_name,
"parameters": parameters_for_log,
"thought_process": "Agent strategy execution started",
},
metadata={
"icon": self.agent_strategy_icon,
"agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
"agent_strategy": cast(AgentNodeData, self._node_data).agent_strategy_name,
},
),
)
@ -180,7 +180,7 @@ class AgentNode(BaseNode):
messages=message_stream,
tool_info={
"icon": self.agent_strategy_icon,
"agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
"agent_strategy": cast(AgentNodeData, self._node_data).agent_strategy_name,
},
parameters_for_log=parameters_for_log,
user_id=self.user_id,
@ -299,7 +299,7 @@ class AgentNode(BaseNode):
)
extra = tool.get("extra", {})
runtime_variable_pool = variable_pool if self.node_data.version != "1" else None
runtime_variable_pool = variable_pool if self._node_data.version != "1" else None
tool_runtime = ToolManager.get_agent_tool_runtime(
self.tenant_id, self.app_id, entity, self.invoke_from, runtime_variable_pool
)
@ -415,7 +415,7 @@ class AgentNode(BaseNode):
plugin
for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}"
== cast(AgentNodeData, self.node_data).agent_strategy_provider_name
== cast(AgentNodeData, self._node_data).agent_strategy_provider_name
)
icon = current_plugin.declaration.icon
except StopIteration:

@ -20,28 +20,28 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser
class AnswerNode(BaseNode):
_node_type = NodeType.ANSWER
node_data: AnswerNodeData
_node_data: AnswerNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = AnswerNodeData.model_validate(data)
self._node_data = AnswerNodeData.model_validate(data)
def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy
return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config
return self._node_data.retry_config
def get_title(self) -> str:
return self.node_data.title
return self._node_data.title
def get_description(self) -> Optional[str]:
return self.node_data.desc
return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self.node_data
return self._node_data
@classmethod
def version(cls) -> str:
@ -53,7 +53,7 @@ class AnswerNode(BaseNode):
:return:
"""
# generate routes
generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(self.node_data)
generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(self._node_data)
answer = ""
files = []

@ -25,28 +25,28 @@ from .exc import (
class CodeNode(BaseNode):
_node_type = NodeType.CODE
node_data: CodeNodeData
_node_data: CodeNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = CodeNodeData.model_validate(data)
self._node_data = CodeNodeData.model_validate(data)
def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy
return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config
return self._node_data.retry_config
def get_title(self) -> str:
return self.node_data.title
return self._node_data.title
def get_description(self) -> Optional[str]:
return self.node_data.desc
return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self.node_data
return self._node_data
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
@ -70,12 +70,12 @@ class CodeNode(BaseNode):
def _run(self) -> NodeRunResult:
# Get code language
code_language = self.node_data.code_language
code = self.node_data.code
code_language = self._node_data.code_language
code = self._node_data.code
# Get variables
variables = {}
for variable_selector in self.node_data.variables:
for variable_selector in self._node_data.variables:
variable_name = variable_selector.variable
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
if isinstance(variable, ArrayFileSegment):
@ -91,7 +91,7 @@ class CodeNode(BaseNode):
)
# Transform result
result = self._transform_result(result=result, output_schema=self.node_data.outputs)
result = self._transform_result(result=result, output_schema=self._node_data.outputs)
except (CodeExecutionError, CodeNodeError) as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
@ -369,8 +369,8 @@ class CodeNode(BaseNode):
@property
def continue_on_error(self) -> bool:
return self.node_data.error_strategy is not None
return self._node_data.error_strategy is not None
@property
def retry(self) -> bool:
return self.node_data.retry_config.retry_enabled
return self._node_data.retry_config.retry_enabled

@ -45,35 +45,35 @@ class DocumentExtractorNode(BaseNode):
_node_type = NodeType.DOCUMENT_EXTRACTOR
node_data: DocumentExtractorNodeData
_node_data: DocumentExtractorNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = DocumentExtractorNodeData.model_validate(data)
self._node_data = DocumentExtractorNodeData.model_validate(data)
def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy
return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config
return self._node_data.retry_config
def get_title(self) -> str:
return self.node_data.title
return self._node_data.title
def get_description(self) -> Optional[str]:
return self.node_data.desc
return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self.node_data
return self._node_data
@classmethod
def version(cls) -> str:
return "1"
def _run(self):
variable_selector = self.node_data.variable_selector
variable_selector = self._node_data.variable_selector
variable = self.graph_runtime_state.variable_pool.get(variable_selector)
if variable is None:

@ -12,28 +12,28 @@ from core.workflow.nodes.enums import ErrorStrategy, NodeType
class EndNode(BaseNode):
_node_type = NodeType.END
node_data: EndNodeData
_node_data: EndNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = EndNodeData(**data)
self._node_data = EndNodeData(**data)
def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy
return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config
return self._node_data.retry_config
def get_title(self) -> str:
return self.node_data.title
return self._node_data.title
def get_description(self) -> Optional[str]:
return self.node_data.desc
return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self.node_data
return self._node_data
@classmethod
def version(cls) -> str:
@ -44,7 +44,7 @@ class EndNode(BaseNode):
Run node
:return:
"""
output_variables = self.node_data.outputs
output_variables = self._node_data.outputs
outputs = {}
for variable_selector in output_variables:

@ -36,28 +36,28 @@ logger = logging.getLogger(__name__)
class HttpRequestNode(BaseNode):
_node_type = NodeType.HTTP_REQUEST
node_data: HttpRequestNodeData
_node_data: HttpRequestNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = HttpRequestNodeData.model_validate(data)
self._node_data = HttpRequestNodeData.model_validate(data)
def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy
return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config
return self._node_data.retry_config
def get_title(self) -> str:
return self.node_data.title
return self._node_data.title
def get_description(self) -> Optional[str]:
return self.node_data.desc
return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self.node_data
return self._node_data
@classmethod
def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict:
@ -92,8 +92,8 @@ class HttpRequestNode(BaseNode):
process_data = {}
try:
http_executor = Executor(
node_data=self.node_data,
timeout=self._get_request_timeout(self.node_data),
node_data=self._node_data,
timeout=self._get_request_timeout(self._node_data),
variable_pool=self.graph_runtime_state.variable_pool,
max_retries=0,
)
@ -246,8 +246,8 @@ class HttpRequestNode(BaseNode):
@property
def continue_on_error(self) -> bool:
return self.node_data.error_strategy is not None
return self._node_data.error_strategy is not None
@property
def retry(self) -> bool:
return self.node_data.retry_config.retry_enabled
return self._node_data.retry_config.retry_enabled

@ -17,28 +17,28 @@ from core.workflow.utils.condition.processor import ConditionProcessor
class IfElseNode(BaseNode):
_node_type = NodeType.IF_ELSE
node_data: IfElseNodeData
_node_data: IfElseNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = IfElseNodeData.model_validate(data)
self._node_data = IfElseNodeData.model_validate(data)
def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy
return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config
return self._node_data.retry_config
def get_title(self) -> str:
return self.node_data.title
return self._node_data.title
def get_description(self) -> Optional[str]:
return self.node_data.desc
return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self.node_data
return self._node_data
@classmethod
def version(cls) -> str:
@ -59,8 +59,8 @@ class IfElseNode(BaseNode):
condition_processor = ConditionProcessor()
try:
# Check if the new cases structure is used
if self.node_data.cases:
for case in self.node_data.cases:
if self._node_data.cases:
for case in self._node_data.cases:
input_conditions, group_result, final_result = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=case.conditions,
@ -86,8 +86,8 @@ class IfElseNode(BaseNode):
input_conditions, group_result, final_result = _should_not_use_old_function(
condition_processor=condition_processor,
variable_pool=self.graph_runtime_state.variable_pool,
conditions=self.node_data.conditions or [],
operator=self.node_data.logical_operator or "and",
conditions=self._node_data.conditions or [],
operator=self._node_data.logical_operator or "and",
)
selected_case_id = "true" if final_result else "false"

@ -64,28 +64,28 @@ class IterationNode(BaseNode):
_node_type = NodeType.ITERATION
node_data: IterationNodeData
_node_data: IterationNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = IterationNodeData.model_validate(data)
self._node_data = IterationNodeData.model_validate(data)
def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy
return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config
return self._node_data.retry_config
def get_title(self) -> str:
return self.node_data.title
return self._node_data.title
def get_description(self) -> Optional[str]:
return self.node_data.desc
return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self.node_data
return self._node_data
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
@ -106,10 +106,10 @@ class IterationNode(BaseNode):
"""
Run the node.
"""
variable = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
variable = self.graph_runtime_state.variable_pool.get(self._node_data.iterator_selector)
if not variable:
raise IteratorVariableNotFoundError(f"iterator variable {self.node_data.iterator_selector} not found")
raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found")
if not isinstance(variable, ArrayVariable) and not isinstance(variable, NoneVariable):
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
@ -139,10 +139,10 @@ class IterationNode(BaseNode):
graph_config = self.graph_config
if not self.node_data.start_node_id:
if not self._node_data.start_node_id:
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self.node_id} not found")
root_node_id = self.node_data.start_node_id
root_node_id = self._node_data.start_node_id
# init graph
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id)
@ -185,7 +185,7 @@ class IterationNode(BaseNode):
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
metadata={"iterator_length": len(iterator_list_value)},
@ -196,7 +196,7 @@ class IterationNode(BaseNode):
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
iteration_node_data=self._node_data,
index=0,
pre_iteration_output=None,
duration=None,
@ -204,11 +204,11 @@ class IterationNode(BaseNode):
iter_run_map: dict[str, float] = {}
outputs: list[Any] = [None] * len(iterator_list_value)
try:
if self.node_data.is_parallel:
if self._node_data.is_parallel:
futures: list[Future] = []
q: Queue = Queue()
thread_pool = GraphEngineThreadPool(
max_workers=self.node_data.parallel_nums, max_submit_count=dify_config.MAX_SUBMIT_COUNT
max_workers=self._node_data.parallel_nums, max_submit_count=dify_config.MAX_SUBMIT_COUNT
)
for index, item in enumerate(iterator_list_value):
future: Future = thread_pool.submit(
@ -265,7 +265,7 @@ class IterationNode(BaseNode):
iteration_graph=iteration_graph,
iter_run_map=iter_run_map,
)
if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
if self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
outputs = [output for output in outputs if output is not None]
# Flatten the list of lists
@ -277,7 +277,7 @@ class IterationNode(BaseNode):
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": outputs},
@ -302,7 +302,7 @@ class IterationNode(BaseNode):
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": outputs},
@ -394,7 +394,7 @@ class IterationNode(BaseNode):
"""
if not isinstance(event, BaseNodeEvent):
return event
if self.node_data.is_parallel and isinstance(event, NodeRunStartedEvent):
if self._node_data.is_parallel and isinstance(event, NodeRunStartedEvent):
event.parallel_mode_run_id = parallel_mode_run_id
iter_metadata = {
@ -457,12 +457,12 @@ class IterationNode(BaseNode):
elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent):
# iteration run failed
if self.node_data.is_parallel:
if self._node_data.is_parallel:
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
iteration_node_data=self._node_data,
parallel_mode_run_id=parallel_mode_run_id,
start_at=start_at,
inputs=inputs,
@ -476,7 +476,7 @@ class IterationNode(BaseNode):
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": outputs},
@ -497,7 +497,7 @@ class IterationNode(BaseNode):
event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id
)
if isinstance(event, NodeRunFailedEvent):
if self.node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR:
if self._node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR:
yield NodeInIterationFailedEvent(
**metadata_event.model_dump(),
)
@ -511,14 +511,14 @@ class IterationNode(BaseNode):
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
iteration_node_data=self._node_data,
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=None,
duration=duration,
)
return
elif self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
elif self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
yield NodeInIterationFailedEvent(
**metadata_event.model_dump(),
)
@ -532,14 +532,14 @@ class IterationNode(BaseNode):
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
iteration_node_data=self._node_data,
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=None,
duration=duration,
)
return
elif self.node_data.error_handle_mode == ErrorHandleMode.TERMINATED:
elif self._node_data.error_handle_mode == ErrorHandleMode.TERMINATED:
yield NodeInIterationFailedEvent(
**metadata_event.model_dump(),
)
@ -550,12 +550,12 @@ class IterationNode(BaseNode):
variable_pool.remove([node_id])
# iteration run failed
if self.node_data.is_parallel:
if self._node_data.is_parallel:
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
iteration_node_data=self._node_data,
parallel_mode_run_id=parallel_mode_run_id,
start_at=start_at,
inputs=inputs,
@ -569,7 +569,7 @@ class IterationNode(BaseNode):
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": outputs},
@ -588,7 +588,7 @@ class IterationNode(BaseNode):
return
yield metadata_event
current_output_segment = variable_pool.get(self.node_data.output_selector)
current_output_segment = variable_pool.get(self._node_data.output_selector)
if current_output_segment is None:
raise IterationNodeError("iteration output selector not found")
current_iteration_output = current_output_segment.value
@ -608,7 +608,7 @@ class IterationNode(BaseNode):
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
iteration_node_data=self._node_data,
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=current_iteration_output or None,
@ -621,7 +621,7 @@ class IterationNode(BaseNode):
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": None},

@ -86,7 +86,7 @@ default_retrieval_model = {
class KnowledgeRetrievalNode(BaseNode):
_node_type = NodeType.KNOWLEDGE_RETRIEVAL
node_data: KnowledgeRetrievalNodeData
_node_data: KnowledgeRetrievalNodeData
# Instance attributes specific to LLMNode.
# Output variable for file
@ -126,32 +126,32 @@ class KnowledgeRetrievalNode(BaseNode):
self._llm_file_saver = llm_file_saver
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = KnowledgeRetrievalNodeData.model_validate(data)
self._node_data = KnowledgeRetrievalNodeData.model_validate(data)
def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy
return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config
return self._node_data.retry_config
def get_title(self) -> str:
return self.node_data.title
return self._node_data.title
def get_description(self) -> Optional[str]:
return self.node_data.desc
return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self.node_data
return self._node_data
@classmethod
def version(cls):
return "1"
def _run(self) -> NodeRunResult: # type: ignore
node_data = cast(KnowledgeRetrievalNodeData, self.node_data)
node_data = cast(KnowledgeRetrievalNodeData, self._node_data)
# extract variables
variable = self.graph_runtime_state.variable_pool.get(node_data.query_variable_selector)
if not isinstance(variable, StringSegment):
@ -545,7 +545,7 @@ class KnowledgeRetrievalNode(BaseNode):
prompt_messages=prompt_messages,
stop=stop,
user_id=self.user_id,
structured_output_enabled=self.node_data.structured_output_enabled,
structured_output_enabled=self._node_data.structured_output_enabled,
structured_output=None,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,

@ -17,28 +17,28 @@ from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError
class ListOperatorNode(BaseNode):
_node_type = NodeType.LIST_OPERATOR
node_data: ListOperatorNodeData
_node_data: ListOperatorNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = ListOperatorNodeData(**data)
self._node_data = ListOperatorNodeData(**data)
def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy
return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config
return self._node_data.retry_config
def get_title(self) -> str:
return self.node_data.title
return self._node_data.title
def get_description(self) -> Optional[str]:
return self.node_data.desc
return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self.node_data
return self._node_data
@classmethod
def version(cls) -> str:
@ -49,9 +49,9 @@ class ListOperatorNode(BaseNode):
process_data: dict[str, list] = {}
outputs: dict[str, Any] = {}
variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable)
variable = self.graph_runtime_state.variable_pool.get(self._node_data.variable)
if variable is None:
error_message = f"Variable not found for selector: {self.node_data.variable}"
error_message = f"Variable not found for selector: {self._node_data.variable}"
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
)
@ -71,7 +71,7 @@ class ListOperatorNode(BaseNode):
)
if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
error_message = (
f"Variable {self.node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment "
f"Variable {self._node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment "
"or ArrayStringSegment"
)
return NodeRunResult(
@ -87,19 +87,19 @@ class ListOperatorNode(BaseNode):
try:
# Filter
if self.node_data.filter_by.enabled:
if self._node_data.filter_by.enabled:
variable = self._apply_filter(variable)
# Extract
if self.node_data.extract_by.enabled:
if self._node_data.extract_by.enabled:
variable = self._extract_slice(variable)
# Order
if self.node_data.order_by.enabled:
if self._node_data.order_by.enabled:
variable = self._apply_order(variable)
# Slice
if self.node_data.limit.enabled:
if self._node_data.limit.enabled:
variable = self._apply_slice(variable)
outputs = {
@ -127,7 +127,7 @@ class ListOperatorNode(BaseNode):
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
filter_func: Callable[[Any], bool]
result: list[Any] = []
for condition in self.node_data.filter_by.conditions:
for condition in self._node_data.filter_by.conditions:
if isinstance(variable, ArrayStringSegment):
if not isinstance(condition.value, str):
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
@ -160,14 +160,14 @@ class ListOperatorNode(BaseNode):
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
if isinstance(variable, ArrayStringSegment):
result = _order_string(order=self.node_data.order_by.value, array=variable.value)
result = _order_string(order=self._node_data.order_by.value, array=variable.value)
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayNumberSegment):
result = _order_number(order=self.node_data.order_by.value, array=variable.value)
result = _order_number(order=self._node_data.order_by.value, array=variable.value)
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayFileSegment):
result = _order_file(
order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value
order=self._node_data.order_by.value, order_by=self._node_data.order_by.key, array=variable.value
)
variable = variable.model_copy(update={"value": result})
return variable
@ -175,13 +175,13 @@ class ListOperatorNode(BaseNode):
def _apply_slice(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
result = variable.value[: self.node_data.limit.size]
result = variable.value[: self._node_data.limit.size]
return variable.model_copy(update={"value": result})
def _extract_slice(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
value = int(self.graph_runtime_state.variable_pool.convert_template(self.node_data.extract_by.serial).text)
value = int(self.graph_runtime_state.variable_pool.convert_template(self._node_data.extract_by.serial).text)
if value < 1:
raise ValueError(f"Invalid serial index: must be >= 1, got {value}")
value -= 1

@ -99,7 +99,7 @@ logger = logging.getLogger(__name__)
class LLMNode(BaseNode):
_node_type = NodeType.LLM
node_data: LLMNodeData
_node_data: LLMNodeData
# Instance attributes specific to LLMNode.
# Output variable for file
@ -139,25 +139,25 @@ class LLMNode(BaseNode):
self._llm_file_saver = llm_file_saver
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = LLMNodeData.model_validate(data)
self._node_data = LLMNodeData.model_validate(data)
def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy
return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config
return self._node_data.retry_config
def get_title(self) -> str:
return self.node_data.title
return self._node_data.title
def get_description(self) -> Optional[str]:
return self.node_data.desc
return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self.node_data
return self._node_data
@classmethod
def version(cls) -> str:
@ -173,13 +173,13 @@ class LLMNode(BaseNode):
try:
# init messages template
self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template)
self._node_data.prompt_template = self._transform_chat_messages(self._node_data.prompt_template)
# fetch variables and fetch values from variable pool
inputs = self._fetch_inputs(node_data=self.node_data)
inputs = self._fetch_inputs(node_data=self._node_data)
# fetch jinja2 inputs
jinja_inputs = self._fetch_jinja_inputs(node_data=self.node_data)
jinja_inputs = self._fetch_jinja_inputs(node_data=self._node_data)
# merge inputs
inputs.update(jinja_inputs)
@ -190,9 +190,9 @@ class LLMNode(BaseNode):
files = (
llm_utils.fetch_files(
variable_pool=variable_pool,
selector=self.node_data.vision.configs.variable_selector,
selector=self._node_data.vision.configs.variable_selector,
)
if self.node_data.vision.enabled
if self._node_data.vision.enabled
else []
)
@ -200,7 +200,7 @@ class LLMNode(BaseNode):
node_inputs["#files#"] = [file.to_dict() for file in files]
# fetch context value
generator = self._fetch_context(node_data=self.node_data)
generator = self._fetch_context(node_data=self._node_data)
context = None
for event in generator:
if isinstance(event, RunRetrieverResourceEvent):
@ -211,7 +211,7 @@ class LLMNode(BaseNode):
# fetch model config
model_instance, model_config = LLMNode._fetch_model_config(
node_data_model=self.node_data.model,
node_data_model=self._node_data.model,
tenant_id=self.tenant_id,
)
@ -219,13 +219,13 @@ class LLMNode(BaseNode):
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
app_id=self.app_id,
node_data_memory=self.node_data.memory,
node_data_memory=self._node_data.memory,
model_instance=model_instance,
)
query = None
if self.node_data.memory:
query = self.node_data.memory.query_prompt_template
if self._node_data.memory:
query = self._node_data.memory.query_prompt_template
if not query and (
query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
):
@ -237,24 +237,24 @@ class LLMNode(BaseNode):
context=context,
memory=memory,
model_config=model_config,
prompt_template=self.node_data.prompt_template,
memory_config=self.node_data.memory,
vision_enabled=self.node_data.vision.enabled,
vision_detail=self.node_data.vision.configs.detail,
prompt_template=self._node_data.prompt_template,
memory_config=self._node_data.memory,
vision_enabled=self._node_data.vision.enabled,
vision_detail=self._node_data.vision.configs.detail,
variable_pool=variable_pool,
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
jinja2_variables=self._node_data.prompt_config.jinja2_variables,
tenant_id=self.tenant_id,
)
# handle invoke result
generator = LLMNode.invoke_llm(
node_data_model=self.node_data.model,
node_data_model=self._node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
user_id=self.user_id,
structured_output_enabled=self.node_data.structured_output_enabled,
structured_output=self.node_data.structured_output,
structured_output_enabled=self._node_data.structured_output_enabled,
structured_output=self._node_data.structured_output,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self.node_id,
@ -1010,7 +1010,7 @@ class LLMNode(BaseNode):
"""
Fetch model schema
"""
model_name = self.node_data.model.name
model_name = self._node_data.model.name
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider, model=model_name
@ -1089,11 +1089,11 @@ class LLMNode(BaseNode):
@property
def continue_on_error(self) -> bool:
return self.node_data.error_strategy is not None
return self._node_data.error_strategy is not None
@property
def retry(self) -> bool:
return self.node_data.retry_config.retry_enabled
return self._node_data.retry_config.retry_enabled
def _combine_message_content_with_role(

@ -16,28 +16,28 @@ class LoopEndNode(BaseNode):
_node_type = NodeType.LOOP_END
node_data: LoopEndNodeData
_node_data: LoopEndNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = LoopEndNodeData(**data)
self._node_data = LoopEndNodeData(**data)
def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy
return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config
return self._node_data.retry_config
def get_title(self) -> str:
return self.node_data.title
return self._node_data.title
def get_description(self) -> Optional[str]:
return self.node_data.desc
return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self.node_data
return self._node_data
@classmethod
def version(cls) -> str:

@ -51,28 +51,28 @@ class LoopNode(BaseNode):
_node_type = NodeType.LOOP
node_data: LoopNodeData
_node_data: LoopNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = LoopNodeData.model_validate(data)
self._node_data = LoopNodeData.model_validate(data)
def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy
return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config
return self._node_data.retry_config
def get_title(self) -> str:
return self.node_data.title
return self._node_data.title
def get_description(self) -> Optional[str]:
return self.node_data.desc
return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self.node_data
return self._node_data
@classmethod
def version(cls) -> str:
@ -81,17 +81,17 @@ class LoopNode(BaseNode):
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""Run the node."""
# Get inputs
loop_count = self.node_data.loop_count
break_conditions = self.node_data.break_conditions
logical_operator = self.node_data.logical_operator
loop_count = self._node_data.loop_count
break_conditions = self._node_data.break_conditions
logical_operator = self._node_data.logical_operator
inputs = {"loop_count": loop_count}
if not self.node_data.start_node_id:
if not self._node_data.start_node_id:
raise ValueError(f"field start_node_id in loop {self.node_id} not found")
# Initialize graph
loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self.node_data.start_node_id)
loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self._node_data.start_node_id)
if not loop_graph:
raise ValueError("loop graph not found")
@ -101,8 +101,8 @@ class LoopNode(BaseNode):
# Initialize loop variables
loop_variable_selectors = {}
if self.node_data.loop_variables:
for loop_variable in self.node_data.loop_variables:
if self._node_data.loop_variables:
for loop_variable in self._node_data.loop_variables:
value_processor = {
"constant": lambda var=loop_variable: self._get_segment_for_constant(var.var_type, var.value),
"variable": lambda var=loop_variable: variable_pool.get(var.value),
@ -151,7 +151,7 @@ class LoopNode(BaseNode):
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.node_type,
loop_node_data=self.node_data,
loop_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
metadata={"loop_length": loop_count},
@ -208,10 +208,10 @@ class LoopNode(BaseNode):
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.node_type,
loop_node_data=self.node_data,
loop_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
outputs=self.node_data.outputs,
outputs=self._node_data.outputs,
steps=loop_count,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
@ -229,7 +229,7 @@ class LoopNode(BaseNode):
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
outputs=self.node_data.outputs,
outputs=self._node_data.outputs,
inputs=inputs,
)
)
@ -241,7 +241,7 @@ class LoopNode(BaseNode):
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.node_type,
loop_node_data=self.node_data,
loop_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
steps=loop_count,
@ -344,7 +344,7 @@ class LoopNode(BaseNode):
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.node_type,
loop_node_data=self.node_data,
loop_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
steps=current_index,
@ -375,7 +375,7 @@ class LoopNode(BaseNode):
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.node_type,
loop_node_data=self.node_data,
loop_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
steps=current_index,
@ -411,7 +411,7 @@ class LoopNode(BaseNode):
_outputs[loop_variable_key] = None
_outputs["loop_round"] = current_index + 1
self.node_data.outputs = _outputs
self._node_data.outputs = _outputs
if check_break_result:
return {"check_break_result": True}
@ -424,9 +424,9 @@ class LoopNode(BaseNode):
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.node_type,
loop_node_data=self.node_data,
loop_node_data=self._node_data,
index=next_index,
pre_loop_output=self.node_data.outputs,
pre_loop_output=self._node_data.outputs,
)
return {"check_break_result": False}

@ -16,28 +16,28 @@ class LoopStartNode(BaseNode):
_node_type = NodeType.LOOP_START
node_data: LoopStartNodeData
_node_data: LoopStartNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = LoopStartNodeData(**data)
self._node_data = LoopStartNodeData(**data)
def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy
return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config
return self._node_data.retry_config
def get_title(self) -> str:
return self.node_data.title
return self._node_data.title
def get_description(self) -> Optional[str]:
return self.node_data.desc
return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self.node_data
return self._node_data
@classmethod
def version(cls) -> str:

@ -94,28 +94,28 @@ class ParameterExtractorNode(BaseNode):
_node_type = NodeType.PARAMETER_EXTRACTOR
node_data: ParameterExtractorNodeData
_node_data: ParameterExtractorNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = ParameterExtractorNodeData.model_validate(data)
self._node_data = ParameterExtractorNodeData.model_validate(data)
def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy
return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config
return self._node_data.retry_config
def get_title(self) -> str:
return self.node_data.title
return self._node_data.title
def get_description(self) -> Optional[str]:
return self.node_data.desc
return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self.node_data
return self._node_data
_model_instance: Optional[ModelInstance] = None
_model_config: Optional[ModelConfigWithCredentialsEntity] = None
@ -141,7 +141,7 @@ class ParameterExtractorNode(BaseNode):
"""
Run the node.
"""
node_data = cast(ParameterExtractorNodeData, self.node_data)
node_data = cast(ParameterExtractorNodeData, self._node_data)
variable = self.graph_runtime_state.variable_pool.get(node_data.query)
query = variable.text if variable else ""

@ -47,7 +47,7 @@ if TYPE_CHECKING:
class QuestionClassifierNode(BaseNode):
_node_type = NodeType.QUESTION_CLASSIFIER
node_data: QuestionClassifierNodeData
_node_data: QuestionClassifierNodeData
_file_outputs: list["File"]
_llm_file_saver: LLMFileSaver
@ -84,32 +84,32 @@ class QuestionClassifierNode(BaseNode):
self._llm_file_saver = llm_file_saver
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = QuestionClassifierNodeData.model_validate(data)
self._node_data = QuestionClassifierNodeData.model_validate(data)
def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy
return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config
return self._node_data.retry_config
def get_title(self) -> str:
return self.node_data.title
return self._node_data.title
def get_description(self) -> Optional[str]:
return self.node_data.desc
return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self.node_data
return self._node_data
@classmethod
def version(cls):
return "1"
def _run(self):
node_data = cast(QuestionClassifierNodeData, self.node_data)
node_data = cast(QuestionClassifierNodeData, self._node_data)
variable_pool = self.graph_runtime_state.variable_pool
# extract variables

@ -13,28 +13,28 @@ from core.workflow.nodes.start.entities import StartNodeData
class StartNode(BaseNode):
_node_type = NodeType.START
node_data: StartNodeData
_node_data: StartNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = StartNodeData(**data)
self._node_data = StartNodeData(**data)
def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy
return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config
return self._node_data.retry_config
def get_title(self) -> str:
return self.node_data.title
return self._node_data.title
def get_description(self) -> Optional[str]:
return self.node_data.desc
return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self.node_data
return self._node_data
@classmethod
def version(cls) -> str:

@ -16,28 +16,28 @@ MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MA
class TemplateTransformNode(BaseNode):
_node_type = NodeType.TEMPLATE_TRANSFORM
node_data: TemplateTransformNodeData
_node_data: TemplateTransformNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = TemplateTransformNodeData.model_validate(data)
self._node_data = TemplateTransformNodeData.model_validate(data)
def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy
return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config
return self._node_data.retry_config
def get_title(self) -> str:
return self.node_data.title
return self._node_data.title
def get_description(self) -> Optional[str]:
return self.node_data.desc
return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self.node_data
return self._node_data
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
@ -58,14 +58,14 @@ class TemplateTransformNode(BaseNode):
def _run(self) -> NodeRunResult:
# Get variables
variables = {}
for variable_selector in self.node_data.variables:
for variable_selector in self._node_data.variables:
variable_name = variable_selector.variable
value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
variables[variable_name] = value.to_object() if value else None
# Run code
try:
result = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables
language=CodeLanguage.JINJA2, code=self._node_data.template, inputs=variables
)
except CodeExecutionError as e:
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))

@ -43,10 +43,10 @@ class ToolNode(BaseNode):
_node_type = NodeType.TOOL
node_data: ToolNodeData
_node_data: ToolNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = ToolNodeData.model_validate(data)
self._node_data = ToolNodeData.model_validate(data)
@classmethod
def version(cls) -> str:
@ -57,7 +57,7 @@ class ToolNode(BaseNode):
Run the tool node
"""
node_data = cast(ToolNodeData, self.node_data)
node_data = cast(ToolNodeData, self._node_data)
# fetch tool icon
tool_info = {
@ -70,9 +70,9 @@ class ToolNode(BaseNode):
try:
from core.tools.tool_manager import ToolManager
variable_pool = self.graph_runtime_state.variable_pool if self.node_data.version != "1" else None
variable_pool = self.graph_runtime_state.variable_pool if self._node_data.version != "1" else None
tool_runtime = ToolManager.get_workflow_tool_runtime(
self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from, variable_pool
self.tenant_id, self.app_id, self.node_id, self._node_data, self.invoke_from, variable_pool
)
except ToolNodeError as e:
yield RunCompletedEvent(
@ -91,12 +91,12 @@ class ToolNode(BaseNode):
parameters = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data,
node_data=self._node_data,
)
parameters_for_log = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data,
node_data=self._node_data,
for_log=True,
)
# get conversation id
@ -404,27 +404,27 @@ class ToolNode(BaseNode):
return result
def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy
return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config
return self._node_data.retry_config
def get_title(self) -> str:
return self.node_data.title
return self._node_data.title
def get_description(self) -> Optional[str]:
return self.node_data.desc
return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self.node_data
return self._node_data
@property
def continue_on_error(self) -> bool:
return self.node_data.error_strategy is not None
return self._node_data.error_strategy is not None
@property
def retry(self) -> bool:
return self.node_data.retry_config.retry_enabled
return self._node_data.retry_config.retry_enabled

@ -13,28 +13,28 @@ from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNod
class VariableAggregatorNode(BaseNode):
_node_type = NodeType.VARIABLE_AGGREGATOR
node_data: VariableAssignerNodeData
_node_data: VariableAssignerNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = VariableAssignerNodeData(**data)
self._node_data = VariableAssignerNodeData(**data)
def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy
return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config
return self._node_data.retry_config
def get_title(self) -> str:
return self.node_data.title
return self._node_data.title
def get_description(self) -> Optional[str]:
return self.node_data.desc
return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self.node_data
return self._node_data
@classmethod
def version(cls) -> str:
@ -45,8 +45,8 @@ class VariableAggregatorNode(BaseNode):
outputs: dict[str, Segment | Mapping[str, Segment]] = {}
inputs = {}
if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled:
for selector in self.node_data.variables:
if not self._node_data.advanced_settings or not self._node_data.advanced_settings.group_enabled:
for selector in self._node_data.variables:
variable = self.graph_runtime_state.variable_pool.get(selector)
if variable is not None:
outputs = {"output": variable}
@ -54,7 +54,7 @@ class VariableAggregatorNode(BaseNode):
inputs = {".".join(selector[1:]): variable.to_object()}
break
else:
for group in self.node_data.advanced_settings.groups:
for group in self._node_data.advanced_settings.groups:
for selector in group.variables:
variable = self.graph_runtime_state.variable_pool.get(selector)

@ -27,28 +27,28 @@ class VariableAssignerNode(BaseNode):
_node_type = NodeType.VARIABLE_ASSIGNER
_conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY
node_data: VariableAssignerData
_node_data: VariableAssignerData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = VariableAssignerData.model_validate(data)
self._node_data = VariableAssignerData.model_validate(data)
def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy
return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config
return self._node_data.retry_config
def get_title(self) -> str:
return self.node_data.title
return self._node_data.title
def get_description(self) -> Optional[str]:
return self.node_data.desc
return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self.node_data
return self._node_data
def __init__(
self,
@ -100,21 +100,21 @@ class VariableAssignerNode(BaseNode):
return mapping
def _run(self) -> NodeRunResult:
assigned_variable_selector = self.node_data.assigned_variable_selector
assigned_variable_selector = self._node_data.assigned_variable_selector
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
if not isinstance(original_variable, Variable):
raise VariableOperatorNodeError("assigned variable not found")
match self.node_data.write_mode:
match self._node_data.write_mode:
case WriteMode.OVER_WRITE:
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
income_value = self.graph_runtime_state.variable_pool.get(self._node_data.input_variable_selector)
if not income_value:
raise VariableOperatorNodeError("input value not found")
updated_variable = original_variable.model_copy(update={"value": income_value.value})
case WriteMode.APPEND:
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
income_value = self.graph_runtime_state.variable_pool.get(self._node_data.input_variable_selector)
if not income_value:
raise VariableOperatorNodeError("input value not found")
updated_value = original_variable.value + [income_value.value]
@ -127,7 +127,7 @@ class VariableAssignerNode(BaseNode):
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
case _:
raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}")
raise VariableOperatorNodeError(f"unsupported write mode: {self._node_data.write_mode}")
# Over write the variable.
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)

@ -1,6 +1,6 @@
import json
from collections.abc import Callable, Mapping, MutableMapping, Sequence
from typing import Any, Optional, TypeAlias, cast
from collections.abc import Mapping, MutableMapping, Sequence
from typing import Any, Optional, cast
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import SegmentType, Variable
@ -29,8 +29,6 @@ from .exc import (
VariableNotFoundError,
)
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem):
selector_node_id = item.variable_selector[0]
@ -58,28 +56,28 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_
class VariableAssignerNode(BaseNode):
_node_type = NodeType.VARIABLE_ASSIGNER
node_data: VariableAssignerNodeData
_node_data: VariableAssignerNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = VariableAssignerNodeData.model_validate(data)
self._node_data = VariableAssignerNodeData.model_validate(data)
def get_error_strategy(self) -> Optional[ErrorStrategy]:
return self.node_data.error_strategy
return self._node_data.error_strategy
def get_retry_config(self) -> RetryConfig:
return self.node_data.retry_config
return self._node_data.retry_config
def get_title(self) -> str:
return self.node_data.title
return self._node_data.title
def get_description(self) -> Optional[str]:
return self.node_data.desc
return self._node_data.desc
def get_default_value_dict(self) -> dict[str, Any]:
return self.node_data.default_value_dict
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self.node_data
return self._node_data
def _conv_var_updater_factory(self) -> ConversationVariableUpdater:
return conversation_variable_updater_factory()
@ -106,13 +104,13 @@ class VariableAssignerNode(BaseNode):
return var_mapping
def _run(self) -> NodeRunResult:
inputs = self.node_data.model_dump()
inputs = self._node_data.model_dump()
process_data: dict[str, Any] = {}
# NOTE: This node has no outputs
updated_variable_selectors: list[Sequence[str]] = []
try:
for item in self.node_data.items:
for item in self._node_data.items:
variable = self.graph_runtime_state.variable_pool.get(item.variable_selector)
# ==================== Validation Part

@ -234,10 +234,10 @@ def test_execute_code_output_validator_depth():
"object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}},
}
node.node_data = cast(CodeNodeData, node.node_data)
node._node_data = cast(CodeNodeData, node._node_data)
# validate
node._transform_result(result, node.node_data.outputs)
node._transform_result(result, node._node_data.outputs)
# construct result
result = {
@ -250,7 +250,7 @@ def test_execute_code_output_validator_depth():
# validate
with pytest.raises(ValueError):
node._transform_result(result, node.node_data.outputs)
node._transform_result(result, node._node_data.outputs)
# construct result
result = {
@ -263,7 +263,7 @@ def test_execute_code_output_validator_depth():
# validate
with pytest.raises(ValueError):
node._transform_result(result, node.node_data.outputs)
node._transform_result(result, node._node_data.outputs)
# construct result
result = {
@ -276,7 +276,7 @@ def test_execute_code_output_validator_depth():
# validate
with pytest.raises(ValueError):
node._transform_result(result, node.node_data.outputs)
node._transform_result(result, node._node_data.outputs)
def test_execute_code_output_object_list():
@ -330,10 +330,10 @@ def test_execute_code_output_object_list():
]
}
node.node_data = cast(CodeNodeData, node.node_data)
node._node_data = cast(CodeNodeData, node._node_data)
# validate
node._transform_result(result, node.node_data.outputs)
node._transform_result(result, node._node_data.outputs)
# construct result
result = {
@ -353,7 +353,7 @@ def test_execute_code_output_object_list():
# validate
with pytest.raises(ValueError):
node._transform_result(result, node.node_data.outputs)
node._transform_result(result, node._node_data.outputs)
def test_execute_code_scientific_notation():

@ -665,8 +665,8 @@ def test_iteration_run_in_parallel_mode():
# execute node
parallel_result = parallel_iteration_node._run()
sequential_result = sequential_iteration_node._run()
assert parallel_iteration_node.node_data.parallel_nums == 10
assert parallel_iteration_node.node_data.error_handle_mode == ErrorHandleMode.TERMINATED
assert parallel_iteration_node._node_data.parallel_nums == 10
assert parallel_iteration_node._node_data.error_handle_mode == ErrorHandleMode.TERMINATED
count = 0
parallel_arr = []
sequential_arr = []
@ -876,7 +876,7 @@ def test_iteration_run_error_handle():
assert count == 14
# execute remove abnormal output
iteration_node.node_data.error_handle_mode = ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT
iteration_node._node_data.error_handle_mode = ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT
result = iteration_node._run()
count = 0
for item in result:

Loading…
Cancel
Save