|
|
|
|
@ -12,8 +12,10 @@ from controllers.console.wraps import (
|
|
|
|
|
)
|
|
|
|
|
from extensions.ext_database import db
|
|
|
|
|
from libs.login import login_required
|
|
|
|
|
from models.dataset import Pipeline, PipelineCustomizedTemplate
|
|
|
|
|
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
|
|
|
|
|
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
|
|
|
|
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
@ -99,7 +101,14 @@ class CustomizedPipelineTemplateApi(Resource):
|
|
|
|
|
@enterprise_license_required
|
|
|
|
|
def post(self, template_id: str):
|
|
|
|
|
with Session(db.engine) as session:
|
|
|
|
|
dsl = RagPipelineService.export_template_rag_pipeline_dsl(template_id)
|
|
|
|
|
template = session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first()
|
|
|
|
|
if not template:
|
|
|
|
|
raise ValueError("Customized pipeline template not found.")
|
|
|
|
|
pipeline = session.query(Pipeline).filter(Pipeline.id == template.pipeline_id).first()
|
|
|
|
|
if not pipeline:
|
|
|
|
|
raise ValueError("Pipeline not found.")
|
|
|
|
|
|
|
|
|
|
dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline, include_secret=True)
|
|
|
|
|
return {"data": dsl}, 200
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|