From 5ccb8d9736eb3948c81f232920e0f06a2141660d Mon Sep 17 00:00:00 2001 From: Harry Date: Fri, 13 Jun 2025 18:22:15 +0800 Subject: [PATCH] feat: online document --- .../entities/datasource_entities.py | 39 +++++++++++++++---- .../online_document/online_document_plugin.py | 9 ++--- api/core/plugin/impl/datasource.py | 7 ++-- .../nodes/datasource/datasource_node.py | 5 +-- api/services/rag_pipeline/rag_pipeline.py | 4 +- 5 files changed, 42 insertions(+), 22 deletions(-) diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index dd65c85cbc..b9a0c1f150 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -15,7 +15,7 @@ from core.plugin.entities.parameters import ( init_frontend_parameter, ) from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolLabelEnum +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolLabelEnum class DatasourceProviderType(enum.StrEnum): @@ -207,12 +207,6 @@ class DatasourceInvokeFrom(Enum): RAG_PIPELINE = "rag_pipeline" -class GetOnlineDocumentPagesRequest(BaseModel): - """ - Get online document pages request - """ - - class OnlineDocumentPage(BaseModel): """ Online document page @@ -237,7 +231,7 @@ class OnlineDocumentInfo(BaseModel): pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document") -class GetOnlineDocumentPagesResponse(BaseModel): +class OnlineDocumentPagesMessage(BaseModel): """ Get online document pages response """ @@ -300,3 +294,32 @@ class WebsiteCrawlMessage(BaseModel): Get website crawl response """ result: WebSiteInfo = WebSiteInfo(status="", web_info_list=[], total=0, completed=0) + +class DatasourceMessage(ToolInvokeMessage): + pass + + +class DatasourceInvokeMessage(ToolInvokeMessage): + """ + Datasource Invoke Message. + """ + + class WebsiteCrawlMessage(BaseModel): + """ + Website crawl message + """ + + job_id: str = Field(..., description="The job id") + status: str = Field(..., description="The status of the job") + web_info_list: Optional[list[WebSiteInfoDetail]] = [] + + class OnlineDocumentMessage(BaseModel): + """ + Online document message + """ + + workspace_id: str = Field(..., description="The workspace id") + workspace_name: str = Field(..., description="The workspace name") + workspace_icon: str = Field(..., description="The workspace icon") + total: int = Field(..., description="The total number of documents") + pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document") diff --git a/api/core/datasource/online_document/online_document_plugin.py b/api/core/datasource/online_document/online_document_plugin.py index 2ab60cae1e..db73d9a64b 100644 --- a/api/core/datasource/online_document/online_document_plugin.py +++ b/api/core/datasource/online_document/online_document_plugin.py @@ -1,5 +1,5 @@ -from collections.abc import Mapping -from typing import Any, Generator +from collections.abc import Generator, Mapping +from typing import Any from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_runtime import DatasourceRuntime @@ -8,8 +8,7 @@ from core.datasource.entities.datasource_entities import ( DatasourceInvokeMessage, DatasourceProviderType, GetOnlineDocumentPageContentRequest, - GetOnlineDocumentPageContentResponse, - GetOnlineDocumentPagesResponse, + OnlineDocumentPagesMessage, ) from core.plugin.impl.datasource import PluginDatasourceManager @@ -39,7 +38,7 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin): user_id: str, datasource_parameters: Mapping[str, Any], provider_type: str, - ) -> Generator[DatasourceInvokeMessage, None, None]: + ) -> Generator[OnlineDocumentPagesMessage, None, None]: manager = PluginDatasourceManager() return manager.get_online_document_pages( diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index 54325a545f..06ee00c688 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -4,8 +4,7 @@ from typing import Any from core.datasource.entities.datasource_entities import ( DatasourceInvokeMessage, GetOnlineDocumentPageContentRequest, - GetOnlineDocumentPageContentResponse, - GetOnlineDocumentPagesResponse, + OnlineDocumentPagesMessage, WebsiteCrawlMessage, ) from core.plugin.entities.plugin import GenericProviderID, ToolProviderID @@ -129,7 +128,7 @@ class PluginDatasourceManager(BasePluginClient): credentials: dict[str, Any], datasource_parameters: Mapping[str, Any], provider_type: str, - ) -> Generator[DatasourceInvokeMessage, None, None]: + ) -> Generator[OnlineDocumentPagesMessage, None, None]: """ Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. """ @@ -139,7 +138,7 @@ class PluginDatasourceManager(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/datasource/get_online_document_pages", - DatasourceInvokeMessage, + OnlineDocumentPagesMessage, data={ "user_id": user_id, "data": { diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index bd4a6e3a56..240eeeb725 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -1,5 +1,5 @@ -from collections.abc import Mapping, Sequence -from typing import Any, Generator, cast +from collections.abc import Generator, Mapping, Sequence +from typing import Any, cast from sqlalchemy import select from sqlalchemy.orm import Session @@ -9,7 +9,6 @@ from core.datasource.entities.datasource_entities import ( DatasourceParameter, DatasourceProviderType, GetOnlineDocumentPageContentRequest, - GetOnlineDocumentPageContentResponse, ) from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 43b68b3b97..7af607a96b 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -16,7 +16,7 @@ from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.datasource.entities.datasource_entities import ( DatasourceProviderType, - GetOnlineDocumentPagesResponse, + OnlineDocumentPagesMessage, WebsiteCrawlMessage, ) from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin @@ -532,7 +532,7 @@ class RagPipelineService: match datasource_type: case DatasourceProviderType.ONLINE_DOCUMENT: datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) - online_document_result: GetOnlineDocumentPagesResponse = datasource_runtime._get_online_document_pages( + online_document_result: OnlineDocumentPagesMessage = datasource_runtime._get_online_document_pages( user_id=account.id, datasource_parameters=user_inputs, provider_type=datasource_runtime.datasource_provider_type(),