diff --git a/.github/workflows/deploy-rag-dev.yml b/.github/workflows/deploy-rag-dev.yml
new file mode 100644
index 0000000000..86265aad6d
--- /dev/null
+++ b/.github/workflows/deploy-rag-dev.yml
@@ -0,0 +1,28 @@
+name: Deploy RAG Dev
+
+permissions:
+  contents: read
+
+on:
+  workflow_run:
+    workflows: ["Build and Push API & Web"]
+    branches:
+      - "deploy/rag-dev"
+    types:
+      - completed
+
+jobs:
+  deploy:
+    runs-on: ubuntu-latest
+    if: |
+      github.event.workflow_run.conclusion == 'success' &&
+      github.event.workflow_run.head_branch == 'deploy/rag-dev'
+    steps:
+      - name: Deploy to server
+        uses: appleboy/ssh-action@v0.1.8
+        with:
+          host: ${{ secrets.RAG_SSH_HOST }}
+          username: ${{ secrets.SSH_USER }}
+          key: ${{ secrets.SSH_PRIVATE_KEY }}
+          script: |
+            ${{ vars.SSH_SCRIPT || secrets.SSH_SCRIPT }}
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py
index 23d402f914..93976bd6f5 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py
@@ -1,6 +1,5 @@
 import logging
 
-import yaml
 from flask import request
 from flask_restful import Resource, reqparse
 from sqlalchemy.orm import Session
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
index 7b8adfe560..c97b3b1d92 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
@@ -8,6 +8,7 @@ from flask_restful.inputs import int_range  # type: ignore
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
 
+from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
 import services
 from configs import dify_config
 from controllers.console import api
@@ -453,7 +454,7 @@ class RagPipelineDrafDatasourceNodeRunApi(Resource):
             raise ValueError("missing datasource_type")
 
         rag_pipeline_service = RagPipelineService()
-        result = rag_pipeline_service.run_datasource_workflow_node(
+        return helper.compact_generate_response(rag_pipeline_service.run_datasource_workflow_node(
             pipeline=pipeline,
             node_id=node_id,
             user_inputs=inputs,
@@ -461,8 +462,7 @@ class RagPipelineDrafDatasourceNodeRunApi(Resource):
             datasource_type=datasource_type,
             is_published=False
         )
-
-        return result
+        )
 
 
 class RagPipelinePublishedNodeRunApi(Resource):
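The controller change above swaps a blocking dict return for a streamed response: `run_datasource_workflow_node` now returns a generator, and `helper.compact_generate_response` turns it into an HTTP stream. A minimal sketch of the behaviour that wrapper is assumed to have (the real helper lives elsewhere in the codebase; this is not the verbatim implementation):

```python
# Sketch only: assumed behaviour of helper.compact_generate_response.
# A Mapping is serialized once as JSON; a generator becomes a
# text/event-stream response that flushes each yielded chunk.
import json
from collections.abc import Generator, Mapping
from typing import Any, Union

from flask import Response


def compact_generate_response_sketch(response: Union[Mapping[str, Any], Generator]) -> Response:
    if isinstance(response, Mapping):
        # one-shot result: serialize once and return plain JSON
        return Response(json.dumps(response), status=200, mimetype="application/json")

    def generate() -> Generator[str, None, None]:
        for chunk in response:
            # each yielded chunk becomes one SSE data frame
            yield f"data: {chunk}\n\n"

    return Response(generate(), status=200, mimetype="text/event-stream")
```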
diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py
index adcdcccf83..b9a0c1f150 100644
--- a/api/core/datasource/entities/datasource_entities.py
+++ b/api/core/datasource/entities/datasource_entities.py
@@ -15,7 +15,7 @@ from core.plugin.entities.parameters import (
     init_frontend_parameter,
 )
 from core.tools.entities.common_entities import I18nObject
-from core.tools.entities.tool_entities import ToolLabelEnum
+from core.tools.entities.tool_entities import ToolInvokeMessage, ToolLabelEnum
 
 
 class DatasourceProviderType(enum.StrEnum):
@@ -207,12 +207,6 @@ class DatasourceInvokeFrom(Enum):
     RAG_PIPELINE = "rag_pipeline"
 
 
-class GetOnlineDocumentPagesRequest(BaseModel):
-    """
-    Get online document pages request
-    """
-
-
 class OnlineDocumentPage(BaseModel):
     """
     Online document page
@@ -237,7 +231,7 @@ class OnlineDocumentInfo(BaseModel):
     pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document")
 
 
-class GetOnlineDocumentPagesResponse(BaseModel):
+class OnlineDocumentPagesMessage(BaseModel):
     """
     Get online document pages response
     """
@@ -290,14 +284,42 @@ class WebSiteInfo(BaseModel):
     """
     Website info
    """
-    job_id: str = Field(..., description="The job id")
-    status: str = Field(..., description="The status of the job")
+    status: Optional[str] = Field(..., description="crawl job status")
     web_info_list: Optional[list[WebSiteInfoDetail]] = []
+    total: Optional[int] = Field(default=0, description="The total number of websites")
+    completed: Optional[int] = Field(default=0, description="The number of completed websites")
 
-
-class GetWebsiteCrawlResponse(BaseModel):
+class WebsiteCrawlMessage(BaseModel):
     """
     Get website crawl response
     """
+    result: WebSiteInfo = WebSiteInfo(status="", web_info_list=[], total=0, completed=0)
+
+class DatasourceMessage(ToolInvokeMessage):
+    pass
+
+
+class DatasourceInvokeMessage(ToolInvokeMessage):
+    """
+    Datasource Invoke Message.
+    """
+
+    class WebsiteCrawlMessage(BaseModel):
+        """
+        Website crawl message
+        """
+
+        job_id: str = Field(..., description="The job id")
+        status: str = Field(..., description="The status of the job")
+        web_info_list: Optional[list[WebSiteInfoDetail]] = []
+
+    class OnlineDocumentMessage(BaseModel):
+        """
+        Online document message
+        """
-    result: WebSiteInfo = WebSiteInfo(job_id="", status="", web_info_list=[])
+
+        workspace_id: str = Field(..., description="The workspace id")
+        workspace_name: str = Field(..., description="The workspace name")
+        workspace_icon: str = Field(..., description="The workspace icon")
+        total: int = Field(..., description="The total number of documents")
+        pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document")
diff --git a/api/core/datasource/online_document/online_document_plugin.py b/api/core/datasource/online_document/online_document_plugin.py
index f94031656e..db73d9a64b 100644
--- a/api/core/datasource/online_document/online_document_plugin.py
+++ b/api/core/datasource/online_document/online_document_plugin.py
@@ -1,14 +1,14 @@
-from collections.abc import Mapping
+from collections.abc import Generator, Mapping
 from typing import Any
 
 from core.datasource.__base.datasource_plugin import DatasourcePlugin
 from core.datasource.__base.datasource_runtime import DatasourceRuntime
 from core.datasource.entities.datasource_entities import (
     DatasourceEntity,
+    DatasourceInvokeMessage,
     DatasourceProviderType,
     GetOnlineDocumentPageContentRequest,
-    GetOnlineDocumentPageContentResponse,
-    GetOnlineDocumentPagesResponse,
+    OnlineDocumentPagesMessage,
 )
 from core.plugin.impl.datasource import PluginDatasourceManager
@@ -38,7 +38,7 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
         user_id: str,
         datasource_parameters: Mapping[str, Any],
         provider_type: str,
-    ) -> GetOnlineDocumentPagesResponse:
+    ) -> Generator[OnlineDocumentPagesMessage, None, None]:
         manager = PluginDatasourceManager()
 
         return manager.get_online_document_pages(
@@ -56,7 +56,7 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
         user_id: str,
         datasource_parameters: GetOnlineDocumentPageContentRequest,
         provider_type: str,
-    ) -> GetOnlineDocumentPageContentResponse:
+    ) -> Generator[DatasourceInvokeMessage, None, None]:
         manager = PluginDatasourceManager()
 
         return manager.get_online_document_page_content(
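With the plugin now returning a `Generator[...]` instead of a single response object, callers iterate rather than indexing into `.result`. A hedged consumption sketch (the `plugin` instance and empty parameters are hypothetical; method and message names are the ones from this diff):

```python
# Hypothetical consumer of the generator-based plugin API.
from core.datasource.entities.datasource_entities import OnlineDocumentPagesMessage
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin


def collect_pages(plugin: OnlineDocumentDatasourcePlugin, user_id: str) -> list[OnlineDocumentPagesMessage]:
    messages = plugin._get_online_document_pages(
        user_id=user_id,
        datasource_parameters={},  # provider-specific filters would go here
        provider_type="online_document",
    )
    # The daemon request is only dispatched once iteration begins, so
    # exhausting the generator here keeps the lazy semantics contained.
    return list(messages)
```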
diff --git a/api/core/datasource/utils/message_transformer.py b/api/core/datasource/utils/message_transformer.py
index a10030d93b..bd99387e8d 100644
--- a/api/core/datasource/utils/message_transformer.py
+++ b/api/core/datasource/utils/message_transformer.py
@@ -39,7 +39,7 @@ class DatasourceFileMessageTransformer:
                         conversation_id=conversation_id,
                     )
 
-                    url = f"/files/datasources/{file.id}{guess_extension(file.mimetype) or '.png'}"
+                    url = f"/files/datasources/{file.id}{guess_extension(file.mime_type) or '.png'}"
 
                     yield DatasourceInvokeMessage(
                         type=DatasourceInvokeMessage.MessageType.IMAGE_LINK,
@@ -77,7 +77,7 @@ class DatasourceFileMessageTransformer:
                         filename=filename,
                     )
 
-                    url = cls.get_datasource_file_url(datasource_file_id=file.id, extension=guess_extension(file.mimetype))
+                    url = cls.get_datasource_file_url(datasource_file_id=file.id, extension=guess_extension(file.mime_type))
 
                     # check if file is image
                     if "image" in mimetype:
@@ -98,7 +98,7 @@ class DatasourceFileMessageTransformer:
             if isinstance(file, File):
                 if file.transfer_method == FileTransferMethod.TOOL_FILE:
                     assert file.related_id is not None
-                    url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension)
+                    url = cls.get_datasource_file_url(datasource_file_id=file.related_id, extension=file.extension)
                     if file.type == FileType.IMAGE:
                         yield DatasourceInvokeMessage(
                             type=DatasourceInvokeMessage.MessageType.IMAGE_LINK,
diff --git a/api/core/datasource/website_crawl/website_crawl_plugin.py b/api/core/datasource/website_crawl/website_crawl_plugin.py
index e8256b3282..1625670165 100644
--- a/api/core/datasource/website_crawl/website_crawl_plugin.py
+++ b/api/core/datasource/website_crawl/website_crawl_plugin.py
@@ -1,12 +1,13 @@
-from collections.abc import Mapping
+from collections.abc import Generator, Mapping
 from typing import Any
 
 from core.datasource.__base.datasource_plugin import DatasourcePlugin
 from core.datasource.__base.datasource_runtime import DatasourceRuntime
 from core.datasource.entities.datasource_entities import (
     DatasourceEntity,
+    DatasourceInvokeMessage,
     DatasourceProviderType,
-    GetWebsiteCrawlResponse,
+    WebsiteCrawlMessage,
 )
 from core.plugin.impl.datasource import PluginDatasourceManager
@@ -31,12 +32,12 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
         self.icon = icon
         self.plugin_unique_identifier = plugin_unique_identifier
 
-    def _get_website_crawl(
+    def get_website_crawl(
         self,
         user_id: str,
         datasource_parameters: Mapping[str, Any],
         provider_type: str,
-    ) -> GetWebsiteCrawlResponse:
+    ) -> Generator[WebsiteCrawlMessage, None, None]:
         manager = PluginDatasourceManager()
 
         return manager.get_website_crawl(
diff --git a/api/core/datasource/website_crawl/website_crawl_provider.py b/api/core/datasource/website_crawl/website_crawl_provider.py
index a65efb750e..0567f1a480 100644
--- a/api/core/datasource/website_crawl/website_crawl_provider.py
+++ b/api/core/datasource/website_crawl/website_crawl_provider.py
@@ -1,4 +1,3 @@
-from core.datasource.__base import datasource_provider
 from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
 from core.datasource.__base.datasource_runtime import DatasourceRuntime
 from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
diff --git a/api/core/file/enums.py b/api/core/file/enums.py
index a50a651dd3..170eb4fc23 100644
--- a/api/core/file/enums.py
+++ b/api/core/file/enums.py
@@ -20,6 +20,7 @@ class FileTransferMethod(StrEnum):
     REMOTE_URL = "remote_url"
     LOCAL_FILE = "local_file"
     TOOL_FILE = "tool_file"
+    DATASOURCE_FILE = "datasource_file"
 
     @staticmethod
     def value_of(value):
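The new enum member keeps the `StrEnum` contract, so string comparison and `value_of` lookups work unchanged. A quick sanity sketch (assuming `value_of` resolves members by raw value, as its use elsewhere suggests; its body is not shown in this diff):

```python
from core.file.enums import FileTransferMethod

# StrEnum members compare equal to their raw string values.
assert FileTransferMethod.DATASOURCE_FILE == "datasource_file"
# Assumed: value_of resolves a raw value back to the member.
assert FileTransferMethod.value_of("datasource_file") is FileTransferMethod.DATASOURCE_FILE
```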
diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py
index 98ee0bb11e..06ee00c688 100644
--- a/api/core/plugin/impl/datasource.py
+++ b/api/core/plugin/impl/datasource.py
@@ -1,11 +1,11 @@
-from collections.abc import Mapping
+from collections.abc import Generator, Mapping
 from typing import Any
 
 from core.datasource.entities.datasource_entities import (
+    DatasourceInvokeMessage,
     GetOnlineDocumentPageContentRequest,
-    GetOnlineDocumentPageContentResponse,
-    GetOnlineDocumentPagesResponse,
-    GetWebsiteCrawlResponse,
+    OnlineDocumentPagesMessage,
+    WebsiteCrawlMessage,
 )
 from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
 from core.plugin.entities.plugin_daemon import (
@@ -93,17 +93,17 @@ class PluginDatasourceManager(BasePluginClient):
         credentials: dict[str, Any],
         datasource_parameters: Mapping[str, Any],
         provider_type: str,
-    ) -> GetWebsiteCrawlResponse:
+    ) -> Generator[WebsiteCrawlMessage, None, None]:
         """
         Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
         """
 
         datasource_provider_id = GenericProviderID(datasource_provider)
 
-        response = self._request_with_plugin_daemon_response_stream(
+        return self._request_with_plugin_daemon_response_stream(
             "POST",
             f"plugin/{tenant_id}/dispatch/datasource/get_website_crawl",
-            GetWebsiteCrawlResponse,
+            WebsiteCrawlMessage,
             data={
                 "user_id": user_id,
                 "data": {
@@ -118,10 +118,6 @@ class PluginDatasourceManager(BasePluginClient):
                 "Content-Type": "application/json",
             },
         )
-        for resp in response:
-            return resp
-
-        raise Exception("No response from plugin daemon")
 
     def get_online_document_pages(
         self,
@@ -132,7 +128,7 @@ class PluginDatasourceManager(BasePluginClient):
         credentials: dict[str, Any],
         datasource_parameters: Mapping[str, Any],
         provider_type: str,
-    ) -> GetOnlineDocumentPagesResponse:
+    ) -> Generator[OnlineDocumentPagesMessage, None, None]:
         """
         Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
         """
@@ -142,7 +138,7 @@ class PluginDatasourceManager(BasePluginClient):
         response = self._request_with_plugin_daemon_response_stream(
             "POST",
             f"plugin/{tenant_id}/dispatch/datasource/get_online_document_pages",
-            GetOnlineDocumentPagesResponse,
+            OnlineDocumentPagesMessage,
             data={
                 "user_id": user_id,
                 "data": {
@@ -157,10 +153,7 @@ class PluginDatasourceManager(BasePluginClient):
                 "Content-Type": "application/json",
             },
         )
-        for resp in response:
-            return resp
-
-        raise Exception("No response from plugin daemon")
+        yield from response
 
     def get_online_document_page_content(
         self,
@@ -171,7 +164,7 @@ class PluginDatasourceManager(BasePluginClient):
         credentials: dict[str, Any],
         datasource_parameters: GetOnlineDocumentPageContentRequest,
         provider_type: str,
-    ) -> GetOnlineDocumentPageContentResponse:
+    ) -> Generator[DatasourceInvokeMessage, None, None]:
         """
         Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
""" @@ -181,7 +174,7 @@ class PluginDatasourceManager(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/datasource/get_online_document_page_content", - GetOnlineDocumentPageContentResponse, + DatasourceInvokeMessage, data={ "user_id": user_id, "data": { @@ -196,10 +189,7 @@ class PluginDatasourceManager(BasePluginClient): "Content-Type": "application/json", }, ) - for resp in response: - return resp - - raise Exception("No response from plugin daemon") + yield from response def validate_provider_credentials( self, tenant_id: str, user_id: str, provider: str, plugin_id: str, credentials: dict[str, Any] diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 03047c0545..34a86555f7 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -188,6 +188,8 @@ class ToolInvokeMessage(BaseModel): FILE = "file" LOG = "log" BLOB_CHUNK = "blob_chunk" + WEBSITE_CRAWL = "website_crawl" + ONLINE_DOCUMENT = "online_document" type: MessageType = MessageType.TEXT """ diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index 9a4939502e..0d8a4ee821 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -273,3 +273,8 @@ class AgentLogEvent(BaseAgentEvent): InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | BaseAgentEvent | BaseLoopEvent + + +class DatasourceRunEvent(BaseModel): + status: str = Field(..., description="status") + result: dict[str, Any] = Field(..., description="result") diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 2782f2fb4c..240eeeb725 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -1,13 +1,17 @@ -from collections.abc import Mapping, Sequence +from collections.abc import Generator, Mapping, Sequence from typing import Any, cast +from sqlalchemy import select +from sqlalchemy.orm import Session + from core.datasource.entities.datasource_entities import ( + DatasourceInvokeMessage, DatasourceParameter, DatasourceProviderType, GetOnlineDocumentPageContentRequest, - GetOnlineDocumentPageContentResponse, ) from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin +from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer from core.file import File from core.file.enums import FileTransferMethod, FileType from core.plugin.impl.exc import PluginDaemonClientSideError @@ -19,8 +23,11 @@ from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution from core.workflow.enums import SystemVariableKey from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent +from core.workflow.nodes.tool.exc import ToolFileError from core.workflow.utils.variable_template_parser import VariableTemplateParser from extensions.ext_database import db +from factories import file_factory from models.model import UploadFile from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey @@ -36,7 +43,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): _node_data_cls = DatasourceNodeData _node_type = NodeType.DATASOURCE - def _run(self) -> 
diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py
index 2782f2fb4c..240eeeb725 100644
--- a/api/core/workflow/nodes/datasource/datasource_node.py
+++ b/api/core/workflow/nodes/datasource/datasource_node.py
@@ -1,13 +1,17 @@
-from collections.abc import Mapping, Sequence
+from collections.abc import Generator, Mapping, Sequence
 from typing import Any, cast
 
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
 from core.datasource.entities.datasource_entities import (
+    DatasourceInvokeMessage,
     DatasourceParameter,
     DatasourceProviderType,
     GetOnlineDocumentPageContentRequest,
-    GetOnlineDocumentPageContentResponse,
 )
 from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
+from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
 from core.file import File
 from core.file.enums import FileTransferMethod, FileType
 from core.plugin.impl.exc import PluginDaemonClientSideError
@@ -19,8 +23,11 @@ from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution
 from core.workflow.enums import SystemVariableKey
 from core.workflow.nodes.base import BaseNode
 from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent
+from core.workflow.nodes.tool.exc import ToolFileError
 from core.workflow.utils.variable_template_parser import VariableTemplateParser
 from extensions.ext_database import db
+from factories import file_factory
 from models.model import UploadFile
 
 from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
@@ -36,7 +43,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
     _node_data_cls = DatasourceNodeData
     _node_type = NodeType.DATASOURCE
 
-    def _run(self) -> NodeRunResult:
+    def _run(self) -> Generator:
         """
         Run the datasource node
         """
@@ -65,13 +72,16 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                 datasource_type=DatasourceProviderType.value_of(datasource_type),
             )
         except DatasourceNodeError as e:
-            return NodeRunResult(
+            yield RunCompletedEvent(
+                run_result=NodeRunResult(
                 status=WorkflowNodeExecutionStatus.FAILED,
                 inputs={},
                 metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                 error=f"Failed to get datasource runtime: {str(e)}",
                 error_type=type(e).__name__,
             )
+            )
+            return
 
         # get parameters
         datasource_parameters = datasource_runtime.entity.parameters
@@ -91,25 +101,22 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
             match datasource_type:
                 case DatasourceProviderType.ONLINE_DOCUMENT:
                     datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
-                    online_document_result: GetOnlineDocumentPageContentResponse = (
+                    online_document_result: Generator[DatasourceInvokeMessage, None, None] = (
                         datasource_runtime._get_online_document_page_content(
                             user_id=self.user_id,
                             datasource_parameters=GetOnlineDocumentPageContentRequest(**parameters),
                             provider_type=datasource_type,
                         )
                     )
-                    return NodeRunResult(
-                        status=WorkflowNodeExecutionStatus.SUCCEEDED,
-                        inputs=parameters_for_log,
-                        metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
-                        outputs={
-                            **online_document_result.result.model_dump(),
-                            "datasource_type": datasource_type,
-                        },
+                    yield from self._transform_message(
+                        messages=online_document_result,
+                        parameters_for_log=parameters_for_log,
+                        datasource_info=datasource_info,
                     )
+
                 case DatasourceProviderType.WEBSITE_CRAWL:
-                    return NodeRunResult(
+                    yield RunCompletedEvent(run_result=NodeRunResult(
                         status=WorkflowNodeExecutionStatus.SUCCEEDED,
                         inputs=parameters_for_log,
                         metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
@@ -117,7 +124,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                         outputs={
                             **datasource_info,
                             "datasource_type": datasource_type,
                         },
-                    )
+                    ))
                 case DatasourceProviderType.LOCAL_FILE:
                     related_id = datasource_info.get("related_id")
                     if not related_id:
@@ -149,7 +156,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                             variable_key_list=new_key_list,
                             variable_value=value,
                         )
-                    return NodeRunResult(
+                    yield RunCompletedEvent(run_result=NodeRunResult(
                         status=WorkflowNodeExecutionStatus.SUCCEEDED,
                         inputs=parameters_for_log,
                         metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
@@ -157,25 +164,25 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                         outputs={
                             "file_info": datasource_info,
                             "datasource_type": datasource_type,
                         },
-                    )
+                    ))
                 case _:
                     raise DatasourceNodeError(f"Unsupported datasource provider: {datasource_type}")
         except PluginDaemonClientSideError as e:
-            return NodeRunResult(
+            yield RunCompletedEvent(run_result=NodeRunResult(
                 status=WorkflowNodeExecutionStatus.FAILED,
                 inputs=parameters_for_log,
                 metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                 error=f"Failed to transform datasource message: {str(e)}",
                 error_type=type(e).__name__,
-            )
+            ))
         except DatasourceNodeError as e:
-            return NodeRunResult(
+            yield RunCompletedEvent(run_result=NodeRunResult(
                 status=WorkflowNodeExecutionStatus.FAILED,
                 inputs=parameters_for_log,
                 metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                 error=f"Failed to invoke datasource: {str(e)}",
                 error_type=type(e).__name__,
-            )
+            ))
 
     def _generate_parameters(
         self,
@@ -279,3 +286,136 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
             result = {node_id + "." + key: value for key, value in result.items()}
"." + key: value for key, value in result.items()} return result + + + + def _transform_message( + self, + messages: Generator[DatasourceInvokeMessage, None, None], + parameters_for_log: dict[str, Any], + datasource_info: dict[str, Any], + ) -> Generator: + """ + Convert ToolInvokeMessages into tuple[plain_text, files] + """ + # transform message and handle file storage + message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=messages, + user_id=self.user_id, + tenant_id=self.tenant_id, + conversation_id=None, + ) + + text = "" + files: list[File] = [] + json: list[dict] = [] + + variables: dict[str, Any] = {} + + for message in message_stream: + if message.type in { + DatasourceInvokeMessage.MessageType.IMAGE_LINK, + DatasourceInvokeMessage.MessageType.BINARY_LINK, + DatasourceInvokeMessage.MessageType.IMAGE, + }: + assert isinstance(message.message, DatasourceInvokeMessage.TextMessage) + + url = message.message.text + if message.meta: + transfer_method = message.meta.get("transfer_method", FileTransferMethod.DATASOURCE_FILE) + else: + transfer_method = FileTransferMethod.DATASOURCE_FILE + + datasource_file_id = str(url).split("/")[-1].split(".")[0] + + with Session(db.engine) as session: + stmt = select(UploadFile).where(UploadFile.id == datasource_file_id) + datasource_file = session.scalar(stmt) + if datasource_file is None: + raise ToolFileError(f"Tool file {datasource_file_id} does not exist") + + mapping = { + "datasource_file_id": datasource_file_id, + "type": file_factory.get_file_type_by_mime_type(datasource_file.mime_type), + "transfer_method": transfer_method, + "url": url, + } + file = file_factory.build_from_mapping( + mapping=mapping, + tenant_id=self.tenant_id, + ) + files.append(file) + elif message.type == DatasourceInvokeMessage.MessageType.BLOB: + # get tool file id + assert isinstance(message.message, DatasourceInvokeMessage.TextMessage) + assert message.meta + + datasource_file_id = message.message.text.split("/")[-1].split(".")[0] + with Session(db.engine) as session: + stmt = select(UploadFile).where(UploadFile.id == datasource_file_id) + datasource_file = session.scalar(stmt) + if datasource_file is None: + raise ToolFileError(f"datasource file {datasource_file_id} not exists") + + mapping = { + "datasource_file_id": datasource_file_id, + "transfer_method": FileTransferMethod.DATASOURCE_FILE, + } + + files.append( + file_factory.build_from_mapping( + mapping=mapping, + tenant_id=self.tenant_id, + ) + ) + elif message.type == DatasourceInvokeMessage.MessageType.TEXT: + assert isinstance(message.message, DatasourceInvokeMessage.TextMessage) + text += message.message.text + yield RunStreamChunkEvent( + chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"] + ) + elif message.type == DatasourceInvokeMessage.MessageType.JSON: + assert isinstance(message.message, DatasourceInvokeMessage.JsonMessage) + if self.node_type == NodeType.AGENT: + msg_metadata = message.message.json_object.pop("execution_metadata", {}) + agent_execution_metadata = { + key: value + for key, value in msg_metadata.items() + if key in WorkflowNodeExecutionMetadataKey.__members__.values() + } + json.append(message.message.json_object) + elif message.type == DatasourceInvokeMessage.MessageType.LINK: + assert isinstance(message.message, DatasourceInvokeMessage.TextMessage) + stream_text = f"Link: {message.message.text}\n" + text += stream_text + yield RunStreamChunkEvent(chunk_content=stream_text, 
+            elif message.type == DatasourceInvokeMessage.MessageType.VARIABLE:
+                assert isinstance(message.message, DatasourceInvokeMessage.VariableMessage)
+                variable_name = message.message.variable_name
+                variable_value = message.message.variable_value
+                if message.message.stream:
+                    if not isinstance(variable_value, str):
+                        raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
+                    if variable_name not in variables:
+                        variables[variable_name] = ""
+                    variables[variable_name] += variable_value
+
+                    yield RunStreamChunkEvent(
+                        chunk_content=variable_value, from_variable_selector=[self.node_id, variable_name]
+                    )
+                else:
+                    variables[variable_name] = variable_value
+            elif message.type == DatasourceInvokeMessage.MessageType.FILE:
+                assert message.meta is not None
+                files.append(message.meta["file"])
+
+        yield RunCompletedEvent(
+            run_result=NodeRunResult(
+                status=WorkflowNodeExecutionStatus.SUCCEEDED,
+                outputs={"json": json, "files": files, **variables, "text": text},
+                metadata={
+                    WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info,
+                },
+                inputs=parameters_for_log,
+            )
+        )
diff --git a/api/core/workflow/nodes/event/event.py b/api/core/workflow/nodes/event/event.py
index b72d111f49..3ebe80f245 100644
--- a/api/core/workflow/nodes/event/event.py
+++ b/api/core/workflow/nodes/event/event.py
@@ -6,7 +6,6 @@ from pydantic import BaseModel, Field
 from core.model_runtime.entities.llm_entities import LLMUsage
 from core.rag.entities.citation_metadata import RetrievalSourceMetadata
 from core.workflow.entities.node_entities import NodeRunResult
-from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
 
 
 class RunCompletedEvent(BaseModel):
@@ -39,11 +38,3 @@ class RunRetryEvent(BaseModel):
     error: str = Field(..., description="error")
     retry_index: int = Field(..., description="Retry attempt number")
     start_at: datetime = Field(..., description="Retry start time")
-
-
-class SingleStepRetryEvent(NodeRunResult):
-    """Single step retry event"""
-
-    status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RETRY
-
-    elapsed_time: float = Field(..., description="elapsed time")
diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
index c63d837106..49c8ec1e69 100644
--- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
+++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
@@ -1,7 +1,7 @@
 import datetime
 import logging
-from collections.abc import Mapping
 import time
+from collections.abc import Mapping
 from typing import Any, cast
 
 from sqlalchemy import func
diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py
index 52f119936f..128041a27d 100644
--- a/api/factories/file_factory.py
+++ b/api/factories/file_factory.py
@@ -60,6 +60,7 @@ def build_from_mapping(
         FileTransferMethod.LOCAL_FILE: _build_from_local_file,
         FileTransferMethod.REMOTE_URL: _build_from_remote_url,
         FileTransferMethod.TOOL_FILE: _build_from_tool_file,
+        FileTransferMethod.DATASOURCE_FILE: _build_from_datasource_file,
     }
 
     build_func = build_functions.get(transfer_method)
@@ -302,6 +303,52 @@ def _build_from_tool_file(
     )
 
 
+def _build_from_datasource_file(
+    *,
+    mapping: Mapping[str, Any],
+    tenant_id: str,
+    transfer_method: FileTransferMethod,
+    strict_type_validation: bool = False,
+) -> File:
+    datasource_file = (
+        db.session.query(UploadFile)
+        .filter(
+            UploadFile.id == mapping.get("datasource_file_id"),
+            UploadFile.tenant_id == tenant_id,
+        )
+        .first()
+    )
+
+    if datasource_file is None:
+        raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found")
+
+    extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin"
+
+    detected_file_type = _standardize_file_type(extension=extension, mime_type=datasource_file.mime_type)
+
+    specified_type = mapping.get("type")
+
+    if strict_type_validation and specified_type and detected_file_type.value != specified_type:
+        raise ValueError("Detected file type does not match the specified type. Please verify the file.")
+
+    file_type = (
+        FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type
+    )
+
+    return File(
+        id=mapping.get("id"),
+        tenant_id=tenant_id,
+        filename=datasource_file.name,
+        type=file_type,
+        transfer_method=transfer_method,
+        remote_url=datasource_file.source_url,
+        related_id=datasource_file.id,
+        extension=extension,
+        mime_type=datasource_file.mime_type,
+        size=datasource_file.size,
+        storage_key=datasource_file.key,
+    )
+
+
 def _is_file_valid_with_config(
     *,
     input_file_type: str,
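For reference, this is roughly how the datasource node exercises the new branch: it persists the upload, then hands `build_from_mapping` a mapping keyed by `datasource_file_id` (sketch; the id and tenant values are hypothetical):

```python
from core.file.enums import FileTransferMethod
from factories import file_factory

file = file_factory.build_from_mapping(
    mapping={
        "datasource_file_id": "9f8e7d6c-0000-0000-0000-000000000000",  # hypothetical UploadFile.id
        "transfer_method": FileTransferMethod.DATASOURCE_FILE,
    },
    tenant_id="tenant-id",  # hypothetical
)
assert file.transfer_method == FileTransferMethod.DATASOURCE_FILE
```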
diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py
index d97b43d557..8915c18bd8 100644
--- a/api/factories/variable_factory.py
+++ b/api/factories/variable_factory.py
@@ -42,10 +42,6 @@ from core.workflow.constants import (
 )
 
 
-class InvalidSelectorError(ValueError):
-    pass
-
-
 class UnsupportedSegmentTypeError(Exception):
     pass
diff --git a/api/services/errors/__init__.py b/api/services/errors/__init__.py
index eb1f055708..697e691224 100644
--- a/api/services/errors/__init__.py
+++ b/api/services/errors/__init__.py
@@ -4,7 +4,6 @@ from . import (
     app_model_config,
     audio,
     base,
-    completion,
     conversation,
     dataset,
     document,
@@ -19,7 +18,6 @@ __all__ = [
     "app_model_config",
     "audio",
     "base",
-    "completion",
     "conversation",
     "dataset",
     "document",
diff --git a/api/services/errors/account.py b/api/services/errors/account.py
index 5aca12ffeb..4d3d150e07 100644
--- a/api/services/errors/account.py
+++ b/api/services/errors/account.py
@@ -55,7 +55,3 @@ class MemberNotInTenantError(BaseServiceError):
 
 class RoleAlreadyAssignedError(BaseServiceError):
     pass
-
-
-class RateLimitExceededError(BaseServiceError):
-    pass
diff --git a/api/services/errors/completion.py b/api/services/errors/completion.py
deleted file mode 100644
index 7fc50a588e..0000000000
--- a/api/services/errors/completion.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from services.errors.base import BaseServiceError
-
-
-class CompletionStoppedError(BaseServiceError):
-    pass
diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
index df9fea805c..1d61677bea 100644
--- a/api/services/rag_pipeline/rag_pipeline.py
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -15,13 +15,13 @@ import contexts
 from configs import dify_config
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.datasource.entities.datasource_entities import (
+    DatasourceInvokeMessage,
     DatasourceProviderType,
-    GetOnlineDocumentPagesResponse,
-    GetWebsiteCrawlResponse,
+    OnlineDocumentPagesMessage,
+    WebsiteCrawlMessage,
 )
 from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
 from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
-from core.model_runtime.utils.encoders import jsonable_encoder
 from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
 from core.variables.variables import Variable
 from core.workflow.entities.node_entities import NodeRunResult
@@ -31,7 +31,7 @@ from core.workflow.entities.workflow_node_execution import (
 )
 from core.workflow.enums import SystemVariableKey
 from core.workflow.errors import WorkflowNodeRunFailedError
-from core.workflow.graph_engine.entities.event import InNodeEvent
+from core.workflow.graph_engine.entities.event import DatasourceRunEvent, InNodeEvent
 from core.workflow.nodes.base.node import BaseNode
 from core.workflow.nodes.enums import ErrorStrategy, NodeType
 from core.workflow.nodes.event.event import RunCompletedEvent
@@ -43,14 +43,14 @@ from extensions.ext_database import db
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
 from models.account import Account
 from models.dataset import Document, Pipeline, PipelineCustomizedTemplate  # type: ignore
-from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
+from models.enums import WorkflowRunTriggeredFrom
 from models.model import EndUser
-from models.oauth import DatasourceProvider
 from models.workflow import (
     Workflow,
+    WorkflowNodeExecutionModel,
     WorkflowNodeExecutionTriggeredFrom,
     WorkflowRun,
-    WorkflowType, WorkflowNodeExecutionModel,
+    WorkflowType,
 )
 from services.dataset_service import DatasetService
 from services.datasource_provider_service import DatasourceProviderService
@@ -468,15 +468,16 @@ class RagPipelineService:
             case DatasourceProviderType.WEBSITE_CRAWL:
                 datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
-                website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
+                website_crawl_results: list[WebsiteCrawlMessage] = []
+                for website_message in datasource_runtime.get_website_crawl(
                     user_id=account.id,
                     datasource_parameters={"job_id": job_id},
                     provider_type=datasource_runtime.datasource_provider_type(),
-                )
+                ):
+                    website_crawl_results.append(website_message)
                 return {
-                    "result": [result for result in website_crawl_result.result],
-                    "job_id": website_crawl_result.result.job_id,
-                    "status": website_crawl_result.result.status,
+                    "result": [message.result.model_dump() for message in website_crawl_results],
+                    "status": website_crawl_results[-1].result.status if website_crawl_results else "",
                     "provider_type": datasource_node_data.get("provider_type"),
                 }
             case _:
@@ -485,7 +486,7 @@ class RagPipelineService:
 
     def run_datasource_workflow_node(
         self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str, is_published: bool
-    ) -> dict:
+    ) -> Generator[DatasourceRunEvent, None, None]:
         """
         Run published workflow datasource
         """
@@ -532,29 +533,25 @@ class RagPipelineService:
         match datasource_type:
             case DatasourceProviderType.ONLINE_DOCUMENT:
                 datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
-                online_document_result: GetOnlineDocumentPagesResponse = datasource_runtime._get_online_document_pages(
+                online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = datasource_runtime._get_online_document_pages(
                     user_id=account.id,
                     datasource_parameters=user_inputs,
                     provider_type=datasource_runtime.datasource_provider_type(),
                 )
-                return {
-                    "result": [page.model_dump() for page in online_document_result.result],
-                    "provider_type": datasource_node_data.get("provider_type"),
-                }
+                for message in online_document_result:
+                    yield DatasourceRunEvent(
+                        status="success",
+                        result=message.model_dump(),
+                    )
             case DatasourceProviderType.WEBSITE_CRAWL:
                 datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
-                website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
+                website_crawl_result: Generator[WebsiteCrawlMessage, None, None] = datasource_runtime.get_website_crawl(
                     user_id=account.id,
                     datasource_parameters=user_inputs,
                     provider_type=datasource_runtime.datasource_provider_type(),
                 )
-                return {
-                    "result": [result.model_dump() for result in website_crawl_result.result.web_info_list] if website_crawl_result.result.web_info_list else [],
-                    "job_id": website_crawl_result.result.job_id,
-                    "status": website_crawl_result.result.status,
-                    "provider_type": datasource_node_data.get("provider_type"),
-                }
+                yield from (DatasourceRunEvent(status="success", result=message.model_dump()) for message in website_crawl_result)
             case _:
                 raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
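Taken together, the service now yields `DatasourceRunEvent` objects as the plugin daemon streams results, and the controller wraps that generator for the HTTP layer instead of blocking until the crawl finishes. An end-to-end consumption sketch (hypothetical wiring; the names are the ones introduced in this diff):

```python
# Hypothetical caller: pipeline, node_id, and current_user come from the
# surrounding request context, as in the controller hunk above.
events = rag_pipeline_service.run_datasource_workflow_node(
    pipeline=pipeline,
    node_id=node_id,
    user_inputs={"url": "https://example.com"},  # illustrative input
    account=current_user,
    datasource_type="website_crawl",
    is_published=False,
)
for event in events:  # each event is a DatasourceRunEvent
    print(event.status, event.result)
```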