feat/datasource
parent
9bafd3a226
commit
b82b26bba5
@ -0,0 +1,95 @@
|
|||||||
|
from collections.abc import Generator
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||||
|
from core.app.entities.task_entities import (
|
||||||
|
AppStreamResponse,
|
||||||
|
ErrorStreamResponse,
|
||||||
|
NodeFinishStreamResponse,
|
||||||
|
NodeStartStreamResponse,
|
||||||
|
PingStreamResponse,
|
||||||
|
WorkflowAppBlockingResponse,
|
||||||
|
WorkflowAppStreamResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||||
|
_blocking_response_type = WorkflowAppBlockingResponse
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
|
||||||
|
"""
|
||||||
|
Convert blocking full response.
|
||||||
|
:param blocking_response: blocking response
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return dict(blocking_response.to_dict())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
|
||||||
|
"""
|
||||||
|
Convert blocking simple response.
|
||||||
|
:param blocking_response: blocking response
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return cls.convert_blocking_full_response(blocking_response)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def convert_stream_full_response(
|
||||||
|
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||||
|
) -> Generator[dict | str, None, None]:
|
||||||
|
"""
|
||||||
|
Convert stream full response.
|
||||||
|
:param stream_response: stream response
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
for chunk in stream_response:
|
||||||
|
chunk = cast(WorkflowAppStreamResponse, chunk)
|
||||||
|
sub_stream_response = chunk.stream_response
|
||||||
|
|
||||||
|
if isinstance(sub_stream_response, PingStreamResponse):
|
||||||
|
yield "ping"
|
||||||
|
continue
|
||||||
|
|
||||||
|
response_chunk = {
|
||||||
|
"event": sub_stream_response.event.value,
|
||||||
|
"workflow_run_id": chunk.workflow_run_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||||
|
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||||
|
response_chunk.update(data)
|
||||||
|
else:
|
||||||
|
response_chunk.update(sub_stream_response.to_dict())
|
||||||
|
yield response_chunk
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def convert_stream_simple_response(
|
||||||
|
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||||
|
) -> Generator[dict | str, None, None]:
|
||||||
|
"""
|
||||||
|
Convert stream simple response.
|
||||||
|
:param stream_response: stream response
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
for chunk in stream_response:
|
||||||
|
chunk = cast(WorkflowAppStreamResponse, chunk)
|
||||||
|
sub_stream_response = chunk.stream_response
|
||||||
|
|
||||||
|
if isinstance(sub_stream_response, PingStreamResponse):
|
||||||
|
yield "ping"
|
||||||
|
continue
|
||||||
|
|
||||||
|
response_chunk = {
|
||||||
|
"event": sub_stream_response.event.value,
|
||||||
|
"workflow_run_id": chunk.workflow_run_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||||
|
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||||
|
response_chunk.update(data)
|
||||||
|
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
|
||||||
|
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
|
||||||
|
else:
|
||||||
|
response_chunk.update(sub_stream_response.to_dict())
|
||||||
|
yield response_chunk
|
||||||
@ -0,0 +1,63 @@
|
|||||||
|
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
|
||||||
|
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
|
||||||
|
from core.app.app_config.entities import WorkflowUIBasedAppConfig
|
||||||
|
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||||
|
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
|
||||||
|
from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager
|
||||||
|
from models.dataset import Pipeline
|
||||||
|
from models.model import AppMode
|
||||||
|
from models.workflow import Workflow
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineConfig(WorkflowUIBasedAppConfig):
|
||||||
|
"""
|
||||||
|
Pipeline Config Entity.
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineConfigManager(BaseAppConfigManager):
|
||||||
|
@classmethod
|
||||||
|
def get_pipeline_config(cls, pipeline: Pipeline, workflow: Workflow) -> PipelineConfig:
|
||||||
|
pipeline_config = PipelineConfig(
|
||||||
|
tenant_id=pipeline.tenant_id,
|
||||||
|
app_id=pipeline.id,
|
||||||
|
app_mode=AppMode.RAG_PIPELINE,
|
||||||
|
workflow_id=workflow.id,
|
||||||
|
variables=WorkflowVariablesConfigManager.convert(workflow=workflow),
|
||||||
|
)
|
||||||
|
|
||||||
|
return pipeline_config
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
|
||||||
|
"""
|
||||||
|
Validate for pipeline config
|
||||||
|
|
||||||
|
:param tenant_id: tenant id
|
||||||
|
:param config: app model config args
|
||||||
|
:param only_structure_validate: only validate the structure of the config
|
||||||
|
"""
|
||||||
|
related_config_keys = []
|
||||||
|
|
||||||
|
# file upload validation
|
||||||
|
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config)
|
||||||
|
related_config_keys.extend(current_related_config_keys)
|
||||||
|
|
||||||
|
# text_to_speech
|
||||||
|
config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config)
|
||||||
|
related_config_keys.extend(current_related_config_keys)
|
||||||
|
|
||||||
|
# moderation validation
|
||||||
|
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
|
||||||
|
tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
|
||||||
|
)
|
||||||
|
related_config_keys.extend(current_related_config_keys)
|
||||||
|
|
||||||
|
related_config_keys = list(set(related_config_keys))
|
||||||
|
|
||||||
|
# Filter out extra parameters
|
||||||
|
filtered_config = {key: config.get(key) for key in related_config_keys}
|
||||||
|
|
||||||
|
return filtered_config
|
||||||
@ -0,0 +1,496 @@
|
|||||||
|
import contextvars
|
||||||
|
import datetime
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import random
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from collections.abc import Generator, Mapping
|
||||||
|
from typing import Any, Literal, Optional, Union, overload
|
||||||
|
|
||||||
|
from flask import Flask, current_app
|
||||||
|
from pydantic import ValidationError
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
|
import contexts
|
||||||
|
from configs import dify_config
|
||||||
|
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||||
|
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||||
|
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager
|
||||||
|
from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager
|
||||||
|
from core.app.apps.pipeline.pipeline_runner import PipelineRunner
|
||||||
|
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||||
|
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
|
||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity, WorkflowAppGenerateEntity
|
||||||
|
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
|
||||||
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||||
|
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||||
|
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||||
|
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||||
|
from core.workflow.workflow_app_generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||||
|
from models.dataset import Document, Pipeline
|
||||||
|
from services.dataset_service import DocumentService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineGenerator(BaseAppGenerator):
|
||||||
|
@overload
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
pipeline: Pipeline,
|
||||||
|
workflow: Workflow,
|
||||||
|
user: Union[Account, EndUser],
|
||||||
|
args: Mapping[str, Any],
|
||||||
|
invoke_from: InvokeFrom,
|
||||||
|
streaming: Literal[True],
|
||||||
|
call_depth: int,
|
||||||
|
workflow_thread_pool_id: Optional[str],
|
||||||
|
) -> Generator[Mapping | str, None, None]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
pipeline: Pipeline,
|
||||||
|
workflow: Workflow,
|
||||||
|
user: Union[Account, EndUser],
|
||||||
|
args: Mapping[str, Any],
|
||||||
|
invoke_from: InvokeFrom,
|
||||||
|
streaming: Literal[False],
|
||||||
|
call_depth: int,
|
||||||
|
workflow_thread_pool_id: Optional[str],
|
||||||
|
) -> Mapping[str, Any]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
pipeline: Pipeline,
|
||||||
|
workflow: Workflow,
|
||||||
|
user: Union[Account, EndUser],
|
||||||
|
args: Mapping[str, Any],
|
||||||
|
invoke_from: InvokeFrom,
|
||||||
|
streaming: bool,
|
||||||
|
call_depth: int,
|
||||||
|
workflow_thread_pool_id: Optional[str],
|
||||||
|
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
pipeline: Pipeline,
|
||||||
|
workflow: Workflow,
|
||||||
|
user: Union[Account, EndUser],
|
||||||
|
args: Mapping[str, Any],
|
||||||
|
invoke_from: InvokeFrom,
|
||||||
|
streaming: bool = True,
|
||||||
|
call_depth: int = 0,
|
||||||
|
workflow_thread_pool_id: Optional[str] = None,
|
||||||
|
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
|
||||||
|
# convert to app config
|
||||||
|
pipeline_config = PipelineConfigManager.get_pipeline_config(
|
||||||
|
pipeline=pipeline,
|
||||||
|
workflow=workflow,
|
||||||
|
)
|
||||||
|
|
||||||
|
inputs: Mapping[str, Any] = args["inputs"]
|
||||||
|
datasource_type: str = args["datasource_type"]
|
||||||
|
datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"]
|
||||||
|
batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))
|
||||||
|
|
||||||
|
for datasource_info in datasource_info_list:
|
||||||
|
workflow_run_id = str(uuid.uuid4())
|
||||||
|
document_id = None
|
||||||
|
if invoke_from == InvokeFrom.PUBLISHED:
|
||||||
|
position = DocumentService.get_documents_position(pipeline.dataset_id)
|
||||||
|
document = self._build_document(
|
||||||
|
tenant_id=pipeline.tenant_id,
|
||||||
|
dataset_id=pipeline.dataset_id,
|
||||||
|
built_in_field_enabled=pipeline.dataset.built_in_field_enabled,
|
||||||
|
datasource_type=datasource_type,
|
||||||
|
datasource_info=datasource_info,
|
||||||
|
created_from="rag-pipeline",
|
||||||
|
position=position,
|
||||||
|
account=user,
|
||||||
|
batch=batch,
|
||||||
|
document_form=pipeline.dataset.doc_form,
|
||||||
|
)
|
||||||
|
db.session.add(document)
|
||||||
|
db.session.commit()
|
||||||
|
document_id = document.id
|
||||||
|
# init application generate entity
|
||||||
|
application_generate_entity = RagPipelineGenerateEntity(
|
||||||
|
task_id=str(uuid.uuid4()),
|
||||||
|
pipline_config=pipeline_config,
|
||||||
|
datasource_type=datasource_type,
|
||||||
|
datasource_info=datasource_info,
|
||||||
|
dataset_id=pipeline.dataset_id,
|
||||||
|
batch=batch,
|
||||||
|
document_id=document_id,
|
||||||
|
inputs=self._prepare_user_inputs(
|
||||||
|
user_inputs=inputs,
|
||||||
|
variables=pipeline_config.variables,
|
||||||
|
tenant_id=pipeline.tenant_id,
|
||||||
|
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
|
||||||
|
),
|
||||||
|
files=[],
|
||||||
|
user_id=user.id,
|
||||||
|
stream=streaming,
|
||||||
|
invoke_from=invoke_from,
|
||||||
|
call_depth=call_depth,
|
||||||
|
workflow_run_id=workflow_run_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||||
|
contexts.plugin_tool_providers.set({})
|
||||||
|
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||||
|
|
||||||
|
# Create workflow node execution repository
|
||||||
|
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||||
|
|
||||||
|
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
|
session_factory=session_factory,
|
||||||
|
user=user,
|
||||||
|
app_id=application_generate_entity.app_config.app_id,
|
||||||
|
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._generate(
|
||||||
|
pipeline=pipeline,
|
||||||
|
workflow=workflow,
|
||||||
|
user=user,
|
||||||
|
application_generate_entity=application_generate_entity,
|
||||||
|
invoke_from=invoke_from,
|
||||||
|
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||||
|
streaming=streaming,
|
||||||
|
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _generate(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
pipeline: Pipeline,
|
||||||
|
workflow: Workflow,
|
||||||
|
user: Union[Account, EndUser],
|
||||||
|
application_generate_entity: RagPipelineGenerateEntity,
|
||||||
|
invoke_from: InvokeFrom,
|
||||||
|
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||||
|
streaming: bool = True,
|
||||||
|
workflow_thread_pool_id: Optional[str] = None,
|
||||||
|
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||||
|
"""
|
||||||
|
Generate App response.
|
||||||
|
|
||||||
|
:param app_model: App
|
||||||
|
:param workflow: Workflow
|
||||||
|
:param user: account or end user
|
||||||
|
:param application_generate_entity: application generate entity
|
||||||
|
:param invoke_from: invoke from source
|
||||||
|
:param workflow_node_execution_repository: repository for workflow node execution
|
||||||
|
:param streaming: is stream
|
||||||
|
:param workflow_thread_pool_id: workflow thread pool id
|
||||||
|
"""
|
||||||
|
# init queue manager
|
||||||
|
queue_manager = PipelineQueueManager(
|
||||||
|
task_id=application_generate_entity.task_id,
|
||||||
|
user_id=application_generate_entity.user_id,
|
||||||
|
invoke_from=application_generate_entity.invoke_from,
|
||||||
|
app_mode=pipeline.mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
# new thread
|
||||||
|
worker_thread = threading.Thread(
|
||||||
|
target=self._generate_worker,
|
||||||
|
kwargs={
|
||||||
|
"flask_app": current_app._get_current_object(), # type: ignore
|
||||||
|
"application_generate_entity": application_generate_entity,
|
||||||
|
"queue_manager": queue_manager,
|
||||||
|
"context": contextvars.copy_context(),
|
||||||
|
"workflow_thread_pool_id": workflow_thread_pool_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
worker_thread.start()
|
||||||
|
|
||||||
|
# return response or stream generator
|
||||||
|
response = self._handle_response(
|
||||||
|
application_generate_entity=application_generate_entity,
|
||||||
|
workflow=workflow,
|
||||||
|
queue_manager=queue_manager,
|
||||||
|
user=user,
|
||||||
|
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||||
|
stream=streaming,
|
||||||
|
)
|
||||||
|
|
||||||
|
return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
|
||||||
|
|
||||||
|
def single_iteration_generate(
|
||||||
|
self,
|
||||||
|
app_model: App,
|
||||||
|
workflow: Workflow,
|
||||||
|
node_id: str,
|
||||||
|
user: Account | EndUser,
|
||||||
|
args: Mapping[str, Any],
|
||||||
|
streaming: bool = True,
|
||||||
|
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||||
|
"""
|
||||||
|
Generate App response.
|
||||||
|
|
||||||
|
:param app_model: App
|
||||||
|
:param workflow: Workflow
|
||||||
|
:param node_id: the node id
|
||||||
|
:param user: account or end user
|
||||||
|
:param args: request args
|
||||||
|
:param streaming: is streamed
|
||||||
|
"""
|
||||||
|
if not node_id:
|
||||||
|
raise ValueError("node_id is required")
|
||||||
|
|
||||||
|
if args.get("inputs") is None:
|
||||||
|
raise ValueError("inputs is required")
|
||||||
|
|
||||||
|
# convert to app config
|
||||||
|
app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
|
||||||
|
|
||||||
|
# init application generate entity
|
||||||
|
application_generate_entity = WorkflowAppGenerateEntity(
|
||||||
|
task_id=str(uuid.uuid4()),
|
||||||
|
app_config=app_config,
|
||||||
|
inputs={},
|
||||||
|
files=[],
|
||||||
|
user_id=user.id,
|
||||||
|
stream=streaming,
|
||||||
|
invoke_from=InvokeFrom.DEBUGGER,
|
||||||
|
extras={"auto_generate_conversation_name": False},
|
||||||
|
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
|
||||||
|
node_id=node_id, inputs=args["inputs"]
|
||||||
|
),
|
||||||
|
workflow_run_id=str(uuid.uuid4()),
|
||||||
|
)
|
||||||
|
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||||
|
contexts.plugin_tool_providers.set({})
|
||||||
|
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||||
|
|
||||||
|
# Create workflow node execution repository
|
||||||
|
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||||
|
|
||||||
|
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
|
session_factory=session_factory,
|
||||||
|
user=user,
|
||||||
|
app_id=application_generate_entity.app_config.app_id,
|
||||||
|
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._generate(
|
||||||
|
app_model=app_model,
|
||||||
|
workflow=workflow,
|
||||||
|
user=user,
|
||||||
|
invoke_from=InvokeFrom.DEBUGGER,
|
||||||
|
application_generate_entity=application_generate_entity,
|
||||||
|
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||||
|
streaming=streaming,
|
||||||
|
)
|
||||||
|
|
||||||
|
def single_loop_generate(
|
||||||
|
self,
|
||||||
|
app_model: App,
|
||||||
|
workflow: Workflow,
|
||||||
|
node_id: str,
|
||||||
|
user: Account | EndUser,
|
||||||
|
args: Mapping[str, Any],
|
||||||
|
streaming: bool = True,
|
||||||
|
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||||
|
"""
|
||||||
|
Generate App response.
|
||||||
|
|
||||||
|
:param app_model: App
|
||||||
|
:param workflow: Workflow
|
||||||
|
:param node_id: the node id
|
||||||
|
:param user: account or end user
|
||||||
|
:param args: request args
|
||||||
|
:param streaming: is streamed
|
||||||
|
"""
|
||||||
|
if not node_id:
|
||||||
|
raise ValueError("node_id is required")
|
||||||
|
|
||||||
|
if args.get("inputs") is None:
|
||||||
|
raise ValueError("inputs is required")
|
||||||
|
|
||||||
|
# convert to app config
|
||||||
|
app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
|
||||||
|
|
||||||
|
# init application generate entity
|
||||||
|
application_generate_entity = WorkflowAppGenerateEntity(
|
||||||
|
task_id=str(uuid.uuid4()),
|
||||||
|
app_config=app_config,
|
||||||
|
inputs={},
|
||||||
|
files=[],
|
||||||
|
user_id=user.id,
|
||||||
|
stream=streaming,
|
||||||
|
invoke_from=InvokeFrom.DEBUGGER,
|
||||||
|
extras={"auto_generate_conversation_name": False},
|
||||||
|
single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
|
||||||
|
workflow_run_id=str(uuid.uuid4()),
|
||||||
|
)
|
||||||
|
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||||
|
contexts.plugin_tool_providers.set({})
|
||||||
|
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||||
|
|
||||||
|
# Create workflow node execution repository
|
||||||
|
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||||
|
|
||||||
|
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
|
session_factory=session_factory,
|
||||||
|
user=user,
|
||||||
|
app_id=application_generate_entity.app_config.app_id,
|
||||||
|
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._generate(
|
||||||
|
app_model=app_model,
|
||||||
|
workflow=workflow,
|
||||||
|
user=user,
|
||||||
|
invoke_from=InvokeFrom.DEBUGGER,
|
||||||
|
application_generate_entity=application_generate_entity,
|
||||||
|
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||||
|
streaming=streaming,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _generate_worker(
|
||||||
|
self,
|
||||||
|
flask_app: Flask,
|
||||||
|
application_generate_entity: RagPipelineGenerateEntity,
|
||||||
|
queue_manager: AppQueueManager,
|
||||||
|
context: contextvars.Context,
|
||||||
|
workflow_thread_pool_id: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Generate worker in a new thread.
|
||||||
|
:param flask_app: Flask app
|
||||||
|
:param application_generate_entity: application generate entity
|
||||||
|
:param queue_manager: queue manager
|
||||||
|
:param workflow_thread_pool_id: workflow thread pool id
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
for var, val in context.items():
|
||||||
|
var.set(val)
|
||||||
|
with flask_app.app_context():
|
||||||
|
try:
|
||||||
|
# workflow app
|
||||||
|
runner = PipelineRunner(
|
||||||
|
application_generate_entity=application_generate_entity,
|
||||||
|
queue_manager=queue_manager,
|
||||||
|
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
runner.run()
|
||||||
|
except GenerateTaskStoppedError:
|
||||||
|
pass
|
||||||
|
except InvokeAuthorizationError:
|
||||||
|
queue_manager.publish_error(
|
||||||
|
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
|
||||||
|
)
|
||||||
|
except ValidationError as e:
|
||||||
|
logger.exception("Validation Error when generating")
|
||||||
|
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||||
|
except ValueError as e:
|
||||||
|
if dify_config.DEBUG:
|
||||||
|
logger.exception("Error when generating")
|
||||||
|
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Unknown Error when generating")
|
||||||
|
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||||
|
finally:
|
||||||
|
db.session.close()
|
||||||
|
|
||||||
|
def _handle_response(
|
||||||
|
self,
|
||||||
|
application_generate_entity: RagPipelineGenerateEntity,
|
||||||
|
workflow: Workflow,
|
||||||
|
queue_manager: AppQueueManager,
|
||||||
|
user: Union[Account, EndUser],
|
||||||
|
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||||
|
stream: bool = False,
|
||||||
|
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||||
|
"""
|
||||||
|
Handle response.
|
||||||
|
:param application_generate_entity: application generate entity
|
||||||
|
:param workflow: workflow
|
||||||
|
:param queue_manager: queue manager
|
||||||
|
:param user: account or end user
|
||||||
|
:param stream: is stream
|
||||||
|
:param workflow_node_execution_repository: optional repository for workflow node execution
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
# init generate task pipeline
|
||||||
|
generate_task_pipeline = WorkflowAppGenerateTaskPipeline(
|
||||||
|
application_generate_entity=application_generate_entity,
|
||||||
|
workflow=workflow,
|
||||||
|
queue_manager=queue_manager,
|
||||||
|
user=user,
|
||||||
|
stream=stream,
|
||||||
|
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return generate_task_pipeline.process()
|
||||||
|
except ValueError as e:
|
||||||
|
if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error
|
||||||
|
raise GenerateTaskStoppedError()
|
||||||
|
else:
|
||||||
|
logger.exception(
|
||||||
|
f"Fails to process generate task pipeline, task_id: {application_generate_entity.task_id}"
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def _build_document(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
dataset_id: str,
|
||||||
|
built_in_field_enabled: bool,
|
||||||
|
datasource_type: str,
|
||||||
|
datasource_info: Mapping[str, Any],
|
||||||
|
created_from: str,
|
||||||
|
position: int,
|
||||||
|
account: Account,
|
||||||
|
batch: str,
|
||||||
|
document_form: str,
|
||||||
|
):
|
||||||
|
if datasource_type == "local_file":
|
||||||
|
name = datasource_info["name"]
|
||||||
|
elif datasource_type == "online_document":
|
||||||
|
name = datasource_info["page_title"]
|
||||||
|
elif datasource_type == "website_crawl":
|
||||||
|
name = datasource_info["title"]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported datasource type: {datasource_type}")
|
||||||
|
|
||||||
|
document = Document(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
dataset_id=dataset_id,
|
||||||
|
position=position,
|
||||||
|
data_source_type=datasource_type,
|
||||||
|
data_source_info=json.dumps(datasource_info),
|
||||||
|
batch=batch,
|
||||||
|
name=name,
|
||||||
|
created_from=created_from,
|
||||||
|
created_by=account.id,
|
||||||
|
doc_form=document_form,
|
||||||
|
)
|
||||||
|
doc_metadata = {}
|
||||||
|
if built_in_field_enabled:
|
||||||
|
doc_metadata = {
|
||||||
|
BuiltInField.document_name: name,
|
||||||
|
BuiltInField.uploader: account.name,
|
||||||
|
BuiltInField.upload_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
|
||||||
|
BuiltInField.last_update_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
|
||||||
|
BuiltInField.source: datasource_type,
|
||||||
|
}
|
||||||
|
if doc_metadata:
|
||||||
|
document.doc_metadata = doc_metadata
|
||||||
|
return document
|
||||||
@ -0,0 +1,44 @@
|
|||||||
|
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
|
from core.app.entities.queue_entities import (
|
||||||
|
AppQueueEvent,
|
||||||
|
QueueErrorEvent,
|
||||||
|
QueueMessageEndEvent,
|
||||||
|
QueueStopEvent,
|
||||||
|
QueueWorkflowFailedEvent,
|
||||||
|
QueueWorkflowPartialSuccessEvent,
|
||||||
|
QueueWorkflowSucceededEvent,
|
||||||
|
WorkflowQueueMessage,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineQueueManager(AppQueueManager):
|
||||||
|
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None:
|
||||||
|
super().__init__(task_id, user_id, invoke_from)
|
||||||
|
|
||||||
|
self._app_mode = app_mode
|
||||||
|
|
||||||
|
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
|
||||||
|
"""
|
||||||
|
Publish event to queue
|
||||||
|
:param event:
|
||||||
|
:param pub_from:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
message = WorkflowQueueMessage(task_id=self._task_id, app_mode=self._app_mode, event=event)
|
||||||
|
|
||||||
|
self._q.put(message)
|
||||||
|
|
||||||
|
if isinstance(
|
||||||
|
event,
|
||||||
|
QueueStopEvent
|
||||||
|
| QueueErrorEvent
|
||||||
|
| QueueMessageEndEvent
|
||||||
|
| QueueWorkflowSucceededEvent
|
||||||
|
| QueueWorkflowFailedEvent
|
||||||
|
| QueueWorkflowPartialSuccessEvent,
|
||||||
|
):
|
||||||
|
self.stop_listen()
|
||||||
|
|
||||||
|
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
|
||||||
|
raise GenerateTaskStoppedError()
|
||||||
@ -0,0 +1,154 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional, cast
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||||
|
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig
|
||||||
|
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||||
|
from core.app.entities.app_invoke_entities import (
|
||||||
|
InvokeFrom,
|
||||||
|
RagPipelineGenerateEntity,
|
||||||
|
)
|
||||||
|
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
|
||||||
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
from core.workflow.enums import SystemVariableKey
|
||||||
|
from core.workflow.workflow_entry import WorkflowEntry
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.dataset import Pipeline
|
||||||
|
from models.enums import UserFrom
|
||||||
|
from models.model import EndUser
|
||||||
|
from models.workflow import Workflow, WorkflowType
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineRunner(WorkflowBasedAppRunner):
|
||||||
|
"""
|
||||||
|
Pipeline Application Runner
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
application_generate_entity: RagPipelineGenerateEntity,
|
||||||
|
queue_manager: AppQueueManager,
|
||||||
|
workflow_thread_pool_id: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
:param application_generate_entity: application generate entity
|
||||||
|
:param queue_manager: application queue manager
|
||||||
|
:param workflow_thread_pool_id: workflow thread pool id
|
||||||
|
"""
|
||||||
|
self.application_generate_entity = application_generate_entity
|
||||||
|
self.queue_manager = queue_manager
|
||||||
|
self.workflow_thread_pool_id = workflow_thread_pool_id
|
||||||
|
|
||||||
|
def run(self) -> None:
|
||||||
|
"""
|
||||||
|
Run application
|
||||||
|
"""
|
||||||
|
app_config = self.application_generate_entity.app_config
|
||||||
|
app_config = cast(PipelineConfig, app_config)
|
||||||
|
|
||||||
|
user_id = None
|
||||||
|
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
|
||||||
|
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
|
||||||
|
if end_user:
|
||||||
|
user_id = end_user.session_id
|
||||||
|
else:
|
||||||
|
user_id = self.application_generate_entity.user_id
|
||||||
|
|
||||||
|
pipeline = db.session.query(Pipeline).filter(Pipeline.id == app_config.app_id).first()
|
||||||
|
if not pipeline:
|
||||||
|
raise ValueError("Pipeline not found")
|
||||||
|
|
||||||
|
workflow = self.get_workflow(pipeline=pipeline, workflow_id=app_config.workflow_id)
|
||||||
|
if not workflow:
|
||||||
|
raise ValueError("Workflow not initialized")
|
||||||
|
|
||||||
|
db.session.close()
|
||||||
|
|
||||||
|
workflow_callbacks: list[WorkflowCallback] = []
|
||||||
|
if dify_config.DEBUG:
|
||||||
|
workflow_callbacks.append(WorkflowLoggingCallback())
|
||||||
|
|
||||||
|
# if only single iteration run is requested
|
||||||
|
if self.application_generate_entity.single_iteration_run:
|
||||||
|
# if only single iteration run is requested
|
||||||
|
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||||
|
workflow=workflow,
|
||||||
|
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||||
|
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
|
||||||
|
)
|
||||||
|
elif self.application_generate_entity.single_loop_run:
|
||||||
|
# if only single loop run is requested
|
||||||
|
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||||
|
workflow=workflow,
|
||||||
|
node_id=self.application_generate_entity.single_loop_run.node_id,
|
||||||
|
user_inputs=self.application_generate_entity.single_loop_run.inputs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
inputs = self.application_generate_entity.inputs
|
||||||
|
files = self.application_generate_entity.files
|
||||||
|
|
||||||
|
# Create a variable pool.
|
||||||
|
system_inputs = {
|
||||||
|
SystemVariableKey.FILES: files,
|
||||||
|
SystemVariableKey.USER_ID: user_id,
|
||||||
|
SystemVariableKey.APP_ID: app_config.app_id,
|
||||||
|
SystemVariableKey.WORKFLOW_ID: app_config.workflow_id,
|
||||||
|
SystemVariableKey.WORKFLOW_RUN_ID: self.application_generate_entity.workflow_run_id,
|
||||||
|
SystemVariableKey.DOCUMENT_ID: self.application_generate_entity.document_id,
|
||||||
|
SystemVariableKey.BATCH: self.application_generate_entity.batch,
|
||||||
|
SystemVariableKey.DATASET_ID: self.application_generate_entity.dataset_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
variable_pool = VariablePool(
|
||||||
|
system_variables=system_inputs,
|
||||||
|
user_inputs=inputs,
|
||||||
|
environment_variables=workflow.environment_variables,
|
||||||
|
conversation_variables=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
# init graph
|
||||||
|
graph = self._init_graph(graph_config=workflow.graph_dict)
|
||||||
|
|
||||||
|
# RUN WORKFLOW
|
||||||
|
workflow_entry = WorkflowEntry(
|
||||||
|
tenant_id=workflow.tenant_id,
|
||||||
|
app_id=workflow.app_id,
|
||||||
|
workflow_id=workflow.id,
|
||||||
|
workflow_type=WorkflowType.value_of(workflow.type),
|
||||||
|
graph=graph,
|
||||||
|
graph_config=workflow.graph_dict,
|
||||||
|
user_id=self.application_generate_entity.user_id,
|
||||||
|
user_from=(
|
||||||
|
UserFrom.ACCOUNT
|
||||||
|
if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
|
||||||
|
else UserFrom.END_USER
|
||||||
|
),
|
||||||
|
invoke_from=self.application_generate_entity.invoke_from,
|
||||||
|
call_depth=self.application_generate_entity.call_depth,
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
thread_pool_id=self.workflow_thread_pool_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
generator = workflow_entry.run(callbacks=workflow_callbacks)
|
||||||
|
|
||||||
|
for event in generator:
|
||||||
|
self._handle_event(workflow_entry, event)
|
||||||
|
|
||||||
|
def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Optional[Workflow]:
|
||||||
|
"""
|
||||||
|
Get workflow
|
||||||
|
"""
|
||||||
|
# fetch workflow by workflow_id
|
||||||
|
workflow = (
|
||||||
|
db.session.query(Workflow)
|
||||||
|
.filter(
|
||||||
|
Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
# return workflow
|
||||||
|
return workflow
|
||||||
@ -0,0 +1,37 @@
|
|||||||
|
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||||
|
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||||
|
from core.datasource.entities.datasource_entities import (
|
||||||
|
DatasourceEntity,
|
||||||
|
DatasourceProviderType,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LocalFileDatasourcePlugin(DatasourcePlugin):
|
||||||
|
tenant_id: str
|
||||||
|
icon: str
|
||||||
|
plugin_unique_identifier: str
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
entity: DatasourceEntity,
|
||||||
|
runtime: DatasourceRuntime,
|
||||||
|
tenant_id: str,
|
||||||
|
icon: str,
|
||||||
|
plugin_unique_identifier: str,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(entity, runtime)
|
||||||
|
self.tenant_id = tenant_id
|
||||||
|
self.icon = icon
|
||||||
|
self.plugin_unique_identifier = plugin_unique_identifier
|
||||||
|
|
||||||
|
def datasource_provider_type(self) -> DatasourceProviderType:
|
||||||
|
return DatasourceProviderType.LOCAL_FILE
|
||||||
|
|
||||||
|
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
|
||||||
|
return DatasourcePlugin(
|
||||||
|
entity=self.entity,
|
||||||
|
runtime=runtime,
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
icon=self.icon,
|
||||||
|
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||||
|
)
|
||||||
@ -0,0 +1,58 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||||
|
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||||
|
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||||
|
from core.datasource.local_file.local_file_plugin import LocalFileDatasourcePlugin
|
||||||
|
|
||||||
|
|
||||||
|
class LocalFileDatasourcePluginProviderController(DatasourcePluginProviderController):
|
||||||
|
entity: DatasourceProviderEntityWithPlugin
|
||||||
|
tenant_id: str
|
||||||
|
plugin_id: str
|
||||||
|
plugin_unique_identifier: str
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
|
||||||
|
) -> None:
|
||||||
|
super().__init__(entity)
|
||||||
|
self.tenant_id = tenant_id
|
||||||
|
self.plugin_id = plugin_id
|
||||||
|
self.plugin_unique_identifier = plugin_unique_identifier
|
||||||
|
|
||||||
|
@property
|
||||||
|
def provider_type(self) -> DatasourceProviderType:
|
||||||
|
"""
|
||||||
|
returns the type of the provider
|
||||||
|
"""
|
||||||
|
return DatasourceProviderType.LOCAL_FILE
|
||||||
|
|
||||||
|
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
validate the credentials of the provider
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_datasource(self, datasource_name: str) -> LocalFileDatasourcePlugin: # type: ignore
|
||||||
|
"""
|
||||||
|
return datasource with given name
|
||||||
|
"""
|
||||||
|
datasource_entity = next(
|
||||||
|
(
|
||||||
|
datasource_entity
|
||||||
|
for datasource_entity in self.entity.datasources
|
||||||
|
if datasource_entity.identity.name == datasource_name
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not datasource_entity:
|
||||||
|
raise ValueError(f"Datasource with name {datasource_name} not found")
|
||||||
|
|
||||||
|
return LocalFileDatasourcePlugin(
|
||||||
|
entity=datasource_entity,
|
||||||
|
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
icon=self.entity.identity.icon,
|
||||||
|
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||||
|
)
|
||||||
@ -0,0 +1,80 @@
|
|||||||
|
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||||
|
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||||
|
from core.datasource.entities.datasource_entities import (
|
||||||
|
DatasourceEntity,
|
||||||
|
DatasourceProviderType,
|
||||||
|
GetOnlineDocumentPageContentRequest,
|
||||||
|
GetOnlineDocumentPageContentResponse,
|
||||||
|
GetOnlineDocumentPagesRequest,
|
||||||
|
GetOnlineDocumentPagesResponse,
|
||||||
|
)
|
||||||
|
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||||
|
|
||||||
|
|
||||||
|
class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
|
||||||
|
tenant_id: str
|
||||||
|
icon: str
|
||||||
|
plugin_unique_identifier: str
|
||||||
|
entity: DatasourceEntity
|
||||||
|
runtime: DatasourceRuntime
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
entity: DatasourceEntity,
|
||||||
|
runtime: DatasourceRuntime,
|
||||||
|
tenant_id: str,
|
||||||
|
icon: str,
|
||||||
|
plugin_unique_identifier: str,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(entity, runtime)
|
||||||
|
self.tenant_id = tenant_id
|
||||||
|
self.icon = icon
|
||||||
|
self.plugin_unique_identifier = plugin_unique_identifier
|
||||||
|
|
||||||
|
def _get_online_document_pages(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
datasource_parameters: GetOnlineDocumentPagesRequest,
|
||||||
|
provider_type: str,
|
||||||
|
) -> GetOnlineDocumentPagesResponse:
|
||||||
|
manager = PluginDatasourceManager()
|
||||||
|
|
||||||
|
return manager.get_online_document_pages(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
datasource_provider=self.entity.identity.provider,
|
||||||
|
datasource_name=self.entity.identity.name,
|
||||||
|
credentials=self.runtime.credentials,
|
||||||
|
datasource_parameters=datasource_parameters,
|
||||||
|
provider_type=provider_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_online_document_page_content(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
datasource_parameters: GetOnlineDocumentPageContentRequest,
|
||||||
|
provider_type: str,
|
||||||
|
) -> GetOnlineDocumentPageContentResponse:
|
||||||
|
manager = PluginDatasourceManager()
|
||||||
|
|
||||||
|
return manager.get_online_document_page_content(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
datasource_provider=self.entity.identity.provider,
|
||||||
|
datasource_name=self.entity.identity.name,
|
||||||
|
credentials=self.runtime.credentials,
|
||||||
|
datasource_parameters=datasource_parameters,
|
||||||
|
provider_type=provider_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
def datasource_provider_type(self) -> DatasourceProviderType:
|
||||||
|
return DatasourceProviderType.ONLINE_DOCUMENT
|
||||||
|
|
||||||
|
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
|
||||||
|
return DatasourcePlugin(
|
||||||
|
entity=self.entity,
|
||||||
|
runtime=runtime,
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
icon=self.icon,
|
||||||
|
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||||
|
)
|
||||||
@ -0,0 +1,50 @@
|
|||||||
|
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||||
|
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||||
|
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||||
|
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||||
|
|
||||||
|
|
||||||
|
class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderController):
|
||||||
|
entity: DatasourceProviderEntityWithPlugin
|
||||||
|
tenant_id: str
|
||||||
|
plugin_id: str
|
||||||
|
plugin_unique_identifier: str
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
|
||||||
|
) -> None:
|
||||||
|
super().__init__(entity)
|
||||||
|
self.tenant_id = tenant_id
|
||||||
|
self.plugin_id = plugin_id
|
||||||
|
self.plugin_unique_identifier = plugin_unique_identifier
|
||||||
|
|
||||||
|
@property
|
||||||
|
def provider_type(self) -> DatasourceProviderType:
|
||||||
|
"""
|
||||||
|
returns the type of the provider
|
||||||
|
"""
|
||||||
|
return DatasourceProviderType.ONLINE_DOCUMENT
|
||||||
|
|
||||||
|
def get_datasource(self, datasource_name: str) -> DatasourcePlugin: # type: ignore
|
||||||
|
"""
|
||||||
|
return datasource with given name
|
||||||
|
"""
|
||||||
|
datasource_entity = next(
|
||||||
|
(
|
||||||
|
datasource_entity
|
||||||
|
for datasource_entity in self.entity.datasources
|
||||||
|
if datasource_entity.identity.name == datasource_name
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not datasource_entity:
|
||||||
|
raise ValueError(f"Datasource with name {datasource_name} not found")
|
||||||
|
|
||||||
|
return DatasourcePlugin(
|
||||||
|
entity=datasource_entity,
|
||||||
|
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
icon=self.entity.identity.icon,
|
||||||
|
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||||
|
)
|
||||||
@ -0,0 +1,63 @@
|
|||||||
|
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||||
|
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||||
|
from core.datasource.entities.datasource_entities import (
|
||||||
|
DatasourceEntity,
|
||||||
|
DatasourceProviderType,
|
||||||
|
GetWebsiteCrawlRequest,
|
||||||
|
GetWebsiteCrawlResponse,
|
||||||
|
)
|
||||||
|
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||||
|
from core.plugin.utils.converter import convert_parameters_to_plugin_format
|
||||||
|
|
||||||
|
|
||||||
|
class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
|
||||||
|
tenant_id: str
|
||||||
|
icon: str
|
||||||
|
plugin_unique_identifier: str
|
||||||
|
entity: DatasourceEntity
|
||||||
|
runtime: DatasourceRuntime
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
entity: DatasourceEntity,
|
||||||
|
runtime: DatasourceRuntime,
|
||||||
|
tenant_id: str,
|
||||||
|
icon: str,
|
||||||
|
plugin_unique_identifier: str,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(entity, runtime)
|
||||||
|
self.tenant_id = tenant_id
|
||||||
|
self.icon = icon
|
||||||
|
self.plugin_unique_identifier = plugin_unique_identifier
|
||||||
|
|
||||||
|
def _get_website_crawl(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
datasource_parameters: GetWebsiteCrawlRequest,
|
||||||
|
provider_type: str,
|
||||||
|
) -> GetWebsiteCrawlResponse:
|
||||||
|
manager = PluginDatasourceManager()
|
||||||
|
|
||||||
|
datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters)
|
||||||
|
|
||||||
|
return manager.invoke_first_step(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
datasource_provider=self.entity.identity.provider,
|
||||||
|
datasource_name=self.entity.identity.name,
|
||||||
|
credentials=self.runtime.credentials,
|
||||||
|
datasource_parameters=datasource_parameters,
|
||||||
|
provider_type=provider_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
def datasource_provider_type(self) -> DatasourceProviderType:
|
||||||
|
return DatasourceProviderType.WEBSITE_CRAWL
|
||||||
|
|
||||||
|
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
|
||||||
|
return DatasourcePlugin(
|
||||||
|
entity=self.entity,
|
||||||
|
runtime=runtime,
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
icon=self.icon,
|
||||||
|
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||||
|
)
|
||||||
@ -0,0 +1,50 @@
|
|||||||
|
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||||
|
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||||
|
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||||
|
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||||
|
|
||||||
|
|
||||||
|
class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderController):
|
||||||
|
entity: DatasourceProviderEntityWithPlugin
|
||||||
|
tenant_id: str
|
||||||
|
plugin_id: str
|
||||||
|
plugin_unique_identifier: str
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
|
||||||
|
) -> None:
|
||||||
|
super().__init__(entity)
|
||||||
|
self.tenant_id = tenant_id
|
||||||
|
self.plugin_id = plugin_id
|
||||||
|
self.plugin_unique_identifier = plugin_unique_identifier
|
||||||
|
|
||||||
|
@property
|
||||||
|
def provider_type(self) -> DatasourceProviderType:
|
||||||
|
"""
|
||||||
|
returns the type of the provider
|
||||||
|
"""
|
||||||
|
return DatasourceProviderType.WEBSITE_CRAWL
|
||||||
|
|
||||||
|
def get_datasource(self, datasource_name: str) -> DatasourcePlugin: # type: ignore
|
||||||
|
"""
|
||||||
|
return datasource with given name
|
||||||
|
"""
|
||||||
|
datasource_entity = next(
|
||||||
|
(
|
||||||
|
datasource_entity
|
||||||
|
for datasource_entity in self.entity.datasources
|
||||||
|
if datasource_entity.identity.name == datasource_name
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not datasource_entity:
|
||||||
|
raise ValueError(f"Datasource with name {datasource_name} not found")
|
||||||
|
|
||||||
|
return DatasourcePlugin(
|
||||||
|
entity=datasource_entity,
|
||||||
|
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
icon=self.entity.identity.icon,
|
||||||
|
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||||
|
)
|
||||||
@ -0,0 +1,113 @@
|
|||||||
|
"""add_pipeline_info_2
|
||||||
|
|
||||||
|
Revision ID: abb18a379e62
|
||||||
|
Revises: b35c3db83d09
|
||||||
|
Create Date: 2025-05-16 16:59:16.423127
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import models as models
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = 'abb18a379e62'
|
||||||
|
down_revision = 'b35c3db83d09'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_table('component_failure_stats')
|
||||||
|
op.drop_table('reliability_data')
|
||||||
|
op.drop_table('maintenance')
|
||||||
|
op.drop_table('operational_data')
|
||||||
|
op.drop_table('component_failure')
|
||||||
|
op.drop_table('tool_providers')
|
||||||
|
op.drop_table('safety_data')
|
||||||
|
op.drop_table('incident_data')
|
||||||
|
with op.batch_alter_table('pipelines', schema=None) as batch_op:
|
||||||
|
batch_op.drop_column('mode')
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('pipelines', schema=None) as batch_op:
|
||||||
|
batch_op.add_column(sa.Column('mode', sa.VARCHAR(length=255), autoincrement=False, nullable=False))
|
||||||
|
|
||||||
|
op.create_table('incident_data',
|
||||||
|
sa.Column('IncidentID', sa.INTEGER(), autoincrement=True, nullable=False),
|
||||||
|
sa.Column('IncidentDescription', sa.TEXT(), autoincrement=False, nullable=False),
|
||||||
|
sa.Column('IncidentDate', sa.DATE(), autoincrement=False, nullable=False),
|
||||||
|
sa.Column('Consequences', sa.TEXT(), autoincrement=False, nullable=True),
|
||||||
|
sa.Column('ResponseActions', sa.TEXT(), autoincrement=False, nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint('IncidentID', name='incident_data_pkey')
|
||||||
|
)
|
||||||
|
op.create_table('safety_data',
|
||||||
|
sa.Column('SafetyID', sa.INTEGER(), autoincrement=True, nullable=False),
|
||||||
|
sa.Column('SafetyInspectionDate', sa.DATE(), autoincrement=False, nullable=False),
|
||||||
|
sa.Column('SafetyFindings', sa.TEXT(), autoincrement=False, nullable=True),
|
||||||
|
sa.Column('SafetyIncidentDescription', sa.TEXT(), autoincrement=False, nullable=True),
|
||||||
|
sa.Column('ComplianceStatus', sa.VARCHAR(length=50), autoincrement=False, nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint('SafetyID', name='safety_data_pkey')
|
||||||
|
)
|
||||||
|
op.create_table('tool_providers',
|
||||||
|
sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), autoincrement=False, nullable=False),
|
||||||
|
sa.Column('tenant_id', sa.UUID(), autoincrement=False, nullable=False),
|
||||||
|
sa.Column('tool_name', sa.VARCHAR(length=40), autoincrement=False, nullable=False),
|
||||||
|
sa.Column('encrypted_credentials', sa.TEXT(), autoincrement=False, nullable=True),
|
||||||
|
sa.Column('is_enabled', sa.BOOLEAN(), server_default=sa.text('false'), autoincrement=False, nullable=False),
|
||||||
|
sa.Column('created_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), autoincrement=False, nullable=False),
|
||||||
|
sa.Column('updated_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), autoincrement=False, nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint('id', name='tool_provider_pkey'),
|
||||||
|
sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name')
|
||||||
|
)
|
||||||
|
op.create_table('component_failure',
|
||||||
|
sa.Column('FailureID', sa.INTEGER(), autoincrement=True, nullable=False),
|
||||||
|
sa.Column('Date', sa.DATE(), autoincrement=False, nullable=False),
|
||||||
|
sa.Column('Component', sa.VARCHAR(length=255), autoincrement=False, nullable=False),
|
||||||
|
sa.Column('FailureMode', sa.VARCHAR(length=255), autoincrement=False, nullable=False),
|
||||||
|
sa.Column('Cause', sa.VARCHAR(length=255), autoincrement=False, nullable=False),
|
||||||
|
sa.Column('RepairAction', sa.TEXT(), autoincrement=False, nullable=True),
|
||||||
|
sa.Column('Technician', sa.VARCHAR(length=255), autoincrement=False, nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint('FailureID', name='component_failure_pkey'),
|
||||||
|
sa.UniqueConstraint('Date', 'Component', 'FailureMode', 'Cause', 'Technician', name='unique_failure_entry')
|
||||||
|
)
|
||||||
|
op.create_table('operational_data',
|
||||||
|
sa.Column('OperationID', sa.INTEGER(), autoincrement=True, nullable=False),
|
||||||
|
sa.Column('CraneUsage', sa.INTEGER(), autoincrement=False, nullable=False),
|
||||||
|
sa.Column('LoadWeight', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=False),
|
||||||
|
sa.Column('LoadFrequency', sa.INTEGER(), autoincrement=False, nullable=False),
|
||||||
|
sa.Column('EnvironmentalConditions', sa.TEXT(), autoincrement=False, nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint('OperationID', name='operational_data_pkey')
|
||||||
|
)
|
||||||
|
op.create_table('maintenance',
|
||||||
|
sa.Column('MaintenanceID', sa.INTEGER(), autoincrement=True, nullable=False),
|
||||||
|
sa.Column('MaintenanceType', sa.VARCHAR(length=255), autoincrement=False, nullable=False),
|
||||||
|
sa.Column('MaintenanceDate', sa.DATE(), autoincrement=False, nullable=False),
|
||||||
|
sa.Column('ServiceDescription', sa.TEXT(), autoincrement=False, nullable=True),
|
||||||
|
sa.Column('PartsReplaced', sa.TEXT(), autoincrement=False, nullable=True),
|
||||||
|
sa.Column('Technician', sa.VARCHAR(length=255), autoincrement=False, nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint('MaintenanceID', name='maintenance_pkey')
|
||||||
|
)
|
||||||
|
op.create_table('reliability_data',
|
||||||
|
sa.Column('ComponentID', sa.INTEGER(), autoincrement=True, nullable=False),
|
||||||
|
sa.Column('ComponentName', sa.VARCHAR(length=255), autoincrement=False, nullable=False),
|
||||||
|
sa.Column('MTBF', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=False),
|
||||||
|
sa.Column('FailureRate', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint('ComponentID', name='reliability_data_pkey')
|
||||||
|
)
|
||||||
|
op.create_table('component_failure_stats',
|
||||||
|
sa.Column('StatID', sa.INTEGER(), autoincrement=True, nullable=False),
|
||||||
|
sa.Column('Component', sa.VARCHAR(length=255), autoincrement=False, nullable=False),
|
||||||
|
sa.Column('FailureMode', sa.VARCHAR(length=255), autoincrement=False, nullable=False),
|
||||||
|
sa.Column('Cause', sa.VARCHAR(length=255), autoincrement=False, nullable=False),
|
||||||
|
sa.Column('PossibleAction', sa.TEXT(), autoincrement=False, nullable=True),
|
||||||
|
sa.Column('Probability', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=False),
|
||||||
|
sa.Column('MTBF', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint('StatID', name='component_failure_stats_pkey')
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
@ -0,0 +1,109 @@
|
|||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import Any, Union
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||||
|
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
|
||||||
|
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
|
from models.dataset import Pipeline
|
||||||
|
from models.model import Account, App, AppMode, EndUser
|
||||||
|
from models.workflow import Workflow
|
||||||
|
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineGenerateService:
|
||||||
|
@classmethod
|
||||||
|
def generate(
|
||||||
|
cls,
|
||||||
|
pipeline: Pipeline,
|
||||||
|
user: Union[Account, EndUser],
|
||||||
|
args: Mapping[str, Any],
|
||||||
|
invoke_from: InvokeFrom,
|
||||||
|
streaming: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Pipeline Content Generate
|
||||||
|
:param pipeline: pipeline
|
||||||
|
:param user: user
|
||||||
|
:param args: args
|
||||||
|
:param invoke_from: invoke from
|
||||||
|
:param streaming: streaming
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
workflow = cls._get_workflow(pipeline, invoke_from)
|
||||||
|
return PipelineGenerator.convert_to_event_stream(
|
||||||
|
PipelineGenerator().generate(
|
||||||
|
pipeline=pipeline,
|
||||||
|
workflow=workflow,
|
||||||
|
user=user,
|
||||||
|
args=args,
|
||||||
|
invoke_from=invoke_from,
|
||||||
|
streaming=streaming,
|
||||||
|
call_depth=0,
|
||||||
|
workflow_thread_pool_id=None,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
raise
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_max_active_requests(app_model: App) -> int:
|
||||||
|
max_active_requests = app_model.max_active_requests
|
||||||
|
if max_active_requests is None:
|
||||||
|
max_active_requests = int(dify_config.APP_MAX_ACTIVE_REQUESTS)
|
||||||
|
return max_active_requests
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
|
||||||
|
if app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||||
|
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
||||||
|
return AdvancedChatAppGenerator.convert_to_event_stream(
|
||||||
|
AdvancedChatAppGenerator().single_iteration_generate(
|
||||||
|
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif app_model.mode == AppMode.WORKFLOW.value:
|
||||||
|
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
||||||
|
return AdvancedChatAppGenerator.convert_to_event_stream(
|
||||||
|
WorkflowAppGenerator().single_iteration_generate(
|
||||||
|
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid app mode {app_model.mode}")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def generate_single_loop(cls, pipeline: Pipeline, user: Account, node_id: str, args: Any, streaming: bool = True):
|
||||||
|
workflow = cls._get_workflow(pipeline, InvokeFrom.DEBUGGER)
|
||||||
|
return WorkflowAppGenerator.convert_to_event_stream(
|
||||||
|
WorkflowAppGenerator().single_loop_generate(
|
||||||
|
app_model=pipeline, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_workflow(cls, pipeline: Pipeline, invoke_from: InvokeFrom) -> Workflow:
|
||||||
|
"""
|
||||||
|
Get workflow
|
||||||
|
:param pipeline: pipeline
|
||||||
|
:param invoke_from: invoke from
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
rag_pipeline_service = RagPipelineService()
|
||||||
|
if invoke_from == InvokeFrom.DEBUGGER:
|
||||||
|
# fetch draft workflow by app_model
|
||||||
|
workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
|
||||||
|
|
||||||
|
if not workflow:
|
||||||
|
raise ValueError("Workflow not initialized")
|
||||||
|
else:
|
||||||
|
# fetch published workflow by app_model
|
||||||
|
workflow = rag_pipeline_service.get_published_workflow(pipeline=pipeline)
|
||||||
|
|
||||||
|
if not workflow:
|
||||||
|
raise ValueError("Workflow not published")
|
||||||
|
|
||||||
|
return workflow
|
||||||
Loading…
Reference in New Issue