|
|
|
|
@ -8,7 +8,7 @@ from typing import Any, Optional, cast
|
|
|
|
|
from uuid import uuid4
|
|
|
|
|
|
|
|
|
|
from flask_login import current_user
|
|
|
|
|
from sqlalchemy import or_, select
|
|
|
|
|
from sqlalchemy import func, or_, select
|
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
|
|
|
|
|
|
import contexts
|
|
|
|
|
@ -78,15 +78,20 @@ class RagPipelineService:
|
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def get_pipeline_template_detail(cls, template_id: str) -> Optional[dict]:
|
|
|
|
|
def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> Optional[dict]:
|
|
|
|
|
"""
|
|
|
|
|
Get pipeline template detail.
|
|
|
|
|
:param template_id: template id
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
|
|
|
|
|
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
|
|
|
|
|
result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id)
|
|
|
|
|
if type == "built-in":
|
|
|
|
|
mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
|
|
|
|
|
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
|
|
|
|
|
result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id)
|
|
|
|
|
else:
|
|
|
|
|
mode = "customized"
|
|
|
|
|
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
|
|
|
|
|
result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id)
|
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
@ -930,5 +935,24 @@ class RagPipelineService:
|
|
|
|
|
workflow = db.session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
|
|
|
|
|
if not workflow:
|
|
|
|
|
raise ValueError("Workflow not found")
|
|
|
|
|
dataset = pipeline.dataset
|
|
|
|
|
if not dataset:
|
|
|
|
|
raise ValueError("Dataset not found")
|
|
|
|
|
|
|
|
|
|
max_position = db.session.query(func.max(PipelineCustomizedTemplate.position)).filter(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id).scalar()
|
|
|
|
|
|
|
|
|
|
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
|
|
|
|
dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True)
|
|
|
|
|
|
|
|
|
|
pipeline_customized_template = PipelineCustomizedTemplate(
|
|
|
|
|
name=args.get("name"),
|
|
|
|
|
description=args.get("description"),
|
|
|
|
|
icon=args.get("icon_info"),
|
|
|
|
|
tenant_id=pipeline.tenant_id,
|
|
|
|
|
yaml_content=dsl,
|
|
|
|
|
position=max_position + 1 if max_position else 1,
|
|
|
|
|
chunk_structure=dataset.chunk_structure,
|
|
|
|
|
language="en-US",
|
|
|
|
|
)
|
|
|
|
|
db.session.add(pipeline_customized_template)
|
|
|
|
|
db.session.commit()
|
|
|
|
|
|