refactor: decouple Node and NodeData (#22581)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
pull/22617/head
-LAN- 10 months ago committed by GitHub
parent 54c56f2d05
commit 460a825ef1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -17,7 +17,8 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom

@ -15,7 +15,8 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
from core.app.apps.agent_chat.app_runner import AgentChatAppRunner from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom

@ -169,7 +169,3 @@ class AppQueueManager:
raise TypeError( raise TypeError(
"Critical Error: Passing SQLAlchemy Model instances that cause thread safety issues is not allowed." "Critical Error: Passing SQLAlchemy Model instances that cause thread safety issues is not allowed."
) )
class GenerateTaskStoppedError(Exception):
pass

@ -118,7 +118,7 @@ class AppRunner:
else: else:
memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False)) memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
model_mode = ModelMode.value_of(model_config.mode) model_mode = ModelMode(model_config.mode)
prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]] prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]]
if model_mode == ModelMode.COMPLETION: if model_mode == ModelMode.COMPLETION:
advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template

@ -11,10 +11,11 @@ from configs import dify_config
from constants import UUID_NIL from constants import UUID_NIL
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.chat.app_config_manager import ChatAppConfigManager from core.app.apps.chat.app_config_manager import ChatAppConfigManager
from core.app.apps.chat.app_runner import ChatAppRunner from core.app.apps.chat.app_runner import ChatAppRunner
from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom

@ -10,10 +10,11 @@ from pydantic import ValidationError
from configs import dify_config from configs import dify_config
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
from core.app.apps.completion.app_runner import CompletionAppRunner from core.app.apps.completion.app_runner import CompletionAppRunner
from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom

@ -0,0 +1,2 @@
class GenerateTaskStoppedError(Exception):
pass

@ -6,7 +6,8 @@ from typing import Optional, Union, cast
from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import ( from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity, AdvancedChatAppGenerateEntity,
AgentChatAppGenerateEntity, AgentChatAppGenerateEntity,

@ -1,4 +1,5 @@
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
AppQueueEvent, AppQueueEvent,

@ -13,7 +13,8 @@ import contexts
from configs import dify_config from configs import dify_config
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
from core.app.apps.workflow.app_runner import WorkflowAppRunner from core.app.apps.workflow.app_runner import WorkflowAppRunner

@ -1,4 +1,5 @@
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
AppQueueEvent, AppQueueEvent,

@ -29,19 +29,6 @@ class ModelMode(enum.StrEnum):
COMPLETION = "completion" COMPLETION = "completion"
CHAT = "chat" CHAT = "chat"
@classmethod
def value_of(cls, value: str) -> "ModelMode":
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f"invalid mode value {value}")
prompt_file_contents: dict[str, Any] = {} prompt_file_contents: dict[str, Any] = {}
@ -65,7 +52,7 @@ class SimplePromptTransform(PromptTransform):
) -> tuple[list[PromptMessage], Optional[list[str]]]: ) -> tuple[list[PromptMessage], Optional[list[str]]]:
inputs = {key: str(value) for key, value in inputs.items()} inputs = {key: str(value) for key, value in inputs.items()}
model_mode = ModelMode.value_of(model_config.mode) model_mode = ModelMode(model_config.mode)
if model_mode == ModelMode.CHAT: if model_mode == ModelMode.CHAT:
prompt_messages, stops = self._get_chat_model_prompt_messages( prompt_messages, stops = self._get_chat_model_prompt_messages(
app_mode=app_mode, app_mode=app_mode,

@ -1137,7 +1137,7 @@ class DatasetRetrieval:
def _get_prompt_template( def _get_prompt_template(
self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list, query: str self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list, query: str
): ):
model_mode = ModelMode.value_of(mode) model_mode = ModelMode(mode)
input_text = query input_text = query
prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]] prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]]

@ -2,7 +2,7 @@ from core.workflow.nodes.base import BaseNode
class WorkflowNodeRunFailedError(Exception): class WorkflowNodeRunFailedError(Exception):
def __init__(self, node_instance: BaseNode, error: str): def __init__(self, node: BaseNode, err_msg: str):
self.node_instance = node_instance self._node = node
self.error = error self._error = err_msg
super().__init__(f"Node {node_instance.node_data.title} run failed: {error}") super().__init__(f"Node {node.title} run failed: {err_msg}")

@ -1,3 +1,4 @@
from .entities import Graph, GraphInitParams, GraphRuntimeState, RuntimeRouteState from .entities import Graph, GraphInitParams, GraphRuntimeState, RuntimeRouteState
from .graph_engine import GraphEngine
__all__ = ["Graph", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"] __all__ = ["Graph", "GraphEngine", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"]

@ -12,7 +12,7 @@ from typing import Any, Optional, cast
from flask import Flask, current_app from flask import Flask, current_app
from configs import dify_config from configs import dify_config
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunResult from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunResult
from core.workflow.entities.variable_pool import VariablePool, VariableValue from core.workflow.entities.variable_pool import VariablePool, VariableValue
@ -48,11 +48,9 @@ from core.workflow.nodes.agent.entities import AgentNodeData
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.utils import variable_utils from core.workflow.utils import variable_utils
from libs.flask_utils import preserve_flask_contexts from libs.flask_utils import preserve_flask_contexts
from models.enums import UserFrom from models.enums import UserFrom
@ -260,12 +258,16 @@ class GraphEngine:
# convert to specific node # convert to specific node
node_type = NodeType(node_config.get("data", {}).get("type")) node_type = NodeType(node_config.get("data", {}).get("type"))
node_version = node_config.get("data", {}).get("version", "1") node_version = node_config.get("data", {}).get("version", "1")
# Import here to avoid circular import
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version] node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None
# init workflow run state # init workflow run state
node_instance = node_cls( # type: ignore node = node_cls(
id=route_node_state.id, id=route_node_state.id,
config=node_config, config=node_config,
graph_init_params=self.init_params, graph_init_params=self.init_params,
@ -274,11 +276,11 @@ class GraphEngine:
previous_node_id=previous_node_id, previous_node_id=previous_node_id,
thread_pool_id=self.thread_pool_id, thread_pool_id=self.thread_pool_id,
) )
node_instance = cast(BaseNode[BaseNodeData], node_instance) node.init_node_data(node_config.get("data", {}))
try: try:
# run node # run node
generator = self._run_node( generator = self._run_node(
node_instance=node_instance, node=node,
route_node_state=route_node_state, route_node_state=route_node_state,
parallel_id=in_parallel_id, parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id, parallel_start_node_id=parallel_start_node_id,
@ -306,16 +308,16 @@ class GraphEngine:
route_node_state.failed_reason = str(e) route_node_state.failed_reason = str(e)
yield NodeRunFailedEvent( yield NodeRunFailedEvent(
error=str(e), error=str(e),
id=node_instance.id, id=node.id,
node_id=next_node_id, node_id=next_node_id,
node_type=node_type, node_type=node_type,
node_data=node_instance.node_data, node_data=node.get_base_node_data(),
route_node_state=route_node_state, route_node_state=route_node_state,
parallel_id=in_parallel_id, parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id, parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id, parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id, parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(), node_version=node.version(),
) )
raise e raise e
@ -337,7 +339,7 @@ class GraphEngine:
edge = edge_mappings[0] edge = edge_mappings[0]
if ( if (
previous_route_node_state.status == RouteNodeState.Status.EXCEPTION previous_route_node_state.status == RouteNodeState.Status.EXCEPTION
and node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH and node.error_strategy == ErrorStrategy.FAIL_BRANCH
and edge.run_condition is None and edge.run_condition is None
): ):
break break
@ -413,8 +415,8 @@ class GraphEngine:
next_node_id = final_node_id next_node_id = final_node_id
elif ( elif (
node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH node.continue_on_error
and node_instance.should_continue_on_error and node.error_strategy == ErrorStrategy.FAIL_BRANCH
and previous_route_node_state.status == RouteNodeState.Status.EXCEPTION and previous_route_node_state.status == RouteNodeState.Status.EXCEPTION
): ):
break break
@ -597,7 +599,7 @@ class GraphEngine:
def _run_node( def _run_node(
self, self,
node_instance: BaseNode[BaseNodeData], node: BaseNode,
route_node_state: RouteNodeState, route_node_state: RouteNodeState,
parallel_id: Optional[str] = None, parallel_id: Optional[str] = None,
parallel_start_node_id: Optional[str] = None, parallel_start_node_id: Optional[str] = None,
@ -611,29 +613,29 @@ class GraphEngine:
# trigger node run start event # trigger node run start event
agent_strategy = ( agent_strategy = (
AgentNodeStrategyInit( AgentNodeStrategyInit(
name=cast(AgentNodeData, node_instance.node_data).agent_strategy_name, name=cast(AgentNodeData, node.get_base_node_data()).agent_strategy_name,
icon=cast(AgentNode, node_instance).agent_strategy_icon, icon=cast(AgentNode, node).agent_strategy_icon,
) )
if node_instance.node_type == NodeType.AGENT if node.type_ == NodeType.AGENT
else None else None
) )
yield NodeRunStartedEvent( yield NodeRunStartedEvent(
id=node_instance.id, id=node.id,
node_id=node_instance.node_id, node_id=node.node_id,
node_type=node_instance.node_type, node_type=node.type_,
node_data=node_instance.node_data, node_data=node.get_base_node_data(),
route_node_state=route_node_state, route_node_state=route_node_state,
predecessor_node_id=node_instance.previous_node_id, predecessor_node_id=node.previous_node_id,
parallel_id=parallel_id, parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id, parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id, parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id, parent_parallel_start_node_id=parent_parallel_start_node_id,
agent_strategy=agent_strategy, agent_strategy=agent_strategy,
node_version=node_instance.version(), node_version=node.version(),
) )
max_retries = node_instance.node_data.retry_config.max_retries max_retries = node.retry_config.max_retries
retry_interval = node_instance.node_data.retry_config.retry_interval_seconds retry_interval = node.retry_config.retry_interval_seconds
retries = 0 retries = 0
should_continue_retry = True should_continue_retry = True
while should_continue_retry and retries <= max_retries: while should_continue_retry and retries <= max_retries:
@ -642,7 +644,7 @@ class GraphEngine:
retry_start_at = datetime.now(UTC).replace(tzinfo=None) retry_start_at = datetime.now(UTC).replace(tzinfo=None)
# yield control to other threads # yield control to other threads
time.sleep(0.001) time.sleep(0.001)
event_stream = node_instance.run() event_stream = node.run()
for event in event_stream: for event in event_stream:
if isinstance(event, GraphEngineEvent): if isinstance(event, GraphEngineEvent):
# add parallel info to iteration event # add parallel info to iteration event
@ -658,21 +660,21 @@ class GraphEngine:
if run_result.status == WorkflowNodeExecutionStatus.FAILED: if run_result.status == WorkflowNodeExecutionStatus.FAILED:
if ( if (
retries == max_retries retries == max_retries
and node_instance.node_type == NodeType.HTTP_REQUEST and node.type_ == NodeType.HTTP_REQUEST
and run_result.outputs and run_result.outputs
and not node_instance.should_continue_on_error and not node.continue_on_error
): ):
run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED
if node_instance.should_retry and retries < max_retries: if node.retry and retries < max_retries:
retries += 1 retries += 1
route_node_state.node_run_result = run_result route_node_state.node_run_result = run_result
yield NodeRunRetryEvent( yield NodeRunRetryEvent(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
node_id=node_instance.node_id, node_id=node.node_id,
node_type=node_instance.node_type, node_type=node.type_,
node_data=node_instance.node_data, node_data=node.get_base_node_data(),
route_node_state=route_node_state, route_node_state=route_node_state,
predecessor_node_id=node_instance.previous_node_id, predecessor_node_id=node.previous_node_id,
parallel_id=parallel_id, parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id, parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id, parent_parallel_id=parent_parallel_id,
@ -680,17 +682,17 @@ class GraphEngine:
error=run_result.error or "Unknown error", error=run_result.error or "Unknown error",
retry_index=retries, retry_index=retries,
start_at=retry_start_at, start_at=retry_start_at,
node_version=node_instance.version(), node_version=node.version(),
) )
time.sleep(retry_interval) time.sleep(retry_interval)
break break
route_node_state.set_finished(run_result=run_result) route_node_state.set_finished(run_result=run_result)
if run_result.status == WorkflowNodeExecutionStatus.FAILED: if run_result.status == WorkflowNodeExecutionStatus.FAILED:
if node_instance.should_continue_on_error: if node.continue_on_error:
# if run failed, handle error # if run failed, handle error
run_result = self._handle_continue_on_error( run_result = self._handle_continue_on_error(
node_instance, node,
event.run_result, event.run_result,
self.graph_runtime_state.variable_pool, self.graph_runtime_state.variable_pool,
handle_exceptions=handle_exceptions, handle_exceptions=handle_exceptions,
@ -701,44 +703,44 @@ class GraphEngine:
for variable_key, variable_value in run_result.outputs.items(): for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively # append variables to variable pool recursively
self._append_variables_recursively( self._append_variables_recursively(
node_id=node_instance.node_id, node_id=node.node_id,
variable_key_list=[variable_key], variable_key_list=[variable_key],
variable_value=variable_value, variable_value=variable_value,
) )
yield NodeRunExceptionEvent( yield NodeRunExceptionEvent(
error=run_result.error or "System Error", error=run_result.error or "System Error",
id=node_instance.id, id=node.id,
node_id=node_instance.node_id, node_id=node.node_id,
node_type=node_instance.node_type, node_type=node.type_,
node_data=node_instance.node_data, node_data=node.get_base_node_data(),
route_node_state=route_node_state, route_node_state=route_node_state,
parallel_id=parallel_id, parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id, parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id, parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id, parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(), node_version=node.version(),
) )
should_continue_retry = False should_continue_retry = False
else: else:
yield NodeRunFailedEvent( yield NodeRunFailedEvent(
error=route_node_state.failed_reason or "Unknown error.", error=route_node_state.failed_reason or "Unknown error.",
id=node_instance.id, id=node.id,
node_id=node_instance.node_id, node_id=node.node_id,
node_type=node_instance.node_type, node_type=node.type_,
node_data=node_instance.node_data, node_data=node.get_base_node_data(),
route_node_state=route_node_state, route_node_state=route_node_state,
parallel_id=parallel_id, parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id, parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id, parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id, parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(), node_version=node.version(),
) )
should_continue_retry = False should_continue_retry = False
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
if ( if (
node_instance.should_continue_on_error node.continue_on_error
and self.graph.edge_mapping.get(node_instance.node_id) and self.graph.edge_mapping.get(node.node_id)
and node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH and node.error_strategy is ErrorStrategy.FAIL_BRANCH
): ):
run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
if run_result.metadata and run_result.metadata.get( if run_result.metadata and run_result.metadata.get(
@ -758,7 +760,7 @@ class GraphEngine:
for variable_key, variable_value in run_result.outputs.items(): for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively # append variables to variable pool recursively
self._append_variables_recursively( self._append_variables_recursively(
node_id=node_instance.node_id, node_id=node.node_id,
variable_key_list=[variable_key], variable_key_list=[variable_key],
variable_value=variable_value, variable_value=variable_value,
) )
@ -783,26 +785,26 @@ class GraphEngine:
run_result.metadata = metadata_dict run_result.metadata = metadata_dict
yield NodeRunSucceededEvent( yield NodeRunSucceededEvent(
id=node_instance.id, id=node.id,
node_id=node_instance.node_id, node_id=node.node_id,
node_type=node_instance.node_type, node_type=node.type_,
node_data=node_instance.node_data, node_data=node.get_base_node_data(),
route_node_state=route_node_state, route_node_state=route_node_state,
parallel_id=parallel_id, parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id, parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id, parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id, parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(), node_version=node.version(),
) )
should_continue_retry = False should_continue_retry = False
break break
elif isinstance(event, RunStreamChunkEvent): elif isinstance(event, RunStreamChunkEvent):
yield NodeRunStreamChunkEvent( yield NodeRunStreamChunkEvent(
id=node_instance.id, id=node.id,
node_id=node_instance.node_id, node_id=node.node_id,
node_type=node_instance.node_type, node_type=node.type_,
node_data=node_instance.node_data, node_data=node.get_base_node_data(),
chunk_content=event.chunk_content, chunk_content=event.chunk_content,
from_variable_selector=event.from_variable_selector, from_variable_selector=event.from_variable_selector,
route_node_state=route_node_state, route_node_state=route_node_state,
@ -810,14 +812,14 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id, parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id, parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id, parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(), node_version=node.version(),
) )
elif isinstance(event, RunRetrieverResourceEvent): elif isinstance(event, RunRetrieverResourceEvent):
yield NodeRunRetrieverResourceEvent( yield NodeRunRetrieverResourceEvent(
id=node_instance.id, id=node.id,
node_id=node_instance.node_id, node_id=node.node_id,
node_type=node_instance.node_type, node_type=node.type_,
node_data=node_instance.node_data, node_data=node.get_base_node_data(),
retriever_resources=event.retriever_resources, retriever_resources=event.retriever_resources,
context=event.context, context=event.context,
route_node_state=route_node_state, route_node_state=route_node_state,
@ -825,7 +827,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id, parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id, parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id, parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(), node_version=node.version(),
) )
except GenerateTaskStoppedError: except GenerateTaskStoppedError:
# trigger node run failed event # trigger node run failed event
@ -833,20 +835,20 @@ class GraphEngine:
route_node_state.failed_reason = "Workflow stopped." route_node_state.failed_reason = "Workflow stopped."
yield NodeRunFailedEvent( yield NodeRunFailedEvent(
error="Workflow stopped.", error="Workflow stopped.",
id=node_instance.id, id=node.id,
node_id=node_instance.node_id, node_id=node.node_id,
node_type=node_instance.node_type, node_type=node.type_,
node_data=node_instance.node_data, node_data=node.get_base_node_data(),
route_node_state=route_node_state, route_node_state=route_node_state,
parallel_id=parallel_id, parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id, parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id, parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id, parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(), node_version=node.version(),
) )
return return
except Exception as e: except Exception as e:
logger.exception(f"Node {node_instance.node_data.title} run failed") logger.exception(f"Node {node.title} run failed")
raise e raise e
def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue): def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
@ -886,22 +888,14 @@ class GraphEngine:
def _handle_continue_on_error( def _handle_continue_on_error(
self, self,
node_instance: BaseNode[BaseNodeData], node: BaseNode,
error_result: NodeRunResult, error_result: NodeRunResult,
variable_pool: VariablePool, variable_pool: VariablePool,
handle_exceptions: list[str] = [], handle_exceptions: list[str] = [],
) -> NodeRunResult: ) -> NodeRunResult:
"""
handle continue on error when self._should_continue_on_error is True
:param error_result (NodeRunResult): error run result
:param variable_pool (VariablePool): variable pool
:return: excption run result
"""
# add error message and error type to variable pool # add error message and error type to variable pool
variable_pool.add([node_instance.node_id, "error_message"], error_result.error) variable_pool.add([node.node_id, "error_message"], error_result.error)
variable_pool.add([node_instance.node_id, "error_type"], error_result.error_type) variable_pool.add([node.node_id, "error_type"], error_result.error_type)
# add error message to handle_exceptions # add error message to handle_exceptions
handle_exceptions.append(error_result.error or "") handle_exceptions.append(error_result.error or "")
node_error_args: dict[str, Any] = { node_error_args: dict[str, Any] = {
@ -909,21 +903,21 @@ class GraphEngine:
"error": error_result.error, "error": error_result.error,
"inputs": error_result.inputs, "inputs": error_result.inputs,
"metadata": { "metadata": {
WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node_instance.node_data.error_strategy, WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node.error_strategy,
}, },
} }
if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE: if node.error_strategy is ErrorStrategy.DEFAULT_VALUE:
return NodeRunResult( return NodeRunResult(
**node_error_args, **node_error_args,
outputs={ outputs={
**node_instance.node_data.default_value_dict, **node.default_value_dict,
"error_message": error_result.error, "error_message": error_result.error,
"error_type": error_result.error_type, "error_type": error_result.error_type,
}, },
) )
elif node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH: elif node.error_strategy is ErrorStrategy.FAIL_BRANCH:
if self.graph.edge_mapping.get(node_instance.node_id): if self.graph.edge_mapping.get(node.node_id):
node_error_args["edge_source_handle"] = FailBranchSourceHandle.FAILED node_error_args["edge_source_handle"] = FailBranchSourceHandle.FAILED
return NodeRunResult( return NodeRunResult(
**node_error_args, **node_error_args,

@ -1,5 +1,4 @@
import json import json
import uuid
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast from typing import Any, Optional, cast
@ -11,8 +10,10 @@ from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter from core.agent.plugin_entities import AgentStrategyParameter
from core.agent.strategy.plugin import PluginAgentStrategy from core.agent.strategy.plugin import PluginAgentStrategy
from core.file import File, FileTransferMethod
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.plugin.entities.request import InvokeCredentials from core.plugin.entities.request import InvokeCredentials
from core.plugin.impl.exc import PluginDaemonClientSideError from core.plugin.impl.exc import PluginDaemonClientSideError
@ -25,45 +26,75 @@ from core.tools.entities.tool_entities import (
ToolProviderType, ToolProviderType,
) )
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from core.variables.segments import StringSegment from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayFileSegment, StringSegment
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import AgentLogEvent
from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
from core.workflow.nodes.base.entities import BaseNodeData from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.event.event import RunCompletedEvent from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.tool.tool_node import ToolNode from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.utils.variable_template_parser import VariableTemplateParser from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db from extensions.ext_database import db
from factories import file_factory
from factories.agent_factory import get_plugin_agent_strategy from factories.agent_factory import get_plugin_agent_strategy
from models import ToolFile
from models.model import Conversation from models.model import Conversation
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from .exc import (
AgentInputTypeError,
AgentInvocationError,
AgentMessageTransformError,
AgentVariableNotFoundError,
AgentVariableTypeError,
ToolFileNotFoundError,
)
class AgentNode(ToolNode): class AgentNode(BaseNode):
""" """
Agent Node Agent Node
""" """
_node_data_cls = AgentNodeData # type: ignore
_node_type = NodeType.AGENT _node_type = NodeType.AGENT
_node_data: AgentNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = AgentNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"
def _run(self) -> Generator: def _run(self) -> Generator:
"""
Run the agent node
"""
node_data = cast(AgentNodeData, self.node_data)
try: try:
strategy = get_plugin_agent_strategy( strategy = get_plugin_agent_strategy(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
agent_strategy_provider_name=node_data.agent_strategy_provider_name, agent_strategy_provider_name=self._node_data.agent_strategy_provider_name,
agent_strategy_name=node_data.agent_strategy_name, agent_strategy_name=self._node_data.agent_strategy_name,
) )
except Exception as e: except Exception as e:
yield RunCompletedEvent( yield RunCompletedEvent(
@ -81,13 +112,13 @@ class AgentNode(ToolNode):
parameters = self._generate_agent_parameters( parameters = self._generate_agent_parameters(
agent_parameters=agent_parameters, agent_parameters=agent_parameters,
variable_pool=self.graph_runtime_state.variable_pool, variable_pool=self.graph_runtime_state.variable_pool,
node_data=node_data, node_data=self._node_data,
strategy=strategy, strategy=strategy,
) )
parameters_for_log = self._generate_agent_parameters( parameters_for_log = self._generate_agent_parameters(
agent_parameters=agent_parameters, agent_parameters=agent_parameters,
variable_pool=self.graph_runtime_state.variable_pool, variable_pool=self.graph_runtime_state.variable_pool,
node_data=node_data, node_data=self._node_data,
for_log=True, for_log=True,
strategy=strategy, strategy=strategy,
) )
@ -105,59 +136,39 @@ class AgentNode(ToolNode):
credentials=credentials, credentials=credentials,
) )
except Exception as e: except Exception as e:
error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e)
yield RunCompletedEvent( yield RunCompletedEvent(
run_result=NodeRunResult( run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log, inputs=parameters_for_log,
error=f"Failed to invoke agent: {str(e)}", error=str(error),
) )
) )
return return
try: try:
# convert tool messages
agent_thoughts: list = []
thought_log_message = ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.LOG,
message=ToolInvokeMessage.LogMessage(
id=str(uuid.uuid4()),
label=f"Agent Strategy: {cast(AgentNodeData, self.node_data).agent_strategy_name}",
parent_id=None,
error=None,
status=ToolInvokeMessage.LogMessage.LogStatus.START,
data={
"strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
"parameters": parameters_for_log,
"thought_process": "Agent strategy execution started",
},
metadata={
"icon": self.agent_strategy_icon,
"agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
},
),
)
def enhanced_message_stream():
yield thought_log_message
yield from message_stream
yield from self._transform_message( yield from self._transform_message(
message_stream, messages=message_stream,
{ tool_info={
"icon": self.agent_strategy_icon, "icon": self.agent_strategy_icon,
"agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name, "agent_strategy": cast(AgentNodeData, self._node_data).agent_strategy_name,
}, },
parameters_for_log, parameters_for_log=parameters_for_log,
agent_thoughts, user_id=self.user_id,
tenant_id=self.tenant_id,
node_type=self.type_,
node_id=self.node_id,
node_execution_id=self.id,
) )
except PluginDaemonClientSideError as e: except PluginDaemonClientSideError as e:
transform_error = AgentMessageTransformError(
f"Failed to transform agent message: {str(e)}", original_error=e
)
yield RunCompletedEvent( yield RunCompletedEvent(
run_result=NodeRunResult( run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log, inputs=parameters_for_log,
error=f"Failed to transform agent message: {str(e)}", error=str(transform_error),
) )
) )
@ -194,7 +205,7 @@ class AgentNode(ToolNode):
if agent_input.type == "variable": if agent_input.type == "variable":
variable = variable_pool.get(agent_input.value) # type: ignore variable = variable_pool.get(agent_input.value) # type: ignore
if variable is None: if variable is None:
raise ValueError(f"Variable {agent_input.value} does not exist") raise AgentVariableNotFoundError(str(agent_input.value))
parameter_value = variable.value parameter_value = variable.value
elif agent_input.type in {"mixed", "constant"}: elif agent_input.type in {"mixed", "constant"}:
# variable_pool.convert_template expects a string template, # variable_pool.convert_template expects a string template,
@ -216,7 +227,7 @@ class AgentNode(ToolNode):
except json.JSONDecodeError: except json.JSONDecodeError:
parameter_value = parameter_value parameter_value = parameter_value
else: else:
raise ValueError(f"Unknown agent input type '{agent_input.type}'") raise AgentInputTypeError(agent_input.type)
value = parameter_value value = parameter_value
if parameter.type == "array[tools]": if parameter.type == "array[tools]":
value = cast(list[dict[str, Any]], value) value = cast(list[dict[str, Any]], value)
@ -259,7 +270,7 @@ class AgentNode(ToolNode):
) )
extra = tool.get("extra", {}) extra = tool.get("extra", {})
runtime_variable_pool = variable_pool if self.node_data.version != "1" else None runtime_variable_pool = variable_pool if self._node_data.version != "1" else None
tool_runtime = ToolManager.get_agent_tool_runtime( tool_runtime = ToolManager.get_agent_tool_runtime(
self.tenant_id, self.app_id, entity, self.invoke_from, runtime_variable_pool self.tenant_id, self.app_id, entity, self.invoke_from, runtime_variable_pool
) )
@ -343,19 +354,14 @@ class AgentNode(ToolNode):
*, *,
graph_config: Mapping[str, Any], graph_config: Mapping[str, Any],
node_id: str, node_id: str,
node_data: BaseNodeData, node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
""" # Create typed NodeData from dict
Extract variable selector to variable mapping typed_node_data = AgentNodeData.model_validate(node_data)
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
node_data = cast(AgentNodeData, node_data)
result: dict[str, Any] = {} result: dict[str, Any] = {}
for parameter_name in node_data.agent_parameters: for parameter_name in typed_node_data.agent_parameters:
input = node_data.agent_parameters[parameter_name] input = typed_node_data.agent_parameters[parameter_name]
if input.type in ["mixed", "constant"]: if input.type in ["mixed", "constant"]:
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors() selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
for selector in selectors: for selector in selectors:
@ -380,7 +386,7 @@ class AgentNode(ToolNode):
plugin plugin
for plugin in plugins for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}" if f"{plugin.plugin_id}/{plugin.name}"
== cast(AgentNodeData, self.node_data).agent_strategy_provider_name == cast(AgentNodeData, self._node_data).agent_strategy_provider_name
) )
icon = current_plugin.declaration.icon icon = current_plugin.declaration.icon
except StopIteration: except StopIteration:
@ -448,3 +454,236 @@ class AgentNode(ToolNode):
return tools return tools
else: else:
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP.value] return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP.value]
def _transform_message(
self,
messages: Generator[ToolInvokeMessage, None, None],
tool_info: Mapping[str, Any],
parameters_for_log: dict[str, Any],
user_id: str,
tenant_id: str,
node_type: NodeType,
node_id: str,
node_execution_id: str,
) -> Generator:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
"""
# transform message and handle file storage
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=messages,
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
)
text = ""
files: list[File] = []
json: list[dict] = []
agent_logs: list[AgentLogEvent] = []
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
llm_usage: LLMUsage | None = None
variables: dict[str, Any] = {}
for message in message_stream:
if message.type in {
ToolInvokeMessage.MessageType.IMAGE_LINK,
ToolInvokeMessage.MessageType.BINARY_LINK,
ToolInvokeMessage.MessageType.IMAGE,
}:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
url = message.message.text
if message.meta:
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
else:
transfer_method = FileTransferMethod.TOOL_FILE
tool_file_id = str(url).split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileNotFoundError(tool_file_id)
mapping = {
"tool_file_id": tool_file_id,
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
"transfer_method": transfer_method,
"url": url,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
)
files.append(file)
elif message.type == ToolInvokeMessage.MessageType.BLOB:
# get tool file id
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
assert message.meta
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileNotFoundError(tool_file_id)
mapping = {
"tool_file_id": tool_file_id,
"transfer_method": FileTransferMethod.TOOL_FILE,
}
files.append(
file_factory.build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
)
)
elif message.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
text += message.message.text
yield RunStreamChunkEvent(chunk_content=message.message.text, from_variable_selector=[node_id, "text"])
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
if node_type == NodeType.AGENT:
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
llm_usage = LLMUsage.from_metadata(msg_metadata)
agent_execution_metadata = {
WorkflowNodeExecutionMetadataKey(key): value
for key, value in msg_metadata.items()
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
}
if message.message.json_object is not None:
json.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"])
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
if not isinstance(variable_value, str):
raise AgentVariableTypeError(
"When 'stream' is True, 'variable_value' must be a string.",
variable_name=variable_name,
expected_type="str",
actual_type=type(variable_value).__name__,
)
if variable_name not in variables:
variables[variable_name] = ""
variables[variable_name] += variable_value
yield RunStreamChunkEvent(
chunk_content=variable_value, from_variable_selector=[node_id, variable_name]
)
else:
variables[variable_name] = variable_value
elif message.type == ToolInvokeMessage.MessageType.FILE:
assert message.meta is not None
assert isinstance(message.meta, File)
files.append(message.meta["file"])
elif message.type == ToolInvokeMessage.MessageType.LOG:
assert isinstance(message.message, ToolInvokeMessage.LogMessage)
if message.message.metadata:
icon = tool_info.get("icon", "")
dict_metadata = dict(message.message.metadata)
if dict_metadata.get("provider"):
manager = PluginInstaller()
plugins = manager.list_plugins(tenant_id)
try:
current_plugin = next(
plugin
for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
)
icon = current_plugin.declaration.icon
except StopIteration:
pass
icon_dark = None
try:
builtin_tool = next(
provider
for provider in BuiltinToolManageService.list_builtin_tools(
user_id,
tenant_id,
)
if provider.name == dict_metadata["provider"]
)
icon = builtin_tool.icon
icon_dark = builtin_tool.icon_dark
except StopIteration:
pass
dict_metadata["icon"] = icon
dict_metadata["icon_dark"] = icon_dark
message.message.metadata = dict_metadata
agent_log = AgentLogEvent(
id=message.message.id,
node_execution_id=node_execution_id,
parent_id=message.message.parent_id,
error=message.message.error,
status=message.message.status.value,
data=message.message.data,
label=message.message.label,
metadata=message.message.metadata,
node_id=node_id,
)
# check if the agent log is already in the list
for log in agent_logs:
if log.id == agent_log.id:
# update the log
log.data = agent_log.data
log.status = agent_log.status
log.error = agent_log.error
log.label = agent_log.label
log.metadata = agent_log.metadata
break
else:
agent_logs.append(agent_log)
yield agent_log
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
json_output: list[dict[str, Any]] = []
# Step 1: append each agent log as its own dict.
if agent_logs:
for log in agent_logs:
json_output.append(
{
"id": log.id,
"parent_id": log.parent_id,
"error": log.error,
"status": log.status,
"data": log.data,
"label": log.label,
"metadata": log.metadata,
"node_id": log.node_id,
}
)
# Step 2: normalize JSON into {"data": [...]}.change json to list[dict]
if json:
json_output.extend(json)
else:
json_output.append({"data": []})
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
metadata={
**agent_execution_metadata,
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
},
inputs=parameters_for_log,
llm_usage=llm_usage,
)
)

@ -0,0 +1,124 @@
from typing import Optional
class AgentNodeError(Exception):
"""Base exception for all agent node errors."""
def __init__(self, message: str):
self.message = message
super().__init__(self.message)
class AgentStrategyError(AgentNodeError):
"""Exception raised when there's an error with the agent strategy."""
def __init__(self, message: str, strategy_name: Optional[str] = None, provider_name: Optional[str] = None):
self.strategy_name = strategy_name
self.provider_name = provider_name
super().__init__(message)
class AgentStrategyNotFoundError(AgentStrategyError):
"""Exception raised when the specified agent strategy is not found."""
def __init__(self, strategy_name: str, provider_name: Optional[str] = None):
super().__init__(
f"Agent strategy '{strategy_name}' not found"
+ (f" for provider '{provider_name}'" if provider_name else ""),
strategy_name,
provider_name,
)
class AgentInvocationError(AgentNodeError):
"""Exception raised when there's an error invoking the agent."""
def __init__(self, message: str, original_error: Optional[Exception] = None):
self.original_error = original_error
super().__init__(message)
class AgentParameterError(AgentNodeError):
"""Exception raised when there's an error with agent parameters."""
def __init__(self, message: str, parameter_name: Optional[str] = None):
self.parameter_name = parameter_name
super().__init__(message)
class AgentVariableError(AgentNodeError):
"""Exception raised when there's an error with variables in the agent node."""
def __init__(self, message: str, variable_name: Optional[str] = None):
self.variable_name = variable_name
super().__init__(message)
class AgentVariableNotFoundError(AgentVariableError):
"""Exception raised when a variable is not found in the variable pool."""
def __init__(self, variable_name: str):
super().__init__(f"Variable '{variable_name}' does not exist", variable_name)
class AgentInputTypeError(AgentNodeError):
"""Exception raised when an unknown agent input type is encountered."""
def __init__(self, input_type: str):
super().__init__(f"Unknown agent input type '{input_type}'")
class ToolFileError(AgentNodeError):
"""Exception raised when there's an error with a tool file."""
def __init__(self, message: str, file_id: Optional[str] = None):
self.file_id = file_id
super().__init__(message)
class ToolFileNotFoundError(ToolFileError):
"""Exception raised when a tool file is not found."""
def __init__(self, file_id: str):
super().__init__(f"Tool file '{file_id}' does not exist", file_id)
class AgentMessageTransformError(AgentNodeError):
"""Exception raised when there's an error transforming agent messages."""
def __init__(self, message: str, original_error: Optional[Exception] = None):
self.original_error = original_error
super().__init__(message)
class AgentModelError(AgentNodeError):
"""Exception raised when there's an error with the model used by the agent."""
def __init__(self, message: str, model_name: Optional[str] = None, provider: Optional[str] = None):
self.model_name = model_name
self.provider = provider
super().__init__(message)
class AgentMemoryError(AgentNodeError):
"""Exception raised when there's an error with the agent's memory."""
def __init__(self, message: str, conversation_id: Optional[str] = None):
self.conversation_id = conversation_id
super().__init__(message)
class AgentVariableTypeError(AgentNodeError):
"""Exception raised when a variable has an unexpected type."""
def __init__(
self,
message: str,
variable_name: Optional[str] = None,
expected_type: Optional[str] = None,
actual_type: Optional[str] = None,
):
self.variable_name = variable_name
self.expected_type = expected_type
self.actual_type = actual_type
super().__init__(message)

@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import Any, cast from typing import Any, Optional, cast
from core.variables import ArrayFileSegment, FileSegment from core.variables import ArrayFileSegment, FileSegment
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
@ -12,14 +12,37 @@ from core.workflow.nodes.answer.entities import (
VarGenerateRouteChunk, VarGenerateRouteChunk,
) )
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.utils.variable_template_parser import VariableTemplateParser from core.workflow.utils.variable_template_parser import VariableTemplateParser
class AnswerNode(BaseNode[AnswerNodeData]): class AnswerNode(BaseNode):
_node_data_cls = AnswerNodeData
_node_type = NodeType.ANSWER _node_type = NodeType.ANSWER
_node_data: AnswerNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = AnswerNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"
@ -30,7 +53,7 @@ class AnswerNode(BaseNode[AnswerNodeData]):
:return: :return:
""" """
# generate routes # generate routes
generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(self.node_data) generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(self._node_data)
answer = "" answer = ""
files = [] files = []
@ -60,16 +83,12 @@ class AnswerNode(BaseNode[AnswerNodeData]):
*, *,
graph_config: Mapping[str, Any], graph_config: Mapping[str, Any],
node_id: str, node_id: str,
node_data: AnswerNodeData, node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
""" # Create typed NodeData from dict
Extract variable selector to variable mapping typed_node_data = AnswerNodeData.model_validate(node_data)
:param graph_config: graph config
:param node_id: node id variable_template_parser = VariableTemplateParser(template=typed_node_data.answer)
:param node_data: node data
:return:
"""
variable_template_parser = VariableTemplateParser(template=node_data.answer)
variable_selectors = variable_template_parser.extract_variable_selectors() variable_selectors = variable_template_parser.extract_variable_selectors()
variable_mapping = {} variable_mapping = {}

@ -122,13 +122,13 @@ class RetryConfig(BaseModel):
class BaseNodeData(ABC, BaseModel): class BaseNodeData(ABC, BaseModel):
title: str title: str
desc: Optional[str] = None desc: Optional[str] = None
version: str = "1"
error_strategy: Optional[ErrorStrategy] = None error_strategy: Optional[ErrorStrategy] = None
default_value: Optional[list[DefaultValue]] = None default_value: Optional[list[DefaultValue]] = None
version: str = "1"
retry_config: RetryConfig = RetryConfig() retry_config: RetryConfig = RetryConfig()
@property @property
def default_value_dict(self): def default_value_dict(self) -> dict[str, Any]:
if self.default_value: if self.default_value:
return {item.key: item.value for item in self.default_value} return {item.key: item.value for item in self.default_value}
return {} return {}

@ -1,28 +1,22 @@
import logging import logging
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, cast from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from .entities import BaseNodeData
if TYPE_CHECKING: if TYPE_CHECKING:
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData)
class BaseNode(Generic[GenericNodeData]): class BaseNode:
_node_data_cls: type[GenericNodeData]
_node_type: ClassVar[NodeType] _node_type: ClassVar[NodeType]
def __init__( def __init__(
@ -56,8 +50,8 @@ class BaseNode(Generic[GenericNodeData]):
self.node_id = node_id self.node_id = node_id
node_data = self._node_data_cls.model_validate(config.get("data", {})) @abstractmethod
self.node_data = node_data def init_node_data(self, data: Mapping[str, Any]) -> None: ...
@abstractmethod @abstractmethod
def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]: def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
@ -130,9 +124,9 @@ class BaseNode(Generic[GenericNodeData]):
if not node_id: if not node_id:
raise ValueError("Node ID is required when extracting variable selector to variable mapping.") raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
node_data = cls._node_data_cls(**config.get("data", {})) # Pass raw dict data instead of creating NodeData instance
data = cls._extract_variable_selector_to_variable_mapping( data = cls._extract_variable_selector_to_variable_mapping(
graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data) graph_config=graph_config, node_id=node_id, node_data=config.get("data", {})
) )
return data return data
@ -142,32 +136,16 @@ class BaseNode(Generic[GenericNodeData]):
*, *,
graph_config: Mapping[str, Any], graph_config: Mapping[str, Any],
node_id: str, node_id: str,
node_data: GenericNodeData, node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
return {} return {}
@classmethod @classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict: def get_default_config(cls, filters: Optional[dict] = None) -> dict:
"""
Get default config of node.
:param filters: filter by node config parameters.
:return:
"""
return {} return {}
@property @property
def node_type(self) -> NodeType: def type_(self) -> NodeType:
"""
Get node type
:return:
"""
return self._node_type return self._node_type
@classmethod @classmethod
@ -181,19 +159,68 @@ class BaseNode(Generic[GenericNodeData]):
raise NotImplementedError("subclasses of BaseNode must implement `version` method.") raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
@property @property
def should_continue_on_error(self) -> bool: def continue_on_error(self) -> bool:
"""judge if should continue on error return False
Returns: @property
bool: if should continue on error def retry(self) -> bool:
""" return False
return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE
# Abstract methods that subclasses must implement to provide access
# to BaseNodeData properties in a type-safe way
@abstractmethod
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
"""Get the error strategy for this node."""
...
@abstractmethod
def _get_retry_config(self) -> RetryConfig:
"""Get the retry configuration for this node."""
...
@abstractmethod
def _get_title(self) -> str:
"""Get the node title."""
...
@abstractmethod
def _get_description(self) -> Optional[str]:
"""Get the node description."""
...
@abstractmethod
def _get_default_value_dict(self) -> dict[str, Any]:
"""Get the default values dictionary for this node."""
...
@abstractmethod
def get_base_node_data(self) -> BaseNodeData:
"""Get the BaseNodeData object for this node."""
...
# Public interface properties that delegate to abstract methods
@property @property
def should_retry(self) -> bool: def error_strategy(self) -> Optional[ErrorStrategy]:
"""judge if should retry """Get the error strategy for this node."""
return self._get_error_strategy()
Returns: @property
bool: if should retry def retry_config(self) -> RetryConfig:
""" """Get the retry configuration for this node."""
return self.node_data.retry_config.retry_enabled and self.node_type in RETRY_ON_ERROR_NODE_TYPE return self._get_retry_config()
@property
def title(self) -> str:
"""Get the node title."""
return self._get_title()
@property
def description(self) -> Optional[str]:
"""Get the node description."""
return self._get_description()
@property
def default_value_dict(self) -> dict[str, Any]:
"""Get the default values dictionary for this node."""
return self._get_default_value_dict()

@ -11,8 +11,9 @@ from core.variables.segments import ArrayFileSegment
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.code.entities import CodeNodeData from core.workflow.nodes.code.entities import CodeNodeData
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import ErrorStrategy, NodeType
from .exc import ( from .exc import (
CodeNodeError, CodeNodeError,
@ -21,10 +22,32 @@ from .exc import (
) )
class CodeNode(BaseNode[CodeNodeData]): class CodeNode(BaseNode):
_node_data_cls = CodeNodeData
_node_type = NodeType.CODE _node_type = NodeType.CODE
_node_data: CodeNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = CodeNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod @classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict: def get_default_config(cls, filters: Optional[dict] = None) -> dict:
""" """
@ -47,12 +70,12 @@ class CodeNode(BaseNode[CodeNodeData]):
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
# Get code language # Get code language
code_language = self.node_data.code_language code_language = self._node_data.code_language
code = self.node_data.code code = self._node_data.code
# Get variables # Get variables
variables = {} variables = {}
for variable_selector in self.node_data.variables: for variable_selector in self._node_data.variables:
variable_name = variable_selector.variable variable_name = variable_selector.variable
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
if isinstance(variable, ArrayFileSegment): if isinstance(variable, ArrayFileSegment):
@ -68,7 +91,7 @@ class CodeNode(BaseNode[CodeNodeData]):
) )
# Transform result # Transform result
result = self._transform_result(result=result, output_schema=self.node_data.outputs) result = self._transform_result(result=result, output_schema=self._node_data.outputs)
except (CodeExecutionError, CodeNodeError) as e: except (CodeExecutionError, CodeNodeError) as e:
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__ status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
@ -334,16 +357,20 @@ class CodeNode(BaseNode[CodeNodeData]):
*, *,
graph_config: Mapping[str, Any], graph_config: Mapping[str, Any],
node_id: str, node_id: str,
node_data: CodeNodeData, node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
""" # Create typed NodeData from dict
Extract variable selector to variable mapping typed_node_data = CodeNodeData.model_validate(node_data)
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
return { return {
node_id + "." + variable_selector.variable: variable_selector.value_selector node_id + "." + variable_selector.variable: variable_selector.value_selector
for variable_selector in node_data.variables for variable_selector in typed_node_data.variables
} }
@property
def continue_on_error(self) -> bool:
return self._node_data.error_strategy is not None
@property
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled

@ -5,7 +5,7 @@ import logging
import os import os
import tempfile import tempfile
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import Any, cast from typing import Any, Optional, cast
import chardet import chardet
import docx import docx
@ -28,7 +28,8 @@ from core.variables.segments import ArrayStringSegment, FileSegment
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from .entities import DocumentExtractorNodeData from .entities import DocumentExtractorNodeData
from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError
@ -36,21 +37,43 @@ from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError,
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]): class DocumentExtractorNode(BaseNode):
""" """
Extracts text content from various file types. Extracts text content from various file types.
Supports plain text, PDF, and DOC/DOCX files. Supports plain text, PDF, and DOC/DOCX files.
""" """
_node_data_cls = DocumentExtractorNodeData
_node_type = NodeType.DOCUMENT_EXTRACTOR _node_type = NodeType.DOCUMENT_EXTRACTOR
_node_data: DocumentExtractorNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = DocumentExtractorNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"
def _run(self): def _run(self):
variable_selector = self.node_data.variable_selector variable_selector = self._node_data.variable_selector
variable = self.graph_runtime_state.variable_pool.get(variable_selector) variable = self.graph_runtime_state.variable_pool.get(variable_selector)
if variable is None: if variable is None:
@ -97,16 +120,12 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
*, *,
graph_config: Mapping[str, Any], graph_config: Mapping[str, Any],
node_id: str, node_id: str,
node_data: DocumentExtractorNodeData, node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
""" # Create typed NodeData from dict
Extract variable selector to variable mapping typed_node_data = DocumentExtractorNodeData.model_validate(node_data)
:param graph_config: graph config
:param node_id: node id return {node_id + ".files": typed_node_data.variable_selector}
:param node_data: node data
:return:
"""
return {node_id + ".files": node_data.variable_selector}
def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str: def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:

@ -1,14 +1,40 @@
from collections.abc import Mapping
from typing import Any, Optional
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.end.entities import EndNodeData from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import ErrorStrategy, NodeType
class EndNode(BaseNode[EndNodeData]): class EndNode(BaseNode):
_node_data_cls = EndNodeData
_node_type = NodeType.END _node_type = NodeType.END
_node_data: EndNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = EndNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"
@ -18,7 +44,7 @@ class EndNode(BaseNode[EndNodeData]):
Run node Run node
:return: :return:
""" """
output_variables = self.node_data.outputs output_variables = self._node_data.outputs
outputs = {} outputs = {}
for variable_selector in output_variables: for variable_selector in output_variables:

@ -35,7 +35,3 @@ class ErrorStrategy(StrEnum):
class FailBranchSourceHandle(StrEnum): class FailBranchSourceHandle(StrEnum):
FAILED = "fail-branch" FAILED = "fail-branch"
SUCCESS = "success-branch" SUCCESS = "success-branch"
CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST]
RETRY_ON_ERROR_NODE_TYPE = CONTINUE_ON_ERROR_NODE_TYPE

@ -11,7 +11,8 @@ from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.http_request.executor import Executor from core.workflow.nodes.http_request.executor import Executor
from core.workflow.utils import variable_template_parser from core.workflow.utils import variable_template_parser
from factories import file_factory from factories import file_factory
@ -32,10 +33,32 @@ HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class HttpRequestNode(BaseNode[HttpRequestNodeData]): class HttpRequestNode(BaseNode):
_node_data_cls = HttpRequestNodeData
_node_type = NodeType.HTTP_REQUEST _node_type = NodeType.HTTP_REQUEST
_node_data: HttpRequestNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = HttpRequestNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod @classmethod
def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict: def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict:
return { return {
@ -69,8 +92,8 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
process_data = {} process_data = {}
try: try:
http_executor = Executor( http_executor = Executor(
node_data=self.node_data, node_data=self._node_data,
timeout=self._get_request_timeout(self.node_data), timeout=self._get_request_timeout(self._node_data),
variable_pool=self.graph_runtime_state.variable_pool, variable_pool=self.graph_runtime_state.variable_pool,
max_retries=0, max_retries=0,
) )
@ -78,7 +101,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
response = http_executor.invoke() response = http_executor.invoke()
files = self.extract_files(url=http_executor.url, response=response) files = self.extract_files(url=http_executor.url, response=response)
if not response.response.is_success and (self.should_continue_on_error or self.should_retry): if not response.response.is_success and (self.continue_on_error or self.retry):
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, status=WorkflowNodeExecutionStatus.FAILED,
outputs={ outputs={
@ -131,15 +154,18 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
*, *,
graph_config: Mapping[str, Any], graph_config: Mapping[str, Any],
node_id: str, node_id: str,
node_data: HttpRequestNodeData, node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = HttpRequestNodeData.model_validate(node_data)
selectors: list[VariableSelector] = [] selectors: list[VariableSelector] = []
selectors += variable_template_parser.extract_selectors_from_template(node_data.url) selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.url)
selectors += variable_template_parser.extract_selectors_from_template(node_data.headers) selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.headers)
selectors += variable_template_parser.extract_selectors_from_template(node_data.params) selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.params)
if node_data.body: if typed_node_data.body:
body_type = node_data.body.type body_type = typed_node_data.body.type
data = node_data.body.data data = typed_node_data.body.data
match body_type: match body_type:
case "binary": case "binary":
if len(data) != 1: if len(data) != 1:
@ -217,3 +243,11 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
files.append(file) files.append(file)
return ArrayFileSegment(value=files) return ArrayFileSegment(value=files)
@property
def continue_on_error(self) -> bool:
return self._node_data.error_strategy is not None
@property
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled

@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import Any, Literal from typing import Any, Literal, Optional
from typing_extensions import deprecated from typing_extensions import deprecated
@ -7,16 +7,39 @@ from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.if_else.entities import IfElseNodeData from core.workflow.nodes.if_else.entities import IfElseNodeData
from core.workflow.utils.condition.entities import Condition from core.workflow.utils.condition.entities import Condition
from core.workflow.utils.condition.processor import ConditionProcessor from core.workflow.utils.condition.processor import ConditionProcessor
class IfElseNode(BaseNode[IfElseNodeData]): class IfElseNode(BaseNode):
_node_data_cls = IfElseNodeData
_node_type = NodeType.IF_ELSE _node_type = NodeType.IF_ELSE
_node_data: IfElseNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = IfElseNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"
@ -36,8 +59,8 @@ class IfElseNode(BaseNode[IfElseNodeData]):
condition_processor = ConditionProcessor() condition_processor = ConditionProcessor()
try: try:
# Check if the new cases structure is used # Check if the new cases structure is used
if self.node_data.cases: if self._node_data.cases:
for case in self.node_data.cases: for case in self._node_data.cases:
input_conditions, group_result, final_result = condition_processor.process_conditions( input_conditions, group_result, final_result = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool, variable_pool=self.graph_runtime_state.variable_pool,
conditions=case.conditions, conditions=case.conditions,
@ -63,8 +86,8 @@ class IfElseNode(BaseNode[IfElseNodeData]):
input_conditions, group_result, final_result = _should_not_use_old_function( input_conditions, group_result, final_result = _should_not_use_old_function(
condition_processor=condition_processor, condition_processor=condition_processor,
variable_pool=self.graph_runtime_state.variable_pool, variable_pool=self.graph_runtime_state.variable_pool,
conditions=self.node_data.conditions or [], conditions=self._node_data.conditions or [],
operator=self.node_data.logical_operator or "and", operator=self._node_data.logical_operator or "and",
) )
selected_case_id = "true" if final_result else "false" selected_case_id = "true" if final_result else "false"
@ -98,10 +121,13 @@ class IfElseNode(BaseNode[IfElseNodeData]):
*, *,
graph_config: Mapping[str, Any], graph_config: Mapping[str, Any],
node_id: str, node_id: str,
node_data: IfElseNodeData, node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = IfElseNodeData.model_validate(node_data)
var_mapping: dict[str, list[str]] = {} var_mapping: dict[str, list[str]] = {}
for case in node_data.cases or []: for case in typed_node_data.cases or []:
for condition in case.conditions: for condition in case.conditions:
key = "{}.#{}#".format(node_id, ".".join(condition.variable_selector)) key = "{}.#{}#".format(node_id, ".".join(condition.variable_selector))
var_mapping[key] = condition.variable_selector var_mapping[key] = condition.variable_selector

@ -36,7 +36,8 @@ from core.workflow.graph_engine.entities.event import (
) )
from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from factories.variable_factory import build_segment from factories.variable_factory import build_segment
@ -56,14 +57,36 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class IterationNode(BaseNode[IterationNodeData]): class IterationNode(BaseNode):
""" """
Iteration Node. Iteration Node.
""" """
_node_data_cls = IterationNodeData
_node_type = NodeType.ITERATION _node_type = NodeType.ITERATION
_node_data: IterationNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = IterationNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod @classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict: def get_default_config(cls, filters: Optional[dict] = None) -> dict:
return { return {
@ -83,10 +106,10 @@ class IterationNode(BaseNode[IterationNodeData]):
""" """
Run the node. Run the node.
""" """
variable = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector) variable = self.graph_runtime_state.variable_pool.get(self._node_data.iterator_selector)
if not variable: if not variable:
raise IteratorVariableNotFoundError(f"iterator variable {self.node_data.iterator_selector} not found") raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found")
if not isinstance(variable, ArrayVariable) and not isinstance(variable, NoneVariable): if not isinstance(variable, ArrayVariable) and not isinstance(variable, NoneVariable):
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.") raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
@ -116,10 +139,10 @@ class IterationNode(BaseNode[IterationNodeData]):
graph_config = self.graph_config graph_config = self.graph_config
if not self.node_data.start_node_id: if not self._node_data.start_node_id:
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self.node_id} not found") raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self.node_id} not found")
root_node_id = self.node_data.start_node_id root_node_id = self._node_data.start_node_id
# init graph # init graph
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id) iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id)
@ -161,8 +184,8 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunStartedEvent( yield IterationRunStartedEvent(
iteration_id=self.id, iteration_id=self.id,
iteration_node_id=self.node_id, iteration_node_id=self.node_id,
iteration_node_type=self.node_type, iteration_node_type=self.type_,
iteration_node_data=self.node_data, iteration_node_data=self._node_data,
start_at=start_at, start_at=start_at,
inputs=inputs, inputs=inputs,
metadata={"iterator_length": len(iterator_list_value)}, metadata={"iterator_length": len(iterator_list_value)},
@ -172,8 +195,8 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunNextEvent( yield IterationRunNextEvent(
iteration_id=self.id, iteration_id=self.id,
iteration_node_id=self.node_id, iteration_node_id=self.node_id,
iteration_node_type=self.node_type, iteration_node_type=self.type_,
iteration_node_data=self.node_data, iteration_node_data=self._node_data,
index=0, index=0,
pre_iteration_output=None, pre_iteration_output=None,
duration=None, duration=None,
@ -181,11 +204,11 @@ class IterationNode(BaseNode[IterationNodeData]):
iter_run_map: dict[str, float] = {} iter_run_map: dict[str, float] = {}
outputs: list[Any] = [None] * len(iterator_list_value) outputs: list[Any] = [None] * len(iterator_list_value)
try: try:
if self.node_data.is_parallel: if self._node_data.is_parallel:
futures: list[Future] = [] futures: list[Future] = []
q: Queue = Queue() q: Queue = Queue()
thread_pool = GraphEngineThreadPool( thread_pool = GraphEngineThreadPool(
max_workers=self.node_data.parallel_nums, max_submit_count=dify_config.MAX_SUBMIT_COUNT max_workers=self._node_data.parallel_nums, max_submit_count=dify_config.MAX_SUBMIT_COUNT
) )
for index, item in enumerate(iterator_list_value): for index, item in enumerate(iterator_list_value):
future: Future = thread_pool.submit( future: Future = thread_pool.submit(
@ -242,7 +265,7 @@ class IterationNode(BaseNode[IterationNodeData]):
iteration_graph=iteration_graph, iteration_graph=iteration_graph,
iter_run_map=iter_run_map, iter_run_map=iter_run_map,
) )
if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: if self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
outputs = [output for output in outputs if output is not None] outputs = [output for output in outputs if output is not None]
# Flatten the list of lists # Flatten the list of lists
@ -253,8 +276,8 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunSucceededEvent( yield IterationRunSucceededEvent(
iteration_id=self.id, iteration_id=self.id,
iteration_node_id=self.node_id, iteration_node_id=self.node_id,
iteration_node_type=self.node_type, iteration_node_type=self.type_,
iteration_node_data=self.node_data, iteration_node_data=self._node_data,
start_at=start_at, start_at=start_at,
inputs=inputs, inputs=inputs,
outputs={"output": outputs}, outputs={"output": outputs},
@ -278,8 +301,8 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunFailedEvent( yield IterationRunFailedEvent(
iteration_id=self.id, iteration_id=self.id,
iteration_node_id=self.node_id, iteration_node_id=self.node_id,
iteration_node_type=self.node_type, iteration_node_type=self.type_,
iteration_node_data=self.node_data, iteration_node_data=self._node_data,
start_at=start_at, start_at=start_at,
inputs=inputs, inputs=inputs,
outputs={"output": outputs}, outputs={"output": outputs},
@ -305,21 +328,17 @@ class IterationNode(BaseNode[IterationNodeData]):
*, *,
graph_config: Mapping[str, Any], graph_config: Mapping[str, Any],
node_id: str, node_id: str,
node_data: IterationNodeData, node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
""" # Create typed NodeData from dict
Extract variable selector to variable mapping typed_node_data = IterationNodeData.model_validate(node_data)
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
variable_mapping: dict[str, Sequence[str]] = { variable_mapping: dict[str, Sequence[str]] = {
f"{node_id}.input_selector": node_data.iterator_selector, f"{node_id}.input_selector": typed_node_data.iterator_selector,
} }
# init graph # init graph
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id) iteration_graph = Graph.init(graph_config=graph_config, root_node_id=typed_node_data.start_node_id)
if not iteration_graph: if not iteration_graph:
raise IterationGraphNotFoundError("iteration graph not found") raise IterationGraphNotFoundError("iteration graph not found")
@ -375,7 +394,7 @@ class IterationNode(BaseNode[IterationNodeData]):
""" """
if not isinstance(event, BaseNodeEvent): if not isinstance(event, BaseNodeEvent):
return event return event
if self.node_data.is_parallel and isinstance(event, NodeRunStartedEvent): if self._node_data.is_parallel and isinstance(event, NodeRunStartedEvent):
event.parallel_mode_run_id = parallel_mode_run_id event.parallel_mode_run_id = parallel_mode_run_id
iter_metadata = { iter_metadata = {
@ -438,12 +457,12 @@ class IterationNode(BaseNode[IterationNodeData]):
elif isinstance(event, BaseGraphEvent): elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent): if isinstance(event, GraphRunFailedEvent):
# iteration run failed # iteration run failed
if self.node_data.is_parallel: if self._node_data.is_parallel:
yield IterationRunFailedEvent( yield IterationRunFailedEvent(
iteration_id=self.id, iteration_id=self.id,
iteration_node_id=self.node_id, iteration_node_id=self.node_id,
iteration_node_type=self.node_type, iteration_node_type=self.type_,
iteration_node_data=self.node_data, iteration_node_data=self._node_data,
parallel_mode_run_id=parallel_mode_run_id, parallel_mode_run_id=parallel_mode_run_id,
start_at=start_at, start_at=start_at,
inputs=inputs, inputs=inputs,
@ -456,8 +475,8 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunFailedEvent( yield IterationRunFailedEvent(
iteration_id=self.id, iteration_id=self.id,
iteration_node_id=self.node_id, iteration_node_id=self.node_id,
iteration_node_type=self.node_type, iteration_node_type=self.type_,
iteration_node_data=self.node_data, iteration_node_data=self._node_data,
start_at=start_at, start_at=start_at,
inputs=inputs, inputs=inputs,
outputs={"output": outputs}, outputs={"output": outputs},
@ -478,7 +497,7 @@ class IterationNode(BaseNode[IterationNodeData]):
event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id
) )
if isinstance(event, NodeRunFailedEvent): if isinstance(event, NodeRunFailedEvent):
if self.node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR: if self._node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR:
yield NodeInIterationFailedEvent( yield NodeInIterationFailedEvent(
**metadata_event.model_dump(), **metadata_event.model_dump(),
) )
@ -491,15 +510,15 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunNextEvent( yield IterationRunNextEvent(
iteration_id=self.id, iteration_id=self.id,
iteration_node_id=self.node_id, iteration_node_id=self.node_id,
iteration_node_type=self.node_type, iteration_node_type=self.type_,
iteration_node_data=self.node_data, iteration_node_data=self._node_data,
index=next_index, index=next_index,
parallel_mode_run_id=parallel_mode_run_id, parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=None, pre_iteration_output=None,
duration=duration, duration=duration,
) )
return return
elif self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: elif self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
yield NodeInIterationFailedEvent( yield NodeInIterationFailedEvent(
**metadata_event.model_dump(), **metadata_event.model_dump(),
) )
@ -512,15 +531,15 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunNextEvent( yield IterationRunNextEvent(
iteration_id=self.id, iteration_id=self.id,
iteration_node_id=self.node_id, iteration_node_id=self.node_id,
iteration_node_type=self.node_type, iteration_node_type=self.type_,
iteration_node_data=self.node_data, iteration_node_data=self._node_data,
index=next_index, index=next_index,
parallel_mode_run_id=parallel_mode_run_id, parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=None, pre_iteration_output=None,
duration=duration, duration=duration,
) )
return return
elif self.node_data.error_handle_mode == ErrorHandleMode.TERMINATED: elif self._node_data.error_handle_mode == ErrorHandleMode.TERMINATED:
yield NodeInIterationFailedEvent( yield NodeInIterationFailedEvent(
**metadata_event.model_dump(), **metadata_event.model_dump(),
) )
@ -531,12 +550,12 @@ class IterationNode(BaseNode[IterationNodeData]):
variable_pool.remove([node_id]) variable_pool.remove([node_id])
# iteration run failed # iteration run failed
if self.node_data.is_parallel: if self._node_data.is_parallel:
yield IterationRunFailedEvent( yield IterationRunFailedEvent(
iteration_id=self.id, iteration_id=self.id,
iteration_node_id=self.node_id, iteration_node_id=self.node_id,
iteration_node_type=self.node_type, iteration_node_type=self.type_,
iteration_node_data=self.node_data, iteration_node_data=self._node_data,
parallel_mode_run_id=parallel_mode_run_id, parallel_mode_run_id=parallel_mode_run_id,
start_at=start_at, start_at=start_at,
inputs=inputs, inputs=inputs,
@ -549,8 +568,8 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunFailedEvent( yield IterationRunFailedEvent(
iteration_id=self.id, iteration_id=self.id,
iteration_node_id=self.node_id, iteration_node_id=self.node_id,
iteration_node_type=self.node_type, iteration_node_type=self.type_,
iteration_node_data=self.node_data, iteration_node_data=self._node_data,
start_at=start_at, start_at=start_at,
inputs=inputs, inputs=inputs,
outputs={"output": outputs}, outputs={"output": outputs},
@ -569,7 +588,7 @@ class IterationNode(BaseNode[IterationNodeData]):
return return
yield metadata_event yield metadata_event
current_output_segment = variable_pool.get(self.node_data.output_selector) current_output_segment = variable_pool.get(self._node_data.output_selector)
if current_output_segment is None: if current_output_segment is None:
raise IterationNodeError("iteration output selector not found") raise IterationNodeError("iteration output selector not found")
current_iteration_output = current_output_segment.value current_iteration_output = current_output_segment.value
@ -588,8 +607,8 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunNextEvent( yield IterationRunNextEvent(
iteration_id=self.id, iteration_id=self.id,
iteration_node_id=self.node_id, iteration_node_id=self.node_id,
iteration_node_type=self.node_type, iteration_node_type=self.type_,
iteration_node_data=self.node_data, iteration_node_data=self._node_data,
index=next_index, index=next_index,
parallel_mode_run_id=parallel_mode_run_id, parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=current_iteration_output or None, pre_iteration_output=current_iteration_output or None,
@ -601,8 +620,8 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunFailedEvent( yield IterationRunFailedEvent(
iteration_id=self.id, iteration_id=self.id,
iteration_node_id=self.node_id, iteration_node_id=self.node_id,
iteration_node_type=self.node_type, iteration_node_type=self.type_,
iteration_node_data=self.node_data, iteration_node_data=self._node_data,
start_at=start_at, start_at=start_at,
inputs=inputs, inputs=inputs,
outputs={"output": None}, outputs={"output": None},

@ -1,18 +1,44 @@
from collections.abc import Mapping
from typing import Any, Optional
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.iteration.entities import IterationStartNodeData from core.workflow.nodes.iteration.entities import IterationStartNodeData
class IterationStartNode(BaseNode[IterationStartNodeData]): class IterationStartNode(BaseNode):
""" """
Iteration Start Node. Iteration Start Node.
""" """
_node_data_cls = IterationStartNodeData
_node_type = NodeType.ITERATION_START _node_type = NodeType.ITERATION_START
_node_data: IterationStartNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = IterationStartNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"

@ -1,10 +1,10 @@
from collections.abc import Sequence from collections.abc import Sequence
from typing import Any, Literal, Optional from typing import Literal, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from core.workflow.nodes.base import BaseNodeData from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm.entities import VisionConfig from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig
class RerankingModelConfig(BaseModel): class RerankingModelConfig(BaseModel):
@ -56,17 +56,6 @@ class MultipleRetrievalConfig(BaseModel):
weights: Optional[WeightedScoreConfig] = None weights: Optional[WeightedScoreConfig] = None
class ModelConfig(BaseModel):
"""
Model Config.
"""
provider: str
name: str
mode: str
completion_params: dict[str, Any] = {}
class SingleRetrievalConfig(BaseModel): class SingleRetrievalConfig(BaseModel):
""" """
Single Retrieval Config. Single Retrieval Config.
@ -129,7 +118,7 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None
single_retrieval_config: Optional[SingleRetrievalConfig] = None single_retrieval_config: Optional[SingleRetrievalConfig] = None
metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled" metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled"
metadata_model_config: Optional[ModelConfig] = None metadata_model_config: ModelConfig
metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None
vision: VisionConfig = Field(default_factory=VisionConfig) vision: VisionConfig = Field(default_factory=VisionConfig)

@ -4,7 +4,7 @@ import re
import time import time
from collections import defaultdict from collections import defaultdict
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast from typing import TYPE_CHECKING, Any, Optional, cast
from sqlalchemy import Float, and_, func, or_, text from sqlalchemy import Float, and_, func, or_, text
from sqlalchemy import cast as sqlalchemy_cast from sqlalchemy import cast as sqlalchemy_cast
@ -15,20 +15,31 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
from core.entities.agent_entities import PlanningStrategy from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus from core.entities.model_entities import ModelStatus
from core.model_manager import ModelInstance, ModelManager from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.message_entities import PromptMessageRole from core.model_runtime.entities.message_entities import (
from core.model_runtime.entities.model_entities import ModelFeature, ModelType PromptMessageRole,
)
from core.model_runtime.entities.model_entities import (
ModelFeature,
ModelType,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.simple_prompt_transform import ModelMode from core.prompt.simple_prompt_transform import ModelMode
from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.metadata_entities import Condition, MetadataCondition from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables import StringSegment from core.variables import (
StringSegment,
)
from core.variables.segments import ArrayObjectSegment from core.variables.segments import ArrayObjectSegment
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.event.event import ModelInvokeCompletedEvent from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import (
ModelInvokeCompletedEvent,
)
from core.workflow.nodes.knowledge_retrieval.template_prompts import ( from core.workflow.nodes.knowledge_retrieval.template_prompts import (
METADATA_FILTER_ASSISTANT_PROMPT_1, METADATA_FILTER_ASSISTANT_PROMPT_1,
METADATA_FILTER_ASSISTANT_PROMPT_2, METADATA_FILTER_ASSISTANT_PROMPT_2,
@ -38,7 +49,8 @@ from core.workflow.nodes.knowledge_retrieval.template_prompts import (
METADATA_FILTER_USER_PROMPT_2, METADATA_FILTER_USER_PROMPT_2,
METADATA_FILTER_USER_PROMPT_3, METADATA_FILTER_USER_PROMPT_3,
) )
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, ModelConfig
from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from core.workflow.nodes.llm.node import LLMNode from core.workflow.nodes.llm.node import LLMNode
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
@ -46,7 +58,7 @@ from libs.json_in_md_parser import parse_and_check_json_markdown
from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog
from services.feature_service import FeatureService from services.feature_service import FeatureService
from .entities import KnowledgeRetrievalNodeData, ModelConfig from .entities import KnowledgeRetrievalNodeData
from .exc import ( from .exc import (
InvalidModelTypeError, InvalidModelTypeError,
KnowledgeRetrievalNodeError, KnowledgeRetrievalNodeError,
@ -56,6 +68,10 @@ from .exc import (
ModelQuotaExceededError, ModelQuotaExceededError,
) )
if TYPE_CHECKING:
from core.file.models import File
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
default_retrieval_model = { default_retrieval_model = {
@ -67,18 +83,76 @@ default_retrieval_model = {
} }
class KnowledgeRetrievalNode(LLMNode): class KnowledgeRetrievalNode(BaseNode):
_node_data_cls = KnowledgeRetrievalNodeData # type: ignore
_node_type = NodeType.KNOWLEDGE_RETRIEVAL _node_type = NodeType.KNOWLEDGE_RETRIEVAL
_node_data: KnowledgeRetrievalNodeData
# Instance attributes specific to LLMNode.
# Output variable for file
_file_outputs: list["File"]
_llm_file_saver: LLMFileSaver
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph: "Graph",
graph_runtime_state: "GraphRuntimeState",
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None,
*,
llm_file_saver: LLMFileSaver | None = None,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph=graph,
graph_runtime_state=graph_runtime_state,
previous_node_id=previous_node_id,
thread_pool_id=thread_pool_id,
)
# LLM file outputs, used for MultiModal outputs.
self._file_outputs: list[File] = []
if llm_file_saver is None:
llm_file_saver = FileSaverImpl(
user_id=graph_init_params.user_id,
tenant_id=graph_init_params.tenant_id,
)
self._llm_file_saver = llm_file_saver
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = KnowledgeRetrievalNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod @classmethod
def version(cls): def version(cls):
return "1" return "1"
def _run(self) -> NodeRunResult: # type: ignore def _run(self) -> NodeRunResult: # type: ignore
node_data = cast(KnowledgeRetrievalNodeData, self.node_data)
# extract variables # extract variables
variable = self.graph_runtime_state.variable_pool.get(node_data.query_variable_selector) variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector)
if not isinstance(variable, StringSegment): if not isinstance(variable, StringSegment):
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, status=WorkflowNodeExecutionStatus.FAILED,
@ -119,7 +193,7 @@ class KnowledgeRetrievalNode(LLMNode):
# retrieve knowledge # retrieve knowledge
try: try:
results = self._fetch_dataset_retriever(node_data=node_data, query=query) results = self._fetch_dataset_retriever(node_data=self._node_data, query=query)
outputs = {"result": ArrayObjectSegment(value=results)} outputs = {"result": ArrayObjectSegment(value=results)}
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -435,20 +509,15 @@ class KnowledgeRetrievalNode(LLMNode):
# get all metadata field # get all metadata field
metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all() metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields] all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
# get metadata model config # get metadata model instance and fetch model config
metadata_model_config = node_data.metadata_model_config model_instance, model_config = self.get_model_config(node_data.metadata_model_config)
if metadata_model_config is None:
raise ValueError("metadata_model_config is required")
# get metadata model instance
# fetch model config
model_instance, model_config = self.get_model_config(metadata_model_config)
# fetch prompt messages # fetch prompt messages
prompt_template = self._get_prompt_template( prompt_template = self._get_prompt_template(
node_data=node_data, node_data=node_data,
metadata_fields=all_metadata_fields, metadata_fields=all_metadata_fields,
query=query or "", query=query or "",
) )
prompt_messages, stop = self._fetch_prompt_messages( prompt_messages, stop = LLMNode.fetch_prompt_messages(
prompt_template=prompt_template, prompt_template=prompt_template,
sys_query=query, sys_query=query,
memory=None, memory=None,
@ -458,16 +527,23 @@ class KnowledgeRetrievalNode(LLMNode):
vision_detail=node_data.vision.configs.detail, vision_detail=node_data.vision.configs.detail,
variable_pool=self.graph_runtime_state.variable_pool, variable_pool=self.graph_runtime_state.variable_pool,
jinja2_variables=[], jinja2_variables=[],
tenant_id=self.tenant_id,
) )
result_text = "" result_text = ""
try: try:
# handle invoke result # handle invoke result
generator = self._invoke_llm( generator = LLMNode.invoke_llm(
node_data_model=node_data.metadata_model_config, # type: ignore node_data_model=node_data.metadata_model_config,
model_instance=model_instance, model_instance=model_instance,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
stop=stop, stop=stop,
user_id=self.user_id,
structured_output_enabled=self._node_data.structured_output_enabled,
structured_output=None,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self.node_id,
) )
for event in generator: for event in generator:
@ -557,17 +633,13 @@ class KnowledgeRetrievalNode(LLMNode):
*, *,
graph_config: Mapping[str, Any], graph_config: Mapping[str, Any],
node_id: str, node_id: str,
node_data: KnowledgeRetrievalNodeData, # type: ignore node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
""" # Create typed NodeData from dict
Extract variable selector to variable mapping typed_node_data = KnowledgeRetrievalNodeData.model_validate(node_data)
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
variable_mapping = {} variable_mapping = {}
variable_mapping[node_id + ".query"] = node_data.query_variable_selector variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector
return variable_mapping return variable_mapping
def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
@ -629,7 +701,7 @@ class KnowledgeRetrievalNode(LLMNode):
) )
def _get_prompt_template(self, node_data: KnowledgeRetrievalNodeData, metadata_fields: list, query: str): def _get_prompt_template(self, node_data: KnowledgeRetrievalNodeData, metadata_fields: list, query: str):
model_mode = ModelMode.value_of(node_data.metadata_model_config.mode) # type: ignore model_mode = ModelMode(node_data.metadata_model_config.mode)
input_text = query input_text = query
prompt_messages: list[LLMNodeChatModelMessage] = [] prompt_messages: list[LLMNodeChatModelMessage] = []

@ -1,5 +1,5 @@
from collections.abc import Callable, Sequence from collections.abc import Callable, Mapping, Sequence
from typing import Any, Literal, Union from typing import Any, Literal, Optional, Union
from core.file import File from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
@ -7,16 +7,39 @@ from core.variables.segments import ArrayAnySegment, ArraySegment
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from .entities import ListOperatorNodeData from .entities import ListOperatorNodeData
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError
class ListOperatorNode(BaseNode[ListOperatorNodeData]): class ListOperatorNode(BaseNode):
_node_data_cls = ListOperatorNodeData
_node_type = NodeType.LIST_OPERATOR _node_type = NodeType.LIST_OPERATOR
_node_data: ListOperatorNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = ListOperatorNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"
@ -26,9 +49,9 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
process_data: dict[str, list] = {} process_data: dict[str, list] = {}
outputs: dict[str, Any] = {} outputs: dict[str, Any] = {}
variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable) variable = self.graph_runtime_state.variable_pool.get(self._node_data.variable)
if variable is None: if variable is None:
error_message = f"Variable not found for selector: {self.node_data.variable}" error_message = f"Variable not found for selector: {self._node_data.variable}"
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
) )
@ -48,7 +71,7 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
) )
if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment): if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
error_message = ( error_message = (
f"Variable {self.node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment " f"Variable {self._node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment "
"or ArrayStringSegment" "or ArrayStringSegment"
) )
return NodeRunResult( return NodeRunResult(
@ -64,19 +87,19 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
try: try:
# Filter # Filter
if self.node_data.filter_by.enabled: if self._node_data.filter_by.enabled:
variable = self._apply_filter(variable) variable = self._apply_filter(variable)
# Extract # Extract
if self.node_data.extract_by.enabled: if self._node_data.extract_by.enabled:
variable = self._extract_slice(variable) variable = self._extract_slice(variable)
# Order # Order
if self.node_data.order_by.enabled: if self._node_data.order_by.enabled:
variable = self._apply_order(variable) variable = self._apply_order(variable)
# Slice # Slice
if self.node_data.limit.enabled: if self._node_data.limit.enabled:
variable = self._apply_slice(variable) variable = self._apply_slice(variable)
outputs = { outputs = {
@ -104,7 +127,7 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
filter_func: Callable[[Any], bool] filter_func: Callable[[Any], bool]
result: list[Any] = [] result: list[Any] = []
for condition in self.node_data.filter_by.conditions: for condition in self._node_data.filter_by.conditions:
if isinstance(variable, ArrayStringSegment): if isinstance(variable, ArrayStringSegment):
if not isinstance(condition.value, str): if not isinstance(condition.value, str):
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}") raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
@ -137,14 +160,14 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
if isinstance(variable, ArrayStringSegment): if isinstance(variable, ArrayStringSegment):
result = _order_string(order=self.node_data.order_by.value, array=variable.value) result = _order_string(order=self._node_data.order_by.value, array=variable.value)
variable = variable.model_copy(update={"value": result}) variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayNumberSegment): elif isinstance(variable, ArrayNumberSegment):
result = _order_number(order=self.node_data.order_by.value, array=variable.value) result = _order_number(order=self._node_data.order_by.value, array=variable.value)
variable = variable.model_copy(update={"value": result}) variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayFileSegment): elif isinstance(variable, ArrayFileSegment):
result = _order_file( result = _order_file(
order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value order=self._node_data.order_by.value, order_by=self._node_data.order_by.key, array=variable.value
) )
variable = variable.model_copy(update={"value": result}) variable = variable.model_copy(update={"value": result})
return variable return variable
@ -152,13 +175,13 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
def _apply_slice( def _apply_slice(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
result = variable.value[: self.node_data.limit.size] result = variable.value[: self._node_data.limit.size]
return variable.model_copy(update={"value": result}) return variable.model_copy(update={"value": result})
def _extract_slice( def _extract_slice(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
value = int(self.graph_runtime_state.variable_pool.convert_template(self.node_data.extract_by.serial).text) value = int(self.graph_runtime_state.variable_pool.convert_template(self._node_data.extract_by.serial).text)
if value < 1: if value < 1:
raise ValueError(f"Invalid serial index: must be >= 1, got {value}") raise ValueError(f"Invalid serial index: must be >= 1, got {value}")
value -= 1 value -= 1

@ -1,4 +1,4 @@
from collections.abc import Sequence from collections.abc import Mapping, Sequence
from typing import Any, Optional from typing import Any, Optional
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
@ -65,7 +65,7 @@ class LLMNodeData(BaseNodeData):
memory: Optional[MemoryConfig] = None memory: Optional[MemoryConfig] = None
context: ContextConfig context: ContextConfig
vision: VisionConfig = Field(default_factory=VisionConfig) vision: VisionConfig = Field(default_factory=VisionConfig)
structured_output: dict | None = None structured_output: Mapping[str, Any] | None = None
# We used 'structured_output_enabled' in the past, but it's not a good name. # We used 'structured_output_enabled' in the past, but it's not a good name.
structured_output_switch_on: bool = Field(False, alias="structured_output_enabled") structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")

@ -59,7 +59,8 @@ from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import ( from core.workflow.nodes.event import (
ModelInvokeCompletedEvent, ModelInvokeCompletedEvent,
NodeEvent, NodeEvent,
@ -90,17 +91,16 @@ from .file_saver import FileSaverImpl, LLMFileSaver
if TYPE_CHECKING: if TYPE_CHECKING:
from core.file.models import File from core.file.models import File
from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class LLMNode(BaseNode[LLMNodeData]): class LLMNode(BaseNode):
_node_data_cls = LLMNodeData
_node_type = NodeType.LLM _node_type = NodeType.LLM
_node_data: LLMNodeData
# Instance attributes specific to LLMNode. # Instance attributes specific to LLMNode.
# Output variable for file # Output variable for file
_file_outputs: list["File"] _file_outputs: list["File"]
@ -138,6 +138,27 @@ class LLMNode(BaseNode[LLMNodeData]):
) )
self._llm_file_saver = llm_file_saver self._llm_file_saver = llm_file_saver
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = LLMNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"
@ -152,13 +173,13 @@ class LLMNode(BaseNode[LLMNodeData]):
try: try:
# init messages template # init messages template
self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template) self._node_data.prompt_template = self._transform_chat_messages(self._node_data.prompt_template)
# fetch variables and fetch values from variable pool # fetch variables and fetch values from variable pool
inputs = self._fetch_inputs(node_data=self.node_data) inputs = self._fetch_inputs(node_data=self._node_data)
# fetch jinja2 inputs # fetch jinja2 inputs
jinja_inputs = self._fetch_jinja_inputs(node_data=self.node_data) jinja_inputs = self._fetch_jinja_inputs(node_data=self._node_data)
# merge inputs # merge inputs
inputs.update(jinja_inputs) inputs.update(jinja_inputs)
@ -169,9 +190,9 @@ class LLMNode(BaseNode[LLMNodeData]):
files = ( files = (
llm_utils.fetch_files( llm_utils.fetch_files(
variable_pool=variable_pool, variable_pool=variable_pool,
selector=self.node_data.vision.configs.variable_selector, selector=self._node_data.vision.configs.variable_selector,
) )
if self.node_data.vision.enabled if self._node_data.vision.enabled
else [] else []
) )
@ -179,7 +200,7 @@ class LLMNode(BaseNode[LLMNodeData]):
node_inputs["#files#"] = [file.to_dict() for file in files] node_inputs["#files#"] = [file.to_dict() for file in files]
# fetch context value # fetch context value
generator = self._fetch_context(node_data=self.node_data) generator = self._fetch_context(node_data=self._node_data)
context = None context = None
for event in generator: for event in generator:
if isinstance(event, RunRetrieverResourceEvent): if isinstance(event, RunRetrieverResourceEvent):
@ -189,44 +210,54 @@ class LLMNode(BaseNode[LLMNodeData]):
node_inputs["#context#"] = context node_inputs["#context#"] = context
# fetch model config # fetch model config
model_instance, model_config = self._fetch_model_config(self.node_data.model) model_instance, model_config = LLMNode._fetch_model_config(
node_data_model=self._node_data.model,
tenant_id=self.tenant_id,
)
# fetch memory # fetch memory
memory = llm_utils.fetch_memory( memory = llm_utils.fetch_memory(
variable_pool=variable_pool, variable_pool=variable_pool,
app_id=self.app_id, app_id=self.app_id,
node_data_memory=self.node_data.memory, node_data_memory=self._node_data.memory,
model_instance=model_instance, model_instance=model_instance,
) )
query = None query = None
if self.node_data.memory: if self._node_data.memory:
query = self.node_data.memory.query_prompt_template query = self._node_data.memory.query_prompt_template
if not query and ( if not query and (
query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)) query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
): ):
query = query_variable.text query = query_variable.text
prompt_messages, stop = self._fetch_prompt_messages( prompt_messages, stop = LLMNode.fetch_prompt_messages(
sys_query=query, sys_query=query,
sys_files=files, sys_files=files,
context=context, context=context,
memory=memory, memory=memory,
model_config=model_config, model_config=model_config,
prompt_template=self.node_data.prompt_template, prompt_template=self._node_data.prompt_template,
memory_config=self.node_data.memory, memory_config=self._node_data.memory,
vision_enabled=self.node_data.vision.enabled, vision_enabled=self._node_data.vision.enabled,
vision_detail=self.node_data.vision.configs.detail, vision_detail=self._node_data.vision.configs.detail,
variable_pool=variable_pool, variable_pool=variable_pool,
jinja2_variables=self.node_data.prompt_config.jinja2_variables, jinja2_variables=self._node_data.prompt_config.jinja2_variables,
tenant_id=self.tenant_id,
) )
# handle invoke result # handle invoke result
generator = self._invoke_llm( generator = LLMNode.invoke_llm(
node_data_model=self.node_data.model, node_data_model=self._node_data.model,
model_instance=model_instance, model_instance=model_instance,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
stop=stop, stop=stop,
user_id=self.user_id,
structured_output_enabled=self._node_data.structured_output_enabled,
structured_output=self._node_data.structured_output,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self.node_id,
) )
structured_output: LLMStructuredOutput | None = None structured_output: LLMStructuredOutput | None = None
@ -296,12 +327,19 @@ class LLMNode(BaseNode[LLMNodeData]):
) )
) )
def _invoke_llm( @staticmethod
self, def invoke_llm(
*,
node_data_model: ModelConfig, node_data_model: ModelConfig,
model_instance: ModelInstance, model_instance: ModelInstance,
prompt_messages: Sequence[PromptMessage], prompt_messages: Sequence[PromptMessage],
stop: Optional[Sequence[str]] = None, stop: Optional[Sequence[str]] = None,
user_id: str,
structured_output_enabled: bool,
structured_output: Optional[Mapping[str, Any]] = None,
file_saver: LLMFileSaver,
file_outputs: list["File"],
node_id: str,
) -> Generator[NodeEvent | LLMStructuredOutput, None, None]: ) -> Generator[NodeEvent | LLMStructuredOutput, None, None]:
model_schema = model_instance.model_type_instance.get_model_schema( model_schema = model_instance.model_type_instance.get_model_schema(
node_data_model.name, model_instance.credentials node_data_model.name, model_instance.credentials
@ -309,8 +347,10 @@ class LLMNode(BaseNode[LLMNodeData]):
if not model_schema: if not model_schema:
raise ValueError(f"Model schema not found for {node_data_model.name}") raise ValueError(f"Model schema not found for {node_data_model.name}")
if self.node_data.structured_output_enabled: if structured_output_enabled:
output_schema = self._fetch_structured_output_schema() output_schema = LLMNode.fetch_structured_output_schema(
structured_output=structured_output or {},
)
invoke_result = invoke_llm_with_structured_output( invoke_result = invoke_llm_with_structured_output(
provider=model_instance.provider, provider=model_instance.provider,
model_schema=model_schema, model_schema=model_schema,
@ -320,7 +360,7 @@ class LLMNode(BaseNode[LLMNodeData]):
model_parameters=node_data_model.completion_params, model_parameters=node_data_model.completion_params,
stop=list(stop or []), stop=list(stop or []),
stream=True, stream=True,
user=self.user_id, user=user_id,
) )
else: else:
invoke_result = model_instance.invoke_llm( invoke_result = model_instance.invoke_llm(
@ -328,17 +368,31 @@ class LLMNode(BaseNode[LLMNodeData]):
model_parameters=node_data_model.completion_params, model_parameters=node_data_model.completion_params,
stop=list(stop or []), stop=list(stop or []),
stream=True, stream=True,
user=self.user_id, user=user_id,
) )
return self._handle_invoke_result(invoke_result=invoke_result) return LLMNode.handle_invoke_result(
invoke_result=invoke_result,
file_saver=file_saver,
file_outputs=file_outputs,
node_id=node_id,
)
def _handle_invoke_result( @staticmethod
self, invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None] def handle_invoke_result(
*,
invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None],
file_saver: LLMFileSaver,
file_outputs: list["File"],
node_id: str,
) -> Generator[NodeEvent | LLMStructuredOutput, None, None]: ) -> Generator[NodeEvent | LLMStructuredOutput, None, None]:
# For blocking mode # For blocking mode
if isinstance(invoke_result, LLMResult): if isinstance(invoke_result, LLMResult):
event = self._handle_blocking_result(invoke_result=invoke_result) event = LLMNode.handle_blocking_result(
invoke_result=invoke_result,
saver=file_saver,
file_outputs=file_outputs,
)
yield event yield event
return return
@ -356,11 +410,13 @@ class LLMNode(BaseNode[LLMNodeData]):
yield result yield result
if isinstance(result, LLMResultChunk): if isinstance(result, LLMResultChunk):
contents = result.delta.message.content contents = result.delta.message.content
for text_part in self._save_multimodal_output_and_convert_result_to_markdown(contents): for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown(
contents=contents,
file_saver=file_saver,
file_outputs=file_outputs,
):
full_text_buffer.write(text_part) full_text_buffer.write(text_part)
yield RunStreamChunkEvent( yield RunStreamChunkEvent(chunk_content=text_part, from_variable_selector=[node_id, "text"])
chunk_content=text_part, from_variable_selector=[self.node_id, "text"]
)
# Update the whole metadata # Update the whole metadata
if not model and result.model: if not model and result.model:
@ -378,7 +434,8 @@ class LLMNode(BaseNode[LLMNodeData]):
yield ModelInvokeCompletedEvent(text=full_text_buffer.getvalue(), usage=usage, finish_reason=finish_reason) yield ModelInvokeCompletedEvent(text=full_text_buffer.getvalue(), usage=usage, finish_reason=finish_reason)
def _image_file_to_markdown(self, file: "File", /): @staticmethod
def _image_file_to_markdown(file: "File", /):
text_chunk = f"![]({file.generate_url()})" text_chunk = f"![]({file.generate_url()})"
return text_chunk return text_chunk
@ -539,11 +596,14 @@ class LLMNode(BaseNode[LLMNodeData]):
return None return None
@staticmethod
def _fetch_model_config( def _fetch_model_config(
self, node_data_model: ModelConfig *,
node_data_model: ModelConfig,
tenant_id: str,
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
model, model_config_with_cred = llm_utils.fetch_model_config( model, model_config_with_cred = llm_utils.fetch_model_config(
tenant_id=self.tenant_id, node_data_model=node_data_model tenant_id=tenant_id, node_data_model=node_data_model
) )
completion_params = model_config_with_cred.parameters completion_params = model_config_with_cred.parameters
@ -556,8 +616,8 @@ class LLMNode(BaseNode[LLMNodeData]):
node_data_model.completion_params = completion_params node_data_model.completion_params = completion_params
return model, model_config_with_cred return model, model_config_with_cred
def _fetch_prompt_messages( @staticmethod
self, def fetch_prompt_messages(
*, *,
sys_query: str | None = None, sys_query: str | None = None,
sys_files: Sequence["File"], sys_files: Sequence["File"],
@ -570,13 +630,14 @@ class LLMNode(BaseNode[LLMNodeData]):
vision_detail: ImagePromptMessageContent.DETAIL, vision_detail: ImagePromptMessageContent.DETAIL,
variable_pool: VariablePool, variable_pool: VariablePool,
jinja2_variables: Sequence[VariableSelector], jinja2_variables: Sequence[VariableSelector],
tenant_id: str,
) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]: ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
prompt_messages: list[PromptMessage] = [] prompt_messages: list[PromptMessage] = []
if isinstance(prompt_template, list): if isinstance(prompt_template, list):
# For chat model # For chat model
prompt_messages.extend( prompt_messages.extend(
self._handle_list_messages( LLMNode.handle_list_messages(
messages=prompt_template, messages=prompt_template,
context=context, context=context,
jinja2_variables=jinja2_variables, jinja2_variables=jinja2_variables,
@ -602,7 +663,7 @@ class LLMNode(BaseNode[LLMNodeData]):
edition_type="basic", edition_type="basic",
) )
prompt_messages.extend( prompt_messages.extend(
self._handle_list_messages( LLMNode.handle_list_messages(
messages=[message], messages=[message],
context="", context="",
jinja2_variables=[], jinja2_variables=[],
@ -731,7 +792,7 @@ class LLMNode(BaseNode[LLMNodeData]):
) )
model = ModelManager().get_model_instance( model = ModelManager().get_model_instance(
tenant_id=self.tenant_id, tenant_id=tenant_id,
model_type=ModelType.LLM, model_type=ModelType.LLM,
provider=model_config.provider, provider=model_config.provider,
model=model_config.model, model=model_config.model,
@ -750,10 +811,12 @@ class LLMNode(BaseNode[LLMNodeData]):
*, *,
graph_config: Mapping[str, Any], graph_config: Mapping[str, Any],
node_id: str, node_id: str,
node_data: LLMNodeData, node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
prompt_template = node_data.prompt_template # Create typed NodeData from dict
typed_node_data = LLMNodeData.model_validate(node_data)
prompt_template = typed_node_data.prompt_template
variable_selectors = [] variable_selectors = []
if isinstance(prompt_template, list) and all( if isinstance(prompt_template, list) and all(
isinstance(prompt, LLMNodeChatModelMessage) for prompt in prompt_template isinstance(prompt, LLMNodeChatModelMessage) for prompt in prompt_template
@ -773,7 +836,7 @@ class LLMNode(BaseNode[LLMNodeData]):
for variable_selector in variable_selectors: for variable_selector in variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector variable_mapping[variable_selector.variable] = variable_selector.value_selector
memory = node_data.memory memory = typed_node_data.memory
if memory and memory.query_prompt_template: if memory and memory.query_prompt_template:
query_variable_selectors = VariableTemplateParser( query_variable_selectors = VariableTemplateParser(
template=memory.query_prompt_template template=memory.query_prompt_template
@ -781,16 +844,16 @@ class LLMNode(BaseNode[LLMNodeData]):
for variable_selector in query_variable_selectors: for variable_selector in query_variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector variable_mapping[variable_selector.variable] = variable_selector.value_selector
if node_data.context.enabled: if typed_node_data.context.enabled:
variable_mapping["#context#"] = node_data.context.variable_selector variable_mapping["#context#"] = typed_node_data.context.variable_selector
if node_data.vision.enabled: if typed_node_data.vision.enabled:
variable_mapping["#files#"] = node_data.vision.configs.variable_selector variable_mapping["#files#"] = typed_node_data.vision.configs.variable_selector
if node_data.memory: if typed_node_data.memory:
variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY.value] variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY.value]
if node_data.prompt_config: if typed_node_data.prompt_config:
enable_jinja = False enable_jinja = False
if isinstance(prompt_template, list): if isinstance(prompt_template, list):
@ -803,7 +866,7 @@ class LLMNode(BaseNode[LLMNodeData]):
enable_jinja = True enable_jinja = True
if enable_jinja: if enable_jinja:
for variable_selector in node_data.prompt_config.jinja2_variables or []: for variable_selector in typed_node_data.prompt_config.jinja2_variables or []:
variable_mapping[variable_selector.variable] = variable_selector.value_selector variable_mapping[variable_selector.variable] = variable_selector.value_selector
variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()} variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
@ -835,8 +898,8 @@ class LLMNode(BaseNode[LLMNodeData]):
}, },
} }
def _handle_list_messages( @staticmethod
self, def handle_list_messages(
*, *,
messages: Sequence[LLMNodeChatModelMessage], messages: Sequence[LLMNodeChatModelMessage],
context: Optional[str], context: Optional[str],
@ -897,9 +960,19 @@ class LLMNode(BaseNode[LLMNodeData]):
return prompt_messages return prompt_messages
def _handle_blocking_result(self, *, invoke_result: LLMResult) -> ModelInvokeCompletedEvent: @staticmethod
def handle_blocking_result(
*,
invoke_result: LLMResult,
saver: LLMFileSaver,
file_outputs: list["File"],
) -> ModelInvokeCompletedEvent:
buffer = io.StringIO() buffer = io.StringIO()
for text_part in self._save_multimodal_output_and_convert_result_to_markdown(invoke_result.message.content): for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown(
contents=invoke_result.message.content,
file_saver=saver,
file_outputs=file_outputs,
):
buffer.write(text_part) buffer.write(text_part)
return ModelInvokeCompletedEvent( return ModelInvokeCompletedEvent(
@ -908,7 +981,12 @@ class LLMNode(BaseNode[LLMNodeData]):
finish_reason=None, finish_reason=None,
) )
def _save_multimodal_image_output(self, content: ImagePromptMessageContent) -> "File": @staticmethod
def save_multimodal_image_output(
*,
content: ImagePromptMessageContent,
file_saver: LLMFileSaver,
) -> "File":
"""_save_multimodal_output saves multi-modal contents generated by LLM plugins. """_save_multimodal_output saves multi-modal contents generated by LLM plugins.
There are two kinds of multimodal outputs: There are two kinds of multimodal outputs:
@ -918,26 +996,21 @@ class LLMNode(BaseNode[LLMNodeData]):
Currently, only image files are supported. Currently, only image files are supported.
""" """
# Inject the saver somehow...
_saver = self._llm_file_saver
# If this
if content.url != "": if content.url != "":
saved_file = _saver.save_remote_url(content.url, FileType.IMAGE) saved_file = file_saver.save_remote_url(content.url, FileType.IMAGE)
else: else:
saved_file = _saver.save_binary_string( saved_file = file_saver.save_binary_string(
data=base64.b64decode(content.base64_data), data=base64.b64decode(content.base64_data),
mime_type=content.mime_type, mime_type=content.mime_type,
file_type=FileType.IMAGE, file_type=FileType.IMAGE,
) )
self._file_outputs.append(saved_file)
return saved_file return saved_file
def _fetch_model_schema(self, provider: str) -> AIModelEntity | None: def _fetch_model_schema(self, provider: str) -> AIModelEntity | None:
""" """
Fetch model schema Fetch model schema
""" """
model_name = self.node_data.model.name model_name = self._node_data.model.name
model_manager = ModelManager() model_manager = ModelManager()
model_instance = model_manager.get_model_instance( model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider, model=model_name tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider, model=model_name
@ -948,16 +1021,20 @@ class LLMNode(BaseNode[LLMNodeData]):
model_schema = model_type_instance.get_model_schema(model_name, model_credentials) model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
return model_schema return model_schema
def _fetch_structured_output_schema(self) -> dict[str, Any]: @staticmethod
def fetch_structured_output_schema(
*,
structured_output: Mapping[str, Any],
) -> dict[str, Any]:
""" """
Fetch the structured output schema from the node data. Fetch the structured output schema from the node data.
Returns: Returns:
dict[str, Any]: The structured output schema dict[str, Any]: The structured output schema
""" """
if not self.node_data.structured_output: if not structured_output:
raise LLMNodeError("Please provide a valid structured output schema") raise LLMNodeError("Please provide a valid structured output schema")
structured_output_schema = json.dumps(self.node_data.structured_output.get("schema", {}), ensure_ascii=False) structured_output_schema = json.dumps(structured_output.get("schema", {}), ensure_ascii=False)
if not structured_output_schema: if not structured_output_schema:
raise LLMNodeError("Please provide a valid structured output schema") raise LLMNodeError("Please provide a valid structured output schema")
@ -969,9 +1046,12 @@ class LLMNode(BaseNode[LLMNodeData]):
except json.JSONDecodeError: except json.JSONDecodeError:
raise LLMNodeError("structured_output_schema is not valid JSON format") raise LLMNodeError("structured_output_schema is not valid JSON format")
@staticmethod
def _save_multimodal_output_and_convert_result_to_markdown( def _save_multimodal_output_and_convert_result_to_markdown(
self, *,
contents: str | list[PromptMessageContentUnionTypes] | None, contents: str | list[PromptMessageContentUnionTypes] | None,
file_saver: LLMFileSaver,
file_outputs: list["File"],
) -> Generator[str, None, None]: ) -> Generator[str, None, None]:
"""Convert intermediate prompt messages into strings and yield them to the caller. """Convert intermediate prompt messages into strings and yield them to the caller.
@ -994,9 +1074,12 @@ class LLMNode(BaseNode[LLMNodeData]):
if isinstance(item, TextPromptMessageContent): if isinstance(item, TextPromptMessageContent):
yield item.data yield item.data
elif isinstance(item, ImagePromptMessageContent): elif isinstance(item, ImagePromptMessageContent):
file = self._save_multimodal_image_output(item) file = LLMNode.save_multimodal_image_output(
self._file_outputs.append(file) content=item,
yield self._image_file_to_markdown(file) file_saver=file_saver,
)
file_outputs.append(file)
yield LLMNode._image_file_to_markdown(file)
else: else:
logger.warning("unknown item type encountered, type=%s", type(item)) logger.warning("unknown item type encountered, type=%s", type(item))
yield str(item) yield str(item)
@ -1004,6 +1087,14 @@ class LLMNode(BaseNode[LLMNodeData]):
logger.warning("unknown contents type encountered, type=%s", type(contents)) logger.warning("unknown contents type encountered, type=%s", type(contents))
yield str(contents) yield str(contents)
@property
def continue_on_error(self) -> bool:
return self._node_data.error_strategy is not None
@property
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled
def _combine_message_content_with_role( def _combine_message_content_with_role(
*, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole *, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole

@ -1,18 +1,44 @@
from collections.abc import Mapping
from typing import Any, Optional
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.loop.entities import LoopEndNodeData from core.workflow.nodes.loop.entities import LoopEndNodeData
class LoopEndNode(BaseNode[LoopEndNodeData]): class LoopEndNode(BaseNode):
""" """
Loop End Node. Loop End Node.
""" """
_node_data_cls = LoopEndNodeData
_node_type = NodeType.LOOP_END _node_type = NodeType.LOOP_END
_node_data: LoopEndNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = LoopEndNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"

@ -3,7 +3,7 @@ import logging
import time import time
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, Literal, cast from typing import TYPE_CHECKING, Any, Literal, Optional, cast
from configs import dify_config from configs import dify_config
from core.variables import ( from core.variables import (
@ -30,7 +30,8 @@ from core.workflow.graph_engine.entities.event import (
) )
from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.loop.entities import LoopNodeData from core.workflow.nodes.loop.entities import LoopNodeData
from core.workflow.utils.condition.processor import ConditionProcessor from core.workflow.utils.condition.processor import ConditionProcessor
@ -43,14 +44,36 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class LoopNode(BaseNode[LoopNodeData]): class LoopNode(BaseNode):
""" """
Loop Node. Loop Node.
""" """
_node_data_cls = LoopNodeData
_node_type = NodeType.LOOP _node_type = NodeType.LOOP
_node_data: LoopNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = LoopNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"
@ -58,17 +81,17 @@ class LoopNode(BaseNode[LoopNodeData]):
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""Run the node.""" """Run the node."""
# Get inputs # Get inputs
loop_count = self.node_data.loop_count loop_count = self._node_data.loop_count
break_conditions = self.node_data.break_conditions break_conditions = self._node_data.break_conditions
logical_operator = self.node_data.logical_operator logical_operator = self._node_data.logical_operator
inputs = {"loop_count": loop_count} inputs = {"loop_count": loop_count}
if not self.node_data.start_node_id: if not self._node_data.start_node_id:
raise ValueError(f"field start_node_id in loop {self.node_id} not found") raise ValueError(f"field start_node_id in loop {self.node_id} not found")
# Initialize graph # Initialize graph
loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self.node_data.start_node_id) loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self._node_data.start_node_id)
if not loop_graph: if not loop_graph:
raise ValueError("loop graph not found") raise ValueError("loop graph not found")
@ -78,8 +101,8 @@ class LoopNode(BaseNode[LoopNodeData]):
# Initialize loop variables # Initialize loop variables
loop_variable_selectors = {} loop_variable_selectors = {}
if self.node_data.loop_variables: if self._node_data.loop_variables:
for loop_variable in self.node_data.loop_variables: for loop_variable in self._node_data.loop_variables:
value_processor = { value_processor = {
"constant": lambda var=loop_variable: self._get_segment_for_constant(var.var_type, var.value), "constant": lambda var=loop_variable: self._get_segment_for_constant(var.var_type, var.value),
"variable": lambda var=loop_variable: variable_pool.get(var.value), "variable": lambda var=loop_variable: variable_pool.get(var.value),
@ -127,8 +150,8 @@ class LoopNode(BaseNode[LoopNodeData]):
yield LoopRunStartedEvent( yield LoopRunStartedEvent(
loop_id=self.id, loop_id=self.id,
loop_node_id=self.node_id, loop_node_id=self.node_id,
loop_node_type=self.node_type, loop_node_type=self.type_,
loop_node_data=self.node_data, loop_node_data=self._node_data,
start_at=start_at, start_at=start_at,
inputs=inputs, inputs=inputs,
metadata={"loop_length": loop_count}, metadata={"loop_length": loop_count},
@ -184,11 +207,11 @@ class LoopNode(BaseNode[LoopNodeData]):
yield LoopRunSucceededEvent( yield LoopRunSucceededEvent(
loop_id=self.id, loop_id=self.id,
loop_node_id=self.node_id, loop_node_id=self.node_id,
loop_node_type=self.node_type, loop_node_type=self.type_,
loop_node_data=self.node_data, loop_node_data=self._node_data,
start_at=start_at, start_at=start_at,
inputs=inputs, inputs=inputs,
outputs=self.node_data.outputs, outputs=self._node_data.outputs,
steps=loop_count, steps=loop_count,
metadata={ metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
@ -206,7 +229,7 @@ class LoopNode(BaseNode[LoopNodeData]):
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
}, },
outputs=self.node_data.outputs, outputs=self._node_data.outputs,
inputs=inputs, inputs=inputs,
) )
) )
@ -217,8 +240,8 @@ class LoopNode(BaseNode[LoopNodeData]):
yield LoopRunFailedEvent( yield LoopRunFailedEvent(
loop_id=self.id, loop_id=self.id,
loop_node_id=self.node_id, loop_node_id=self.node_id,
loop_node_type=self.node_type, loop_node_type=self.type_,
loop_node_data=self.node_data, loop_node_data=self._node_data,
start_at=start_at, start_at=start_at,
inputs=inputs, inputs=inputs,
steps=loop_count, steps=loop_count,
@ -320,8 +343,8 @@ class LoopNode(BaseNode[LoopNodeData]):
yield LoopRunFailedEvent( yield LoopRunFailedEvent(
loop_id=self.id, loop_id=self.id,
loop_node_id=self.node_id, loop_node_id=self.node_id,
loop_node_type=self.node_type, loop_node_type=self.type_,
loop_node_data=self.node_data, loop_node_data=self._node_data,
start_at=start_at, start_at=start_at,
inputs=inputs, inputs=inputs,
steps=current_index, steps=current_index,
@ -351,8 +374,8 @@ class LoopNode(BaseNode[LoopNodeData]):
yield LoopRunFailedEvent( yield LoopRunFailedEvent(
loop_id=self.id, loop_id=self.id,
loop_node_id=self.node_id, loop_node_id=self.node_id,
loop_node_type=self.node_type, loop_node_type=self.type_,
loop_node_data=self.node_data, loop_node_data=self._node_data,
start_at=start_at, start_at=start_at,
inputs=inputs, inputs=inputs,
steps=current_index, steps=current_index,
@ -388,7 +411,7 @@ class LoopNode(BaseNode[LoopNodeData]):
_outputs[loop_variable_key] = None _outputs[loop_variable_key] = None
_outputs["loop_round"] = current_index + 1 _outputs["loop_round"] = current_index + 1
self.node_data.outputs = _outputs self._node_data.outputs = _outputs
if check_break_result: if check_break_result:
return {"check_break_result": True} return {"check_break_result": True}
@ -400,10 +423,10 @@ class LoopNode(BaseNode[LoopNodeData]):
yield LoopRunNextEvent( yield LoopRunNextEvent(
loop_id=self.id, loop_id=self.id,
loop_node_id=self.node_id, loop_node_id=self.node_id,
loop_node_type=self.node_type, loop_node_type=self.type_,
loop_node_data=self.node_data, loop_node_data=self._node_data,
index=next_index, index=next_index,
pre_loop_output=self.node_data.outputs, pre_loop_output=self._node_data.outputs,
) )
return {"check_break_result": False} return {"check_break_result": False}
@ -438,19 +461,15 @@ class LoopNode(BaseNode[LoopNodeData]):
*, *,
graph_config: Mapping[str, Any], graph_config: Mapping[str, Any],
node_id: str, node_id: str,
node_data: LoopNodeData, node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
""" # Create typed NodeData from dict
Extract variable selector to variable mapping typed_node_data = LoopNodeData.model_validate(node_data)
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
variable_mapping = {} variable_mapping = {}
# init graph # init graph
loop_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id) loop_graph = Graph.init(graph_config=graph_config, root_node_id=typed_node_data.start_node_id)
if not loop_graph: if not loop_graph:
raise ValueError("loop graph not found") raise ValueError("loop graph not found")
@ -486,7 +505,7 @@ class LoopNode(BaseNode[LoopNodeData]):
variable_mapping.update(sub_node_variable_mapping) variable_mapping.update(sub_node_variable_mapping)
for loop_variable in node_data.loop_variables or []: for loop_variable in typed_node_data.loop_variables or []:
if loop_variable.value_type == "variable": if loop_variable.value_type == "variable":
assert loop_variable.value is not None, "Loop variable value must be provided for variable type" assert loop_variable.value is not None, "Loop variable value must be provided for variable type"
# add loop variable to variable mapping # add loop variable to variable mapping

@ -1,18 +1,44 @@
from collections.abc import Mapping
from typing import Any, Optional
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.loop.entities import LoopStartNodeData from core.workflow.nodes.loop.entities import LoopStartNodeData
class LoopStartNode(BaseNode[LoopStartNodeData]): class LoopStartNode(BaseNode):
""" """
Loop Start Node. Loop Start Node.
""" """
_node_data_cls = LoopStartNodeData
_node_type = NodeType.LOOP_START _node_type = NodeType.LOOP_START
_node_data: LoopStartNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = LoopStartNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"

@ -29,8 +29,9 @@ from core.variables.types import SegmentType
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import BaseNode from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.llm import ModelConfig, llm_utils from core.workflow.nodes.llm import ModelConfig, llm_utils
from core.workflow.utils import variable_template_parser from core.workflow.utils import variable_template_parser
from factories.variable_factory import build_segment_with_type from factories.variable_factory import build_segment_with_type
@ -91,10 +92,31 @@ class ParameterExtractorNode(BaseNode):
Parameter Extractor Node. Parameter Extractor Node.
""" """
# FIXME: figure out why here is different from super class
_node_data_cls = ParameterExtractorNodeData # type: ignore
_node_type = NodeType.PARAMETER_EXTRACTOR _node_type = NodeType.PARAMETER_EXTRACTOR
_node_data: ParameterExtractorNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = ParameterExtractorNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
_model_instance: Optional[ModelInstance] = None _model_instance: Optional[ModelInstance] = None
_model_config: Optional[ModelConfigWithCredentialsEntity] = None _model_config: Optional[ModelConfigWithCredentialsEntity] = None
@ -119,7 +141,7 @@ class ParameterExtractorNode(BaseNode):
""" """
Run the node. Run the node.
""" """
node_data = cast(ParameterExtractorNodeData, self.node_data) node_data = cast(ParameterExtractorNodeData, self._node_data)
variable = self.graph_runtime_state.variable_pool.get(node_data.query) variable = self.graph_runtime_state.variable_pool.get(node_data.query)
query = variable.text if variable else "" query = variable.text if variable else ""
@ -398,7 +420,7 @@ class ParameterExtractorNode(BaseNode):
""" """
Generate prompt engineering prompt. Generate prompt engineering prompt.
""" """
model_mode = ModelMode.value_of(data.model.mode) model_mode = ModelMode(data.model.mode)
if model_mode == ModelMode.COMPLETION: if model_mode == ModelMode.COMPLETION:
return self._generate_prompt_engineering_completion_prompt( return self._generate_prompt_engineering_completion_prompt(
@ -694,7 +716,7 @@ class ParameterExtractorNode(BaseNode):
memory: Optional[TokenBufferMemory], memory: Optional[TokenBufferMemory],
max_token_limit: int = 2000, max_token_limit: int = 2000,
) -> list[ChatModelMessage]: ) -> list[ChatModelMessage]:
model_mode = ModelMode.value_of(node_data.model.mode) model_mode = ModelMode(node_data.model.mode)
input_text = query input_text = query
memory_str = "" memory_str = ""
instruction = variable_pool.convert_template(node_data.instruction or "").text instruction = variable_pool.convert_template(node_data.instruction or "").text
@ -721,7 +743,7 @@ class ParameterExtractorNode(BaseNode):
memory: Optional[TokenBufferMemory], memory: Optional[TokenBufferMemory],
max_token_limit: int = 2000, max_token_limit: int = 2000,
): ):
model_mode = ModelMode.value_of(node_data.model.mode) model_mode = ModelMode(node_data.model.mode)
input_text = query input_text = query
memory_str = "" memory_str = ""
instruction = variable_pool.convert_template(node_data.instruction or "").text instruction = variable_pool.convert_template(node_data.instruction or "").text
@ -827,19 +849,15 @@ class ParameterExtractorNode(BaseNode):
*, *,
graph_config: Mapping[str, Any], graph_config: Mapping[str, Any],
node_id: str, node_id: str,
node_data: ParameterExtractorNodeData, # type: ignore node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
""" # Create typed NodeData from dict
Extract variable selector to variable mapping typed_node_data = ParameterExtractorNodeData.model_validate(node_data)
:param graph_config: graph config
:param node_id: node id variable_mapping: dict[str, Sequence[str]] = {"query": typed_node_data.query}
:param node_data: node data
:return:
"""
variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query}
if node_data.instruction: if typed_node_data.instruction:
selectors = variable_template_parser.extract_selectors_from_template(node_data.instruction) selectors = variable_template_parser.extract_selectors_from_template(typed_node_data.instruction)
for selector in selectors: for selector in selectors:
variable_mapping[selector.variable] = selector.value_selector variable_mapping[selector.variable] = selector.value_selector

@ -1,6 +1,6 @@
import json import json
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast from typing import TYPE_CHECKING, Any, Optional, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
@ -11,8 +11,11 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.simple_prompt_transform import ModelMode from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import ModelInvokeCompletedEvent from core.workflow.nodes.event import ModelInvokeCompletedEvent
from core.workflow.nodes.llm import ( from core.workflow.nodes.llm import (
LLMNode, LLMNode,
@ -20,6 +23,7 @@ from core.workflow.nodes.llm import (
LLMNodeCompletionModelPromptTemplate, LLMNodeCompletionModelPromptTemplate,
llm_utils, llm_utils,
) )
from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from core.workflow.utils.variable_template_parser import VariableTemplateParser from core.workflow.utils.variable_template_parser import VariableTemplateParser
from libs.json_in_md_parser import parse_and_check_json_markdown from libs.json_in_md_parser import parse_and_check_json_markdown
@ -35,17 +39,77 @@ from .template_prompts import (
QUESTION_CLASSIFIER_USER_PROMPT_3, QUESTION_CLASSIFIER_USER_PROMPT_3,
) )
if TYPE_CHECKING:
from core.file.models import File
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
class QuestionClassifierNode(LLMNode):
_node_data_cls = QuestionClassifierNodeData # type: ignore class QuestionClassifierNode(BaseNode):
_node_type = NodeType.QUESTION_CLASSIFIER _node_type = NodeType.QUESTION_CLASSIFIER
_node_data: QuestionClassifierNodeData
_file_outputs: list["File"]
_llm_file_saver: LLMFileSaver
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph: "Graph",
graph_runtime_state: "GraphRuntimeState",
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None,
*,
llm_file_saver: LLMFileSaver | None = None,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph=graph,
graph_runtime_state=graph_runtime_state,
previous_node_id=previous_node_id,
thread_pool_id=thread_pool_id,
)
# LLM file outputs, used for MultiModal outputs.
self._file_outputs: list[File] = []
if llm_file_saver is None:
llm_file_saver = FileSaverImpl(
user_id=graph_init_params.user_id,
tenant_id=graph_init_params.tenant_id,
)
self._llm_file_saver = llm_file_saver
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = QuestionClassifierNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod @classmethod
def version(cls): def version(cls):
return "1" return "1"
def _run(self): def _run(self):
node_data = cast(QuestionClassifierNodeData, self.node_data) node_data = cast(QuestionClassifierNodeData, self._node_data)
variable_pool = self.graph_runtime_state.variable_pool variable_pool = self.graph_runtime_state.variable_pool
# extract variables # extract variables
@ -53,7 +117,10 @@ class QuestionClassifierNode(LLMNode):
query = variable.value if variable else None query = variable.value if variable else None
variables = {"query": query} variables = {"query": query}
# fetch model config # fetch model config
model_instance, model_config = self._fetch_model_config(node_data.model) model_instance, model_config = LLMNode._fetch_model_config(
node_data_model=node_data.model,
tenant_id=self.tenant_id,
)
# fetch memory # fetch memory
memory = llm_utils.fetch_memory( memory = llm_utils.fetch_memory(
variable_pool=variable_pool, variable_pool=variable_pool,
@ -91,7 +158,7 @@ class QuestionClassifierNode(LLMNode):
# If both self._get_prompt_template and self._fetch_prompt_messages append a user prompt, # If both self._get_prompt_template and self._fetch_prompt_messages append a user prompt,
# two consecutive user prompts will be generated, causing model's error. # two consecutive user prompts will be generated, causing model's error.
# To avoid this, set sys_query to an empty string so that only one user prompt is appended at the end. # To avoid this, set sys_query to an empty string so that only one user prompt is appended at the end.
prompt_messages, stop = self._fetch_prompt_messages( prompt_messages, stop = LLMNode.fetch_prompt_messages(
prompt_template=prompt_template, prompt_template=prompt_template,
sys_query="", sys_query="",
memory=memory, memory=memory,
@ -101,6 +168,7 @@ class QuestionClassifierNode(LLMNode):
vision_detail=node_data.vision.configs.detail, vision_detail=node_data.vision.configs.detail,
variable_pool=variable_pool, variable_pool=variable_pool,
jinja2_variables=[], jinja2_variables=[],
tenant_id=self.tenant_id,
) )
result_text = "" result_text = ""
@ -109,11 +177,17 @@ class QuestionClassifierNode(LLMNode):
try: try:
# handle invoke result # handle invoke result
generator = self._invoke_llm( generator = LLMNode.invoke_llm(
node_data_model=node_data.model, node_data_model=node_data.model,
model_instance=model_instance, model_instance=model_instance,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
stop=stop, stop=stop,
user_id=self.user_id,
structured_output_enabled=False,
structured_output=None,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self.node_id,
) )
for event in generator: for event in generator:
@ -183,23 +257,18 @@ class QuestionClassifierNode(LLMNode):
*, *,
graph_config: Mapping[str, Any], graph_config: Mapping[str, Any],
node_id: str, node_id: str,
node_data: Any, node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
""" # Create typed NodeData from dict
Extract variable selector to variable mapping typed_node_data = QuestionClassifierNodeData.model_validate(node_data)
:param graph_config: graph config
:param node_id: node id variable_mapping = {"query": typed_node_data.query_variable_selector}
:param node_data: node data variable_selectors: list[VariableSelector] = []
:return: if typed_node_data.instruction:
""" variable_template_parser = VariableTemplateParser(template=typed_node_data.instruction)
node_data = cast(QuestionClassifierNodeData, node_data)
variable_mapping = {"query": node_data.query_variable_selector}
variable_selectors = []
if node_data.instruction:
variable_template_parser = VariableTemplateParser(template=node_data.instruction)
variable_selectors.extend(variable_template_parser.extract_variable_selectors()) variable_selectors.extend(variable_template_parser.extract_variable_selectors())
for variable_selector in variable_selectors: for variable_selector in variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector variable_mapping[variable_selector.variable] = list(variable_selector.value_selector)
variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()} variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
@ -265,7 +334,7 @@ class QuestionClassifierNode(LLMNode):
memory: Optional[TokenBufferMemory], memory: Optional[TokenBufferMemory],
max_token_limit: int = 2000, max_token_limit: int = 2000,
): ):
model_mode = ModelMode.value_of(node_data.model.mode) model_mode = ModelMode(node_data.model.mode)
classes = node_data.classes classes = node_data.classes
categories = [] categories = []
for class_ in classes: for class_ in classes:

@ -1,15 +1,41 @@
from collections.abc import Mapping
from typing import Any, Optional
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.start.entities import StartNodeData from core.workflow.nodes.start.entities import StartNodeData
class StartNode(BaseNode[StartNodeData]): class StartNode(BaseNode):
_node_data_cls = StartNodeData
_node_type = NodeType.START _node_type = NodeType.START
_node_data: StartNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = StartNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"

@ -6,16 +6,39 @@ from core.helper.code_executor.code_executor import CodeExecutionError, CodeExec
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000")) MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000"))
class TemplateTransformNode(BaseNode[TemplateTransformNodeData]): class TemplateTransformNode(BaseNode):
_node_data_cls = TemplateTransformNodeData
_node_type = NodeType.TEMPLATE_TRANSFORM _node_type = NodeType.TEMPLATE_TRANSFORM
_node_data: TemplateTransformNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = TemplateTransformNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod @classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict: def get_default_config(cls, filters: Optional[dict] = None) -> dict:
""" """
@ -35,14 +58,14 @@ class TemplateTransformNode(BaseNode[TemplateTransformNodeData]):
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
# Get variables # Get variables
variables = {} variables = {}
for variable_selector in self.node_data.variables: for variable_selector in self._node_data.variables:
variable_name = variable_selector.variable variable_name = variable_selector.variable
value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
variables[variable_name] = value.to_object() if value else None variables[variable_name] = value.to_object() if value else None
# Run code # Run code
try: try:
result = CodeExecutor.execute_workflow_code_template( result = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables language=CodeLanguage.JINJA2, code=self._node_data.template, inputs=variables
) )
except CodeExecutionError as e: except CodeExecutionError as e:
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e)) return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
@ -60,16 +83,12 @@ class TemplateTransformNode(BaseNode[TemplateTransformNodeData]):
@classmethod @classmethod
def _extract_variable_selector_to_variable_mapping( def _extract_variable_selector_to_variable_mapping(
cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: Mapping[str, Any]
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
""" # Create typed NodeData from dict
Extract variable selector to variable mapping typed_node_data = TemplateTransformNodeData.model_validate(node_data)
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
return { return {
node_id + "." + variable_selector.variable: variable_selector.value_selector node_id + "." + variable_selector.variable: variable_selector.value_selector
for variable_selector in node_data.variables for variable_selector in typed_node_data.variables
} }

@ -6,7 +6,6 @@ from sqlalchemy.orm import Session
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file import File, FileTransferMethod from core.file import File, FileTransferMethod
from core.model_runtime.entities.llm_entities import LLMUsage
from core.plugin.impl.exc import PluginDaemonClientSideError from core.plugin.impl.exc import PluginDaemonClientSideError
from core.plugin.impl.plugin import PluginInstaller from core.plugin.impl.plugin import PluginInstaller
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
@ -19,10 +18,10 @@ from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import AgentLogEvent
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.utils.variable_template_parser import VariableTemplateParser from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db from extensions.ext_database import db
from factories import file_factory from factories import file_factory
@ -37,14 +36,18 @@ from .exc import (
) )
class ToolNode(BaseNode[ToolNodeData]): class ToolNode(BaseNode):
""" """
Tool Node Tool Node
""" """
_node_data_cls = ToolNodeData
_node_type = NodeType.TOOL _node_type = NodeType.TOOL
_node_data: ToolNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = ToolNodeData.model_validate(data)
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"
@ -54,7 +57,7 @@ class ToolNode(BaseNode[ToolNodeData]):
Run the tool node Run the tool node
""" """
node_data = cast(ToolNodeData, self.node_data) node_data = cast(ToolNodeData, self._node_data)
# fetch tool icon # fetch tool icon
tool_info = { tool_info = {
@ -67,9 +70,9 @@ class ToolNode(BaseNode[ToolNodeData]):
try: try:
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
variable_pool = self.graph_runtime_state.variable_pool if self.node_data.version != "1" else None variable_pool = self.graph_runtime_state.variable_pool if self._node_data.version != "1" else None
tool_runtime = ToolManager.get_workflow_tool_runtime( tool_runtime = ToolManager.get_workflow_tool_runtime(
self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from, variable_pool self.tenant_id, self.app_id, self.node_id, self._node_data, self.invoke_from, variable_pool
) )
except ToolNodeError as e: except ToolNodeError as e:
yield RunCompletedEvent( yield RunCompletedEvent(
@ -88,12 +91,12 @@ class ToolNode(BaseNode[ToolNodeData]):
parameters = self._generate_parameters( parameters = self._generate_parameters(
tool_parameters=tool_parameters, tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool, variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data, node_data=self._node_data,
) )
parameters_for_log = self._generate_parameters( parameters_for_log = self._generate_parameters(
tool_parameters=tool_parameters, tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool, variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data, node_data=self._node_data,
for_log=True, for_log=True,
) )
# get conversation id # get conversation id
@ -124,7 +127,14 @@ class ToolNode(BaseNode[ToolNodeData]):
try: try:
# convert tool messages # convert tool messages
yield from self._transform_message(message_stream, tool_info, parameters_for_log) yield from self._transform_message(
messages=message_stream,
tool_info=tool_info,
parameters_for_log=parameters_for_log,
user_id=self.user_id,
tenant_id=self.tenant_id,
node_id=self.node_id,
)
except (PluginDaemonClientSideError, ToolInvokeError) as e: except (PluginDaemonClientSideError, ToolInvokeError) as e:
yield RunCompletedEvent( yield RunCompletedEvent(
run_result=NodeRunResult( run_result=NodeRunResult(
@ -191,7 +201,9 @@ class ToolNode(BaseNode[ToolNodeData]):
messages: Generator[ToolInvokeMessage, None, None], messages: Generator[ToolInvokeMessage, None, None],
tool_info: Mapping[str, Any], tool_info: Mapping[str, Any],
parameters_for_log: dict[str, Any], parameters_for_log: dict[str, Any],
agent_thoughts: Optional[list] = None, user_id: str,
tenant_id: str,
node_id: str,
) -> Generator: ) -> Generator:
""" """
Convert ToolInvokeMessages into tuple[plain_text, files] Convert ToolInvokeMessages into tuple[plain_text, files]
@ -199,8 +211,8 @@ class ToolNode(BaseNode[ToolNodeData]):
# transform message and handle file storage # transform message and handle file storage
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages( message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=messages, messages=messages,
user_id=self.user_id, user_id=user_id,
tenant_id=self.tenant_id, tenant_id=tenant_id,
conversation_id=None, conversation_id=None,
) )
@ -208,9 +220,6 @@ class ToolNode(BaseNode[ToolNodeData]):
files: list[File] = [] files: list[File] = []
json: list[dict] = [] json: list[dict] = []
agent_logs: list[AgentLogEvent] = []
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
llm_usage: LLMUsage | None = None
variables: dict[str, Any] = {} variables: dict[str, Any] = {}
for message in message_stream: for message in message_stream:
@ -243,7 +252,7 @@ class ToolNode(BaseNode[ToolNodeData]):
} }
file = file_factory.build_from_mapping( file = file_factory.build_from_mapping(
mapping=mapping, mapping=mapping,
tenant_id=self.tenant_id, tenant_id=tenant_id,
) )
files.append(file) files.append(file)
elif message.type == ToolInvokeMessage.MessageType.BLOB: elif message.type == ToolInvokeMessage.MessageType.BLOB:
@ -266,45 +275,36 @@ class ToolNode(BaseNode[ToolNodeData]):
files.append( files.append(
file_factory.build_from_mapping( file_factory.build_from_mapping(
mapping=mapping, mapping=mapping,
tenant_id=self.tenant_id, tenant_id=tenant_id,
) )
) )
elif message.type == ToolInvokeMessage.MessageType.TEXT: elif message.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(message.message, ToolInvokeMessage.TextMessage) assert isinstance(message.message, ToolInvokeMessage.TextMessage)
text += message.message.text text += message.message.text
yield RunStreamChunkEvent( yield RunStreamChunkEvent(chunk_content=message.message.text, from_variable_selector=[node_id, "text"])
chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"]
)
elif message.type == ToolInvokeMessage.MessageType.JSON: elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage) assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
if self.node_type == NodeType.AGENT: # JSON message handling for tool node
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
llm_usage = LLMUsage.from_metadata(msg_metadata)
agent_execution_metadata = {
WorkflowNodeExecutionMetadataKey(key): value
for key, value in msg_metadata.items()
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
}
if message.message.json_object is not None: if message.message.json_object is not None:
json.append(message.message.json_object) json.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK: elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage) assert isinstance(message.message, ToolInvokeMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n" stream_text = f"Link: {message.message.text}\n"
text += stream_text text += stream_text
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"]) yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"])
elif message.type == ToolInvokeMessage.MessageType.VARIABLE: elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
assert isinstance(message.message, ToolInvokeMessage.VariableMessage) assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
variable_name = message.message.variable_name variable_name = message.message.variable_name
variable_value = message.message.variable_value variable_value = message.message.variable_value
if message.message.stream: if message.message.stream:
if not isinstance(variable_value, str): if not isinstance(variable_value, str):
raise ValueError("When 'stream' is True, 'variable_value' must be a string.") raise ToolNodeError("When 'stream' is True, 'variable_value' must be a string.")
if variable_name not in variables: if variable_name not in variables:
variables[variable_name] = "" variables[variable_name] = ""
variables[variable_name] += variable_value variables[variable_name] += variable_value
yield RunStreamChunkEvent( yield RunStreamChunkEvent(
chunk_content=variable_value, from_variable_selector=[self.node_id, variable_name] chunk_content=variable_value, from_variable_selector=[node_id, variable_name]
) )
else: else:
variables[variable_name] = variable_value variables[variable_name] = variable_value
@ -319,7 +319,7 @@ class ToolNode(BaseNode[ToolNodeData]):
dict_metadata = dict(message.message.metadata) dict_metadata = dict(message.message.metadata)
if dict_metadata.get("provider"): if dict_metadata.get("provider"):
manager = PluginInstaller() manager = PluginInstaller()
plugins = manager.list_plugins(self.tenant_id) plugins = manager.list_plugins(tenant_id)
try: try:
current_plugin = next( current_plugin = next(
plugin plugin
@ -334,8 +334,8 @@ class ToolNode(BaseNode[ToolNodeData]):
builtin_tool = next( builtin_tool = next(
provider provider
for provider in BuiltinToolManageService.list_builtin_tools( for provider in BuiltinToolManageService.list_builtin_tools(
self.user_id, user_id,
self.tenant_id, tenant_id,
) )
if provider.name == dict_metadata["provider"] if provider.name == dict_metadata["provider"]
) )
@ -347,57 +347,10 @@ class ToolNode(BaseNode[ToolNodeData]):
dict_metadata["icon"] = icon dict_metadata["icon"] = icon
dict_metadata["icon_dark"] = icon_dark dict_metadata["icon_dark"] = icon_dark
message.message.metadata = dict_metadata message.message.metadata = dict_metadata
agent_log = AgentLogEvent(
id=message.message.id,
node_execution_id=self.id,
parent_id=message.message.parent_id,
error=message.message.error,
status=message.message.status.value,
data=message.message.data,
label=message.message.label,
metadata=message.message.metadata,
node_id=self.node_id,
)
# check if the agent log is already in the list
for log in agent_logs:
if log.id == agent_log.id:
# update the log
log.data = agent_log.data
log.status = agent_log.status
log.error = agent_log.error
log.label = agent_log.label
log.metadata = agent_log.metadata
break
else:
agent_logs.append(agent_log)
yield agent_log
elif message.type == ToolInvokeMessage.MessageType.RETRIEVER_RESOURCES:
assert isinstance(message.message, ToolInvokeMessage.RetrieverResourceMessage)
yield RunRetrieverResourceEvent(
retriever_resources=message.message.retriever_resources,
context=message.message.context,
)
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process # Add agent_logs to outputs['json'] to ensure frontend can access thinking process
json_output: list[dict[str, Any]] = [] json_output: list[dict[str, Any]] = []
# Step 1: append each agent log as its own dict.
if agent_logs:
for log in agent_logs:
json_output.append(
{
"id": log.id,
"parent_id": log.parent_id,
"error": log.error,
"status": log.status,
"data": log.data,
"label": log.label,
"metadata": log.metadata,
"node_id": log.node_id,
}
)
# Step 2: normalize JSON into {"data": [...]}.change json to list[dict] # Step 2: normalize JSON into {"data": [...]}.change json to list[dict]
if json: if json:
json_output.extend(json) json_output.extend(json)
@ -409,12 +362,9 @@ class ToolNode(BaseNode[ToolNodeData]):
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables}, outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
metadata={ metadata={
**agent_execution_metadata,
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
}, },
inputs=parameters_for_log, inputs=parameters_for_log,
llm_usage=llm_usage,
) )
) )
@ -424,7 +374,7 @@ class ToolNode(BaseNode[ToolNodeData]):
*, *,
graph_config: Mapping[str, Any], graph_config: Mapping[str, Any],
node_id: str, node_id: str,
node_data: ToolNodeData, node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
""" """
Extract variable selector to variable mapping Extract variable selector to variable mapping
@ -433,9 +383,12 @@ class ToolNode(BaseNode[ToolNodeData]):
:param node_data: node data :param node_data: node data
:return: :return:
""" """
# Create typed NodeData from dict
typed_node_data = ToolNodeData.model_validate(node_data)
result = {} result = {}
for parameter_name in node_data.tool_parameters: for parameter_name in typed_node_data.tool_parameters:
input = node_data.tool_parameters[parameter_name] input = typed_node_data.tool_parameters[parameter_name]
if input.type == "mixed": if input.type == "mixed":
assert isinstance(input.value, str) assert isinstance(input.value, str)
selectors = VariableTemplateParser(input.value).extract_variable_selectors() selectors = VariableTemplateParser(input.value).extract_variable_selectors()
@ -449,3 +402,29 @@ class ToolNode(BaseNode[ToolNodeData]):
result = {node_id + "." + key: value for key, value in result.items()} result = {node_id + "." + key: value for key, value in result.items()}
return result return result
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@property
def continue_on_error(self) -> bool:
return self._node_data.error_strategy is not None
@property
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled

@ -1,17 +1,41 @@
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, Optional
from core.variables.segments import Segment from core.variables.segments import Segment
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData
class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]): class VariableAggregatorNode(BaseNode):
_node_data_cls = VariableAssignerNodeData
_node_type = NodeType.VARIABLE_AGGREGATOR _node_type = NodeType.VARIABLE_AGGREGATOR
_node_data: VariableAssignerNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = VariableAssignerNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"
@ -21,8 +45,8 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
outputs: dict[str, Segment | Mapping[str, Segment]] = {} outputs: dict[str, Segment | Mapping[str, Segment]] = {}
inputs = {} inputs = {}
if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled: if not self._node_data.advanced_settings or not self._node_data.advanced_settings.group_enabled:
for selector in self.node_data.variables: for selector in self._node_data.variables:
variable = self.graph_runtime_state.variable_pool.get(selector) variable = self.graph_runtime_state.variable_pool.get(selector)
if variable is not None: if variable is not None:
outputs = {"output": variable} outputs = {"output": variable}
@ -30,7 +54,7 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
inputs = {".".join(selector[1:]): variable.to_object()} inputs = {".".join(selector[1:]): variable.to_object()}
break break
else: else:
for group in self.node_data.advanced_settings.groups: for group in self._node_data.advanced_settings.groups:
for selector in group.variables: for selector in group.variables:
variable = self.graph_runtime_state.variable_pool.get(selector) variable = self.graph_runtime_state.variable_pool.get(selector)

@ -7,7 +7,8 @@ from core.workflow.conversation_variable_updater import ConversationVariableUpda
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from factories import variable_factory from factories import variable_factory
@ -22,11 +23,33 @@ if TYPE_CHECKING:
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater] _CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
class VariableAssignerNode(BaseNode[VariableAssignerData]): class VariableAssignerNode(BaseNode):
_node_data_cls = VariableAssignerData
_node_type = NodeType.VARIABLE_ASSIGNER _node_type = NodeType.VARIABLE_ASSIGNER
_conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY _conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY
_node_data: VariableAssignerData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = VariableAssignerData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
def __init__( def __init__(
self, self,
id: str, id: str,
@ -59,36 +82,39 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
*, *,
graph_config: Mapping[str, Any], graph_config: Mapping[str, Any],
node_id: str, node_id: str,
node_data: VariableAssignerData, node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = VariableAssignerData.model_validate(node_data)
mapping = {} mapping = {}
assigned_variable_node_id = node_data.assigned_variable_selector[0] assigned_variable_node_id = typed_node_data.assigned_variable_selector[0]
if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID: if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID:
selector_key = ".".join(node_data.assigned_variable_selector) selector_key = ".".join(typed_node_data.assigned_variable_selector)
key = f"{node_id}.#{selector_key}#" key = f"{node_id}.#{selector_key}#"
mapping[key] = node_data.assigned_variable_selector mapping[key] = typed_node_data.assigned_variable_selector
selector_key = ".".join(node_data.input_variable_selector) selector_key = ".".join(typed_node_data.input_variable_selector)
key = f"{node_id}.#{selector_key}#" key = f"{node_id}.#{selector_key}#"
mapping[key] = node_data.input_variable_selector mapping[key] = typed_node_data.input_variable_selector
return mapping return mapping
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
assigned_variable_selector = self.node_data.assigned_variable_selector assigned_variable_selector = self._node_data.assigned_variable_selector
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector) original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
if not isinstance(original_variable, Variable): if not isinstance(original_variable, Variable):
raise VariableOperatorNodeError("assigned variable not found") raise VariableOperatorNodeError("assigned variable not found")
match self.node_data.write_mode: match self._node_data.write_mode:
case WriteMode.OVER_WRITE: case WriteMode.OVER_WRITE:
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector) income_value = self.graph_runtime_state.variable_pool.get(self._node_data.input_variable_selector)
if not income_value: if not income_value:
raise VariableOperatorNodeError("input value not found") raise VariableOperatorNodeError("input value not found")
updated_variable = original_variable.model_copy(update={"value": income_value.value}) updated_variable = original_variable.model_copy(update={"value": income_value.value})
case WriteMode.APPEND: case WriteMode.APPEND:
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector) income_value = self.graph_runtime_state.variable_pool.get(self._node_data.input_variable_selector)
if not income_value: if not income_value:
raise VariableOperatorNodeError("input value not found") raise VariableOperatorNodeError("input value not found")
updated_value = original_variable.value + [income_value.value] updated_value = original_variable.value + [income_value.value]
@ -101,7 +127,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()}) updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
case _: case _:
raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}") raise VariableOperatorNodeError(f"unsupported write mode: {self._node_data.write_mode}")
# Over write the variable. # Over write the variable.
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable) self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)

@ -1,6 +1,6 @@
import json import json
from collections.abc import Callable, Mapping, MutableMapping, Sequence from collections.abc import Mapping, MutableMapping, Sequence
from typing import Any, TypeAlias, cast from typing import Any, Optional, cast
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import SegmentType, Variable from core.variables import SegmentType, Variable
@ -10,7 +10,8 @@ from core.workflow.conversation_variable_updater import ConversationVariableUpda
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
@ -28,8 +29,6 @@ from .exc import (
VariableNotFoundError, VariableNotFoundError,
) )
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem): def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem):
selector_node_id = item.variable_selector[0] selector_node_id = item.variable_selector[0]
@ -54,10 +53,32 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_
mapping[key] = selector mapping[key] = selector
class VariableAssignerNode(BaseNode[VariableAssignerNodeData]): class VariableAssignerNode(BaseNode):
_node_data_cls = VariableAssignerNodeData
_node_type = NodeType.VARIABLE_ASSIGNER _node_type = NodeType.VARIABLE_ASSIGNER
_node_data: VariableAssignerNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = VariableAssignerNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
def _conv_var_updater_factory(self) -> ConversationVariableUpdater: def _conv_var_updater_factory(self) -> ConversationVariableUpdater:
return conversation_variable_updater_factory() return conversation_variable_updater_factory()
@ -71,22 +92,25 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
*, *,
graph_config: Mapping[str, Any], graph_config: Mapping[str, Any],
node_id: str, node_id: str,
node_data: VariableAssignerNodeData, node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = VariableAssignerNodeData.model_validate(node_data)
var_mapping: dict[str, Sequence[str]] = {} var_mapping: dict[str, Sequence[str]] = {}
for item in node_data.items: for item in typed_node_data.items:
_target_mapping_from_item(var_mapping, node_id, item) _target_mapping_from_item(var_mapping, node_id, item)
_source_mapping_from_item(var_mapping, node_id, item) _source_mapping_from_item(var_mapping, node_id, item)
return var_mapping return var_mapping
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
inputs = self.node_data.model_dump() inputs = self._node_data.model_dump()
process_data: dict[str, Any] = {} process_data: dict[str, Any] = {}
# NOTE: This node has no outputs # NOTE: This node has no outputs
updated_variable_selectors: list[Sequence[str]] = [] updated_variable_selectors: list[Sequence[str]] = []
try: try:
for item in self.node_data.items: for item in self._node_data.items:
variable = self.graph_runtime_state.variable_pool.get(item.variable_selector) variable = self.graph_runtime_state.variable_pool.get(item.variable_selector)
# ==================== Validation Part # ==================== Validation Part

@ -5,7 +5,7 @@ from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast from typing import Any, Optional, cast
from configs import dify_config from configs import dify_config
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File from core.file.models import File
from core.workflow.callbacks import WorkflowCallback from core.workflow.callbacks import WorkflowCallback
@ -146,7 +146,7 @@ class WorkflowEntry:
graph = Graph.init(graph_config=workflow.graph_dict) graph = Graph.init(graph_config=workflow.graph_dict)
# init workflow run state # init workflow run state
node_instance = node_cls( node = node_cls(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
config=node_config, config=node_config,
graph_init_params=GraphInitParams( graph_init_params=GraphInitParams(
@ -190,17 +190,11 @@ class WorkflowEntry:
try: try:
# run node # run node
generator = node_instance.run() generator = node.run()
except Exception as e: except Exception as e:
logger.exception( logger.exception(f"error while running node, {workflow.id=}, {node.id=}, {node.type_=}, {node.version()=}")
"error while running node_instance, workflow_id=%s, node_id=%s, type=%s, version=%s", raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))
workflow.id, return node, generator
node_instance.id,
node_instance.node_type,
node_instance.version(),
)
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
return node_instance, generator
@classmethod @classmethod
def run_free_node( def run_free_node(
@ -262,7 +256,7 @@ class WorkflowEntry:
node_cls = cast(type[BaseNode], node_cls) node_cls = cast(type[BaseNode], node_cls)
# init workflow run state # init workflow run state
node_instance: BaseNode = node_cls( node: BaseNode = node_cls(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
config=node_config, config=node_config,
graph_init_params=GraphInitParams( graph_init_params=GraphInitParams(
@ -297,17 +291,12 @@ class WorkflowEntry:
) )
# run node # run node
generator = node_instance.run() generator = node.run()
return node_instance, generator return node, generator
except Exception as e: except Exception as e:
logger.exception( logger.exception(f"error while running node, {node.id=}, {node.type_=}, {node.version()=}")
"error while running node_instance, node_id=%s, type=%s, version=%s", raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))
node_instance.id,
node_instance.node_type,
node_instance.version(),
)
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
@staticmethod @staticmethod
def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None: def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None:

@ -465,10 +465,10 @@ class WorkflowService:
node_id: str, node_id: str,
) -> WorkflowNodeExecution: ) -> WorkflowNodeExecution:
try: try:
node_instance, generator = invoke_node_fn() node, node_events = invoke_node_fn()
node_run_result: NodeRunResult | None = None node_run_result: NodeRunResult | None = None
for event in generator: for event in node_events:
if isinstance(event, RunCompletedEvent): if isinstance(event, RunCompletedEvent):
node_run_result = event.run_result node_run_result = event.run_result
@ -479,18 +479,18 @@ class WorkflowService:
if not node_run_result: if not node_run_result:
raise ValueError("Node run failed with no run result") raise ValueError("Node run failed with no run result")
# single step debug mode error handling return # single step debug mode error handling return
if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error: if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node.continue_on_error:
node_error_args: dict[str, Any] = { node_error_args: dict[str, Any] = {
"status": WorkflowNodeExecutionStatus.EXCEPTION, "status": WorkflowNodeExecutionStatus.EXCEPTION,
"error": node_run_result.error, "error": node_run_result.error,
"inputs": node_run_result.inputs, "inputs": node_run_result.inputs,
"metadata": {"error_strategy": node_instance.node_data.error_strategy}, "metadata": {"error_strategy": node.error_strategy},
} }
if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE: if node.error_strategy is ErrorStrategy.DEFAULT_VALUE:
node_run_result = NodeRunResult( node_run_result = NodeRunResult(
**node_error_args, **node_error_args,
outputs={ outputs={
**node_instance.node_data.default_value_dict, **node.default_value_dict,
"error_message": node_run_result.error, "error_message": node_run_result.error,
"error_type": node_run_result.error_type, "error_type": node_run_result.error_type,
}, },
@ -509,10 +509,10 @@ class WorkflowService:
) )
error = node_run_result.error if not run_succeeded else None error = node_run_result.error if not run_succeeded else None
except WorkflowNodeRunFailedError as e: except WorkflowNodeRunFailedError as e:
node_instance = e.node_instance node = e._node
run_succeeded = False run_succeeded = False
node_run_result = None node_run_result = None
error = e.error error = e._error
# Create a NodeExecution domain model # Create a NodeExecution domain model
node_execution = WorkflowNodeExecution( node_execution = WorkflowNodeExecution(
@ -520,8 +520,8 @@ class WorkflowService:
workflow_id="", # This is a single-step execution, so no workflow ID workflow_id="", # This is a single-step execution, so no workflow ID
index=1, index=1,
node_id=node_id, node_id=node_id,
node_type=node_instance.node_type, node_type=node.type_,
title=node_instance.node_data.title, title=node.title,
elapsed_time=time.perf_counter() - start_at, elapsed_time=time.perf_counter() - start_at,
created_at=datetime.now(UTC).replace(tzinfo=None), created_at=datetime.now(UTC).replace(tzinfo=None),
finished_at=datetime.now(UTC).replace(tzinfo=None), finished_at=datetime.now(UTC).replace(tzinfo=None),

@ -15,7 +15,7 @@ def get_mocked_fetch_model_config(
mode: str, mode: str,
credentials: dict, credentials: dict,
): ):
model_provider_factory = ModelProviderFactory(tenant_id="test_tenant") model_provider_factory = ModelProviderFactory(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b")
model_type_instance = model_provider_factory.get_model_type_instance(provider, ModelType.LLM) model_type_instance = model_provider_factory.get_model_type_instance(provider, ModelType.LLM)
provider_model_bundle = ProviderModelBundle( provider_model_bundle = ProviderModelBundle(
configuration=ProviderConfiguration( configuration=ProviderConfiguration(

@ -66,6 +66,10 @@ def init_code_node(code_config: dict):
config=code_config, config=code_config,
) )
# Initialize node data
if "data" in code_config:
node.init_node_data(code_config["data"])
return node return node
@ -234,10 +238,10 @@ def test_execute_code_output_validator_depth():
"object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}}, "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}},
} }
node.node_data = cast(CodeNodeData, node.node_data) node._node_data = cast(CodeNodeData, node._node_data)
# validate # validate
node._transform_result(result, node.node_data.outputs) node._transform_result(result, node._node_data.outputs)
# construct result # construct result
result = { result = {
@ -250,7 +254,7 @@ def test_execute_code_output_validator_depth():
# validate # validate
with pytest.raises(ValueError): with pytest.raises(ValueError):
node._transform_result(result, node.node_data.outputs) node._transform_result(result, node._node_data.outputs)
# construct result # construct result
result = { result = {
@ -263,7 +267,7 @@ def test_execute_code_output_validator_depth():
# validate # validate
with pytest.raises(ValueError): with pytest.raises(ValueError):
node._transform_result(result, node.node_data.outputs) node._transform_result(result, node._node_data.outputs)
# construct result # construct result
result = { result = {
@ -276,7 +280,7 @@ def test_execute_code_output_validator_depth():
# validate # validate
with pytest.raises(ValueError): with pytest.raises(ValueError):
node._transform_result(result, node.node_data.outputs) node._transform_result(result, node._node_data.outputs)
def test_execute_code_output_object_list(): def test_execute_code_output_object_list():
@ -330,10 +334,10 @@ def test_execute_code_output_object_list():
] ]
} }
node.node_data = cast(CodeNodeData, node.node_data) node._node_data = cast(CodeNodeData, node._node_data)
# validate # validate
node._transform_result(result, node.node_data.outputs) node._transform_result(result, node._node_data.outputs)
# construct result # construct result
result = { result = {
@ -353,7 +357,7 @@ def test_execute_code_output_object_list():
# validate # validate
with pytest.raises(ValueError): with pytest.raises(ValueError):
node._transform_result(result, node.node_data.outputs) node._transform_result(result, node._node_data.outputs)
def test_execute_code_scientific_notation(): def test_execute_code_scientific_notation():

@ -52,7 +52,7 @@ def init_http_node(config: dict):
variable_pool.add(["a", "b123", "args1"], 1) variable_pool.add(["a", "b123", "args1"], 1)
variable_pool.add(["a", "b123", "args2"], 2) variable_pool.add(["a", "b123", "args2"], 2)
return HttpRequestNode( node = HttpRequestNode(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
graph_init_params=init_params, graph_init_params=init_params,
graph=graph, graph=graph,
@ -60,6 +60,12 @@ def init_http_node(config: dict):
config=config, config=config,
) )
# Initialize node data
if "data" in config:
node.init_node_data(config["data"])
return node
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_get(setup_http_mock): def test_get(setup_http_mock):

@ -2,15 +2,10 @@ import json
import time import time
import uuid import uuid
from collections.abc import Generator from collections.abc import Generator
from decimal import Decimal
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.llm_generator.output_parser.structured_output import _parse_structured_output from core.llm_generator.output_parser.structured_output import _parse_structured_output
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph import Graph
@ -24,8 +19,6 @@ from models.enums import UserFrom
from models.workflow import WorkflowType from models.workflow import WorkflowType
"""FOR MOCK FIXTURES, DO NOT REMOVE""" """FOR MOCK FIXTURES, DO NOT REMOVE"""
from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
def init_llm_node(config: dict) -> LLMNode: def init_llm_node(config: dict) -> LLMNode:
@ -84,10 +77,14 @@ def init_llm_node(config: dict) -> LLMNode:
config=config, config=config,
) )
# Initialize node data
if "data" in config:
node.init_node_data(config["data"])
return node return node
def test_execute_llm(flask_req_ctx): def test_execute_llm():
node = init_llm_node( node = init_llm_node(
config={ config={
"id": "llm", "id": "llm",
@ -95,7 +92,7 @@ def test_execute_llm(flask_req_ctx):
"title": "123", "title": "123",
"type": "llm", "type": "llm",
"model": { "model": {
"provider": "langgenius/openai/openai", "provider": "openai",
"name": "gpt-3.5-turbo", "name": "gpt-3.5-turbo",
"mode": "chat", "mode": "chat",
"completion_params": {}, "completion_params": {},
@ -114,53 +111,62 @@ def test_execute_llm(flask_req_ctx):
}, },
) )
# Create a proper LLM result with real entities db.session.close = MagicMock()
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal(1000),
prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal(1000),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),
currency="USD",
latency=0.5,
)
mock_message = AssistantPromptMessage(content="This is a test response from the mocked LLM.")
mock_llm_result = LLMResult(
model="gpt-3.5-turbo",
prompt_messages=[],
message=mock_message,
usage=mock_usage,
)
# Create a simple mock model instance that doesn't call real providers
mock_model_instance = MagicMock()
mock_model_instance.invoke_llm.return_value = mock_llm_result
# Create a simple mock model config with required attributes # Mock the _fetch_model_config to avoid database calls
mock_model_config = MagicMock() def mock_fetch_model_config(**_kwargs):
mock_model_config.mode = "chat" from decimal import Decimal
mock_model_config.provider = "langgenius/openai/openai" from unittest.mock import MagicMock
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b" from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage
# Create mock model instance
mock_model_instance = MagicMock()
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal(1000),
prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal(1000),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),
currency="USD",
latency=0.5,
)
mock_message = AssistantPromptMessage(content="Test response from mock")
mock_llm_result = LLMResult(
model="gpt-3.5-turbo",
prompt_messages=[],
message=mock_message,
usage=mock_usage,
)
mock_model_instance.invoke_llm.return_value = mock_llm_result
# Create mock model config
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.parameters = {}
# Mock the _fetch_model_config method
def mock_fetch_model_config_func(_node_data_model):
return mock_model_instance, mock_model_config return mock_model_instance, mock_model_config
# Also mock ModelManager.get_model_instance to avoid database calls # Mock fetch_prompt_messages to avoid database calls
def mock_get_model_instance(_self, **kwargs): def mock_fetch_prompt_messages_1(**_kwargs):
return mock_model_instance from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
return [
SystemPromptMessage(content="you are a helpful assistant. today's weather is sunny."),
UserPromptMessage(content="what's the weather today?"),
], []
with ( with (
patch.object(node, "_fetch_model_config", mock_fetch_model_config_func), patch.object(LLMNode, "_fetch_model_config", mock_fetch_model_config),
patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance), patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_1),
): ):
# execute node # execute node
result = node._run() result = node._run()
@ -168,6 +174,9 @@ def test_execute_llm(flask_req_ctx):
for item in result: for item in result:
if isinstance(item, RunCompletedEvent): if isinstance(item, RunCompletedEvent):
if item.run_result.status != WorkflowNodeExecutionStatus.SUCCEEDED:
print(f"Error: {item.run_result.error}")
print(f"Error type: {item.run_result.error_type}")
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.process_data is not None assert item.run_result.process_data is not None
assert item.run_result.outputs is not None assert item.run_result.outputs is not None
@ -175,8 +184,7 @@ def test_execute_llm(flask_req_ctx):
assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0 assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0
@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) def test_execute_llm_with_jinja2():
def test_execute_llm_with_jinja2(flask_req_ctx, setup_code_executor_mock):
""" """
Test execute LLM node with jinja2 Test execute LLM node with jinja2
""" """
@ -217,53 +225,60 @@ def test_execute_llm_with_jinja2(flask_req_ctx, setup_code_executor_mock):
# Mock db.session.close() # Mock db.session.close()
db.session.close = MagicMock() db.session.close = MagicMock()
# Create a proper LLM result with real entities
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal(1000),
prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal(1000),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),
currency="USD",
latency=0.5,
)
mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?")
mock_llm_result = LLMResult(
model="gpt-3.5-turbo",
prompt_messages=[],
message=mock_message,
usage=mock_usage,
)
# Create a simple mock model instance that doesn't call real providers
mock_model_instance = MagicMock()
mock_model_instance.invoke_llm.return_value = mock_llm_result
# Create a simple mock model config with required attributes
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
# Mock the _fetch_model_config method # Mock the _fetch_model_config method
def mock_fetch_model_config_func(_node_data_model): def mock_fetch_model_config(**_kwargs):
from decimal import Decimal
from unittest.mock import MagicMock
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage
# Create mock model instance
mock_model_instance = MagicMock()
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal(1000),
prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal(1000),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),
currency="USD",
latency=0.5,
)
mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?")
mock_llm_result = LLMResult(
model="gpt-3.5-turbo",
prompt_messages=[],
message=mock_message,
usage=mock_usage,
)
mock_model_instance.invoke_llm.return_value = mock_llm_result
# Create mock model config
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.parameters = {}
return mock_model_instance, mock_model_config return mock_model_instance, mock_model_config
# Also mock ModelManager.get_model_instance to avoid database calls # Mock fetch_prompt_messages to avoid database calls
def mock_get_model_instance(_self, **kwargs): def mock_fetch_prompt_messages_2(**_kwargs):
return mock_model_instance from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
return [
SystemPromptMessage(content="you are a helpful assistant. today's weather is sunny."),
UserPromptMessage(content="what's the weather today?"),
], []
with ( with (
patch.object(node, "_fetch_model_config", mock_fetch_model_config_func), patch.object(LLMNode, "_fetch_model_config", mock_fetch_model_config),
patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance), patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_2),
): ):
# execute node # execute node
result = node._run() result = node._run()

@ -74,13 +74,15 @@ def init_parameter_extractor_node(config: dict):
variable_pool.add(["a", "b123", "args1"], 1) variable_pool.add(["a", "b123", "args1"], 1)
variable_pool.add(["a", "b123", "args2"], 2) variable_pool.add(["a", "b123", "args2"], 2)
return ParameterExtractorNode( node = ParameterExtractorNode(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
graph_init_params=init_params, graph_init_params=init_params,
graph=graph, graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=config, config=config,
) )
node.init_node_data(config.get("data", {}))
return node
def test_function_calling_parameter_extractor(setup_model_mock): def test_function_calling_parameter_extractor(setup_model_mock):

@ -76,6 +76,7 @@ def test_execute_code(setup_code_executor_mock):
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=config, config=config,
) )
node.init_node_data(config.get("data", {}))
# execute node # execute node
result = node._run() result = node._run()

@ -50,13 +50,15 @@ def init_tool_node(config: dict):
conversation_variables=[], conversation_variables=[],
) )
return ToolNode( node = ToolNode(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
graph_init_params=init_params, graph_init_params=init_params,
graph=graph, graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=config, config=config,
) )
node.init_node_data(config.get("data", {}))
return node
def test_tool_variable_invoke(): def test_tool_variable_invoke():

@ -58,21 +58,26 @@ def test_execute_answer():
pool.add(["start", "weather"], "sunny") pool.add(["start", "weather"], "sunny")
pool.add(["llm", "text"], "You are a helpful AI.") pool.add(["llm", "text"], "You are a helpful AI.")
node_config = {
"id": "answer",
"data": {
"title": "123",
"type": "answer",
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
},
}
node = AnswerNode( node = AnswerNode(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
graph_init_params=init_params, graph_init_params=init_params,
graph=graph, graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={ config=node_config,
"id": "answer",
"data": {
"title": "123",
"type": "answer",
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
},
},
) )
# Initialize node data
node.init_node_data(node_config["data"])
# Mock db.session.close() # Mock db.session.close()
db.session.close = MagicMock() db.session.close = MagicMock()

@ -57,12 +57,15 @@ def test_http_request_node_binary_file(monkeypatch):
), ),
), ),
) )
node_config = {
"id": "1",
"data": data.model_dump(),
}
node = HttpRequestNode( node = HttpRequestNode(
id="1", id="1",
config={ config=node_config,
"id": "1",
"data": data.model_dump(),
},
graph_init_params=GraphInitParams( graph_init_params=GraphInitParams(
tenant_id="1", tenant_id="1",
app_id="1", app_id="1",
@ -90,6 +93,9 @@ def test_http_request_node_binary_file(monkeypatch):
start_at=0, start_at=0,
), ),
) )
# Initialize node data
node.init_node_data(node_config["data"])
monkeypatch.setattr( monkeypatch.setattr(
"core.workflow.nodes.http_request.executor.file_manager.download", "core.workflow.nodes.http_request.executor.file_manager.download",
lambda *args, **kwargs: b"test", lambda *args, **kwargs: b"test",
@ -145,12 +151,15 @@ def test_http_request_node_form_with_file(monkeypatch):
), ),
), ),
) )
node_config = {
"id": "1",
"data": data.model_dump(),
}
node = HttpRequestNode( node = HttpRequestNode(
id="1", id="1",
config={ config=node_config,
"id": "1",
"data": data.model_dump(),
},
graph_init_params=GraphInitParams( graph_init_params=GraphInitParams(
tenant_id="1", tenant_id="1",
app_id="1", app_id="1",
@ -178,6 +187,10 @@ def test_http_request_node_form_with_file(monkeypatch):
start_at=0, start_at=0,
), ),
) )
# Initialize node data
node.init_node_data(node_config["data"])
monkeypatch.setattr( monkeypatch.setattr(
"core.workflow.nodes.http_request.executor.file_manager.download", "core.workflow.nodes.http_request.executor.file_manager.download",
lambda *args, **kwargs: b"test", lambda *args, **kwargs: b"test",
@ -257,12 +270,14 @@ def test_http_request_node_form_with_multiple_files(monkeypatch):
), ),
) )
node_config = {
"id": "1",
"data": data.model_dump(),
}
node = HttpRequestNode( node = HttpRequestNode(
id="1", id="1",
config={ config=node_config,
"id": "1",
"data": data.model_dump(),
},
graph_init_params=GraphInitParams( graph_init_params=GraphInitParams(
tenant_id="1", tenant_id="1",
app_id="1", app_id="1",
@ -291,6 +306,9 @@ def test_http_request_node_form_with_multiple_files(monkeypatch):
), ),
) )
# Initialize node data
node.init_node_data(node_config["data"])
monkeypatch.setattr( monkeypatch.setattr(
"core.workflow.nodes.http_request.executor.file_manager.download", "core.workflow.nodes.http_request.executor.file_manager.download",
lambda file: b"test_image_data" if file.mime_type == "image/jpeg" else b"test_pdf_data", lambda file: b"test_image_data" if file.mime_type == "image/jpeg" else b"test_pdf_data",

@ -162,25 +162,30 @@ def test_run():
) )
pool.add(["pe", "list_output"], ["dify-1", "dify-2"]) pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
node_config = {
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "tt",
"title": "迭代",
"type": "iteration",
},
"id": "iteration-1",
}
iteration_node = IterationNode( iteration_node = IterationNode(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
graph_init_params=init_params, graph_init_params=init_params,
graph=graph, graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={ config=node_config,
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "tt",
"title": "迭代",
"type": "iteration",
},
"id": "iteration-1",
},
) )
# Initialize node data
iteration_node.init_node_data(node_config["data"])
def tt_generator(self): def tt_generator(self):
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -379,25 +384,30 @@ def test_run_parallel():
) )
pool.add(["pe", "list_output"], ["dify-1", "dify-2"]) pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
node_config = {
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "迭代",
"type": "iteration",
},
"id": "iteration-1",
}
iteration_node = IterationNode( iteration_node = IterationNode(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
graph_init_params=init_params, graph_init_params=init_params,
graph=graph, graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={ config=node_config,
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "迭代",
"type": "iteration",
},
"id": "iteration-1",
},
) )
# Initialize node data
iteration_node.init_node_data(node_config["data"])
def tt_generator(self): def tt_generator(self):
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -595,45 +605,55 @@ def test_iteration_run_in_parallel_mode():
) )
pool.add(["pe", "list_output"], ["dify-1", "dify-2"]) pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
parallel_node_config = {
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "迭代",
"type": "iteration",
"is_parallel": True,
},
"id": "iteration-1",
}
parallel_iteration_node = IterationNode( parallel_iteration_node = IterationNode(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
graph_init_params=init_params, graph_init_params=init_params,
graph=graph, graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={ config=parallel_node_config,
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "迭代",
"type": "iteration",
"is_parallel": True,
},
"id": "iteration-1",
},
) )
# Initialize node data
parallel_iteration_node.init_node_data(parallel_node_config["data"])
sequential_node_config = {
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "迭代",
"type": "iteration",
"is_parallel": True,
},
"id": "iteration-1",
}
sequential_iteration_node = IterationNode( sequential_iteration_node = IterationNode(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
graph_init_params=init_params, graph_init_params=init_params,
graph=graph, graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={ config=sequential_node_config,
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "迭代",
"type": "iteration",
"is_parallel": True,
},
"id": "iteration-1",
},
) )
# Initialize node data
sequential_iteration_node.init_node_data(sequential_node_config["data"])
def tt_generator(self): def tt_generator(self):
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -645,8 +665,8 @@ def test_iteration_run_in_parallel_mode():
# execute node # execute node
parallel_result = parallel_iteration_node._run() parallel_result = parallel_iteration_node._run()
sequential_result = sequential_iteration_node._run() sequential_result = sequential_iteration_node._run()
assert parallel_iteration_node.node_data.parallel_nums == 10 assert parallel_iteration_node._node_data.parallel_nums == 10
assert parallel_iteration_node.node_data.error_handle_mode == ErrorHandleMode.TERMINATED assert parallel_iteration_node._node_data.error_handle_mode == ErrorHandleMode.TERMINATED
count = 0 count = 0
parallel_arr = [] parallel_arr = []
sequential_arr = [] sequential_arr = []
@ -818,26 +838,31 @@ def test_iteration_run_error_handle():
environment_variables=[], environment_variables=[],
) )
pool.add(["pe", "list_output"], ["1", "1"]) pool.add(["pe", "list_output"], ["1", "1"])
error_node_config = {
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "iteration",
"type": "iteration",
"is_parallel": True,
"error_handle_mode": ErrorHandleMode.CONTINUE_ON_ERROR,
},
"id": "iteration-1",
}
iteration_node = IterationNode( iteration_node = IterationNode(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
graph_init_params=init_params, graph_init_params=init_params,
graph=graph, graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={ config=error_node_config,
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "iteration",
"type": "iteration",
"is_parallel": True,
"error_handle_mode": ErrorHandleMode.CONTINUE_ON_ERROR,
},
"id": "iteration-1",
},
) )
# Initialize node data
iteration_node.init_node_data(error_node_config["data"])
# execute continue on error node # execute continue on error node
result = iteration_node._run() result = iteration_node._run()
result_arr = [] result_arr = []
@ -851,7 +876,7 @@ def test_iteration_run_error_handle():
assert count == 14 assert count == 14
# execute remove abnormal output # execute remove abnormal output
iteration_node.node_data.error_handle_mode = ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT iteration_node._node_data.error_handle_mode = ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT
result = iteration_node._run() result = iteration_node._run()
count = 0 count = 0
for item in result: for item in result:

@ -119,17 +119,20 @@ def llm_node(
llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph: Graph, graph_runtime_state: GraphRuntimeState llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph: Graph, graph_runtime_state: GraphRuntimeState
) -> LLMNode: ) -> LLMNode:
mock_file_saver = mock.MagicMock(spec=LLMFileSaver) mock_file_saver = mock.MagicMock(spec=LLMFileSaver)
node_config = {
"id": "1",
"data": llm_node_data.model_dump(),
}
node = LLMNode( node = LLMNode(
id="1", id="1",
config={ config=node_config,
"id": "1",
"data": llm_node_data.model_dump(),
},
graph_init_params=graph_init_params, graph_init_params=graph_init_params,
graph=graph, graph=graph,
graph_runtime_state=graph_runtime_state, graph_runtime_state=graph_runtime_state,
llm_file_saver=mock_file_saver, llm_file_saver=mock_file_saver,
) )
# Initialize node data
node.init_node_data(node_config["data"])
return node return node
@ -488,7 +491,7 @@ def test_handle_list_messages_basic(llm_node):
variable_pool = llm_node.graph_runtime_state.variable_pool variable_pool = llm_node.graph_runtime_state.variable_pool
vision_detail_config = ImagePromptMessageContent.DETAIL.HIGH vision_detail_config = ImagePromptMessageContent.DETAIL.HIGH
result = llm_node._handle_list_messages( result = llm_node.handle_list_messages(
messages=messages, messages=messages,
context=context, context=context,
jinja2_variables=jinja2_variables, jinja2_variables=jinja2_variables,
@ -506,17 +509,20 @@ def llm_node_for_multimodal(
llm_node_data, graph_init_params, graph, graph_runtime_state llm_node_data, graph_init_params, graph, graph_runtime_state
) -> tuple[LLMNode, LLMFileSaver]: ) -> tuple[LLMNode, LLMFileSaver]:
mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver) mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver)
node_config = {
"id": "1",
"data": llm_node_data.model_dump(),
}
node = LLMNode( node = LLMNode(
id="1", id="1",
config={ config=node_config,
"id": "1",
"data": llm_node_data.model_dump(),
},
graph_init_params=graph_init_params, graph_init_params=graph_init_params,
graph=graph, graph=graph,
graph_runtime_state=graph_runtime_state, graph_runtime_state=graph_runtime_state,
llm_file_saver=mock_file_saver, llm_file_saver=mock_file_saver,
) )
# Initialize node data
node.init_node_data(node_config["data"])
return node, mock_file_saver return node, mock_file_saver
@ -540,7 +546,12 @@ class TestLLMNodeSaveMultiModalImageOutput:
size=9, size=9,
) )
mock_file_saver.save_binary_string.return_value = mock_file mock_file_saver.save_binary_string.return_value = mock_file
file = llm_node._save_multimodal_image_output(content=content) file = llm_node.save_multimodal_image_output(
content=content,
file_saver=mock_file_saver,
)
# Manually append to _file_outputs since the static method doesn't do it
llm_node._file_outputs.append(file)
assert llm_node._file_outputs == [mock_file] assert llm_node._file_outputs == [mock_file]
assert file == mock_file assert file == mock_file
mock_file_saver.save_binary_string.assert_called_once_with( mock_file_saver.save_binary_string.assert_called_once_with(
@ -566,7 +577,12 @@ class TestLLMNodeSaveMultiModalImageOutput:
size=9, size=9,
) )
mock_file_saver.save_remote_url.return_value = mock_file mock_file_saver.save_remote_url.return_value = mock_file
file = llm_node._save_multimodal_image_output(content=content) file = llm_node.save_multimodal_image_output(
content=content,
file_saver=mock_file_saver,
)
# Manually append to _file_outputs since the static method doesn't do it
llm_node._file_outputs.append(file)
assert llm_node._file_outputs == [mock_file] assert llm_node._file_outputs == [mock_file]
assert file == mock_file assert file == mock_file
mock_file_saver.save_remote_url.assert_called_once_with(content.url, FileType.IMAGE) mock_file_saver.save_remote_url.assert_called_once_with(content.url, FileType.IMAGE)
@ -582,7 +598,9 @@ def test_llm_node_image_file_to_markdown(llm_node: LLMNode):
class TestSaveMultimodalOutputAndConvertResultToMarkdown: class TestSaveMultimodalOutputAndConvertResultToMarkdown:
def test_str_content(self, llm_node_for_multimodal): def test_str_content(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown("hello world") gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
contents="hello world", file_saver=mock_file_saver, file_outputs=[]
)
assert list(gen) == ["hello world"] assert list(gen) == ["hello world"]
mock_file_saver.save_binary_string.assert_not_called() mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called() mock_file_saver.save_remote_url.assert_not_called()
@ -590,7 +608,7 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
def test_text_prompt_message_content(self, llm_node_for_multimodal): def test_text_prompt_message_content(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown( gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
[TextPromptMessageContent(data="hello world")] contents=[TextPromptMessageContent(data="hello world")], file_saver=mock_file_saver, file_outputs=[]
) )
assert list(gen) == ["hello world"] assert list(gen) == ["hello world"]
mock_file_saver.save_binary_string.assert_not_called() mock_file_saver.save_binary_string.assert_not_called()
@ -616,13 +634,15 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
) )
mock_file_saver.save_binary_string.return_value = mock_saved_file mock_file_saver.save_binary_string.return_value = mock_saved_file
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown( gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
[ contents=[
ImagePromptMessageContent( ImagePromptMessageContent(
format="png", format="png",
base64_data=image_b64_data, base64_data=image_b64_data,
mime_type="image/png", mime_type="image/png",
) )
] ],
file_saver=mock_file_saver,
file_outputs=llm_node._file_outputs,
) )
yielded_strs = list(gen) yielded_strs = list(gen)
assert len(yielded_strs) == 1 assert len(yielded_strs) == 1
@ -645,21 +665,27 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
def test_unknown_content_type(self, llm_node_for_multimodal): def test_unknown_content_type(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(frozenset(["hello world"])) gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
contents=frozenset(["hello world"]), file_saver=mock_file_saver, file_outputs=[]
)
assert list(gen) == ["frozenset({'hello world'})"] assert list(gen) == ["frozenset({'hello world'})"]
mock_file_saver.save_binary_string.assert_not_called() mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called() mock_file_saver.save_remote_url.assert_not_called()
def test_unknown_item_type(self, llm_node_for_multimodal): def test_unknown_item_type(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown([frozenset(["hello world"])]) gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
contents=[frozenset(["hello world"])], file_saver=mock_file_saver, file_outputs=[]
)
assert list(gen) == ["frozenset({'hello world'})"] assert list(gen) == ["frozenset({'hello world'})"]
mock_file_saver.save_binary_string.assert_not_called() mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called() mock_file_saver.save_remote_url.assert_not_called()
def test_none_content(self, llm_node_for_multimodal): def test_none_content(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(None) gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
contents=None, file_saver=mock_file_saver, file_outputs=[]
)
assert list(gen) == [] assert list(gen) == []
mock_file_saver.save_binary_string.assert_not_called() mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called() mock_file_saver.save_remote_url.assert_not_called()

@ -61,21 +61,26 @@ def test_execute_answer():
variable_pool.add(["start", "weather"], "sunny") variable_pool.add(["start", "weather"], "sunny")
variable_pool.add(["llm", "text"], "You are a helpful AI.") variable_pool.add(["llm", "text"], "You are a helpful AI.")
node_config = {
"id": "answer",
"data": {
"title": "123",
"type": "answer",
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
},
}
node = AnswerNode( node = AnswerNode(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
graph_init_params=init_params, graph_init_params=init_params,
graph=graph, graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={ config=node_config,
"id": "answer",
"data": {
"title": "123",
"type": "answer",
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
},
},
) )
# Initialize node data
node.init_node_data(node_config["data"])
# Mock db.session.close() # Mock db.session.close()
db.session.close = MagicMock() db.session.close = MagicMock()

@ -27,13 +27,17 @@ def document_extractor_node():
title="Test Document Extractor", title="Test Document Extractor",
variable_selector=["node_id", "variable_name"], variable_selector=["node_id", "variable_name"],
) )
return DocumentExtractorNode( node_config = {"id": "test_node_id", "data": node_data.model_dump()}
node = DocumentExtractorNode(
id="test_node_id", id="test_node_id",
config={"id": "test_node_id", "data": node_data.model_dump()}, config=node_config,
graph_init_params=Mock(), graph_init_params=Mock(),
graph=Mock(), graph=Mock(),
graph_runtime_state=Mock(), graph_runtime_state=Mock(),
) )
# Initialize node data
node.init_node_data(node_config["data"])
return node
@pytest.fixture @pytest.fixture

@ -57,57 +57,62 @@ def test_execute_if_else_result_true():
pool.add(["start", "null"], None) pool.add(["start", "null"], None)
pool.add(["start", "not_null"], "1212") pool.add(["start", "not_null"], "1212")
node_config = {
"id": "if-else",
"data": {
"title": "123",
"type": "if-else",
"logical_operator": "and",
"conditions": [
{
"comparison_operator": "contains",
"variable_selector": ["start", "array_contains"],
"value": "ab",
},
{
"comparison_operator": "not contains",
"variable_selector": ["start", "array_not_contains"],
"value": "ab",
},
{"comparison_operator": "contains", "variable_selector": ["start", "contains"], "value": "ab"},
{
"comparison_operator": "not contains",
"variable_selector": ["start", "not_contains"],
"value": "ab",
},
{"comparison_operator": "start with", "variable_selector": ["start", "start_with"], "value": "ab"},
{"comparison_operator": "end with", "variable_selector": ["start", "end_with"], "value": "ab"},
{"comparison_operator": "is", "variable_selector": ["start", "is"], "value": "ab"},
{"comparison_operator": "is not", "variable_selector": ["start", "is_not"], "value": "ab"},
{"comparison_operator": "empty", "variable_selector": ["start", "empty"], "value": "ab"},
{"comparison_operator": "not empty", "variable_selector": ["start", "not_empty"], "value": "ab"},
{"comparison_operator": "=", "variable_selector": ["start", "equals"], "value": "22"},
{"comparison_operator": "", "variable_selector": ["start", "not_equals"], "value": "22"},
{"comparison_operator": ">", "variable_selector": ["start", "greater_than"], "value": "22"},
{"comparison_operator": "<", "variable_selector": ["start", "less_than"], "value": "22"},
{
"comparison_operator": "",
"variable_selector": ["start", "greater_than_or_equal"],
"value": "22",
},
{"comparison_operator": "", "variable_selector": ["start", "less_than_or_equal"], "value": "22"},
{"comparison_operator": "null", "variable_selector": ["start", "null"]},
{"comparison_operator": "not null", "variable_selector": ["start", "not_null"]},
],
},
}
node = IfElseNode( node = IfElseNode(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
graph_init_params=init_params, graph_init_params=init_params,
graph=graph, graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={ config=node_config,
"id": "if-else",
"data": {
"title": "123",
"type": "if-else",
"logical_operator": "and",
"conditions": [
{
"comparison_operator": "contains",
"variable_selector": ["start", "array_contains"],
"value": "ab",
},
{
"comparison_operator": "not contains",
"variable_selector": ["start", "array_not_contains"],
"value": "ab",
},
{"comparison_operator": "contains", "variable_selector": ["start", "contains"], "value": "ab"},
{
"comparison_operator": "not contains",
"variable_selector": ["start", "not_contains"],
"value": "ab",
},
{"comparison_operator": "start with", "variable_selector": ["start", "start_with"], "value": "ab"},
{"comparison_operator": "end with", "variable_selector": ["start", "end_with"], "value": "ab"},
{"comparison_operator": "is", "variable_selector": ["start", "is"], "value": "ab"},
{"comparison_operator": "is not", "variable_selector": ["start", "is_not"], "value": "ab"},
{"comparison_operator": "empty", "variable_selector": ["start", "empty"], "value": "ab"},
{"comparison_operator": "not empty", "variable_selector": ["start", "not_empty"], "value": "ab"},
{"comparison_operator": "=", "variable_selector": ["start", "equals"], "value": "22"},
{"comparison_operator": "", "variable_selector": ["start", "not_equals"], "value": "22"},
{"comparison_operator": ">", "variable_selector": ["start", "greater_than"], "value": "22"},
{"comparison_operator": "<", "variable_selector": ["start", "less_than"], "value": "22"},
{
"comparison_operator": "",
"variable_selector": ["start", "greater_than_or_equal"],
"value": "22",
},
{"comparison_operator": "", "variable_selector": ["start", "less_than_or_equal"], "value": "22"},
{"comparison_operator": "null", "variable_selector": ["start", "null"]},
{"comparison_operator": "not null", "variable_selector": ["start", "not_null"]},
],
},
},
) )
# Initialize node data
node.init_node_data(node_config["data"])
# Mock db.session.close() # Mock db.session.close()
db.session.close = MagicMock() db.session.close = MagicMock()
@ -162,33 +167,38 @@ def test_execute_if_else_result_false():
pool.add(["start", "array_contains"], ["1ab", "def"]) pool.add(["start", "array_contains"], ["1ab", "def"])
pool.add(["start", "array_not_contains"], ["ab", "def"]) pool.add(["start", "array_not_contains"], ["ab", "def"])
node_config = {
"id": "if-else",
"data": {
"title": "123",
"type": "if-else",
"logical_operator": "or",
"conditions": [
{
"comparison_operator": "contains",
"variable_selector": ["start", "array_contains"],
"value": "ab",
},
{
"comparison_operator": "not contains",
"variable_selector": ["start", "array_not_contains"],
"value": "ab",
},
],
},
}
node = IfElseNode( node = IfElseNode(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
graph_init_params=init_params, graph_init_params=init_params,
graph=graph, graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={ config=node_config,
"id": "if-else",
"data": {
"title": "123",
"type": "if-else",
"logical_operator": "or",
"conditions": [
{
"comparison_operator": "contains",
"variable_selector": ["start", "array_contains"],
"value": "ab",
},
{
"comparison_operator": "not contains",
"variable_selector": ["start", "array_not_contains"],
"value": "ab",
},
],
},
},
) )
# Initialize node data
node.init_node_data(node_config["data"])
# Mock db.session.close() # Mock db.session.close()
db.session.close = MagicMock() db.session.close = MagicMock()
@ -228,17 +238,22 @@ def test_array_file_contains_file_name():
], ],
) )
node_config = {
"id": "if-else",
"data": node_data.model_dump(),
}
node = IfElseNode( node = IfElseNode(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
graph_init_params=Mock(), graph_init_params=Mock(),
graph=Mock(), graph=Mock(),
graph_runtime_state=Mock(), graph_runtime_state=Mock(),
config={ config=node_config,
"id": "if-else",
"data": node_data.model_dump(),
},
) )
# Initialize node data
node.init_node_data(node_config["data"])
node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment( node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment(
value=[ value=[
File( File(

@ -33,16 +33,19 @@ def list_operator_node():
"title": "Test Title", "title": "Test Title",
} }
node_data = ListOperatorNodeData(**config) node_data = ListOperatorNodeData(**config)
node_config = {
"id": "test_node_id",
"data": node_data.model_dump(),
}
node = ListOperatorNode( node = ListOperatorNode(
id="test_node_id", id="test_node_id",
config={ config=node_config,
"id": "test_node_id",
"data": node_data.model_dump(),
},
graph_init_params=MagicMock(), graph_init_params=MagicMock(),
graph=MagicMock(), graph=MagicMock(),
graph_runtime_state=MagicMock(), graph_runtime_state=MagicMock(),
) )
# Initialize node data
node.init_node_data(node_config["data"])
node.graph_runtime_state = MagicMock() node.graph_runtime_state = MagicMock()
node.graph_runtime_state.variable_pool = MagicMock() node.graph_runtime_state.variable_pool = MagicMock()
return node return node

@ -38,12 +38,13 @@ def _create_tool_node():
system_variables=SystemVariable.empty(), system_variables=SystemVariable.empty(),
user_inputs={}, user_inputs={},
) )
node_config = {
"id": "1",
"data": data.model_dump(),
}
node = ToolNode( node = ToolNode(
id="1", id="1",
config={ config=node_config,
"id": "1",
"data": data.model_dump(),
},
graph_init_params=GraphInitParams( graph_init_params=GraphInitParams(
tenant_id="1", tenant_id="1",
app_id="1", app_id="1",
@ -71,6 +72,8 @@ def _create_tool_node():
start_at=0, start_at=0,
), ),
) )
# Initialize node data
node.init_node_data(node_config["data"])
return node return node

@ -82,23 +82,28 @@ def test_overwrite_string_variable():
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater) mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater) mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
node_config = {
"id": "node_id",
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.OVER_WRITE.value,
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
},
}
node = VariableAssignerNode( node = VariableAssignerNode(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
graph_init_params=init_params, graph_init_params=init_params,
graph=graph, graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={ config=node_config,
"id": "node_id",
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.OVER_WRITE.value,
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
},
},
conv_var_updater_factory=mock_conv_var_updater_factory, conv_var_updater_factory=mock_conv_var_updater_factory,
) )
# Initialize node data
node.init_node_data(node_config["data"])
list(node.run()) list(node.run())
expected_var = StringVariable( expected_var = StringVariable(
id=conversation_variable.id, id=conversation_variable.id,
@ -178,23 +183,28 @@ def test_append_variable_to_array():
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater) mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater) mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
node_config = {
"id": "node_id",
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.APPEND.value,
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
},
}
node = VariableAssignerNode( node = VariableAssignerNode(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
graph_init_params=init_params, graph_init_params=init_params,
graph=graph, graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={ config=node_config,
"id": "node_id",
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.APPEND.value,
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
},
},
conv_var_updater_factory=mock_conv_var_updater_factory, conv_var_updater_factory=mock_conv_var_updater_factory,
) )
# Initialize node data
node.init_node_data(node_config["data"])
list(node.run()) list(node.run())
expected_value = list(conversation_variable.value) expected_value = list(conversation_variable.value)
expected_value.append(input_variable.value) expected_value.append(input_variable.value)
@ -265,23 +275,28 @@ def test_clear_array():
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater) mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater) mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
node_config = {
"id": "node_id",
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.CLEAR.value,
"input_variable_selector": [],
},
}
node = VariableAssignerNode( node = VariableAssignerNode(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
graph_init_params=init_params, graph_init_params=init_params,
graph=graph, graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={ config=node_config,
"id": "node_id",
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.CLEAR.value,
"input_variable_selector": [],
},
},
conv_var_updater_factory=mock_conv_var_updater_factory, conv_var_updater_factory=mock_conv_var_updater_factory,
) )
# Initialize node data
node.init_node_data(node_config["data"])
list(node.run()) list(node.run())
expected_var = ArrayStringVariable( expected_var = ArrayStringVariable(
id=conversation_variable.id, id=conversation_variable.id,

@ -115,28 +115,33 @@ def test_remove_first_from_array():
conversation_variables=[conversation_variable], conversation_variables=[conversation_variable],
) )
node_config = {
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_FIRST,
"value": None,
}
],
},
}
node = VariableAssignerNode( node = VariableAssignerNode(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
graph_init_params=init_params, graph_init_params=init_params,
graph=graph, graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={ config=node_config,
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_FIRST,
"value": None,
}
],
},
},
) )
# Initialize node data
node.init_node_data(node_config["data"])
# Skip the mock assertion since we're in a test environment # Skip the mock assertion since we're in a test environment
# Print the variable before running # Print the variable before running
print(f"Before: {variable_pool.get(['conversation', conversation_variable.name]).to_object()}") print(f"Before: {variable_pool.get(['conversation', conversation_variable.name]).to_object()}")
@ -202,28 +207,33 @@ def test_remove_last_from_array():
conversation_variables=[conversation_variable], conversation_variables=[conversation_variable],
) )
node_config = {
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_LAST,
"value": None,
}
],
},
}
node = VariableAssignerNode( node = VariableAssignerNode(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
graph_init_params=init_params, graph_init_params=init_params,
graph=graph, graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={ config=node_config,
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_LAST,
"value": None,
}
],
},
},
) )
# Initialize node data
node.init_node_data(node_config["data"])
# Skip the mock assertion since we're in a test environment # Skip the mock assertion since we're in a test environment
list(node.run()) list(node.run())
@ -281,28 +291,33 @@ def test_remove_first_from_empty_array():
conversation_variables=[conversation_variable], conversation_variables=[conversation_variable],
) )
node_config = {
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_FIRST,
"value": None,
}
],
},
}
node = VariableAssignerNode( node = VariableAssignerNode(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
graph_init_params=init_params, graph_init_params=init_params,
graph=graph, graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={ config=node_config,
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_FIRST,
"value": None,
}
],
},
},
) )
# Initialize node data
node.init_node_data(node_config["data"])
# Skip the mock assertion since we're in a test environment # Skip the mock assertion since we're in a test environment
list(node.run()) list(node.run())
@ -360,28 +375,33 @@ def test_remove_last_from_empty_array():
conversation_variables=[conversation_variable], conversation_variables=[conversation_variable],
) )
node_config = {
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_LAST,
"value": None,
}
],
},
}
node = VariableAssignerNode( node = VariableAssignerNode(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
graph_init_params=init_params, graph_init_params=init_params,
graph=graph, graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={ config=node_config,
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_LAST,
"value": None,
}
],
},
},
) )
# Initialize node data
node.init_node_data(node_config["data"])
# Skip the mock assertion since we're in a test environment # Skip the mock assertion since we're in a test environment
list(node.run()) list(node.run())

Loading…
Cancel
Save