@@ -1,5 +1,7 @@
 import base64
+from datetime import UTC, datetime
 import hashlib
+import json
 import logging
 import uuid
 from collections.abc import Mapping
@@ -31,13 +33,12 @@ from extensions.ext_redis import redis_client
 from factories import variable_factory
 from models import Account
 from models.dataset import Dataset, DatasetCollectionBinding, Pipeline
-from models.workflow import Workflow
+from models.workflow import Workflow, WorkflowType
 from services.entities.knowledge_entities.rag_pipeline_entities import (
     KnowledgeConfiguration,
     RagPipelineDatasetCreateEntity,
 )
 from services.plugin.dependencies_analysis import DependenciesAnalysisService
 from services.rag_pipeline.rag_pipeline import RagPipelineService
 
 logger = logging.getLogger(__name__)
@@ -206,12 +207,12 @@ class RagPipelineDslService:
         status = _check_version_compatibility(imported_version)
 
         # Extract app data
-        pipeline_data = data.get("pipeline")
+        pipeline_data = data.get("rag_pipeline")
         if not pipeline_data:
             return RagPipelineImportInfo(
                 id=import_id,
                 status=ImportStatus.FAILED,
-                error="Missing pipeline data in YAML content",
+                error="Missing rag_pipeline data in YAML content",
             )
 
         # If app_id is provided, check if it exists
@@ -256,7 +257,7 @@ class RagPipelineDslService:
         if dependencies:
             check_dependencies_pending_data = [PluginDependency.model_validate(d) for d in dependencies]
 
-        # Create or update app
+        # Create or update pipeline
         pipeline = self._create_or_update_pipeline(
             pipeline=pipeline,
             data=data,
@@ -278,7 +279,9 @@ class RagPipelineDslService:
             if node.get("data", {}).get("type") == "knowledge_index":
                 knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {})
                 knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)
-                if not dataset:
+                if dataset and pipeline.is_published and dataset.chunk_structure != knowledge_configuration.chunk_structure:
+                    raise ValueError("Chunk structure is not compatible with the published pipeline")
+                else:
                     dataset = Dataset(
                         tenant_id=account.current_tenant_id,
                         name=name,
@@ -295,11 +298,6 @@ class RagPipelineDslService:
                         runtime_mode="rag_pipeline",
                         chunk_structure=knowledge_configuration.chunk_structure,
                     )
-                else:
-                    dataset.indexing_technique = knowledge_configuration.index_method.indexing_technique
-                    dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump()
-                    dataset.runtime_mode = "rag_pipeline"
-                    dataset.chunk_structure = knowledge_configuration.chunk_structure
                 if knowledge_configuration.index_method.indexing_technique == "high_quality":
                     dataset_collection_binding = (
                         db.session.query(DatasetCollectionBinding)
@@ -540,33 +538,6 @@ class RagPipelineDslService:
                 icon_type = "emoji"
             icon = str(pipeline_data.get("icon", ""))
 
-        if pipeline:
-            # Update existing pipeline
-            pipeline.name = pipeline_data.get("name", pipeline.name)
-            pipeline.description = pipeline_data.get("description", pipeline.description)
-            pipeline.updated_by = account.id
-        else:
-            if account.current_tenant_id is None:
-                raise ValueError("Current tenant is not set")
-
-            # Create new app
-            pipeline = Pipeline()
-            pipeline.id = str(uuid4())
-            pipeline.tenant_id = account.current_tenant_id
-            pipeline.name = pipeline_data.get("name", "")
-            pipeline.description = pipeline_data.get("description", "")
-            pipeline.created_by = account.id
-            pipeline.updated_by = account.id
-
-        self._session.add(pipeline)
-        self._session.commit()
-        # save dependencies
-        if dependencies:
-            redis_client.setex(
-                f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{pipeline.id}",
-                IMPORT_INFO_REDIS_EXPIRY,
-                CheckDependenciesPendingData(pipeline_id=pipeline.id, dependencies=dependencies).model_dump_json(),
-            )
 
         # Initialize pipeline based on mode
         workflow_data = data.get("workflow")
@@ -583,12 +554,7 @@ class RagPipelineDslService:
         ]
         rag_pipeline_variables_list = workflow_data.get("rag_pipeline_variables", [])
 
-        rag_pipeline_service = RagPipelineService()
-        current_draft_workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
-        if current_draft_workflow:
-            unique_hash = current_draft_workflow.unique_hash
-        else:
-            unique_hash = None
         graph = workflow_data.get("graph", {})
         for node in graph.get("nodes", []):
             if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
@@ -599,20 +565,78 @@ class RagPipelineDslService:
                     if (
                         decrypted_id := self.decrypt_dataset_id(
                             encrypted_data=dataset_id,
-                            tenant_id=pipeline.tenant_id,
+                            tenant_id=account.current_tenant_id,
                         )
                     )
                 ]
-        rag_pipeline_service.sync_draft_workflow(
-            pipeline=pipeline,
-            graph=workflow_data.get("graph", {}),
-            unique_hash=unique_hash,
-            account=account,
-            environment_variables=environment_variables,
-            conversation_variables=conversation_variables,
-            rag_pipeline_variables=rag_pipeline_variables_list,
-        )
+        if pipeline:
+            # Update existing pipeline
+            pipeline.name = pipeline_data.get("name", pipeline.name)
+            pipeline.description = pipeline_data.get("description", pipeline.description)
+            pipeline.updated_by = account.id
+        else:
+            if account.current_tenant_id is None:
+                raise ValueError("Current tenant is not set")
+
+            # Create new app
+            pipeline = Pipeline()
+            pipeline.id = str(uuid4())
+            pipeline.tenant_id = account.current_tenant_id
+            pipeline.name = pipeline_data.get("name", "")
+            pipeline.description = pipeline_data.get("description", "")
+            pipeline.created_by = account.id
+            pipeline.updated_by = account.id
+
+        self._session.add(pipeline)
+        self._session.commit()
+        # save dependencies
+        if dependencies:
+            redis_client.setex(
+                f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{pipeline.id}",
+                IMPORT_INFO_REDIS_EXPIRY,
+                CheckDependenciesPendingData(pipeline_id=pipeline.id, dependencies=dependencies).model_dump_json(),
+            )
+        workflow = (
+            db.session.query(Workflow)
+            .filter(
+                Workflow.tenant_id == pipeline.tenant_id,
+                Workflow.app_id == pipeline.id,
+                Workflow.version == "draft",
+            )
+            .first()
+        )
+
+        # create draft workflow if not found
+        if not workflow:
+            workflow = Workflow(
+                tenant_id=pipeline.tenant_id,
+                app_id=pipeline.id,
+                features="{}",
+                type=WorkflowType.RAG_PIPELINE.value,
+                version="draft",
+                graph=json.dumps(graph),
+                created_by=account.id,
+                environment_variables=environment_variables,
+                conversation_variables=conversation_variables,
+                rag_pipeline_variables=rag_pipeline_variables_list,
+            )
+            db.session.add(workflow)
+            db.session.flush()
+            pipeline.workflow_id = workflow.id
+        else:
+            workflow.graph = json.dumps(graph)
+            workflow.updated_by = account.id
+            workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
+            workflow.environment_variables = environment_variables
+            workflow.conversation_variables = conversation_variables
+            workflow.rag_pipeline_variables = rag_pipeline_variables_list
+        # commit db session changes
+        db.session.commit()
 
         return pipeline
 
     @classmethod
@@ -623,16 +647,19 @@ class RagPipelineDslService:
         :param include_secret: Whether include secret variable
         :return:
         """
+        dataset = pipeline.dataset
+        if not dataset:
+            raise ValueError("Missing dataset for rag pipeline")
+        icon_info = dataset.icon_info
         export_data = {
             "version": CURRENT_DSL_VERSION,
             "kind": "rag_pipeline",
             "pipeline": {
                 "name": pipeline.name,
-                "mode": pipeline.mode,
-                "icon": "🤖" if pipeline.icon_type == "image" else pipeline.icon,
-                "icon_background": "#FFEAD5" if pipeline.icon_type == "image" else pipeline.icon_background,
+                "icon": icon_info.get("icon", "📙") if icon_info else "📙",
+                "icon_type": icon_info.get("icon_type", "emoji") if icon_info else "emoji",
+                "icon_background": icon_info.get("icon_background", "#FFEAD5") if icon_info else "#FFEAD5",
                 "description": pipeline.description,
-                "use_icon_as_answer_icon": pipeline.use_icon_as_answer_icon,
             },
         }
@@ -647,8 +674,16 @@ class RagPipelineDslService:
         :param export_data: export data
         :param pipeline: Pipeline instance
         """
-        rag_pipeline_service = RagPipelineService()
-        workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
+        workflow = (
+            db.session.query(Workflow)
+            .filter(
+                Workflow.tenant_id == pipeline.tenant_id,
+                Workflow.app_id == pipeline.id,
+                Workflow.version == "draft",
+            )
+            .first()
+        )
         if not workflow:
             raise ValueError("Missing draft workflow configuration, please check.")
@@ -855,14 +890,6 @@ class RagPipelineDslService:
                 f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
             )
 
-        dataset = Dataset(
-            name=rag_pipeline_dataset_create_entity.name,
-            description=rag_pipeline_dataset_create_entity.description,
-            permission=rag_pipeline_dataset_create_entity.permission,
-            provider="vendor",
-            runtime_mode="rag-pipeline",
-            icon_info=rag_pipeline_dataset_create_entity.icon_info.model_dump(),
-        )
         with Session(db.engine) as session:
             rag_pipeline_dsl_service = RagPipelineDslService(session)
             account = cast(Account, current_user)
@@ -870,11 +897,11 @@ class RagPipelineDslService:
                 account=account,
                 import_mode=ImportMode.YAML_CONTENT.value,
                 yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
-                dataset=dataset,
+                dataset=None,
             )
             return {
                 "id": rag_pipeline_import_info.id,
-                "dataset_id": dataset.id,
+                "dataset_id": rag_pipeline_import_info.dataset_id,
                 "pipeline_id": rag_pipeline_import_info.pipeline_id,
                 "status": rag_pipeline_import_info.status,
                 "imported_dsl_version": rag_pipeline_import_info.imported_dsl_version,