Merge branch 'feat/r2' into deploy/rag-dev

feat/datasource
jyong 11 months ago
commit 1d71fd5b56

@ -0,0 +1,28 @@
name: Deploy RAG Dev
permissions:
contents: read
on:
workflow_run:
workflows: ["Build and Push API & Web"]
branches:
- "deploy/rag-dev"
types:
- completed
jobs:
deploy:
runs-on: ubuntu-latest
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'deploy/rag-dev'
steps:
- name: Deploy to server
uses: appleboy/ssh-action@v0.1.8
with:
host: ${{ secrets.RAG_SSH_HOST }}
username: ${{ secrets.SSH_USER }}
key: ${{ secrets.SSH_PRIVATE_KEY }}
script: |
${{ vars.SSH_SCRIPT || secrets.SSH_SCRIPT }}

@ -1,6 +1,5 @@
import logging import logging
import yaml
from flask import request from flask import request
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session

@ -8,6 +8,7 @@ from flask_restful.inputs import int_range # type: ignore
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
import services import services
from configs import dify_config from configs import dify_config
from controllers.console import api from controllers.console import api
@ -453,7 +454,7 @@ class RagPipelineDrafDatasourceNodeRunApi(Resource):
raise ValueError("missing datasource_type") raise ValueError("missing datasource_type")
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
result = rag_pipeline_service.run_datasource_workflow_node( return helper.compact_generate_response(rag_pipeline_service.run_datasource_workflow_node(
pipeline=pipeline, pipeline=pipeline,
node_id=node_id, node_id=node_id,
user_inputs=inputs, user_inputs=inputs,
@ -461,8 +462,7 @@ class RagPipelineDrafDatasourceNodeRunApi(Resource):
datasource_type=datasource_type, datasource_type=datasource_type,
is_published=False is_published=False
) )
)
return result
class RagPipelinePublishedNodeRunApi(Resource): class RagPipelinePublishedNodeRunApi(Resource):

@ -15,7 +15,7 @@ from core.plugin.entities.parameters import (
init_frontend_parameter, init_frontend_parameter,
) )
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolLabelEnum from core.tools.entities.tool_entities import ToolInvokeMessage, ToolLabelEnum
class DatasourceProviderType(enum.StrEnum): class DatasourceProviderType(enum.StrEnum):
@ -207,12 +207,6 @@ class DatasourceInvokeFrom(Enum):
RAG_PIPELINE = "rag_pipeline" RAG_PIPELINE = "rag_pipeline"
class GetOnlineDocumentPagesRequest(BaseModel):
"""
Get online document pages request
"""
class OnlineDocumentPage(BaseModel): class OnlineDocumentPage(BaseModel):
""" """
Online document page Online document page
@ -237,7 +231,7 @@ class OnlineDocumentInfo(BaseModel):
pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document") pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document")
class GetOnlineDocumentPagesResponse(BaseModel): class OnlineDocumentPagesMessage(BaseModel):
""" """
Get online document pages response Get online document pages response
""" """
@ -290,14 +284,42 @@ class WebSiteInfo(BaseModel):
""" """
Website info Website info
""" """
job_id: str = Field(..., description="The job id") status: Optional[str] = Field(..., description="crawl job status")
status: str = Field(..., description="The status of the job")
web_info_list: Optional[list[WebSiteInfoDetail]] = [] web_info_list: Optional[list[WebSiteInfoDetail]] = []
total: Optional[int] = Field(default=0, description="The total number of websites")
completed: Optional[int] = Field(default=0, description="The number of completed websites")
class WebsiteCrawlMessage(BaseModel):
class GetWebsiteCrawlResponse(BaseModel):
""" """
Get website crawl response Get website crawl response
""" """
result: WebSiteInfo = WebSiteInfo(status="", web_info_list=[], total=0, completed=0)
class DatasourceMessage(ToolInvokeMessage):
pass
class DatasourceInvokeMessage(ToolInvokeMessage):
"""
Datasource Invoke Message.
"""
class WebsiteCrawlMessage(BaseModel):
"""
Website crawl message
"""
job_id: str = Field(..., description="The job id")
status: str = Field(..., description="The status of the job")
web_info_list: Optional[list[WebSiteInfoDetail]] = []
class OnlineDocumentMessage(BaseModel):
"""
Online document message
"""
result: WebSiteInfo = WebSiteInfo(job_id="", status="", web_info_list=[]) workspace_id: str = Field(..., description="The workspace id")
workspace_name: str = Field(..., description="The workspace name")
workspace_icon: str = Field(..., description="The workspace icon")
total: int = Field(..., description="The total number of documents")
pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document")

@ -1,14 +1,14 @@
from collections.abc import Mapping from collections.abc import Generator, Mapping
from typing import Any from typing import Any
from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import ( from core.datasource.entities.datasource_entities import (
DatasourceEntity, DatasourceEntity,
DatasourceInvokeMessage,
DatasourceProviderType, DatasourceProviderType,
GetOnlineDocumentPageContentRequest, GetOnlineDocumentPageContentRequest,
GetOnlineDocumentPageContentResponse, OnlineDocumentPagesMessage,
GetOnlineDocumentPagesResponse,
) )
from core.plugin.impl.datasource import PluginDatasourceManager from core.plugin.impl.datasource import PluginDatasourceManager
@ -38,7 +38,7 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
user_id: str, user_id: str,
datasource_parameters: Mapping[str, Any], datasource_parameters: Mapping[str, Any],
provider_type: str, provider_type: str,
) -> GetOnlineDocumentPagesResponse: ) -> Generator[OnlineDocumentPagesMessage, None, None]:
manager = PluginDatasourceManager() manager = PluginDatasourceManager()
return manager.get_online_document_pages( return manager.get_online_document_pages(
@ -56,7 +56,7 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
user_id: str, user_id: str,
datasource_parameters: GetOnlineDocumentPageContentRequest, datasource_parameters: GetOnlineDocumentPageContentRequest,
provider_type: str, provider_type: str,
) -> GetOnlineDocumentPageContentResponse: ) -> Generator[DatasourceInvokeMessage, None, None]:
manager = PluginDatasourceManager() manager = PluginDatasourceManager()
return manager.get_online_document_page_content( return manager.get_online_document_page_content(

@ -39,7 +39,7 @@ class DatasourceFileMessageTransformer:
conversation_id=conversation_id, conversation_id=conversation_id,
) )
url = f"/files/datasources/{file.id}{guess_extension(file.mimetype) or '.png'}" url = f"/files/datasources/{file.id}{guess_extension(file.mime_type) or '.png'}"
yield DatasourceInvokeMessage( yield DatasourceInvokeMessage(
type=DatasourceInvokeMessage.MessageType.IMAGE_LINK, type=DatasourceInvokeMessage.MessageType.IMAGE_LINK,
@ -77,7 +77,7 @@ class DatasourceFileMessageTransformer:
filename=filename, filename=filename,
) )
url = cls.get_datasource_file_url(datasource_file_id=file.id, extension=guess_extension(file.mimetype)) url = cls.get_datasource_file_url(datasource_file_id=file.id, extension=guess_extension(file.mime_type))
# check if file is image # check if file is image
if "image" in mimetype: if "image" in mimetype:
@ -98,7 +98,7 @@ class DatasourceFileMessageTransformer:
if isinstance(file, File): if isinstance(file, File):
if file.transfer_method == FileTransferMethod.TOOL_FILE: if file.transfer_method == FileTransferMethod.TOOL_FILE:
assert file.related_id is not None assert file.related_id is not None
url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension) url = cls.get_datasource_file_url(datasource_file_id=file.related_id, extension=file.extension)
if file.type == FileType.IMAGE: if file.type == FileType.IMAGE:
yield DatasourceInvokeMessage( yield DatasourceInvokeMessage(
type=DatasourceInvokeMessage.MessageType.IMAGE_LINK, type=DatasourceInvokeMessage.MessageType.IMAGE_LINK,

@ -1,12 +1,13 @@
from collections.abc import Mapping from collections.abc import Generator, Mapping
from typing import Any from typing import Any
from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import ( from core.datasource.entities.datasource_entities import (
DatasourceEntity, DatasourceEntity,
DatasourceInvokeMessage,
DatasourceProviderType, DatasourceProviderType,
GetWebsiteCrawlResponse, WebsiteCrawlMessage,
) )
from core.plugin.impl.datasource import PluginDatasourceManager from core.plugin.impl.datasource import PluginDatasourceManager
@ -31,12 +32,12 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
self.icon = icon self.icon = icon
self.plugin_unique_identifier = plugin_unique_identifier self.plugin_unique_identifier = plugin_unique_identifier
def _get_website_crawl( def get_website_crawl(
self, self,
user_id: str, user_id: str,
datasource_parameters: Mapping[str, Any], datasource_parameters: Mapping[str, Any],
provider_type: str, provider_type: str,
) -> GetWebsiteCrawlResponse: ) -> Generator[WebsiteCrawlMessage, None, None]:
manager = PluginDatasourceManager() manager = PluginDatasourceManager()
return manager.get_website_crawl( return manager.get_website_crawl(

@ -1,4 +1,3 @@
from core.datasource.__base import datasource_provider
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType

@ -20,6 +20,7 @@ class FileTransferMethod(StrEnum):
REMOTE_URL = "remote_url" REMOTE_URL = "remote_url"
LOCAL_FILE = "local_file" LOCAL_FILE = "local_file"
TOOL_FILE = "tool_file" TOOL_FILE = "tool_file"
DATASOURCE_FILE = "datasource_file"
@staticmethod @staticmethod
def value_of(value): def value_of(value):

@ -1,11 +1,11 @@
from collections.abc import Mapping from collections.abc import Generator, Mapping
from typing import Any from typing import Any
from core.datasource.entities.datasource_entities import ( from core.datasource.entities.datasource_entities import (
DatasourceInvokeMessage,
GetOnlineDocumentPageContentRequest, GetOnlineDocumentPageContentRequest,
GetOnlineDocumentPageContentResponse, OnlineDocumentPagesMessage,
GetOnlineDocumentPagesResponse, WebsiteCrawlMessage,
GetWebsiteCrawlResponse,
) )
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
from core.plugin.entities.plugin_daemon import ( from core.plugin.entities.plugin_daemon import (
@ -93,17 +93,17 @@ class PluginDatasourceManager(BasePluginClient):
credentials: dict[str, Any], credentials: dict[str, Any],
datasource_parameters: Mapping[str, Any], datasource_parameters: Mapping[str, Any],
provider_type: str, provider_type: str,
) -> GetWebsiteCrawlResponse: ) -> Generator[WebsiteCrawlMessage, None, None]:
""" """
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
""" """
datasource_provider_id = GenericProviderID(datasource_provider) datasource_provider_id = GenericProviderID(datasource_provider)
response = self._request_with_plugin_daemon_response_stream( return self._request_with_plugin_daemon_response_stream(
"POST", "POST",
f"plugin/{tenant_id}/dispatch/datasource/get_website_crawl", f"plugin/{tenant_id}/dispatch/datasource/get_website_crawl",
GetWebsiteCrawlResponse, WebsiteCrawlMessage,
data={ data={
"user_id": user_id, "user_id": user_id,
"data": { "data": {
@ -118,10 +118,6 @@ class PluginDatasourceManager(BasePluginClient):
"Content-Type": "application/json", "Content-Type": "application/json",
}, },
) )
for resp in response:
return resp
raise Exception("No response from plugin daemon")
def get_online_document_pages( def get_online_document_pages(
self, self,
@ -132,7 +128,7 @@ class PluginDatasourceManager(BasePluginClient):
credentials: dict[str, Any], credentials: dict[str, Any],
datasource_parameters: Mapping[str, Any], datasource_parameters: Mapping[str, Any],
provider_type: str, provider_type: str,
) -> GetOnlineDocumentPagesResponse: ) -> Generator[OnlineDocumentPagesMessage, None, None]:
""" """
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
""" """
@ -142,7 +138,7 @@ class PluginDatasourceManager(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream( response = self._request_with_plugin_daemon_response_stream(
"POST", "POST",
f"plugin/{tenant_id}/dispatch/datasource/get_online_document_pages", f"plugin/{tenant_id}/dispatch/datasource/get_online_document_pages",
GetOnlineDocumentPagesResponse, OnlineDocumentPagesMessage,
data={ data={
"user_id": user_id, "user_id": user_id,
"data": { "data": {
@ -157,10 +153,7 @@ class PluginDatasourceManager(BasePluginClient):
"Content-Type": "application/json", "Content-Type": "application/json",
}, },
) )
for resp in response: yield from response
return resp
raise Exception("No response from plugin daemon")
def get_online_document_page_content( def get_online_document_page_content(
self, self,
@ -171,7 +164,7 @@ class PluginDatasourceManager(BasePluginClient):
credentials: dict[str, Any], credentials: dict[str, Any],
datasource_parameters: GetOnlineDocumentPageContentRequest, datasource_parameters: GetOnlineDocumentPageContentRequest,
provider_type: str, provider_type: str,
) -> GetOnlineDocumentPageContentResponse: ) -> Generator[DatasourceInvokeMessage, None, None]:
""" """
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
""" """
@ -181,7 +174,7 @@ class PluginDatasourceManager(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream( response = self._request_with_plugin_daemon_response_stream(
"POST", "POST",
f"plugin/{tenant_id}/dispatch/datasource/get_online_document_page_content", f"plugin/{tenant_id}/dispatch/datasource/get_online_document_page_content",
GetOnlineDocumentPageContentResponse, DatasourceInvokeMessage,
data={ data={
"user_id": user_id, "user_id": user_id,
"data": { "data": {
@ -196,10 +189,7 @@ class PluginDatasourceManager(BasePluginClient):
"Content-Type": "application/json", "Content-Type": "application/json",
}, },
) )
for resp in response: yield from response
return resp
raise Exception("No response from plugin daemon")
def validate_provider_credentials( def validate_provider_credentials(
self, tenant_id: str, user_id: str, provider: str, plugin_id: str, credentials: dict[str, Any] self, tenant_id: str, user_id: str, provider: str, plugin_id: str, credentials: dict[str, Any]

@ -188,6 +188,8 @@ class ToolInvokeMessage(BaseModel):
FILE = "file" FILE = "file"
LOG = "log" LOG = "log"
BLOB_CHUNK = "blob_chunk" BLOB_CHUNK = "blob_chunk"
WEBSITE_CRAWL = "website_crawl"
ONLINE_DOCUMENT = "online_document"
type: MessageType = MessageType.TEXT type: MessageType = MessageType.TEXT
""" """

@ -273,3 +273,8 @@ class AgentLogEvent(BaseAgentEvent):
InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | BaseAgentEvent | BaseLoopEvent InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | BaseAgentEvent | BaseLoopEvent
class DatasourceRunEvent(BaseModel):
status: str = Field(..., description="status")
result: dict[str, Any] = Field(..., description="result")

@ -1,13 +1,17 @@
from collections.abc import Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from typing import Any, cast from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.datasource.entities.datasource_entities import ( from core.datasource.entities.datasource_entities import (
DatasourceInvokeMessage,
DatasourceParameter, DatasourceParameter,
DatasourceProviderType, DatasourceProviderType,
GetOnlineDocumentPageContentRequest, GetOnlineDocumentPageContentRequest,
GetOnlineDocumentPageContentResponse,
) )
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
from core.file import File from core.file import File
from core.file.enums import FileTransferMethod, FileType from core.file.enums import FileTransferMethod, FileType
from core.plugin.impl.exc import PluginDaemonClientSideError from core.plugin.impl.exc import PluginDaemonClientSideError
@ -19,8 +23,11 @@ from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.tool.exc import ToolFileError
from core.workflow.utils.variable_template_parser import VariableTemplateParser from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db from extensions.ext_database import db
from factories import file_factory
from models.model import UploadFile from models.model import UploadFile
from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
@ -36,7 +43,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
_node_data_cls = DatasourceNodeData _node_data_cls = DatasourceNodeData
_node_type = NodeType.DATASOURCE _node_type = NodeType.DATASOURCE
def _run(self) -> NodeRunResult: def _run(self) -> Generator:
""" """
Run the datasource node Run the datasource node
""" """
@ -65,13 +72,15 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
datasource_type=DatasourceProviderType.value_of(datasource_type), datasource_type=DatasourceProviderType.value_of(datasource_type),
) )
except DatasourceNodeError as e: except DatasourceNodeError as e:
return NodeRunResult( yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, status=WorkflowNodeExecutionStatus.FAILED,
inputs={}, inputs={},
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
error=f"Failed to get datasource runtime: {str(e)}", error=f"Failed to get datasource runtime: {str(e)}",
error_type=type(e).__name__, error_type=type(e).__name__,
) )
)
# get parameters # get parameters
datasource_parameters = datasource_runtime.entity.parameters datasource_parameters = datasource_runtime.entity.parameters
@ -91,25 +100,22 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
match datasource_type: match datasource_type:
case DatasourceProviderType.ONLINE_DOCUMENT: case DatasourceProviderType.ONLINE_DOCUMENT:
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
online_document_result: GetOnlineDocumentPageContentResponse = ( online_document_result: Generator[DatasourceInvokeMessage, None, None] = (
datasource_runtime._get_online_document_page_content( datasource_runtime._get_online_document_page_content(
user_id=self.user_id, user_id=self.user_id,
datasource_parameters=GetOnlineDocumentPageContentRequest(**parameters), datasource_parameters=GetOnlineDocumentPageContentRequest(**parameters),
provider_type=datasource_type, provider_type=datasource_type,
) )
) )
return NodeRunResult( yield from self._transform_message(
status=WorkflowNodeExecutionStatus.SUCCEEDED, messages=online_document_result,
inputs=parameters_for_log, parameters_for_log=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, datasource_info=datasource_info,
outputs={
**online_document_result.result.model_dump(),
"datasource_type": datasource_type,
},
) )
case DatasourceProviderType.WEBSITE_CRAWL: case DatasourceProviderType.WEBSITE_CRAWL:
return NodeRunResult( yield RunCompletedEvent(run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log, inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
@ -117,7 +123,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
**datasource_info, **datasource_info,
"datasource_type": datasource_type, "datasource_type": datasource_type,
}, },
) ))
case DatasourceProviderType.LOCAL_FILE: case DatasourceProviderType.LOCAL_FILE:
related_id = datasource_info.get("related_id") related_id = datasource_info.get("related_id")
if not related_id: if not related_id:
@ -149,7 +155,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
variable_key_list=new_key_list, variable_key_list=new_key_list,
variable_value=value, variable_value=value,
) )
return NodeRunResult( yield RunCompletedEvent(run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log, inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
@ -157,25 +163,25 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
"file_info": datasource_info, "file_info": datasource_info,
"datasource_type": datasource_type, "datasource_type": datasource_type,
}, },
) ))
case _: case _:
raise DatasourceNodeError(f"Unsupported datasource provider: {datasource_type}") raise DatasourceNodeError(f"Unsupported datasource provider: {datasource_type}")
except PluginDaemonClientSideError as e: except PluginDaemonClientSideError as e:
return NodeRunResult( yield RunCompletedEvent(run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log, inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
error=f"Failed to transform datasource message: {str(e)}", error=f"Failed to transform datasource message: {str(e)}",
error_type=type(e).__name__, error_type=type(e).__name__,
) ))
except DatasourceNodeError as e: except DatasourceNodeError as e:
return NodeRunResult( yield RunCompletedEvent(run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log, inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
error=f"Failed to invoke datasource: {str(e)}", error=f"Failed to invoke datasource: {str(e)}",
error_type=type(e).__name__, error_type=type(e).__name__,
) ))
def _generate_parameters( def _generate_parameters(
self, self,
@ -279,3 +285,136 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
result = {node_id + "." + key: value for key, value in result.items()} result = {node_id + "." + key: value for key, value in result.items()}
return result return result
def _transform_message(
self,
messages: Generator[DatasourceInvokeMessage, None, None],
parameters_for_log: dict[str, Any],
datasource_info: dict[str, Any],
) -> Generator:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
"""
# transform message and handle file storage
message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
messages=messages,
user_id=self.user_id,
tenant_id=self.tenant_id,
conversation_id=None,
)
text = ""
files: list[File] = []
json: list[dict] = []
variables: dict[str, Any] = {}
for message in message_stream:
if message.type in {
DatasourceInvokeMessage.MessageType.IMAGE_LINK,
DatasourceInvokeMessage.MessageType.BINARY_LINK,
DatasourceInvokeMessage.MessageType.IMAGE,
}:
assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
url = message.message.text
if message.meta:
transfer_method = message.meta.get("transfer_method", FileTransferMethod.DATASOURCE_FILE)
else:
transfer_method = FileTransferMethod.DATASOURCE_FILE
datasource_file_id = str(url).split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(UploadFile).where(UploadFile.id == datasource_file_id)
datasource_file = session.scalar(stmt)
if datasource_file is None:
raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
mapping = {
"datasource_file_id": datasource_file_id,
"type": file_factory.get_file_type_by_mime_type(datasource_file.mime_type),
"transfer_method": transfer_method,
"url": url,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
)
files.append(file)
elif message.type == DatasourceInvokeMessage.MessageType.BLOB:
# get tool file id
assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
assert message.meta
datasource_file_id = message.message.text.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(UploadFile).where(UploadFile.id == datasource_file_id)
datasource_file = session.scalar(stmt)
if datasource_file is None:
raise ToolFileError(f"datasource file {datasource_file_id} not exists")
mapping = {
"datasource_file_id": datasource_file_id,
"transfer_method": FileTransferMethod.DATASOURCE_FILE,
}
files.append(
file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
)
)
elif message.type == DatasourceInvokeMessage.MessageType.TEXT:
assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
text += message.message.text
yield RunStreamChunkEvent(
chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"]
)
elif message.type == DatasourceInvokeMessage.MessageType.JSON:
assert isinstance(message.message, DatasourceInvokeMessage.JsonMessage)
if self.node_type == NodeType.AGENT:
msg_metadata = message.message.json_object.pop("execution_metadata", {})
agent_execution_metadata = {
key: value
for key, value in msg_metadata.items()
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
}
json.append(message.message.json_object)
elif message.type == DatasourceInvokeMessage.MessageType.LINK:
assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"])
elif message.type == DatasourceInvokeMessage.MessageType.VARIABLE:
assert isinstance(message.message, DatasourceInvokeMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
if not isinstance(variable_value, str):
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
if variable_name not in variables:
variables[variable_name] = ""
variables[variable_name] += variable_value
yield RunStreamChunkEvent(
chunk_content=variable_value, from_variable_selector=[self.node_id, variable_name]
)
else:
variables[variable_name] = variable_value
elif message.type == DatasourceInvokeMessage.MessageType.FILE:
assert message.meta is not None
files.append(message.meta["file"])
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"json": json, "files": files, **variables, "text": text},
metadata={
WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info,
},
inputs=parameters_for_log,
)
)

@ -6,7 +6,6 @@ from pydantic import BaseModel, Field
from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.entities.llm_entities import LLMUsage
from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
class RunCompletedEvent(BaseModel): class RunCompletedEvent(BaseModel):
@ -39,11 +38,3 @@ class RunRetryEvent(BaseModel):
error: str = Field(..., description="error") error: str = Field(..., description="error")
retry_index: int = Field(..., description="Retry attempt number") retry_index: int = Field(..., description="Retry attempt number")
start_at: datetime = Field(..., description="Retry start time") start_at: datetime = Field(..., description="Retry start time")
class SingleStepRetryEvent(NodeRunResult):
"""Single step retry event"""
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RETRY
elapsed_time: float = Field(..., description="elapsed time")

@ -1,7 +1,7 @@
import datetime import datetime
import logging import logging
from collections.abc import Mapping
import time import time
from collections.abc import Mapping
from typing import Any, cast from typing import Any, cast
from sqlalchemy import func from sqlalchemy import func

@ -127,7 +127,7 @@ class ToolNode(BaseNode[ToolNodeData]):
inputs=parameters_for_log, inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
error=f"Failed to transform tool message: {str(e)}", error=f"Failed to transform tool message: {str(e)}",
error_type=type(e).__name__, error_type=type(e).__name__, PipelineGenerator.convert_to_event_strea
) )
) )

@ -60,6 +60,7 @@ def build_from_mapping(
FileTransferMethod.LOCAL_FILE: _build_from_local_file, FileTransferMethod.LOCAL_FILE: _build_from_local_file,
FileTransferMethod.REMOTE_URL: _build_from_remote_url, FileTransferMethod.REMOTE_URL: _build_from_remote_url,
FileTransferMethod.TOOL_FILE: _build_from_tool_file, FileTransferMethod.TOOL_FILE: _build_from_tool_file,
FileTransferMethod.DATASOURCE_FILE: _build_from_datasource_file,
} }
build_func = build_functions.get(transfer_method) build_func = build_functions.get(transfer_method)
@ -302,6 +303,52 @@ def _build_from_tool_file(
) )
def _build_from_datasource_file(
*,
mapping: Mapping[str, Any],
tenant_id: str,
transfer_method: FileTransferMethod,
strict_type_validation: bool = False,
) -> File:
datasource_file = (
db.session.query(UploadFile)
.filter(
UploadFile.id == mapping.get("datasource_file_id"),
UploadFile.tenant_id == tenant_id,
)
.first()
)
if datasource_file is None:
raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found")
extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin"
detected_file_type = _standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type)
specified_type = mapping.get("type")
if strict_type_validation and specified_type and detected_file_type.value != specified_type:
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
file_type = (
FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type
)
return File(
id=mapping.get("id"),
tenant_id=tenant_id,
filename=datasource_file.name,
type=file_type,
transfer_method=transfer_method,
remote_url=datasource_file.source_url,
related_id=datasource_file.id,
extension=extension,
mime_type=datasource_file.mime_type,
size=datasource_file.size,
storage_key=datasource_file.key,
)
def _is_file_valid_with_config( def _is_file_valid_with_config(
*, *,
input_file_type: str, input_file_type: str,

@ -42,10 +42,6 @@ from core.workflow.constants import (
) )
class InvalidSelectorError(ValueError):
pass
class UnsupportedSegmentTypeError(Exception): class UnsupportedSegmentTypeError(Exception):
pass pass

@ -4,7 +4,6 @@ from . import (
app_model_config, app_model_config,
audio, audio,
base, base,
completion,
conversation, conversation,
dataset, dataset,
document, document,
@ -19,7 +18,6 @@ __all__ = [
"app_model_config", "app_model_config",
"audio", "audio",
"base", "base",
"completion",
"conversation", "conversation",
"dataset", "dataset",
"document", "document",

@ -55,7 +55,3 @@ class MemberNotInTenantError(BaseServiceError):
class RoleAlreadyAssignedError(BaseServiceError): class RoleAlreadyAssignedError(BaseServiceError):
pass pass
class RateLimitExceededError(BaseServiceError):
pass

@ -1,5 +0,0 @@
from services.errors.base import BaseServiceError
class CompletionStoppedError(BaseServiceError):
pass

@ -15,13 +15,13 @@ import contexts
from configs import dify_config from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.datasource.entities.datasource_entities import ( from core.datasource.entities.datasource_entities import (
DatasourceInvokeMessage,
DatasourceProviderType, DatasourceProviderType,
GetOnlineDocumentPagesResponse, OnlineDocumentPagesMessage,
GetWebsiteCrawlResponse, WebsiteCrawlMessage,
) )
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
from core.model_runtime.utils.encoders import jsonable_encoder
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from core.variables.variables import Variable from core.variables.variables import Variable
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
@ -31,7 +31,7 @@ from core.workflow.entities.workflow_node_execution import (
) )
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.graph_engine.entities.event import DatasourceRunEvent, InNodeEvent
from core.workflow.nodes.base.node import BaseNode from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event.event import RunCompletedEvent from core.workflow.nodes.event.event import RunCompletedEvent
@ -43,14 +43,14 @@ from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.account import Account from models.account import Account
from models.dataset import Document, Pipeline, PipelineCustomizedTemplate # type: ignore from models.dataset import Document, Pipeline, PipelineCustomizedTemplate # type: ignore
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.enums import WorkflowRunTriggeredFrom
from models.model import EndUser from models.model import EndUser
from models.oauth import DatasourceProvider
from models.workflow import ( from models.workflow import (
Workflow, Workflow,
WorkflowNodeExecutionModel,
WorkflowNodeExecutionTriggeredFrom, WorkflowNodeExecutionTriggeredFrom,
WorkflowRun, WorkflowRun,
WorkflowType, WorkflowNodeExecutionModel, WorkflowType,
) )
from services.dataset_service import DatasetService from services.dataset_service import DatasetService
from services.datasource_provider_service import DatasourceProviderService from services.datasource_provider_service import DatasourceProviderService
@ -468,15 +468,16 @@ class RagPipelineService:
case DatasourceProviderType.WEBSITE_CRAWL: case DatasourceProviderType.WEBSITE_CRAWL:
datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl( website_crawl_results: list[WebsiteCrawlMessage] = []
for website_message in datasource_runtime.get_website_crawl(
user_id=account.id, user_id=account.id,
datasource_parameters={"job_id": job_id}, datasource_parameters={"job_id": job_id},
provider_type=datasource_runtime.datasource_provider_type(), provider_type=datasource_runtime.datasource_provider_type(),
) ):
website_crawl_results.append(website_message)
return { return {
"result": [result for result in website_crawl_result.result], "result": [result for result in website_crawl_results.result],
"job_id": website_crawl_result.result.job_id, "status": website_crawl_results.result.status,
"status": website_crawl_result.result.status,
"provider_type": datasource_node_data.get("provider_type"), "provider_type": datasource_node_data.get("provider_type"),
} }
case _: case _:
@ -485,7 +486,7 @@ class RagPipelineService:
def run_datasource_workflow_node( def run_datasource_workflow_node(
self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str, self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str,
is_published: bool is_published: bool
) -> dict: ) -> Generator[DatasourceRunEvent, None, None]:
""" """
Run published workflow datasource Run published workflow datasource
""" """
@ -532,29 +533,25 @@ class RagPipelineService:
match datasource_type: match datasource_type:
case DatasourceProviderType.ONLINE_DOCUMENT: case DatasourceProviderType.ONLINE_DOCUMENT:
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
online_document_result: GetOnlineDocumentPagesResponse = datasource_runtime._get_online_document_pages( online_document_result: Generator[DatasourceInvokeMessage, None, None] = datasource_runtime._get_online_document_pages(
user_id=account.id, user_id=account.id,
datasource_parameters=user_inputs, datasource_parameters=user_inputs,
provider_type=datasource_runtime.datasource_provider_type(), provider_type=datasource_runtime.datasource_provider_type(),
) )
return { for message in online_document_result:
"result": [page.model_dump() for page in online_document_result.result], yield DatasourceRunEvent(
"provider_type": datasource_node_data.get("provider_type"), status="success",
} result=message.model_dump(),
)
case DatasourceProviderType.WEBSITE_CRAWL: case DatasourceProviderType.WEBSITE_CRAWL:
datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl( website_crawl_result: Generator[DatasourceInvokeMessage, None, None] = datasource_runtime._get_website_crawl(
user_id=account.id, user_id=account.id,
datasource_parameters=user_inputs, datasource_parameters=user_inputs,
provider_type=datasource_runtime.datasource_provider_type(), provider_type=datasource_runtime.datasource_provider_type(),
) )
return { yield from website_crawl_result
"result": [result.model_dump() for result in website_crawl_result.result.web_info_list] if website_crawl_result.result.web_info_list else [],
"job_id": website_crawl_result.result.job_id,
"status": website_crawl_result.result.status,
"provider_type": datasource_node_data.get("provider_type"),
}
case _: case _:
raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")

Loading…
Cancel
Save