From e165f4a1021617b7556d857949aceb5734d478e7 Mon Sep 17 00:00:00 2001 From: Dongyu Li <544104925@qq.com> Date: Tue, 24 Jun 2025 17:14:16 +0800 Subject: [PATCH] feat(datasource): add datasource content preview api --- api/controllers/console/__init__.py | 1 + .../datasource_content_preview.py | 52 ++++++++++ api/services/rag_pipeline/rag_pipeline.py | 96 ++++++++++++++++++- 3 files changed, 147 insertions(+), 2 deletions(-) create mode 100644 api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index f17c28dcd4..9d9023f59c 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -89,6 +89,7 @@ from .datasets.rag_pipeline import ( rag_pipeline_datasets, rag_pipeline_import, rag_pipeline_workflow, + datasource_content_preview ) # Import explore controllers diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py new file mode 100644 index 0000000000..30836b3da1 --- /dev/null +++ b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py @@ -0,0 +1,52 @@ +from flask_restful import ( # type: ignore + Resource, # type: ignore + reqparse, +) +from werkzeug.exceptions import Forbidden +from controllers.console.datasets.wraps import get_rag_pipeline +from controllers.console.wraps import setup_required, account_initialization_required +from libs.login import login_required, current_user +from models import Account +from models.dataset import Pipeline +from controllers.console import api +from services.rag_pipeline.rag_pipeline import RagPipelineService + + +class DataSourceContentPreviewApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def post(self, pipeline: Pipeline, node_id: str): + """ + Run datasource content preview + """ + if not isinstance(current_user, Account): + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("datasource_type", type=str, required=True, location="json") + args = parser.parse_args() + + inputs = args.get("inputs") + if inputs is None: + raise ValueError("missing inputs") + datasource_type = args.get("datasource_type") + if datasource_type is None: + raise ValueError("missing datasource_type") + + rag_pipeline_service = RagPipelineService() + return rag_pipeline_service.run_datasource_node_preview( + pipeline=pipeline, + node_id=node_id, + user_inputs=inputs, + account=current_user, + datasource_type=datasource_type, + is_published=True, + ) + +api.add_resource( + DataSourceContentPreviewApi, + "/rag/pipelines//workflows/published/datasource/nodes//preview" +) diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 333d559bf5..842676e29a 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -5,7 +5,7 @@ import threading import time from collections.abc import Callable, Generator, Sequence from datetime import UTC, datetime -from typing import Any, Optional, cast +from typing import Any, Optional, cast, Mapping from uuid import uuid4 from flask_login import current_user @@ -18,7 +18,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.datasource.entities.datasource_entities import ( DatasourceProviderType, OnlineDocumentPagesMessage, - WebsiteCrawlMessage, + WebsiteCrawlMessage, DatasourceMessage, GetOnlineDocumentPageContentRequest, ) from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin @@ -544,6 +544,98 @@ class RagPipelineService: logger.exception("Error in run_datasource_workflow_node.") yield DatasourceErrorEvent(error=str(e)).model_dump() + def run_datasource_node_preview( + self, + pipeline: Pipeline, + node_id: str, + user_inputs: dict, + account: Account, + datasource_type: str, + is_published: bool, + ) -> Mapping[str, Any]: + """ + Run published workflow datasource + """ + try: + if is_published: + # fetch published workflow by app_model + workflow = self.get_published_workflow(pipeline=pipeline) + else: + workflow = self.get_draft_workflow(pipeline=pipeline) + if not workflow: + raise ValueError("Workflow not initialized") + + # run draft workflow node + datasource_node_data = None + datasource_nodes = workflow.graph_dict.get("nodes", []) + for datasource_node in datasource_nodes: + if datasource_node.get("id") == node_id: + datasource_node_data = datasource_node.get("data", {}) + break + if not datasource_node_data: + raise ValueError("Datasource node data not found") + + datasource_parameters = datasource_node_data.get("datasource_parameters", {}) + for key, value in datasource_parameters.items(): + if not user_inputs.get(key): + user_inputs[key] = value["value"] + + from core.datasource.datasource_manager import DatasourceManager + + datasource_runtime = DatasourceManager.get_datasource_runtime( + provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}", + datasource_name=datasource_node_data.get("datasource_name"), + tenant_id=pipeline.tenant_id, + datasource_type=DatasourceProviderType(datasource_type), + ) + datasource_provider_service = DatasourceProviderService() + credentials = datasource_provider_service.get_real_datasource_credentials( + tenant_id=pipeline.tenant_id, + provider=datasource_node_data.get("provider_name"), + plugin_id=datasource_node_data.get("plugin_id"), + ) + if credentials: + datasource_runtime.runtime.credentials = credentials[0].get("credentials") + match datasource_type: + case DatasourceProviderType.ONLINE_DOCUMENT: + datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) + online_document_result: Generator[DatasourceMessage, None, None] = ( + datasource_runtime.get_online_document_page_content( + user_id=account.id, + datasource_parameters=GetOnlineDocumentPageContentRequest( + workspace_id=user_inputs.get("workspace_id"), + page_id=user_inputs.get("page_id"), + type=user_inputs.get("type"), + ), + provider_type=datasource_type, + ) + ) + try: + variables: dict[str, Any] = {} + for message in online_document_result: + if message.type == DatasourceMessage.MessageType.VARIABLE: + assert isinstance(message.message, DatasourceMessage.VariableMessage) + variable_name = message.message.variable_name + variable_value = message.message.variable_value + if message.message.stream: + if not isinstance(variable_value, str): + raise ValueError("When 'stream' is True, 'variable_value' must be a string.") + if variable_name not in variables: + variables[variable_name] = "" + variables[variable_name] += variable_value + else: + variables[variable_name] = variable_value + return variables + except Exception as e: + logger.exception("Error during get online document content.") + raise RuntimeError(str(e)) + #TODO Online Drive + case _: + raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") + except Exception as e: + logger.exception("Error in run_datasource_node_preview.") + raise RuntimeError(str(e)) + def run_free_workflow_node( self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any] ) -> WorkflowNodeExecution: