From bd3f3e7b26d703ecb9eb11b992b40cf8bec3b3a1 Mon Sep 17 00:00:00 2001 From: ziqiang <1694392889@qq.com> Date: Mon, 23 Jun 2025 19:17:10 +0800 Subject: [PATCH] =?UTF-8?q?feat(workflow):=20=E6=9B=B4=E6=96=B0=E6=B6=88?= =?UTF-8?q?=E6=81=AF=E9=98=9F=E5=88=97=E8=8A=82=E7=82=B9=20(MqNode)=20?= =?UTF-8?q?=E7=9A=84=E9=85=8D=E7=BD=AE=EF=BC=8C=E6=B7=BB=E5=8A=A0=E5=88=B0?= =?UTF-8?q?=E9=94=99=E8=AF=AF=E5=A4=84=E7=90=86=E8=8A=82=E7=82=B9=E7=B1=BB?= =?UTF-8?q?=E5=9E=8B=E5=88=97=E8=A1=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/core/workflow/nodes/enums.py | 3 +- api/core/workflow/nodes/mq/__init__.py | 5 + api/core/workflow/nodes/mq/entities.py | 22 +++ api/core/workflow/nodes/mq/mq_node.py | 49 +++++ api/core/workflow/nodes/mq/rabbitmq_client.py | 175 ++++++++++++++++++ api/core/workflow/nodes/node_mapping.py | 5 + web/app/components/workflow/constants.ts | 3 +- .../components/workflow/nodes/mq/default.ts | 32 ++++ web/next.config.js | 5 +- 9 files changed, 296 insertions(+), 3 deletions(-) create mode 100644 api/core/workflow/nodes/mq/__init__.py create mode 100644 api/core/workflow/nodes/mq/entities.py create mode 100644 api/core/workflow/nodes/mq/mq_node.py create mode 100644 api/core/workflow/nodes/mq/rabbitmq_client.py create mode 100644 web/app/components/workflow/nodes/mq/default.ts diff --git a/api/core/workflow/nodes/enums.py b/api/core/workflow/nodes/enums.py index 73b43eeaf7..7b6fc5d80a 100644 --- a/api/core/workflow/nodes/enums.py +++ b/api/core/workflow/nodes/enums.py @@ -12,6 +12,7 @@ class NodeType(StrEnum): TEMPLATE_TRANSFORM = "template-transform" QUESTION_CLASSIFIER = "question-classifier" HTTP_REQUEST = "http-request" + MqNode = "MqNode" TOOL = "tool" VARIABLE_AGGREGATOR = "variable-aggregator" LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database. @@ -37,5 +38,5 @@ class FailBranchSourceHandle(StrEnum): SUCCESS = "success-branch" -CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST] +CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST, NodeType.MqNode] RETRY_ON_ERROR_NODE_TYPE = CONTINUE_ON_ERROR_NODE_TYPE diff --git a/api/core/workflow/nodes/mq/__init__.py b/api/core/workflow/nodes/mq/__init__.py new file mode 100644 index 0000000000..e6ccde43c5 --- /dev/null +++ b/api/core/workflow/nodes/mq/__init__.py @@ -0,0 +1,5 @@ +from .entities import MqNodeData +from .mq_node import MqNode +from .rabbitmq_client import RabbitMQClient + +__all__ = ["MqNode", "MqNodeData", "RabbitMQClient"] \ No newline at end of file diff --git a/api/core/workflow/nodes/mq/entities.py b/api/core/workflow/nodes/mq/entities.py new file mode 100644 index 0000000000..9b76dba14a --- /dev/null +++ b/api/core/workflow/nodes/mq/entities.py @@ -0,0 +1,22 @@ +from typing import Optional + +from pydantic import BaseModel + +from core.workflow.nodes.base import BaseNodeData + + +class MqNodeData(BaseNodeData): + """ + Mq Node Data. + """ + + class Case(BaseModel): + """ + Case entity representing a single logical condition group + """ + + channel: str + message: str + + channel: Optional[str] = "abc" + message: Optional[str] = "message" diff --git a/api/core/workflow/nodes/mq/mq_node.py b/api/core/workflow/nodes/mq/mq_node.py new file mode 100644 index 0000000000..49b9512b2b --- /dev/null +++ b/api/core/workflow/nodes/mq/mq_node.py @@ -0,0 +1,49 @@ + + +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.mq.entities import MqNodeData +from core.workflow.nodes.mq.rabbitmq_client import RabbitMQClient + +# 创建全局RabbitMQ客户端实例 +rabbitmq_client = RabbitMQClient("dify_node") + +class MqNode(BaseNode[MqNodeData]): + _node_data_cls = MqNodeData + _node_type = NodeType.MqNode + + def _run(self) -> NodeRunResult: + """ + Run mq node + :return: + """ + print("go go go execute execute execute") + node_inputs: dict[str, list] = {"conditions": []} + + process_data: dict[str, list] = {"condition_results": []} + + input_conditions = [] + final_result = False + + try: + rabbitmq_client.publish_json({ + "action": "downloadImage", + "msg": '你好' + }) + except Exception: + print("err") + pass + outputs = {"result": True, "message": 'xxx'} + + data = NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=node_inputs, + process_data=process_data, + edge_source_handle="false", # Use case ID or 'default' + outputs=outputs, + ) + + return data + diff --git a/api/core/workflow/nodes/mq/rabbitmq_client.py b/api/core/workflow/nodes/mq/rabbitmq_client.py new file mode 100644 index 0000000000..0414c329f3 --- /dev/null +++ b/api/core/workflow/nodes/mq/rabbitmq_client.py @@ -0,0 +1,175 @@ +import json +import logging +import threading +import time +from typing import Any, Optional + +import pika +from pika.exceptions import AMQPChannelError, AMQPConnectionError, StreamLostError + +logger = logging.getLogger(__name__) + +class RabbitMQClient: + """ + RabbitMQ客户端,实现自动重连和连接管理 + """ + + def __init__(self, queue_name: str): + self.queue_name = queue_name + self._connection_params = self._get_connection_params() + self._stopping = False + self._lock = threading.Lock() + self._local = threading.local() + self._reconnect_delay = 1 # 初始重连延迟(秒) + self._max_reconnect_delay = 30 # 最大重连延迟(秒) + + def _get_connection_params(self) -> pika.ConnectionParameters: + """获取RabbitMQ连接参数""" + RABBITMQ_CONFIG = { + 'host': '127.0.0.1', + 'port': 5672, + 'username': 'apitable', + 'password': 'apitable@com', + 'virtual_host': '/' + } + print('virtual_host:' + RABBITMQ_CONFIG['virtual_host']) + return pika.ConnectionParameters( + host=RABBITMQ_CONFIG['host'], + port=RABBITMQ_CONFIG['port'], + virtual_host=RABBITMQ_CONFIG['virtual_host'], + credentials=pika.PlainCredentials( + username=RABBITMQ_CONFIG['username'], + password=RABBITMQ_CONFIG['password'] + ), + heartbeat=30, # 心跳超时时间 + connection_attempts=3, # 连接尝试次数 + retry_delay=5, # 重试延迟 + socket_timeout=10, # socket超时 + blocked_connection_timeout=300, # 阻塞连接超时 + client_properties={'connection_name': f'spider_client_{threading.get_ident()}'} # 添加线程标识 + ) + + def _get_connection(self): + """获取当前线程的连接""" + if not hasattr(self._local, 'connection') or not self._local.connection or self._local.connection.is_closed: + self._local.connection = pika.BlockingConnection(self._connection_params) + return self._local.connection + + def _get_channel(self): + """获取当前线程的通道""" + if not hasattr(self._local, 'channel') or not self._local.channel or self._local.channel.is_closed: + connection = self._get_connection() + self._local.channel = connection.channel() + self._local.channel.queue_declare(queue=self.queue_name, durable=True) + self._local.channel.confirm_delivery() + return self._local.channel + + def _ensure_connection(self) -> bool: + """ + 确保RabbitMQ连接和通道可用 + + Returns: + bool: 连接是否成功 + """ + try: + self._get_channel() + return True + except Exception as e: + # logger.exception(f"RabbitMQ连接失败: {str(e)}") + # 使用指数退避策略 + time.sleep(self._reconnect_delay) + self._reconnect_delay = min(self._reconnect_delay * 2, self._max_reconnect_delay) + return False + + def publish_json(self, message: dict[str, Any], max_retries: int = 3) -> bool: + """ + 发布任意JSON消息到队列 + + Args: + message: 要发送的消息字典 + max_retries: 最大重试次数 + + Returns: + bool: 发送是否成功 + """ + retries = 0 + while retries < max_retries and not self._stopping: + try: + channel = self._get_channel() + + # 发布消息并等待确认 + channel.basic_publish( + exchange='', + routing_key=self.queue_name, + body=json.dumps(message), + properties=pika.BasicProperties( + delivery_mode=2, # 消息持久化 + content_type='application/json' + ), + mandatory=True # 确保消息能够被路由 + ) + logger.info(f"消息发送成功: {message}") + self._reconnect_delay = 1 # 重置重连延迟 + return True + + except (AMQPConnectionError, AMQPChannelError, StreamLostError) as e: + # logger.exception(f"发送消息失败 (尝试 {retries + 1}/{max_retries}): {str(e)}") + # 清除当前线程的连接和通道 + if hasattr(self._local, 'channel'): + delattr(self._local, 'channel') + if hasattr(self._local, 'connection'): + delattr(self._local, 'connection') + retries += 1 + if retries < max_retries: + time.sleep(self._reconnect_delay) + except Exception as e: + # logger.exception(f"发送消息时发生未知错误: {str(e)}") + return False + + return False + + def publish_message(self, task_id: str, action: str, status: Optional[str] = None, + extra: Optional[str] = None, reason: Optional[str] = None) -> bool: + """ + 发布任务状态更新消息 + + Args: + task_id: 任务ID + action: 动作类型 + status: 状态(FINISHED或EXCEPTION) + extra: 额外信息(字符串) + reason: 原因(异常时使用) + + Returns: + bool: 是否发送成功 + """ + message = { + "taskId": task_id, + "action": action + } + + if status: + message["status"] = status + if extra: + message["extra"] = extra # 直接设置字符串 + if reason: + message["reason"] = reason[:500] if reason else None # 限制reason长度 + + return self.publish_json(message) + + def close(self) -> None: + """关闭连接""" + self._stopping = True + if hasattr(self._local, 'channel') and self._local.channel and not self._local.channel.is_closed: + try: + self._local.channel.close() + except Exception as e: + logger.info(f"关闭通道时发生错误: {str(e)}") + + if hasattr(self._local, 'connection') and self._local.connection and not self._local.connection.is_closed: + try: + self._local.connection.close() + except Exception as e: + logger.info(f"关闭连接时发生错误: {str(e)}") + + logger.info("RabbitMQ连接已关闭") diff --git a/api/core/workflow/nodes/node_mapping.py b/api/core/workflow/nodes/node_mapping.py index 1f1be59542..86e9b9ad65 100644 --- a/api/core/workflow/nodes/node_mapping.py +++ b/api/core/workflow/nodes/node_mapping.py @@ -14,6 +14,7 @@ from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode from core.workflow.nodes.list_operator import ListOperatorNode from core.workflow.nodes.llm import LLMNode from core.workflow.nodes.loop import LoopEndNode, LoopNode, LoopStartNode +from core.workflow.nodes.mq import MqNode from core.workflow.nodes.parameter_extractor import ParameterExtractorNode from core.workflow.nodes.question_classifier import QuestionClassifierNode from core.workflow.nodes.start import StartNode @@ -66,6 +67,10 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = { LATEST_VERSION: HttpRequestNode, "1": HttpRequestNode, }, + NodeType.MqNode: { + LATEST_VERSION: MqNode, + "1": MqNode, + }, NodeType.TOOL: { LATEST_VERSION: ToolNode, "1": ToolNode, diff --git a/web/app/components/workflow/constants.ts b/web/app/components/workflow/constants.ts index 85079da6fc..a2a6cdf83e 100644 --- a/web/app/components/workflow/constants.ts +++ b/web/app/components/workflow/constants.ts @@ -9,6 +9,7 @@ import IfElseDefault from './nodes/if-else/default' import CodeDefault from './nodes/code/default' import TemplateTransformDefault from './nodes/template-transform/default' import HttpRequestDefault from './nodes/http/default' +import MqDefault from './nodes/mq/default' import ParameterExtractorDefault from './nodes/parameter-extractor/default' import ToolDefault from './nodes/tool/default' import VariableAssignerDefault from './nodes/variable-assigner/default' @@ -175,7 +176,7 @@ export const NODES_EXTRA_DATA: Record = { availableNextNodes: [], getAvailablePrevNodes: HttpRequestDefault.getAvailablePrevNodes, getAvailableNextNodes: HttpRequestDefault.getAvailableNextNodes, - checkValid: HttpRequestDefault.checkValid, + checkValid: MqDefault.checkValid, }, [BlockEnum.VariableAssigner]: { author: 'Dify', diff --git a/web/app/components/workflow/nodes/mq/default.ts b/web/app/components/workflow/nodes/mq/default.ts new file mode 100644 index 0000000000..dd5101ec89 --- /dev/null +++ b/web/app/components/workflow/nodes/mq/default.ts @@ -0,0 +1,32 @@ +import { BlockEnum } from '../../types' +import type { NodeDefault } from '../../types' +import { + ALL_CHAT_AVAILABLE_BLOCKS, + ALL_COMPLETION_AVAILABLE_BLOCKS, +} from '@/app/components/workflow/blocks' + +const nodeDefault: NodeDefault = { + defaultValue: { + channelName: '', + message: '', + }, + getAvailablePrevNodes(isChatMode: boolean) { + const nodes = isChatMode + ? ALL_CHAT_AVAILABLE_BLOCKS + : ALL_COMPLETION_AVAILABLE_BLOCKS.filter(type => type !== BlockEnum.End) + return nodes + }, + getAvailableNextNodes(isChatMode: boolean) { + const nodes = isChatMode ? ALL_CHAT_AVAILABLE_BLOCKS : ALL_COMPLETION_AVAILABLE_BLOCKS + return nodes + }, + checkValid(payload: any, t: any) { + const errorMessages = '' + return { + isValid: !errorMessages, + errorMessage: errorMessages, + } + }, +} + +export default nodeDefault diff --git a/web/next.config.js b/web/next.config.js index 9ce1b35644..27f0d0f68b 100644 --- a/web/next.config.js +++ b/web/next.config.js @@ -24,7 +24,10 @@ const nextConfig = { basePath, assetPrefix, webpack: (config, { dev, isServer }) => { - config.plugins.push(codeInspectorPlugin({ bundler: 'webpack' })) + config.plugins.push(codeInspectorPlugin({ + hideDomPathAttr: true, + bundler: 'webpack', +})) return config }, productionBrowserSourceMaps: false, // enable browser source map generation during the production build