@@ -7,7 +7,7 @@ import threading
 import time
 import uuid
 from collections.abc import Generator, Mapping
-from typing import Any, Literal, Optional, Union, overload
+from typing import Any, Literal, Optional, Union, cast, overload
 
 from flask import Flask, current_app
 from pydantic import ValidationError
@@ -24,6 +24,11 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
 from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
 from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
 from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
+from core.datasource.entities.datasource_entities import (
+    DatasourceProviderType,
+    OnlineDriveBrowseFilesRequest,
+)
+from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
 from core.entities.knowledge_entities import PipelineDataset, PipelineDocument
 from core.model_runtime.errors.invoke import InvokeAuthorizationError
 from core.rag.index_processor.constant.built_in_field import BuiltInField
@@ -39,6 +44,7 @@ from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
 from models.enums import WorkflowRunTriggeredFrom
 from models.model import AppMode
 from services.dataset_service import DocumentService
+from services.datasource_provider_service import DatasourceProviderService
 
 logger = logging.getLogger(__name__)
 
@@ -105,13 +111,13 @@ class PipelineGenerator(BaseAppGenerator):
         inputs: Mapping[str, Any] = args["inputs"]
         start_node_id: str = args["start_node_id"]
         datasource_type: str = args["datasource_type"]
-        datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"]
+        datasource_info_list: list[Mapping[str, Any]] = self._format_datasource_info_list(
+            datasource_type, args["datasource_info_list"], pipeline, workflow, start_node_id, user
+        )
         batch = time.strftime("%Y%m%d%H%M%S") + str(secrets.randbelow(900000) + 100000)
         # convert to app config
         pipeline_config = PipelineConfigManager.get_pipeline_config(
-            pipeline=pipeline,
-            workflow=workflow,
-            start_node_id=start_node_id
+            pipeline=pipeline, workflow=workflow, start_node_id=start_node_id
         )
         documents = []
         if invoke_from == InvokeFrom.PUBLISHED:
@@ -353,9 +359,9 @@ class PipelineGenerator(BaseAppGenerator):
             raise ValueError("inputs is required")
 
         # convert to app config
-        pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline,
-                                                                    workflow=workflow,
-                                                                    start_node_id=args.get("start_node_id","shared"))
+        pipeline_config = PipelineConfigManager.get_pipeline_config(
+            pipeline=pipeline, workflow=workflow, start_node_id=args.get("start_node_id", "shared")
+        )
 
         dataset = pipeline.dataset
         if not dataset:
@@ -440,9 +446,9 @@ class PipelineGenerator(BaseAppGenerator):
             raise ValueError("Pipeline dataset is required")
 
         # convert to app config
-        pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline,
-                                                                    workflow=workflow,
-                                                                    start_node_id=args.get("start_node_id","shared"))
+        pipeline_config = PipelineConfigManager.get_pipeline_config(
+            pipeline=pipeline, workflow=workflow, start_node_id=args.get("start_node_id", "shared")
+        )
 
         # init application generate entity
         application_generate_entity = RagPipelineGenerateEntity(
@@ -633,3 +639,107 @@ class PipelineGenerator(BaseAppGenerator):
         if doc_metadata:
             document.doc_metadata = doc_metadata
         return document
+
+    def _format_datasource_info_list(
+        self,
+        datasource_type: str,
+        datasource_info_list: list[Mapping[str, Any]],
+        pipeline: Pipeline,
+        workflow: Workflow,
+        start_node_id: str,
+        user: Union[Account, EndUser],
+    ) -> list[Mapping[str, Any]]:
+        """
+        Format datasource info list.
+        """
+        if datasource_type == "online_drive":
+            all_files = []
+            datasource_node_data = None
+            datasource_nodes = workflow.graph_dict.get("nodes", [])
+            for datasource_node in datasource_nodes:
+                if datasource_node.get("id") == start_node_id:
+                    datasource_node_data = datasource_node.get("data", {})
+                    break
+            if not datasource_node_data:
+                raise ValueError("Datasource node data not found")
+
+            from core.datasource.datasource_manager import DatasourceManager
+
+            datasource_runtime = DatasourceManager.get_datasource_runtime(
+                provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
+                datasource_name=datasource_node_data.get("datasource_name"),
+                tenant_id=pipeline.tenant_id,
+                datasource_type=DatasourceProviderType(datasource_type),
+            )
+            datasource_provider_service = DatasourceProviderService()
+            credentials = datasource_provider_service.get_real_datasource_credentials(
+                tenant_id=pipeline.tenant_id,
+                provider=datasource_node_data.get("provider_name"),
+                plugin_id=datasource_node_data.get("plugin_id"),
+            )
+            if credentials:
+                datasource_runtime.runtime.credentials = credentials[0].get("credentials")
+            datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
+
+            for datasource_info in datasource_info_list:
+                if datasource_info.get("key") and datasource_info.get("key", "").endswith("/"):
+                    # get all files in the folder
+                    self._get_files_in_folder(
+                        datasource_runtime,
+                        datasource_info.get("key", ""),
+                        None,
+                        datasource_info.get("bucket", None),
+                        user.id,
+                        all_files,
+                        datasource_info,
+                    )
+            return all_files
+        else:
+            return datasource_info_list
+
+    def _get_files_in_folder(
+        self,
+        datasource_runtime: OnlineDriveDatasourcePlugin,
+        prefix: str,
+        start_after: Optional[str],
+        bucket: Optional[str],
+        user_id: str,
+        all_files: list,
+        datasource_info: Mapping[str, Any],
+    ):
+        """
+        Get files in a folder.
+        """
+        result_generator = datasource_runtime.online_drive_browse_files(
+            user_id=user_id,
+            request=OnlineDriveBrowseFilesRequest(
+                bucket=bucket,
+                prefix=prefix,
+                max_keys=20,
+                start_after=start_after,
+            ),
+            provider_type=datasource_runtime.datasource_provider_type(),
+        )
+        is_truncated = False
+        last_file_key = None
+        for result in result_generator:
+            for files in result.result:
+                for file in files.files:
+                    if file.key.endswith("/"):
+                        self._get_files_in_folder(
+                            datasource_runtime, file.key, None, bucket, user_id, all_files, datasource_info
+                        )
+                    else:
+                        all_files.append(
+                            {
+                                "key": file.key,
+                                "bucket": bucket,
+                            }
+                        )
+                        last_file_key = file.key
+                is_truncated = files.is_truncated
+
+        if is_truncated:
+            self._get_files_in_folder(
+                datasource_runtime, prefix, last_file_key, bucket, user_id, all_files, datasource_info
+            )