diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index fa4130b762..c0406940a7 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -39,9 +39,9 @@ from libs.helper import TimestampField, uuid_value from libs.login import current_user, login_required from models.account import Account from models.dataset import Pipeline -from services.app_generate_service import AppGenerateService from services.errors.app import WorkflowHashNotEqualError from services.errors.llm import InvokeRateLimitError +from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService from services.rag_pipeline.rag_pipeline import RagPipelineService from services.rag_pipeline.rag_pipeline_manage_service import RagPipelineManageService from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError @@ -170,7 +170,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource): args = parser.parse_args() try: - response = AppGenerateService.generate_single_iteration( + response = PipelineGenerateService.generate_single_iteration( pipeline=pipeline, user=current_user, node_id=node_id, args=args, streaming=True ) @@ -207,7 +207,7 @@ class RagPipelineDraftRunLoopNodeApi(Resource): args = parser.parse_args() try: - response = AppGenerateService.generate_single_loop( + response = PipelineGenerateService.generate_single_loop( pipeline=pipeline, user=current_user, node_id=node_id, args=args, streaming=True ) @@ -241,11 +241,12 @@ class DraftRagPipelineRunApi(Resource): parser = reqparse.RequestParser() parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") - parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("datasource_type", type=str, required=True, location="json") + parser.add_argument("datasource_info", type=list, required=True, location="json") args = parser.parse_args() try: - response = AppGenerateService.generate( + response = PipelineGenerateService.generate( pipeline=pipeline, user=current_user, args=args, @@ -258,7 +259,73 @@ class DraftRagPipelineRunApi(Resource): raise InvokeRateLimitHttpError(ex.description) +class PublishedRagPipelineRunApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def post(self, pipeline: Pipeline): + """ + Run published workflow + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + + if not isinstance(current_user, Account): + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("datasource_type", type=str, required=True, location="json") + parser.add_argument("datasource_info", type=list, required=True, location="json") + args = parser.parse_args() + + try: + response = PipelineGenerateService.generate( + pipeline=pipeline, + user=current_user, + args=args, + invoke_from=InvokeFrom.PUBLISHED, + streaming=True, + ) + + return helper.compact_generate_response(response) + except InvokeRateLimitError as ex: + raise InvokeRateLimitHttpError(ex.description) + + class RagPipelineDatasourceNodeRunApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def post(self, 
pipeline: Pipeline, node_id: str): + """ + Run rag pipeline datasource + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + + if not isinstance(current_user, Account): + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + args = parser.parse_args() + + inputs = args.get("inputs") + + rag_pipeline_service = RagPipelineService() + result = rag_pipeline_service.run_datasource_workflow_node( + pipeline=pipeline, node_id=node_id, user_inputs=inputs, account=current_user + ) + + return result + + +class RagPipelinePublishedNodeRunApi(Resource): @setup_required @login_required @account_initialization_required @@ -283,7 +350,7 @@ class RagPipelineDatasourceNodeRunApi(Resource): raise ValueError("missing inputs") rag_pipeline_service = RagPipelineService() - workflow_node_execution = rag_pipeline_service.run_datasource_workflow_node( + workflow_node_execution = rag_pipeline_service.run_published_workflow_node( pipeline=pipeline, node_id=node_id, user_inputs=inputs, account=current_user ) @@ -354,7 +421,8 @@ class PublishedRagPipelineApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + if not pipeline.is_published: + return None # fetch published workflow by pipeline rag_pipeline_service = RagPipelineService() workflow = rag_pipeline_service.get_published_workflow(pipeline=pipeline) @@ -397,10 +465,8 @@ class PublishedRagPipelineApi(Resource): marked_name=args.marked_name or "", marked_comment=args.marked_comment or "", ) - + pipeline.is_published = True pipeline.workflow_id = workflow.id - db.session.commit() - workflow_created_at = TimestampField().format(workflow.created_at) session.commit() @@ -617,7 +683,7 @@ class RagPipelineByIdApi(Resource): return None, 204 -class RagPipelineSecondStepApi(Resource): +class PublishedRagPipelineSecondStepApi(Resource): @setup_required @login_required @account_initialization_required @@ -632,9 +698,28 @@ class RagPipelineSecondStepApi(Resource): node_id = request.args.get("node_id", required=True, type=str) rag_pipeline_service = RagPipelineService() - variables = rag_pipeline_service.get_second_step_parameters( - pipeline=pipeline, node_id=node_id - ) + variables = rag_pipeline_service.get_published_second_step_parameters(pipeline=pipeline, node_id=node_id) + return { + "variables": variables, + } + + +class DraftRagPipelineSecondStepApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def get(self, pipeline: Pipeline): + """ + Get second step parameters of rag pipeline + """ + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + node_id = request.args.get("node_id", required=True, type=str) + + rag_pipeline_service = RagPipelineService() + variables = rag_pipeline_service.get_draft_second_step_parameters(pipeline=pipeline, node_id=node_id) return { "variables": variables, } @@ -732,15 +817,21 @@ api.add_resource( RagPipelineDraftNodeRunApi, "/rag/pipelines//workflows/draft/nodes//run", ) -# api.add_resource( -# RagPipelinePublishedNodeRunApi, -# "/rag/pipelines//workflows/published/nodes//run", -# ) +api.add_resource( + RagPipelineDatasourceNodeRunApi, + "/rag/pipelines//workflows/datasource/nodes//run", +) api.add_resource( 
RagPipelineDraftRunIterationNodeApi, "/rag/pipelines//workflows/draft/iteration/nodes//run", ) + +api.add_resource( + RagPipelinePublishedNodeRunApi, + "/rag/pipelines//workflows/published/nodes//run", +) + api.add_resource( RagPipelineDraftRunLoopNodeApi, "/rag/pipelines//workflows/draft/loop/nodes//run", @@ -762,7 +853,6 @@ api.add_resource( DefaultRagPipelineBlockConfigApi, "/rag/pipelines//workflows/default-workflow-block-configs/", ) - api.add_resource( RagPipelineByIdApi, "/rag/pipelines//workflows/", @@ -784,6 +874,10 @@ api.add_resource( "/rag/pipelines/datasource-plugins", ) api.add_resource( - RagPipelineSecondStepApi, - "/rag/pipelines//workflows/processing/paramters", + PublishedRagPipelineSecondStepApi, + "/rag/pipelines//workflows/published/processing/paramters", +) +api.add_resource( + DraftRagPipelineSecondStepApi, + "/rag/pipelines//workflows/draft/processing/paramters", ) diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 8ae52131f2..48e8ca5594 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -283,7 +283,7 @@ class AppConfig(BaseModel): tenant_id: str app_id: str app_mode: AppMode - additional_features: AppAdditionalFeatures + additional_features: Optional[AppAdditionalFeatures] = None variables: list[VariableEntity] = [] sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None diff --git a/api/core/app/apps/pipeline/__init__.py b/api/core/app/apps/pipeline/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/app/apps/pipeline/generate_response_converter.py b/api/core/app/apps/pipeline/generate_response_converter.py new file mode 100644 index 0000000000..10ec73a7d2 --- /dev/null +++ b/api/core/app/apps/pipeline/generate_response_converter.py @@ -0,0 +1,95 @@ +from collections.abc import Generator +from typing import cast + +from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter +from core.app.entities.task_entities import ( + AppStreamResponse, + ErrorStreamResponse, + NodeFinishStreamResponse, + NodeStartStreamResponse, + PingStreamResponse, + WorkflowAppBlockingResponse, + WorkflowAppStreamResponse, +) + + +class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): + _blocking_response_type = WorkflowAppBlockingResponse + + @classmethod + def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override] + """ + Convert blocking full response. + :param blocking_response: blocking response + :return: + """ + return dict(blocking_response.to_dict()) + + @classmethod + def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override] + """ + Convert blocking simple response. + :param blocking_response: blocking response + :return: + """ + return cls.convert_blocking_full_response(blocking_response) + + @classmethod + def convert_stream_full_response( + cls, stream_response: Generator[AppStreamResponse, None, None] + ) -> Generator[dict | str, None, None]: + """ + Convert stream full response. 
+ :param stream_response: stream response + :return: + """ + for chunk in stream_response: + chunk = cast(WorkflowAppStreamResponse, chunk) + sub_stream_response = chunk.stream_response + + if isinstance(sub_stream_response, PingStreamResponse): + yield "ping" + continue + + response_chunk = { + "event": sub_stream_response.event.value, + "workflow_run_id": chunk.workflow_run_id, + } + + if isinstance(sub_stream_response, ErrorStreamResponse): + data = cls._error_to_stream_response(sub_stream_response.err) + response_chunk.update(data) + else: + response_chunk.update(sub_stream_response.to_dict()) + yield response_chunk + + @classmethod + def convert_stream_simple_response( + cls, stream_response: Generator[AppStreamResponse, None, None] + ) -> Generator[dict | str, None, None]: + """ + Convert stream simple response. + :param stream_response: stream response + :return: + """ + for chunk in stream_response: + chunk = cast(WorkflowAppStreamResponse, chunk) + sub_stream_response = chunk.stream_response + + if isinstance(sub_stream_response, PingStreamResponse): + yield "ping" + continue + + response_chunk = { + "event": sub_stream_response.event.value, + "workflow_run_id": chunk.workflow_run_id, + } + + if isinstance(sub_stream_response, ErrorStreamResponse): + data = cls._error_to_stream_response(sub_stream_response.err) + response_chunk.update(data) + elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse): + response_chunk.update(sub_stream_response.to_ignore_detail_dict()) + else: + response_chunk.update(sub_stream_response.to_dict()) + yield response_chunk diff --git a/api/core/app/apps/pipeline/pipeline_config_manager.py b/api/core/app/apps/pipeline/pipeline_config_manager.py new file mode 100644 index 0000000000..ddf87eacbb --- /dev/null +++ b/api/core/app/apps/pipeline/pipeline_config_manager.py @@ -0,0 +1,63 @@ +from core.app.app_config.base_app_config_manager import BaseAppConfigManager +from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager +from core.app.app_config.entities import WorkflowUIBasedAppConfig +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager +from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager +from models.dataset import Pipeline +from models.model import AppMode +from models.workflow import Workflow + + +class PipelineConfig(WorkflowUIBasedAppConfig): + """ + Pipeline Config Entity. 
+ """ + + pass + + +class PipelineConfigManager(BaseAppConfigManager): + @classmethod + def get_pipeline_config(cls, pipeline: Pipeline, workflow: Workflow) -> PipelineConfig: + pipeline_config = PipelineConfig( + tenant_id=pipeline.tenant_id, + app_id=pipeline.id, + app_mode=AppMode.RAG_PIPELINE, + workflow_id=workflow.id, + variables=WorkflowVariablesConfigManager.convert(workflow=workflow), + ) + + return pipeline_config + + @classmethod + def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict: + """ + Validate for pipeline config + + :param tenant_id: tenant id + :param config: app model config args + :param only_structure_validate: only validate the structure of the config + """ + related_config_keys = [] + + # file upload validation + config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config) + related_config_keys.extend(current_related_config_keys) + + # text_to_speech + config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config) + related_config_keys.extend(current_related_config_keys) + + # moderation validation + config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate + ) + related_config_keys.extend(current_related_config_keys) + + related_config_keys = list(set(related_config_keys)) + + # Filter out extra parameters + filtered_config = {key: config.get(key) for key in related_config_keys} + + return filtered_config diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py new file mode 100644 index 0000000000..1e880c700c --- /dev/null +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -0,0 +1,496 @@ +import contextvars +import datetime +import json +import logging +import random +import threading +import time +import uuid +from collections.abc import Generator, Mapping +from typing import Any, Literal, Optional, Union, overload + +from flask import Flask, current_app +from pydantic import ValidationError +from sqlalchemy.orm import sessionmaker + +import contexts +from configs import dify_config +from core.app.apps.base_app_generator import BaseAppGenerator +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom +from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager +from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager +from core.app.apps.pipeline.pipeline_runner import PipelineRunner +from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager +from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter +from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity, WorkflowAppGenerateEntity +from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse +from core.model_runtime.errors.invoke import InvokeAuthorizationError +from core.rag.index_processor.constant.built_in_field import BuiltInField +from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.workflow_app_generate_task_pipeline import WorkflowAppGenerateTaskPipeline +from extensions.ext_database import db +from models import Account, App, 
EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom +from models.dataset import Document, Pipeline +from services.dataset_service import DocumentService + +logger = logging.getLogger(__name__) + + +class PipelineGenerator(BaseAppGenerator): + @overload + def generate( + self, + *, + pipeline: Pipeline, + workflow: Workflow, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: Literal[True], + call_depth: int, + workflow_thread_pool_id: Optional[str], + ) -> Generator[Mapping | str, None, None]: ... + + @overload + def generate( + self, + *, + pipeline: Pipeline, + workflow: Workflow, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: Literal[False], + call_depth: int, + workflow_thread_pool_id: Optional[str], + ) -> Mapping[str, Any]: ... + + @overload + def generate( + self, + *, + pipeline: Pipeline, + workflow: Workflow, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: bool, + call_depth: int, + workflow_thread_pool_id: Optional[str], + ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ... + + def generate( + self, + *, + pipeline: Pipeline, + workflow: Workflow, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: bool = True, + call_depth: int = 0, + workflow_thread_pool_id: Optional[str] = None, + ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: + # convert to app config + pipeline_config = PipelineConfigManager.get_pipeline_config( + pipeline=pipeline, + workflow=workflow, + ) + + inputs: Mapping[str, Any] = args["inputs"] + datasource_type: str = args["datasource_type"] + datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"] + batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) + + for datasource_info in datasource_info_list: + workflow_run_id = str(uuid.uuid4()) + document_id = None + if invoke_from == InvokeFrom.PUBLISHED: + position = DocumentService.get_documents_position(pipeline.dataset_id) + document = self._build_document( + tenant_id=pipeline.tenant_id, + dataset_id=pipeline.dataset_id, + built_in_field_enabled=pipeline.dataset.built_in_field_enabled, + datasource_type=datasource_type, + datasource_info=datasource_info, + created_from="rag-pipeline", + position=position, + account=user, + batch=batch, + document_form=pipeline.dataset.doc_form, + ) + db.session.add(document) + db.session.commit() + document_id = document.id + # init application generate entity + application_generate_entity = RagPipelineGenerateEntity( + task_id=str(uuid.uuid4()), + pipline_config=pipeline_config, + datasource_type=datasource_type, + datasource_info=datasource_info, + dataset_id=pipeline.dataset_id, + batch=batch, + document_id=document_id, + inputs=self._prepare_user_inputs( + user_inputs=inputs, + variables=pipeline_config.variables, + tenant_id=pipeline.tenant_id, + strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False, + ), + files=[], + user_id=user.id, + stream=streaming, + invoke_from=invoke_from, + call_depth=call_depth, + workflow_run_id=workflow_run_id, + ) + + contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) + contexts.plugin_tool_providers.set({}) + contexts.plugin_tool_providers_lock.set(threading.Lock()) + + # Create workflow node execution repository + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + + 
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+                session_factory=session_factory,
+                user=user,
+                app_id=application_generate_entity.app_config.app_id,
+                triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
+            )
+
+            return self._generate(
+                pipeline=pipeline,
+                workflow=workflow,
+                user=user,
+                application_generate_entity=application_generate_entity,
+                invoke_from=invoke_from,
+                workflow_node_execution_repository=workflow_node_execution_repository,
+                streaming=streaming,
+                workflow_thread_pool_id=workflow_thread_pool_id,
+            )
+
+    def _generate(
+        self,
+        *,
+        pipeline: Pipeline,
+        workflow: Workflow,
+        user: Union[Account, EndUser],
+        application_generate_entity: RagPipelineGenerateEntity,
+        invoke_from: InvokeFrom,
+        workflow_node_execution_repository: WorkflowNodeExecutionRepository,
+        streaming: bool = True,
+        workflow_thread_pool_id: Optional[str] = None,
+    ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
+        """
+        Generate App response.
+
+        :param pipeline: Pipeline
+        :param workflow: Workflow
+        :param user: account or end user
+        :param application_generate_entity: application generate entity
+        :param invoke_from: invoke from source
+        :param workflow_node_execution_repository: repository for workflow node execution
+        :param streaming: is stream
+        :param workflow_thread_pool_id: workflow thread pool id
+        """
+        # init queue manager
+        queue_manager = PipelineQueueManager(
+            task_id=application_generate_entity.task_id,
+            user_id=application_generate_entity.user_id,
+            invoke_from=application_generate_entity.invoke_from,
+            app_mode=pipeline.mode,
+        )
+
+        # new thread
+        worker_thread = threading.Thread(
+            target=self._generate_worker,
+            kwargs={
+                "flask_app": current_app._get_current_object(),  # type: ignore
+                "application_generate_entity": application_generate_entity,
+                "queue_manager": queue_manager,
+                "context": contextvars.copy_context(),
+                "workflow_thread_pool_id": workflow_thread_pool_id,
+            },
+        )
+
+        worker_thread.start()
+
+        # return response or stream generator
+        response = self._handle_response(
+            application_generate_entity=application_generate_entity,
+            workflow=workflow,
+            queue_manager=queue_manager,
+            user=user,
+            workflow_node_execution_repository=workflow_node_execution_repository,
+            stream=streaming,
+        )
+
+        return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
+
+    def single_iteration_generate(
+        self,
+        app_model: App,
+        workflow: Workflow,
+        node_id: str,
+        user: Account | EndUser,
+        args: Mapping[str, Any],
+        streaming: bool = True,
+    ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
+        """
+        Generate App response.
+ + :param app_model: App + :param workflow: Workflow + :param node_id: the node id + :param user: account or end user + :param args: request args + :param streaming: is streamed + """ + if not node_id: + raise ValueError("node_id is required") + + if args.get("inputs") is None: + raise ValueError("inputs is required") + + # convert to app config + app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) + + # init application generate entity + application_generate_entity = WorkflowAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + inputs={}, + files=[], + user_id=user.id, + stream=streaming, + invoke_from=InvokeFrom.DEBUGGER, + extras={"auto_generate_conversation_name": False}, + single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity( + node_id=node_id, inputs=args["inputs"] + ), + workflow_run_id=str(uuid.uuid4()), + ) + contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) + contexts.plugin_tool_providers.set({}) + contexts.plugin_tool_providers_lock.set(threading.Lock()) + + # Create workflow node execution repository + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + + workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, + ) + + return self._generate( + app_model=app_model, + workflow=workflow, + user=user, + invoke_from=InvokeFrom.DEBUGGER, + application_generate_entity=application_generate_entity, + workflow_node_execution_repository=workflow_node_execution_repository, + streaming=streaming, + ) + + def single_loop_generate( + self, + app_model: App, + workflow: Workflow, + node_id: str, + user: Account | EndUser, + args: Mapping[str, Any], + streaming: bool = True, + ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]: + """ + Generate App response. 
+ + :param app_model: App + :param workflow: Workflow + :param node_id: the node id + :param user: account or end user + :param args: request args + :param streaming: is streamed + """ + if not node_id: + raise ValueError("node_id is required") + + if args.get("inputs") is None: + raise ValueError("inputs is required") + + # convert to app config + app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) + + # init application generate entity + application_generate_entity = WorkflowAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + inputs={}, + files=[], + user_id=user.id, + stream=streaming, + invoke_from=InvokeFrom.DEBUGGER, + extras={"auto_generate_conversation_name": False}, + single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]), + workflow_run_id=str(uuid.uuid4()), + ) + contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) + contexts.plugin_tool_providers.set({}) + contexts.plugin_tool_providers_lock.set(threading.Lock()) + + # Create workflow node execution repository + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + + workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, + ) + + return self._generate( + app_model=app_model, + workflow=workflow, + user=user, + invoke_from=InvokeFrom.DEBUGGER, + application_generate_entity=application_generate_entity, + workflow_node_execution_repository=workflow_node_execution_repository, + streaming=streaming, + ) + + def _generate_worker( + self, + flask_app: Flask, + application_generate_entity: RagPipelineGenerateEntity, + queue_manager: AppQueueManager, + context: contextvars.Context, + workflow_thread_pool_id: Optional[str] = None, + ) -> None: + """ + Generate worker in a new thread. 
+ :param flask_app: Flask app + :param application_generate_entity: application generate entity + :param queue_manager: queue manager + :param workflow_thread_pool_id: workflow thread pool id + :return: + """ + for var, val in context.items(): + var.set(val) + with flask_app.app_context(): + try: + # workflow app + runner = PipelineRunner( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + workflow_thread_pool_id=workflow_thread_pool_id, + ) + + runner.run() + except GenerateTaskStoppedError: + pass + except InvokeAuthorizationError: + queue_manager.publish_error( + InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER + ) + except ValidationError as e: + logger.exception("Validation Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except ValueError as e: + if dify_config.DEBUG: + logger.exception("Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + except Exception as e: + logger.exception("Unknown Error when generating") + queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) + finally: + db.session.close() + + def _handle_response( + self, + application_generate_entity: RagPipelineGenerateEntity, + workflow: Workflow, + queue_manager: AppQueueManager, + user: Union[Account, EndUser], + workflow_node_execution_repository: WorkflowNodeExecutionRepository, + stream: bool = False, + ) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: + """ + Handle response. + :param application_generate_entity: application generate entity + :param workflow: workflow + :param queue_manager: queue manager + :param user: account or end user + :param stream: is stream + :param workflow_node_execution_repository: optional repository for workflow node execution + :return: + """ + # init generate task pipeline + generate_task_pipeline = WorkflowAppGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + workflow=workflow, + queue_manager=queue_manager, + user=user, + stream=stream, + workflow_node_execution_repository=workflow_node_execution_repository, + ) + + try: + return generate_task_pipeline.process() + except ValueError as e: + if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error + raise GenerateTaskStoppedError() + else: + logger.exception( + f"Fails to process generate task pipeline, task_id: {application_generate_entity.task_id}" + ) + raise e + + def _build_document( + self, + tenant_id: str, + dataset_id: str, + built_in_field_enabled: bool, + datasource_type: str, + datasource_info: Mapping[str, Any], + created_from: str, + position: int, + account: Account, + batch: str, + document_form: str, + ): + if datasource_type == "local_file": + name = datasource_info["name"] + elif datasource_type == "online_document": + name = datasource_info["page_title"] + elif datasource_type == "website_crawl": + name = datasource_info["title"] + else: + raise ValueError(f"Unsupported datasource type: {datasource_type}") + + document = Document( + tenant_id=tenant_id, + dataset_id=dataset_id, + position=position, + data_source_type=datasource_type, + data_source_info=json.dumps(datasource_info), + batch=batch, + name=name, + created_from=created_from, + created_by=account.id, + doc_form=document_form, + ) + doc_metadata = {} + if built_in_field_enabled: + doc_metadata = { + BuiltInField.document_name: name, + BuiltInField.uploader: account.name, + 
BuiltInField.upload_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"), + BuiltInField.last_update_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"), + BuiltInField.source: datasource_type, + } + if doc_metadata: + document.doc_metadata = doc_metadata + return document diff --git a/api/core/app/apps/pipeline/pipeline_queue_manager.py b/api/core/app/apps/pipeline/pipeline_queue_manager.py new file mode 100644 index 0000000000..d0aeac8a9c --- /dev/null +++ b/api/core/app/apps/pipeline/pipeline_queue_manager.py @@ -0,0 +1,44 @@ +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import ( + AppQueueEvent, + QueueErrorEvent, + QueueMessageEndEvent, + QueueStopEvent, + QueueWorkflowFailedEvent, + QueueWorkflowPartialSuccessEvent, + QueueWorkflowSucceededEvent, + WorkflowQueueMessage, +) + + +class PipelineQueueManager(AppQueueManager): + def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None: + super().__init__(task_id, user_id, invoke_from) + + self._app_mode = app_mode + + def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: + """ + Publish event to queue + :param event: + :param pub_from: + :return: + """ + message = WorkflowQueueMessage(task_id=self._task_id, app_mode=self._app_mode, event=event) + + self._q.put(message) + + if isinstance( + event, + QueueStopEvent + | QueueErrorEvent + | QueueMessageEndEvent + | QueueWorkflowSucceededEvent + | QueueWorkflowFailedEvent + | QueueWorkflowPartialSuccessEvent, + ): + self.stop_listen() + + if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): + raise GenerateTaskStoppedError() diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py new file mode 100644 index 0000000000..1395a47d88 --- /dev/null +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -0,0 +1,154 @@ +import logging +from typing import Optional, cast + +from configs import dify_config +from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig +from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner +from core.app.entities.app_invoke_entities import ( + InvokeFrom, + RagPipelineGenerateEntity, +) +from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from core.workflow.workflow_entry import WorkflowEntry +from extensions.ext_database import db +from models.dataset import Pipeline +from models.enums import UserFrom +from models.model import EndUser +from models.workflow import Workflow, WorkflowType + +logger = logging.getLogger(__name__) + + +class PipelineRunner(WorkflowBasedAppRunner): + """ + Pipeline Application Runner + """ + + def __init__( + self, + application_generate_entity: RagPipelineGenerateEntity, + queue_manager: AppQueueManager, + workflow_thread_pool_id: Optional[str] = None, + ) -> None: + """ + :param application_generate_entity: application generate entity + :param queue_manager: application queue manager + :param workflow_thread_pool_id: workflow thread pool id + """ + self.application_generate_entity = application_generate_entity + self.queue_manager = queue_manager + self.workflow_thread_pool_id = 
workflow_thread_pool_id + + def run(self) -> None: + """ + Run application + """ + app_config = self.application_generate_entity.app_config + app_config = cast(PipelineConfig, app_config) + + user_id = None + if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: + end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() + if end_user: + user_id = end_user.session_id + else: + user_id = self.application_generate_entity.user_id + + pipeline = db.session.query(Pipeline).filter(Pipeline.id == app_config.app_id).first() + if not pipeline: + raise ValueError("Pipeline not found") + + workflow = self.get_workflow(pipeline=pipeline, workflow_id=app_config.workflow_id) + if not workflow: + raise ValueError("Workflow not initialized") + + db.session.close() + + workflow_callbacks: list[WorkflowCallback] = [] + if dify_config.DEBUG: + workflow_callbacks.append(WorkflowLoggingCallback()) + + # if only single iteration run is requested + if self.application_generate_entity.single_iteration_run: + # if only single iteration run is requested + graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( + workflow=workflow, + node_id=self.application_generate_entity.single_iteration_run.node_id, + user_inputs=self.application_generate_entity.single_iteration_run.inputs, + ) + elif self.application_generate_entity.single_loop_run: + # if only single loop run is requested + graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop( + workflow=workflow, + node_id=self.application_generate_entity.single_loop_run.node_id, + user_inputs=self.application_generate_entity.single_loop_run.inputs, + ) + else: + inputs = self.application_generate_entity.inputs + files = self.application_generate_entity.files + + # Create a variable pool. 
+ system_inputs = { + SystemVariableKey.FILES: files, + SystemVariableKey.USER_ID: user_id, + SystemVariableKey.APP_ID: app_config.app_id, + SystemVariableKey.WORKFLOW_ID: app_config.workflow_id, + SystemVariableKey.WORKFLOW_RUN_ID: self.application_generate_entity.workflow_run_id, + SystemVariableKey.DOCUMENT_ID: self.application_generate_entity.document_id, + SystemVariableKey.BATCH: self.application_generate_entity.batch, + SystemVariableKey.DATASET_ID: self.application_generate_entity.dataset_id, + } + + variable_pool = VariablePool( + system_variables=system_inputs, + user_inputs=inputs, + environment_variables=workflow.environment_variables, + conversation_variables=[], + ) + + # init graph + graph = self._init_graph(graph_config=workflow.graph_dict) + + # RUN WORKFLOW + workflow_entry = WorkflowEntry( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + workflow_id=workflow.id, + workflow_type=WorkflowType.value_of(workflow.type), + graph=graph, + graph_config=workflow.graph_dict, + user_id=self.application_generate_entity.user_id, + user_from=( + UserFrom.ACCOUNT + if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} + else UserFrom.END_USER + ), + invoke_from=self.application_generate_entity.invoke_from, + call_depth=self.application_generate_entity.call_depth, + variable_pool=variable_pool, + thread_pool_id=self.workflow_thread_pool_id, + ) + + generator = workflow_entry.run(callbacks=workflow_callbacks) + + for event in generator: + self._handle_event(workflow_entry, event) + + def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Optional[Workflow]: + """ + Get workflow + """ + # fetch workflow by workflow_id + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id + ) + .first() + ) + + # return workflow + return workflow diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 56e6b46a60..d730704f48 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -21,6 +21,7 @@ class InvokeFrom(Enum): WEB_APP = "web-app" EXPLORE = "explore" DEBUGGER = "debugger" + PUBLISHED = "published" @classmethod def value_of(cls, value: str): @@ -226,3 +227,37 @@ class WorkflowAppGenerateEntity(AppGenerateEntity): inputs: dict single_loop_run: Optional[SingleLoopRunEntity] = None + + +class RagPipelineGenerateEntity(WorkflowAppGenerateEntity): + """ + RAG Pipeline Application Generate Entity. + """ + + # app config + pipline_config: WorkflowUIBasedAppConfig + datasource_type: str + datasource_info: Mapping[str, Any] + dataset_id: str + batch: str + document_id: str + + class SingleIterationRunEntity(BaseModel): + """ + Single Iteration Run Entity. + """ + + node_id: str + inputs: dict + + single_iteration_run: Optional[SingleIterationRunEntity] = None + + class SingleLoopRunEntity(BaseModel): + """ + Single Loop Run Entity. 
+ """ + + node_id: str + inputs: dict + + single_loop_run: Optional[SingleLoopRunEntity] = None diff --git a/api/core/datasource/__base/datasource_plugin.py b/api/core/datasource/__base/datasource_plugin.py index 15d9e7d9ba..d8681b6491 100644 --- a/api/core/datasource/__base/datasource_plugin.py +++ b/api/core/datasource/__base/datasource_plugin.py @@ -1,18 +1,13 @@ -from collections.abc import Mapping -from typing import Any +from abc import ABC, abstractmethod from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import ( DatasourceEntity, + DatasourceProviderType, ) -from core.plugin.impl.datasource import PluginDatasourceManager -from core.plugin.utils.converter import convert_parameters_to_plugin_format -class DatasourcePlugin: - tenant_id: str - icon: str - plugin_unique_identifier: str +class DatasourcePlugin(ABC): entity: DatasourceEntity runtime: DatasourceRuntime @@ -20,57 +15,19 @@ class DatasourcePlugin: self, entity: DatasourceEntity, runtime: DatasourceRuntime, - tenant_id: str, - icon: str, - plugin_unique_identifier: str, ) -> None: self.entity = entity self.runtime = runtime - self.tenant_id = tenant_id - self.icon = icon - self.plugin_unique_identifier = plugin_unique_identifier - def _invoke_first_step( - self, - user_id: str, - datasource_parameters: dict[str, Any], - ) -> Mapping[str, Any]: - manager = PluginDatasourceManager() - - datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters) - - return manager.invoke_first_step( - tenant_id=self.tenant_id, - user_id=user_id, - datasource_provider=self.entity.identity.provider, - datasource_name=self.entity.identity.name, - credentials=self.runtime.credentials, - datasource_parameters=datasource_parameters, - ) - - def _invoke_second_step( - self, - user_id: str, - datasource_parameters: dict[str, Any], - ) -> Mapping[str, Any]: - manager = PluginDatasourceManager() - - datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters) - - return manager.invoke_second_step( - tenant_id=self.tenant_id, - user_id=user_id, - datasource_provider=self.entity.identity.provider, - datasource_name=self.entity.identity.name, - credentials=self.runtime.credentials, - datasource_parameters=datasource_parameters, - ) + @abstractmethod + def datasource_provider_type(self) -> DatasourceProviderType: + """ + returns the type of the datasource provider + """ + return DatasourceProviderType.LOCAL_FILE def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin": - return DatasourcePlugin( - entity=self.entity, + return self.__class__( + entity=self.entity.model_copy(), runtime=runtime, - tenant_id=self.tenant_id, - icon=self.icon, - plugin_unique_identifier=self.plugin_unique_identifier, ) diff --git a/api/core/datasource/__base/datasource_provider.py b/api/core/datasource/__base/datasource_provider.py index 13804f53d9..1544270d7a 100644 --- a/api/core/datasource/__base/datasource_provider.py +++ b/api/core/datasource/__base/datasource_provider.py @@ -1,26 +1,19 @@ +from abc import ABC, abstractmethod from typing import Any from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_runtime import DatasourceRuntime -from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin +from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType from core.entities.provider_entities 
import ProviderConfig from core.plugin.impl.tool import PluginToolManager from core.tools.errors import ToolProviderCredentialValidationError -class DatasourcePluginProviderController: +class DatasourcePluginProviderController(ABC): entity: DatasourceProviderEntityWithPlugin - tenant_id: str - plugin_id: str - plugin_unique_identifier: str - def __init__( - self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str - ) -> None: + def __init__(self, entity: DatasourceProviderEntityWithPlugin) -> None: self.entity = entity - self.tenant_id = tenant_id - self.plugin_id = plugin_id - self.plugin_unique_identifier = plugin_unique_identifier @property def need_credentials(self) -> bool: @@ -44,29 +37,19 @@ class DatasourcePluginProviderController: ): raise ToolProviderCredentialValidationError("Invalid credentials") - def get_datasource(self, datasource_name: str) -> DatasourcePlugin: # type: ignore + @property + def provider_type(self) -> DatasourceProviderType: + """ + returns the type of the provider + """ + return DatasourceProviderType.LOCAL_FILE + + @abstractmethod + def get_datasource(self, datasource_name: str) -> DatasourcePlugin: """ return datasource with given name """ - datasource_entity = next( - ( - datasource_entity - for datasource_entity in self.entity.datasources - if datasource_entity.identity.name == datasource_name - ), - None, - ) - - if not datasource_entity: - raise ValueError(f"Datasource with name {datasource_name} not found") - - return DatasourcePlugin( - entity=datasource_entity, - runtime=DatasourceRuntime(tenant_id=self.tenant_id), - tenant_id=self.tenant_id, - icon=self.entity.identity.icon, - plugin_unique_identifier=self.plugin_unique_identifier, - ) + pass def get_datasources(self) -> list[DatasourcePlugin]: # type: ignore """ diff --git a/api/core/datasource/entities/api_entities.py b/api/core/datasource/entities/api_entities.py index 8d6bed41fa..3b224c9e64 100644 --- a/api/core/datasource/entities/api_entities.py +++ b/api/core/datasource/entities/api_entities.py @@ -28,13 +28,13 @@ class DatasourceProviderApiEntity(BaseModel): description: I18nObject icon: str | dict label: I18nObject # label - type: ToolProviderType + type: str masked_credentials: Optional[dict] = None original_credentials: Optional[dict] = None is_team_authorization: bool = False allow_delete: bool = True - plugin_id: Optional[str] = Field(default="", description="The plugin id of the tool") - plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the tool") + plugin_id: Optional[str] = Field(default="", description="The plugin id of the datasource") + plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the datasource") datasources: list[DatasourceApiEntity] = Field(default_factory=list) labels: list[str] = Field(default_factory=list) diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index 25d7c1c352..7b3fadfee8 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -23,7 +23,7 @@ class DatasourceProviderType(enum.StrEnum): ONLINE_DOCUMENT = "online_document" LOCAL_FILE = "local_file" - WEBSITE = "website" + WEBSITE_CRAWL = "website_crawl" @classmethod def value_of(cls, value: str) -> "DatasourceProviderType": @@ -111,10 +111,10 @@ class DatasourceParameter(PluginParameter): class DatasourceIdentity(BaseModel): 
- author: str = Field(..., description="The author of the tool") - name: str = Field(..., description="The name of the tool") - label: I18nObject = Field(..., description="The label of the tool") - provider: str = Field(..., description="The provider of the tool") + author: str = Field(..., description="The author of the datasource") + name: str = Field(..., description="The name of the datasource") + label: I18nObject = Field(..., description="The label of the datasource") + provider: str = Field(..., description="The provider of the datasource") icon: Optional[str] = None @@ -145,7 +145,7 @@ class DatasourceProviderEntity(ToolProviderEntity): class DatasourceProviderEntityWithPlugin(DatasourceProviderEntity): - datasources: list[DatasourceEntity] = Field(default_factory=list) + datasources: list[DatasourceEntity] = Field(default_factory=list) class DatasourceInvokeMeta(BaseModel): @@ -195,3 +195,105 @@ class DatasourceInvokeFrom(Enum): """ RAG_PIPELINE = "rag_pipeline" + + +class GetOnlineDocumentPagesRequest(BaseModel): + """ + Get online document pages request + """ + + tenant_id: str = Field(..., description="The tenant id") + + +class OnlineDocumentPageIcon(BaseModel): + """ + Online document page icon + """ + + type: str = Field(..., description="The type of the icon") + url: str = Field(..., description="The url of the icon") + + +class OnlineDocumentPage(BaseModel): + """ + Online document page + """ + + page_id: str = Field(..., description="The page id") + page_title: str = Field(..., description="The page title") + page_icon: Optional[OnlineDocumentPageIcon] = Field(None, description="The page icon") + type: str = Field(..., description="The type of the page") + last_edited_time: str = Field(..., description="The last edited time") + + +class OnlineDocumentInfo(BaseModel): + """ + Online document info + """ + + workspace_id: str = Field(..., description="The workspace id") + workspace_name: str = Field(..., description="The workspace name") + workspace_icon: str = Field(..., description="The workspace icon") + total: int = Field(..., description="The total number of documents") + pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document") + + +class GetOnlineDocumentPagesResponse(BaseModel): + """ + Get online document pages response + """ + + result: list[OnlineDocumentInfo] + + +class GetOnlineDocumentPageContentRequest(BaseModel): + """ + Get online document page content request + """ + + online_document_info_list: list[OnlineDocumentInfo] + + +class OnlineDocumentPageContent(BaseModel): + """ + Online document page content + """ + + page_id: str = Field(..., description="The page id") + content: str = Field(..., description="The content of the page") + + +class GetOnlineDocumentPageContentResponse(BaseModel): + """ + Get online document page content response + """ + + result: list[OnlineDocumentPageContent] + + +class GetWebsiteCrawlRequest(BaseModel): + """ + Get website crawl request + """ + + url: str = Field(..., description="The url of the website") + crawl_parameters: dict = Field(..., description="The crawl parameters") + + +class WebSiteInfo(BaseModel): + """ + Website info + """ + + source_url: str = Field(..., description="The url of the website") + markdown: str = Field(..., description="The markdown of the website") + title: str = Field(..., description="The title of the website") + description: str = Field(..., description="The description of the website") + + +class GetWebsiteCrawlResponse(BaseModel): + """ + Get website crawl 
response
+    """
+
+    result: list[WebSiteInfo]
diff --git a/api/core/datasource/local_file/local_file_plugin.py b/api/core/datasource/local_file/local_file_plugin.py
new file mode 100644
index 0000000000..a9dced1186
--- /dev/null
+++ b/api/core/datasource/local_file/local_file_plugin.py
@@ -0,0 +1,37 @@
+from core.datasource.__base.datasource_plugin import DatasourcePlugin
+from core.datasource.__base.datasource_runtime import DatasourceRuntime
+from core.datasource.entities.datasource_entities import (
+    DatasourceEntity,
+    DatasourceProviderType,
+)
+
+
+class LocalFileDatasourcePlugin(DatasourcePlugin):
+    tenant_id: str
+    icon: str
+    plugin_unique_identifier: str
+
+    def __init__(
+        self,
+        entity: DatasourceEntity,
+        runtime: DatasourceRuntime,
+        tenant_id: str,
+        icon: str,
+        plugin_unique_identifier: str,
+    ) -> None:
+        super().__init__(entity, runtime)
+        self.tenant_id = tenant_id
+        self.icon = icon
+        self.plugin_unique_identifier = plugin_unique_identifier
+
+    def datasource_provider_type(self) -> DatasourceProviderType:
+        return DatasourceProviderType.LOCAL_FILE
+
+    def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
+        # DatasourcePlugin is abstract and its __init__ only accepts entity and runtime,
+        # so fork into the concrete local-file plugin class instead.
+        return LocalFileDatasourcePlugin(
+            entity=self.entity,
+            runtime=runtime,
+            tenant_id=self.tenant_id,
+            icon=self.icon,
+            plugin_unique_identifier=self.plugin_unique_identifier,
+        )
diff --git a/api/core/datasource/local_file/local_file_provider.py b/api/core/datasource/local_file/local_file_provider.py
new file mode 100644
index 0000000000..79f885dda5
--- /dev/null
+++ b/api/core/datasource/local_file/local_file_provider.py
@@ -0,0 +1,58 @@
+from typing import Any
+
+from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
+from core.datasource.__base.datasource_runtime import DatasourceRuntime
+from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
+from core.datasource.local_file.local_file_plugin import LocalFileDatasourcePlugin
+
+
+class LocalFileDatasourcePluginProviderController(DatasourcePluginProviderController):
+    entity: DatasourceProviderEntityWithPlugin
+    tenant_id: str
+    plugin_id: str
+    plugin_unique_identifier: str
+
+    def __init__(
+        self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
+    ) -> None:
+        super().__init__(entity)
+        self.tenant_id = tenant_id
+        self.plugin_id = plugin_id
+        self.plugin_unique_identifier = plugin_unique_identifier
+
+    @property
+    def provider_type(self) -> DatasourceProviderType:
+        """
+        returns the type of the provider
+        """
+        return DatasourceProviderType.LOCAL_FILE
+
+    def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
+        """
+        validate the credentials of the provider
+        """
+        pass
+
+    def get_datasource(self, datasource_name: str) -> LocalFileDatasourcePlugin:  # type: ignore
+        """
+        return datasource with given name
+        """
+        datasource_entity = next(
+            (
+                datasource_entity
+                for datasource_entity in self.entity.datasources
+                if datasource_entity.identity.name == datasource_name
+            ),
+            None,
+        )
+
+        if not datasource_entity:
+            raise ValueError(f"Datasource with name {datasource_name} not found")
+
+        return LocalFileDatasourcePlugin(
+            entity=datasource_entity,
+            runtime=DatasourceRuntime(tenant_id=self.tenant_id),
+            tenant_id=self.tenant_id,
+            icon=self.entity.identity.icon,
+            plugin_unique_identifier=self.plugin_unique_identifier,
+        )
diff --git a/api/core/datasource/online_document/online_document_plugin.py
b/api/core/datasource/online_document/online_document_plugin.py new file mode 100644 index 0000000000..197d85ef59 --- /dev/null +++ b/api/core/datasource/online_document/online_document_plugin.py @@ -0,0 +1,80 @@ +from core.datasource.__base.datasource_plugin import DatasourcePlugin +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import ( + DatasourceEntity, + DatasourceProviderType, + GetOnlineDocumentPageContentRequest, + GetOnlineDocumentPageContentResponse, + GetOnlineDocumentPagesRequest, + GetOnlineDocumentPagesResponse, +) +from core.plugin.impl.datasource import PluginDatasourceManager + + +class OnlineDocumentDatasourcePlugin(DatasourcePlugin): + tenant_id: str + icon: str + plugin_unique_identifier: str + entity: DatasourceEntity + runtime: DatasourceRuntime + + def __init__( + self, + entity: DatasourceEntity, + runtime: DatasourceRuntime, + tenant_id: str, + icon: str, + plugin_unique_identifier: str, + ) -> None: + super().__init__(entity, runtime) + self.tenant_id = tenant_id + self.icon = icon + self.plugin_unique_identifier = plugin_unique_identifier + + def _get_online_document_pages( + self, + user_id: str, + datasource_parameters: GetOnlineDocumentPagesRequest, + provider_type: str, + ) -> GetOnlineDocumentPagesResponse: + manager = PluginDatasourceManager() + + return manager.get_online_document_pages( + tenant_id=self.tenant_id, + user_id=user_id, + datasource_provider=self.entity.identity.provider, + datasource_name=self.entity.identity.name, + credentials=self.runtime.credentials, + datasource_parameters=datasource_parameters, + provider_type=provider_type, + ) + + def _get_online_document_page_content( + self, + user_id: str, + datasource_parameters: GetOnlineDocumentPageContentRequest, + provider_type: str, + ) -> GetOnlineDocumentPageContentResponse: + manager = PluginDatasourceManager() + + return manager.get_online_document_page_content( + tenant_id=self.tenant_id, + user_id=user_id, + datasource_provider=self.entity.identity.provider, + datasource_name=self.entity.identity.name, + credentials=self.runtime.credentials, + datasource_parameters=datasource_parameters, + provider_type=provider_type, + ) + + def datasource_provider_type(self) -> DatasourceProviderType: + return DatasourceProviderType.ONLINE_DOCUMENT + + def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin": + return DatasourcePlugin( + entity=self.entity, + runtime=runtime, + tenant_id=self.tenant_id, + icon=self.icon, + plugin_unique_identifier=self.plugin_unique_identifier, + ) diff --git a/api/core/datasource/online_document/online_document_provider.py b/api/core/datasource/online_document/online_document_provider.py new file mode 100644 index 0000000000..06572880b8 --- /dev/null +++ b/api/core/datasource/online_document/online_document_provider.py @@ -0,0 +1,50 @@ +from core.datasource.__base.datasource_plugin import DatasourcePlugin +from core.datasource.__base.datasource_provider import DatasourcePluginProviderController +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType + + +class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderController): + entity: DatasourceProviderEntityWithPlugin + tenant_id: str + plugin_id: str + plugin_unique_identifier: str + + def __init__( + self, entity: DatasourceProviderEntityWithPlugin, plugin_id: 
str, plugin_unique_identifier: str, tenant_id: str + ) -> None: + super().__init__(entity) + self.tenant_id = tenant_id + self.plugin_id = plugin_id + self.plugin_unique_identifier = plugin_unique_identifier + + @property + def provider_type(self) -> DatasourceProviderType: + """ + returns the type of the provider + """ + return DatasourceProviderType.ONLINE_DOCUMENT + + def get_datasource(self, datasource_name: str) -> DatasourcePlugin: # type: ignore + """ + return datasource with given name + """ + datasource_entity = next( + ( + datasource_entity + for datasource_entity in self.entity.datasources + if datasource_entity.identity.name == datasource_name + ), + None, + ) + + if not datasource_entity: + raise ValueError(f"Datasource with name {datasource_name} not found") + + return DatasourcePlugin( + entity=datasource_entity, + runtime=DatasourceRuntime(tenant_id=self.tenant_id), + tenant_id=self.tenant_id, + icon=self.entity.identity.icon, + plugin_unique_identifier=self.plugin_unique_identifier, + ) diff --git a/api/core/datasource/website_crawl/website_crawl_plugin.py b/api/core/datasource/website_crawl/website_crawl_plugin.py new file mode 100644 index 0000000000..8454d1636e --- /dev/null +++ b/api/core/datasource/website_crawl/website_crawl_plugin.py @@ -0,0 +1,63 @@ +from core.datasource.__base.datasource_plugin import DatasourcePlugin +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import ( + DatasourceEntity, + DatasourceProviderType, + GetWebsiteCrawlRequest, + GetWebsiteCrawlResponse, +) +from core.plugin.impl.datasource import PluginDatasourceManager +from core.plugin.utils.converter import convert_parameters_to_plugin_format + + +class WebsiteCrawlDatasourcePlugin(DatasourcePlugin): + tenant_id: str + icon: str + plugin_unique_identifier: str + entity: DatasourceEntity + runtime: DatasourceRuntime + + def __init__( + self, + entity: DatasourceEntity, + runtime: DatasourceRuntime, + tenant_id: str, + icon: str, + plugin_unique_identifier: str, + ) -> None: + super().__init__(entity, runtime) + self.tenant_id = tenant_id + self.icon = icon + self.plugin_unique_identifier = plugin_unique_identifier + + def _get_website_crawl( + self, + user_id: str, + datasource_parameters: GetWebsiteCrawlRequest, + provider_type: str, + ) -> GetWebsiteCrawlResponse: + manager = PluginDatasourceManager() + + datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters) + + return manager.invoke_first_step( + tenant_id=self.tenant_id, + user_id=user_id, + datasource_provider=self.entity.identity.provider, + datasource_name=self.entity.identity.name, + credentials=self.runtime.credentials, + datasource_parameters=datasource_parameters, + provider_type=provider_type, + ) + + def datasource_provider_type(self) -> DatasourceProviderType: + return DatasourceProviderType.WEBSITE_CRAWL + + def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin": + return DatasourcePlugin( + entity=self.entity, + runtime=runtime, + tenant_id=self.tenant_id, + icon=self.icon, + plugin_unique_identifier=self.plugin_unique_identifier, + ) diff --git a/api/core/datasource/website_crawl/website_crawl_provider.py b/api/core/datasource/website_crawl/website_crawl_provider.py new file mode 100644 index 0000000000..9c6bcdb7c2 --- /dev/null +++ b/api/core/datasource/website_crawl/website_crawl_provider.py @@ -0,0 +1,50 @@ +from core.datasource.__base.datasource_plugin import DatasourcePlugin +from 
core.datasource.__base.datasource_provider import DatasourcePluginProviderController +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType + + +class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderController): + entity: DatasourceProviderEntityWithPlugin + tenant_id: str + plugin_id: str + plugin_unique_identifier: str + + def __init__( + self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str + ) -> None: + super().__init__(entity) + self.tenant_id = tenant_id + self.plugin_id = plugin_id + self.plugin_unique_identifier = plugin_unique_identifier + + @property + def provider_type(self) -> DatasourceProviderType: + """ + returns the type of the provider + """ + return DatasourceProviderType.WEBSITE_CRAWL + + def get_datasource(self, datasource_name: str) -> DatasourcePlugin: # type: ignore + """ + return datasource with given name + """ + datasource_entity = next( + ( + datasource_entity + for datasource_entity in self.entity.datasources + if datasource_entity.identity.name == datasource_name + ), + None, + ) + + if not datasource_entity: + raise ValueError(f"Datasource with name {datasource_name} not found") + + return DatasourcePlugin( + entity=datasource_entity, + runtime=DatasourceRuntime(tenant_id=self.tenant_id), + tenant_id=self.tenant_id, + icon=self.entity.identity.icon, + plugin_unique_identifier=self.plugin_unique_identifier, + ) diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index 90086173fa..3b0defbb08 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -52,6 +52,7 @@ class PluginDatasourceProviderEntity(BaseModel): provider: str plugin_unique_identifier: str plugin_id: str + author: str declaration: DatasourceProviderEntityWithPlugin diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index 922e65d725..ebe08bd7eb 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -1,6 +1,14 @@ -from collections.abc import Mapping from typing import Any +from core.datasource.entities.api_entities import DatasourceProviderApiEntity +from core.datasource.entities.datasource_entities import ( + GetOnlineDocumentPageContentRequest, + GetOnlineDocumentPageContentResponse, + GetOnlineDocumentPagesRequest, + GetOnlineDocumentPagesResponse, + GetWebsiteCrawlRequest, + GetWebsiteCrawlResponse, +) from core.plugin.entities.plugin import GenericProviderID, ToolProviderID from core.plugin.entities.plugin_daemon import ( PluginBasicBooleanResponse, @@ -10,7 +18,7 @@ from core.plugin.impl.base import BasePluginClient class PluginDatasourceManager(BasePluginClient): - def fetch_datasource_providers(self, tenant_id: str) -> list[PluginDatasourceProviderEntity]: + def fetch_datasource_providers(self, tenant_id: str) -> list[DatasourceProviderApiEntity]: """ Fetch datasource providers for the given tenant. 
""" @@ -19,27 +27,27 @@ class PluginDatasourceManager(BasePluginClient): for provider in json_response.get("data", []): declaration = provider.get("declaration", {}) or {} provider_name = declaration.get("identity", {}).get("name") - for tool in declaration.get("tools", []): - tool["identity"]["provider"] = provider_name + for datasource in declaration.get("datasources", []): + datasource["identity"]["provider"] = provider_name return json_response - response = self._request_with_plugin_daemon_response( - "GET", - f"plugin/{tenant_id}/management/datasources", - list[PluginDatasourceProviderEntity], - params={"page": 1, "page_size": 256}, - transformer=transformer, - ) + # response = self._request_with_plugin_daemon_response( + # "GET", + # f"plugin/{tenant_id}/management/datasources", + # list[PluginDatasourceProviderEntity], + # params={"page": 1, "page_size": 256}, + # transformer=transformer, + # ) - for provider in response: - provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}" + # for provider in response: + # provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}" - # override the provider name for each tool to plugin_id/provider_name - for datasource in provider.declaration.datasources: - datasource.identity.provider = provider.declaration.identity.name + # # override the provider name for each tool to plugin_id/provider_name + # for datasource in provider.declaration.datasources: + # datasource.identity.provider = provider.declaration.identity.name - return response + return [DatasourceProviderApiEntity(**self._get_local_file_datasource_provider())] def fetch_datasource_provider(self, tenant_id: str, provider: str) -> PluginDatasourceProviderEntity: """ @@ -71,15 +79,16 @@ class PluginDatasourceManager(BasePluginClient): return response - def invoke_first_step( + def get_website_crawl( self, tenant_id: str, user_id: str, datasource_provider: str, datasource_name: str, credentials: dict[str, Any], - datasource_parameters: dict[str, Any], - ) -> Mapping[str, Any]: + datasource_parameters: GetWebsiteCrawlRequest, + provider_type: str, + ) -> GetWebsiteCrawlResponse: """ Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. """ @@ -88,8 +97,8 @@ class PluginDatasourceManager(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( "POST", - f"plugin/{tenant_id}/dispatch/datasource/first_step", - dict, + f"plugin/{tenant_id}/dispatch/datasource/{provider_type}/get_website_crawl", + GetWebsiteCrawlResponse, data={ "user_id": user_id, "data": { @@ -109,15 +118,16 @@ class PluginDatasourceManager(BasePluginClient): raise Exception("No response from plugin daemon") - def invoke_second_step( + def get_online_document_pages( self, tenant_id: str, user_id: str, datasource_provider: str, datasource_name: str, credentials: dict[str, Any], - datasource_parameters: dict[str, Any], - ) -> Mapping[str, Any]: + datasource_parameters: GetOnlineDocumentPagesRequest, + provider_type: str, + ) -> GetOnlineDocumentPagesResponse: """ Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. 
""" @@ -126,8 +136,47 @@ class PluginDatasourceManager(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( "POST", - f"plugin/{tenant_id}/dispatch/datasource/second_step", - dict, + f"plugin/{tenant_id}/dispatch/datasource/{provider_type}/get_online_document_pages", + GetOnlineDocumentPagesResponse, + data={ + "user_id": user_id, + "data": { + "provider": datasource_provider_id.provider_name, + "datasource": datasource_name, + "credentials": credentials, + "datasource_parameters": datasource_parameters, + }, + }, + headers={ + "X-Plugin-ID": datasource_provider_id.plugin_id, + "Content-Type": "application/json", + }, + ) + for resp in response: + return resp + + raise Exception("No response from plugin daemon") + + def get_online_document_page_content( + self, + tenant_id: str, + user_id: str, + datasource_provider: str, + datasource_name: str, + credentials: dict[str, Any], + datasource_parameters: GetOnlineDocumentPageContentRequest, + provider_type: str, + ) -> GetOnlineDocumentPageContentResponse: + """ + Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. + """ + + datasource_provider_id = GenericProviderID(datasource_provider) + + response = self._request_with_plugin_daemon_response_stream( + "POST", + f"plugin/{tenant_id}/dispatch/datasource/{provider_type}/get_online_document_page_content", + GetOnlineDocumentPageContentResponse, data={ "user_id": user_id, "data": { @@ -176,3 +225,53 @@ class PluginDatasourceManager(BasePluginClient): return resp.result return False + + def _get_local_file_datasource_provider(self) -> dict[str, Any]: + return { + "id": "langgenius/file/file", + "author": "langgenius", + "name": "langgenius/file/file", + "plugin_id": "langgenius/file", + "plugin_unique_identifier": "langgenius/file:0.0.1@dify", + "description": { + "zh_Hans": "File", + "en_US": "File", + "pt_BR": "File", + "ja_JP": "File" + }, + "icon": "https://cloud.dify.ai/console/api/workspaces/current/plugin/icon?tenant_id=945b4365-9d99-48c1-8c47-90593fe8b9c9&filename=13d9312f6b1352d3939b90a5257de58ff3cd619d5be4f5b266ff0298935ac328.svg", + "label": { + "zh_Hans": "File", + "en_US": "File", + "pt_BR": "File", + "ja_JP": "File" + }, + "type": "datasource", + "team_credentials": {}, + "is_team_authorization": False, + "allow_delete": True, + "datasources": [{ + "author": "langgenius", + "name": "upload_file", + "label": { + "en_US": "File", + "zh_Hans": "File", + "pt_BR": "File", + "ja_JP": "File" + }, + "description": { + "en_US": "File", + "zh_Hans": "File", + "pt_BR": "File", + "ja_JP": "File." 
+                },
+                "parameters": [],
+                "labels": [
+                    "search"
+                ],
+                "output_schema": None
+            }],
+            "labels": [
+                "search"
+            ]
+        }
diff --git a/api/core/workflow/enums.py b/api/core/workflow/enums.py
index 9642efa1a5..34d17c880a 100644
--- a/api/core/workflow/enums.py
+++ b/api/core/workflow/enums.py
@@ -14,3 +14,7 @@ class SystemVariableKey(StrEnum):
     APP_ID = "app_id"
     WORKFLOW_ID = "workflow_id"
     WORKFLOW_RUN_ID = "workflow_run_id"
+    # RAG Pipeline
+    DOCUMENT_ID = "document_id"
+    BATCH = "batch"
+    DATASET_ID = "dataset_id"
diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py
index e7d4da8426..d25784b781 100644
--- a/api/core/workflow/nodes/datasource/datasource_node.py
+++ b/api/core/workflow/nodes/datasource/datasource_node.py
@@ -3,7 +3,11 @@ from typing import Any, cast
 
 from core.datasource.entities.datasource_entities import (
     DatasourceParameter,
+    DatasourceProviderType,
+    GetWebsiteCrawlResponse,
 )
+from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
+from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
 from core.file import File
 from core.plugin.impl.exc import PluginDaemonClientSideError
 from core.variables.segments import ArrayAnySegment
@@ -77,15 +81,44 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
             for_log=True,
         )
 
-        # get conversation id
-        conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
-
         try:
             # TODO: handle result
-            result = datasource_runtime._invoke_second_step(
-                user_id=self.user_id,
-                datasource_parameters=parameters,
-            )
+            if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT:
+                datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
+                result = datasource_runtime._get_online_document_page_content(
+                    user_id=self.user_id,
+                    datasource_parameters=parameters,
+                    provider_type=node_data.provider_type,
+                )
+                return NodeRunResult(
+                    status=WorkflowNodeExecutionStatus.SUCCEEDED,
+                    inputs=parameters_for_log,
+                    metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
+                    outputs={
+                        "result": result.result.model_dump(),
+                        "datasource_type": datasource_runtime.datasource_provider_type(),
+                    },
+                )
+            elif datasource_runtime.datasource_provider_type() == DatasourceProviderType.WEBSITE_CRAWL:
+                datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
+                result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
+                    user_id=self.user_id,
+                    datasource_parameters=parameters,
+                    provider_type=node_data.provider_type,
+                )
+                return NodeRunResult(
+                    status=WorkflowNodeExecutionStatus.SUCCEEDED,
+                    inputs=parameters_for_log,
+                    metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
+                    outputs={
+                        "result": result.result.model_dump(),
+                        "datasource_type": datasource_runtime.datasource_provider_type(),
+                    },
+                )
+            else:
+                raise DatasourceNodeError(
+                    f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type()}"
+                )
         except PluginDaemonClientSideError as e:
             yield RunCompletedEvent(
                 run_result=NodeRunResult(
diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py
index 6b2c91a8a0..0d0da757d5 100644
--- a/api/core/workflow/nodes/knowledge_index/entities.py
+++ b/api/core/workflow/nodes/knowledge_index/entities.py
@@ -155,9 +155,4 @@ class KnowledgeIndexNodeData(BaseNodeData):
     """
 
     type: str = "knowledge-index"
-    dataset_id:
str - document_id: str index_chunk_variable_selector: list[str] - chunk_structure: Literal["general", "parent-child"] - index_method: IndexMethod - retrieval_setting: RetrievalSetting diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index 1fa6c20bf9..dac541621a 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -1,25 +1,19 @@ import datetime import logging -import time from collections.abc import Mapping from typing import Any, cast -from flask_login import current_user - -from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.variables.segments import ObjectSegment from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.enums import NodeType from core.workflow.nodes.llm.node import LLMNode from extensions.ext_database import db -from extensions.ext_redis import redis_client -from models.dataset import Dataset, Document, RateLimitLog +from models.dataset import Dataset, Document from models.workflow import WorkflowNodeExecutionStatus -from services.dataset_service import DatasetCollectionBindingService -from services.feature_service import FeatureService from .entities import KnowledgeIndexNodeData from .exc import ( @@ -43,8 +37,9 @@ class KnowledgeIndexNode(LLMNode): def _run(self) -> NodeRunResult: # type: ignore node_data = cast(KnowledgeIndexNodeData, self.node_data) + variable_pool = self.graph_runtime_state.variable_pool # extract variables - variable = self.graph_runtime_state.variable_pool.get(node_data.index_chunk_variable_selector) + variable = variable_pool.get(node_data.index_chunk_variable_selector) if not isinstance(variable, ObjectSegment): return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, @@ -57,34 +52,9 @@ class KnowledgeIndexNode(LLMNode): return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks is required." 
) - # check rate limit - if self.tenant_id: - knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id) - if knowledge_rate_limit.enabled: - current_time = int(time.time() * 1000) - key = f"rate_limit_{self.tenant_id}" - redis_client.zadd(key, {current_time: current_time}) - redis_client.zremrangebyscore(key, 0, current_time - 60000) - request_count = redis_client.zcard(key) - if request_count > knowledge_rate_limit.limit: - # add ratelimit record - rate_limit_log = RateLimitLog( - tenant_id=self.tenant_id, - subscription_plan=knowledge_rate_limit.subscription_plan, - operation="knowledge", - ) - db.session.add(rate_limit_log) - db.session.commit() - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=variables, - error="Sorry, you have reached the knowledge base request rate limit of your subscription.", - error_type="RateLimitExceeded", - ) - # retrieve knowledge try: - results = self._invoke_knowledge_index(node_data=node_data, chunks=chunks) + results = self._invoke_knowledge_index(node_data=node_data, chunks=chunks, variable_pool=variable_pool) outputs = {"result": results} return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs @@ -107,54 +77,26 @@ class KnowledgeIndexNode(LLMNode): error_type=type(e).__name__, ) - def _invoke_knowledge_index(self, node_data: KnowledgeIndexNodeData, chunks: Mapping[str, Any]) -> Any: - dataset = Dataset.query.filter_by(id=node_data.dataset_id).first() + def _invoke_knowledge_index( + self, node_data: KnowledgeIndexNodeData, chunks: Mapping[str, Any], variable_pool: VariablePool + ) -> Any: + dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID]) + if not dataset_id: + raise KnowledgeIndexNodeError("Dataset ID is required.") + document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) + if not document_id: + raise KnowledgeIndexNodeError("Document ID is required.") + batch = variable_pool.get(["sys", SystemVariableKey.BATCH]) + if not batch: + raise KnowledgeIndexNodeError("Batch is required.") + dataset = Dataset.query.filter_by(id=dataset_id).first() if not dataset: - raise KnowledgeIndexNodeError(f"Dataset {node_data.dataset_id} not found.") + raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.") - document = Document.query.filter_by(id=node_data.document_id).first() + document = Document.query.filter_by(id=document_id).first() if not document: - raise KnowledgeIndexNodeError(f"Document {node_data.document_id} not found.") - - retrieval_setting = node_data.retrieval_setting - index_method = node_data.index_method - if not dataset.indexing_technique: - if node_data.index_method.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST: - raise ValueError("Indexing technique is invalid") - - dataset.indexing_technique = index_method.indexing_technique - if index_method.indexing_technique == "high_quality": - model_manager = ModelManager() - if ( - index_method.embedding_setting.embedding_model - and index_method.embedding_setting.embedding_model_provider - ): - dataset_embedding_model = index_method.embedding_setting.embedding_model - dataset_embedding_model_provider = index_method.embedding_setting.embedding_model_provider - else: - embedding_model = model_manager.get_default_model_instance( - tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING - ) - dataset_embedding_model = embedding_model.model - dataset_embedding_model_provider = embedding_model.provider - 
dataset.embedding_model = dataset_embedding_model - dataset.embedding_model_provider = dataset_embedding_model_provider - dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - dataset_embedding_model_provider, dataset_embedding_model - ) - dataset.collection_binding_id = dataset_collection_binding.id - if not dataset.retrieval_model: - default_retrieval_model = { - "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, - "reranking_enable": False, - "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, - "top_k": 2, - "score_threshold_enabled": False, - } + raise KnowledgeIndexNodeError(f"Document {document_id} not found.") - dataset.retrieval_model = ( - retrieval_setting.model_dump() if retrieval_setting else default_retrieval_model - ) # type: ignore index_processor = IndexProcessorFactory(node_data.chunk_structure).init_index_processor() index_processor.index(dataset, document, chunks) @@ -166,6 +108,7 @@ class KnowledgeIndexNode(LLMNode): return { "dataset_id": dataset.id, "dataset_name": dataset.name, + "batch": batch, "document_id": document.id, "document_name": document.name, "created_at": document.created_at, diff --git a/api/core/workflow/nodes/knowledge_index/template_prompts.py b/api/core/workflow/nodes/knowledge_index/template_prompts.py deleted file mode 100644 index 7abd55d798..0000000000 --- a/api/core/workflow/nodes/knowledge_index/template_prompts.py +++ /dev/null @@ -1,66 +0,0 @@ -METADATA_FILTER_SYSTEM_PROMPT = """ - ### Job Description', - You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value - ### Task - Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator". - ### Format - The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields. - ### Constraint - DO NOT include anything other than the JSON array in your response. 
-""" # noqa: E501 - -METADATA_FILTER_USER_PROMPT_1 = """ - { "input_text": "I want to know which company’s email address test@example.com is?", - "metadata_fields": ["filename", "email", "phone", "address"] - } -""" - -METADATA_FILTER_ASSISTANT_PROMPT_1 = """ -```json - {"metadata_map": [ - {"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="} - ] - } -``` -""" - -METADATA_FILTER_USER_PROMPT_2 = """ - {"input_text": "What are the movies with a score of more than 9 in 2024?", - "metadata_fields": ["name", "year", "rating", "country"]} -""" - -METADATA_FILTER_ASSISTANT_PROMPT_2 = """ -```json - {"metadata_map": [ - {"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}, - {"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}, - ]} -``` -""" - -METADATA_FILTER_USER_PROMPT_3 = """ - '{{"input_text": "{input_text}",', - '"metadata_fields": {metadata_fields}}}' -""" - -METADATA_FILTER_COMPLETION_PROMPT = """ -### Job Description -You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value -### Task -# Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator". -### Format -The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields. -### Constraint -DO NOT include anything other than the JSON array in your response. -### Example -Here is the chat example between human and assistant, inside XML tags. 
- -User:{{"input_text": ["I want to know which company’s email address test@example.com is?"], "metadata_fields": ["filename", "email", "phone", "address"]}} -Assistant:{{"metadata_map": [{{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}}]}} -User:{{"input_text": "What are the movies with a score of more than 9 in 2024?", "metadata_fields": ["name", "year", "rating", "country"]}} -Assistant:{{"metadata_map": [{{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}, {{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}}]}} - -### User Input -{{"input_text" : "{input_text}", "metadata_fields" : {metadata_fields}}} -### Assistant Output -""" # noqa: E501 diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index 8c702b74ee..00448d2a9b 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -57,8 +57,6 @@ class MultipleRetrievalConfig(BaseModel): class ModelConfig(BaseModel): - - provider: str name: str mode: str diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 69a786e2f5..d829d57812 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -39,7 +39,6 @@ from core.variables.variables import ( from core.workflow.constants import ( CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, - PIPELINE_VARIABLE_NODE_ID, ) @@ -123,6 +122,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen result = result.model_copy(update={"selector": selector}) return cast(Variable, result) + def build_segment(value: Any, /) -> Segment: if value is None: return NoneSegment() diff --git a/api/migrations/versions/2025_05_16_1659-abb18a379e62_add_pipeline_info_2.py b/api/migrations/versions/2025_05_16_1659-abb18a379e62_add_pipeline_info_2.py new file mode 100644 index 0000000000..18e90e49dc --- /dev/null +++ b/api/migrations/versions/2025_05_16_1659-abb18a379e62_add_pipeline_info_2.py @@ -0,0 +1,113 @@ +"""add_pipeline_info_2 + +Revision ID: abb18a379e62 +Revises: b35c3db83d09 +Create Date: 2025-05-16 16:59:16.423127 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'abb18a379e62' +down_revision = 'b35c3db83d09' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('component_failure_stats') + op.drop_table('reliability_data') + op.drop_table('maintenance') + op.drop_table('operational_data') + op.drop_table('component_failure') + op.drop_table('tool_providers') + op.drop_table('safety_data') + op.drop_table('incident_data') + with op.batch_alter_table('pipelines', schema=None) as batch_op: + batch_op.drop_column('mode') + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('pipelines', schema=None) as batch_op: + batch_op.add_column(sa.Column('mode', sa.VARCHAR(length=255), autoincrement=False, nullable=False)) + + op.create_table('incident_data', + sa.Column('IncidentID', sa.INTEGER(), autoincrement=True, nullable=False), + sa.Column('IncidentDescription', sa.TEXT(), autoincrement=False, nullable=False), + sa.Column('IncidentDate', sa.DATE(), autoincrement=False, nullable=False), + sa.Column('Consequences', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('ResponseActions', sa.TEXT(), autoincrement=False, nullable=True), + sa.PrimaryKeyConstraint('IncidentID', name='incident_data_pkey') + ) + op.create_table('safety_data', + sa.Column('SafetyID', sa.INTEGER(), autoincrement=True, nullable=False), + sa.Column('SafetyInspectionDate', sa.DATE(), autoincrement=False, nullable=False), + sa.Column('SafetyFindings', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('SafetyIncidentDescription', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('ComplianceStatus', sa.VARCHAR(length=50), autoincrement=False, nullable=False), + sa.PrimaryKeyConstraint('SafetyID', name='safety_data_pkey') + ) + op.create_table('tool_providers', + sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), autoincrement=False, nullable=False), + sa.Column('tenant_id', sa.UUID(), autoincrement=False, nullable=False), + sa.Column('tool_name', sa.VARCHAR(length=40), autoincrement=False, nullable=False), + sa.Column('encrypted_credentials', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('is_enabled', sa.BOOLEAN(), server_default=sa.text('false'), autoincrement=False, nullable=False), + sa.Column('created_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), autoincrement=False, nullable=False), + sa.Column('updated_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), autoincrement=False, nullable=False), + sa.PrimaryKeyConstraint('id', name='tool_provider_pkey'), + sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name') + ) + op.create_table('component_failure', + sa.Column('FailureID', sa.INTEGER(), autoincrement=True, nullable=False), + sa.Column('Date', sa.DATE(), autoincrement=False, nullable=False), + sa.Column('Component', sa.VARCHAR(length=255), autoincrement=False, nullable=False), + sa.Column('FailureMode', sa.VARCHAR(length=255), autoincrement=False, nullable=False), + sa.Column('Cause', sa.VARCHAR(length=255), autoincrement=False, nullable=False), + sa.Column('RepairAction', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('Technician', sa.VARCHAR(length=255), autoincrement=False, nullable=False), + sa.PrimaryKeyConstraint('FailureID', name='component_failure_pkey'), + sa.UniqueConstraint('Date', 'Component', 'FailureMode', 'Cause', 'Technician', name='unique_failure_entry') + ) + op.create_table('operational_data', + sa.Column('OperationID', sa.INTEGER(), autoincrement=True, nullable=False), + sa.Column('CraneUsage', sa.INTEGER(), autoincrement=False, nullable=False), + sa.Column('LoadWeight', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=False), + sa.Column('LoadFrequency', sa.INTEGER(), autoincrement=False, nullable=False), + sa.Column('EnvironmentalConditions', sa.TEXT(), autoincrement=False, nullable=True), + sa.PrimaryKeyConstraint('OperationID', name='operational_data_pkey') + ) + op.create_table('maintenance', + sa.Column('MaintenanceID', sa.INTEGER(), autoincrement=True, 
nullable=False), + sa.Column('MaintenanceType', sa.VARCHAR(length=255), autoincrement=False, nullable=False), + sa.Column('MaintenanceDate', sa.DATE(), autoincrement=False, nullable=False), + sa.Column('ServiceDescription', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('PartsReplaced', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('Technician', sa.VARCHAR(length=255), autoincrement=False, nullable=False), + sa.PrimaryKeyConstraint('MaintenanceID', name='maintenance_pkey') + ) + op.create_table('reliability_data', + sa.Column('ComponentID', sa.INTEGER(), autoincrement=True, nullable=False), + sa.Column('ComponentName', sa.VARCHAR(length=255), autoincrement=False, nullable=False), + sa.Column('MTBF', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=False), + sa.Column('FailureRate', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=False), + sa.PrimaryKeyConstraint('ComponentID', name='reliability_data_pkey') + ) + op.create_table('component_failure_stats', + sa.Column('StatID', sa.INTEGER(), autoincrement=True, nullable=False), + sa.Column('Component', sa.VARCHAR(length=255), autoincrement=False, nullable=False), + sa.Column('FailureMode', sa.VARCHAR(length=255), autoincrement=False, nullable=False), + sa.Column('Cause', sa.VARCHAR(length=255), autoincrement=False, nullable=False), + sa.Column('PossibleAction', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('Probability', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=False), + sa.Column('MTBF', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=False), + sa.PrimaryKeyConstraint('StatID', name='component_failure_stats_pkey') + ) + # ### end Alembic commands ### diff --git a/api/models/dataset.py b/api/models/dataset.py index 0ed59c898f..22703771d5 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -1170,6 +1170,7 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined] def pipeline(self): return db.session.query(Pipeline).filter(Pipeline.id == self.pipeline_id).first() + class PipelineCustomizedTemplate(Base): # type: ignore[name-defined] __tablename__ = "pipeline_customized_templates" __table_args__ = ( @@ -1205,6 +1206,7 @@ class Pipeline(Base): # type: ignore[name-defined] created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + @property def dataset(self): return db.session.query(Dataset).filter(Dataset.pipeline_id == self.id).first() diff --git a/api/models/model.py b/api/models/model.py index ee79fbd6b5..e088c2e537 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -52,6 +52,7 @@ class AppMode(StrEnum): ADVANCED_CHAT = "advanced-chat" AGENT_CHAT = "agent-chat" CHANNEL = "channel" + RAG_PIPELINE = "rag-pipeline" @classmethod def value_of(cls, value: str) -> "AppMode": diff --git a/api/models/workflow.py b/api/models/workflow.py index d5cf71841e..038648fc8e 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -3,7 +3,7 @@ import logging from collections.abc import Mapping, Sequence from datetime import UTC, datetime from enum import Enum, StrEnum -from typing import TYPE_CHECKING, Any, List, Optional, Self, Union +from typing import TYPE_CHECKING, Any, Optional, Self, Union from uuid import uuid4 from core.variables import utils as variable_utils @@ -43,7 +43,7 @@ class WorkflowType(Enum): WORKFLOW = 
"workflow" CHAT = "chat" - RAG_PIPELINE = "rag_pipeline" + RAG_PIPELINE = "rag-pipeline" @classmethod def value_of(cls, value: str) -> "WorkflowType": @@ -370,7 +370,7 @@ class Workflow(Base): return results @rag_pipeline_variables.setter - def rag_pipeline_variables(self, values: List[dict]) -> None: + def rag_pipeline_variables(self, values: list[dict]) -> None: self._rag_pipeline_variables = json.dumps( {item["variable"]: item for item in values}, ensure_ascii=False, diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 81db03033f..8a87964276 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -1550,7 +1550,7 @@ class DocumentService: @staticmethod def build_document( dataset: Dataset, - process_rule_id: str, + process_rule_id: str | None, data_source_type: str, document_form: str, document_language: str, diff --git a/api/services/rag_pipeline/pipeline_generate_service.py b/api/services/rag_pipeline/pipeline_generate_service.py new file mode 100644 index 0000000000..089519dd0d --- /dev/null +++ b/api/services/rag_pipeline/pipeline_generate_service.py @@ -0,0 +1,109 @@ +from collections.abc import Mapping +from typing import Any, Union + +from configs import dify_config +from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator +from core.app.apps.pipeline.pipeline_generator import PipelineGenerator +from core.app.apps.workflow.app_generator import WorkflowAppGenerator +from core.app.entities.app_invoke_entities import InvokeFrom +from models.dataset import Pipeline +from models.model import Account, App, AppMode, EndUser +from models.workflow import Workflow +from services.rag_pipeline.rag_pipeline import RagPipelineService + + +class PipelineGenerateService: + @classmethod + def generate( + cls, + pipeline: Pipeline, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: bool = True, + ): + """ + Pipeline Content Generate + :param pipeline: pipeline + :param user: user + :param args: args + :param invoke_from: invoke from + :param streaming: streaming + :return: + """ + try: + workflow = cls._get_workflow(pipeline, invoke_from) + return PipelineGenerator.convert_to_event_stream( + PipelineGenerator().generate( + pipeline=pipeline, + workflow=workflow, + user=user, + args=args, + invoke_from=invoke_from, + streaming=streaming, + call_depth=0, + workflow_thread_pool_id=None, + ), + ) + + except Exception: + raise + + @staticmethod + def _get_max_active_requests(app_model: App) -> int: + max_active_requests = app_model.max_active_requests + if max_active_requests is None: + max_active_requests = int(dify_config.APP_MAX_ACTIVE_REQUESTS) + return max_active_requests + + @classmethod + def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True): + if app_model.mode == AppMode.ADVANCED_CHAT.value: + workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) + return AdvancedChatAppGenerator.convert_to_event_stream( + AdvancedChatAppGenerator().single_iteration_generate( + app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming + ) + ) + elif app_model.mode == AppMode.WORKFLOW.value: + workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) + return AdvancedChatAppGenerator.convert_to_event_stream( + WorkflowAppGenerator().single_iteration_generate( + app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming + ) + ) + 
else: + raise ValueError(f"Invalid app mode {app_model.mode}") + + @classmethod + def generate_single_loop(cls, pipeline: Pipeline, user: Account, node_id: str, args: Any, streaming: bool = True): + workflow = cls._get_workflow(pipeline, InvokeFrom.DEBUGGER) + return WorkflowAppGenerator.convert_to_event_stream( + WorkflowAppGenerator().single_loop_generate( + app_model=pipeline, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming + ) + ) + + @classmethod + def _get_workflow(cls, pipeline: Pipeline, invoke_from: InvokeFrom) -> Workflow: + """ + Get workflow + :param pipeline: pipeline + :param invoke_from: invoke from + :return: + """ + rag_pipeline_service = RagPipelineService() + if invoke_from == InvokeFrom.DEBUGGER: + # fetch draft workflow by app_model + workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline) + + if not workflow: + raise ValueError("Workflow not initialized") + else: + # fetch published workflow by app_model + workflow = rag_pipeline_service.get_published_workflow(pipeline=pipeline) + + if not workflow: + raise ValueError("Workflow not published") + + return workflow diff --git a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py index bda29c804c..11071d82e7 100644 --- a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py @@ -29,32 +29,31 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): :param language: language :return: """ - - pipeline_built_in_templates: list[PipelineBuiltInTemplate] = db.session.query(PipelineBuiltInTemplate).filter( - PipelineBuiltInTemplate.language == language - ).all() + + pipeline_built_in_templates: list[PipelineBuiltInTemplate] = ( + db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.language == language).all() + ) recommended_pipelines_results = [] for pipeline_built_in_template in pipeline_built_in_templates: pipeline_model: Pipeline = pipeline_built_in_template.pipeline recommended_pipeline_result = { - 'id': pipeline_built_in_template.id, - 'name': pipeline_built_in_template.name, - 'pipeline_id': pipeline_model.id, - 'description': pipeline_built_in_template.description, - 'icon': pipeline_built_in_template.icon, - 'copyright': pipeline_built_in_template.copyright, - 'privacy_policy': pipeline_built_in_template.privacy_policy, - 'position': pipeline_built_in_template.position, + "id": pipeline_built_in_template.id, + "name": pipeline_built_in_template.name, + "pipeline_id": pipeline_model.id, + "description": pipeline_built_in_template.description, + "icon": pipeline_built_in_template.icon, + "copyright": pipeline_built_in_template.copyright, + "privacy_policy": pipeline_built_in_template.privacy_policy, + "position": pipeline_built_in_template.position, } dataset: Dataset = pipeline_model.dataset if dataset: - recommended_pipeline_result['chunk_structure'] = dataset.chunk_structure + recommended_pipeline_result["chunk_structure"] = dataset.chunk_structure recommended_pipelines_results.append(recommended_pipeline_result) - return {'pipeline_templates': recommended_pipelines_results} - + return {"pipeline_templates": recommended_pipelines_results} @classmethod def fetch_pipeline_template_detail_from_db(cls, pipeline_id: str) -> Optional[dict]: @@ -64,6 +63,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): :return: """ from 
services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService + # is in public recommended list pipeline_template = ( db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.id == pipeline_id).first() diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index a7ad3109c3..a0a890aee7 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -3,7 +3,7 @@ import threading import time from collections.abc import Callable, Generator, Sequence from datetime import UTC, datetime -from typing import Any, Literal, Optional +from typing import Any, Optional from uuid import uuid4 from flask_login import current_user @@ -46,7 +46,7 @@ from services.rag_pipeline.pipeline_template.pipeline_template_factory import Pi class RagPipelineService: @staticmethod def get_pipeline_templates( - type: Literal["built-in", "customized"] = "built-in", language: str = "en-US" + type: str = "built-in", language: str = "en-US" ) -> list[PipelineBuiltInTemplate | PipelineCustomizedTemplate]: if type == "built-in": mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE @@ -358,11 +358,11 @@ class RagPipelineService: return workflow_node_execution - def run_datasource_workflow_node( + def run_published_workflow_node( self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account ) -> WorkflowNodeExecution: """ - Run published workflow datasource + Run published workflow node """ # fetch published workflow by app_model published_workflow = self.get_published_workflow(pipeline=pipeline) @@ -393,6 +393,41 @@ class RagPipelineService: return workflow_node_execution + def run_datasource_workflow_node( + self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account + ) -> WorkflowNodeExecution: + """ + Run published workflow datasource + """ + # fetch published workflow by app_model + published_workflow = self.get_published_workflow(pipeline=pipeline) + if not published_workflow: + raise ValueError("Workflow not initialized") + + # run draft workflow node + start_at = time.perf_counter() + + datasource_node_data = published_workflow.graph_dict.get("nodes", {}).get(node_id, {}).get("data", {}) + if not datasource_node_data: + raise ValueError("Datasource node data not found") + from core.datasource.datasource_manager import DatasourceManager + + datasource_runtime = DatasourceManager.get_datasource_runtime( + provider_id=datasource_node_data.get("provider_id"), + datasource_name=datasource_node_data.get("datasource_name"), + tenant_id=pipeline.tenant_id, + ) + result = datasource_runtime._invoke_first_step( + inputs=user_inputs, + provider_type=datasource_node_data.get("provider_type"), + user_id=account.id, + ) + + return { + "result": result, + "provider_type": datasource_node_data.get("provider_type"), + } + def run_free_workflow_node( self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any] ) -> WorkflowNodeExecution: @@ -552,7 +587,7 @@ class RagPipelineService: return workflow - def get_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> dict: + def get_published_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> dict: """ Get second step parameters of rag pipeline """ @@ -567,9 +602,33 @@ class RagPipelineService: return {} # get datasource provider - datasource_provider_variables = [item for item in rag_pipeline_variables - if item.get("belong_to_node_id") == node_id - or item.get("belong_to_node_id") == 
"shared"] + datasource_provider_variables = [ + item + for item in rag_pipeline_variables + if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared" + ] + return datasource_provider_variables + + def get_draft_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> dict: + """ + Get second step parameters of rag pipeline + """ + + workflow = self.get_draft_workflow(pipeline=pipeline) + if not workflow: + raise ValueError("Workflow not initialized") + + # get second step node + rag_pipeline_variables = workflow.rag_pipeline_variables + if not rag_pipeline_variables: + return {} + + # get datasource provider + datasource_provider_variables = [ + item + for item in rag_pipeline_variables + if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared" + ] return datasource_provider_variables def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination: