feat/datasource
jyong 12 months ago
parent a025db137d
commit e7c48c0b69

@ -1,5 +1,6 @@
import logging import logging
import yaml
from flask import request from flask import request
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -12,10 +13,9 @@ from controllers.console.wraps import (
) )
from extensions.ext_database import db from extensions.ext_database import db
from libs.login import login_required from libs.login import login_required
from models.dataset import Pipeline, PipelineCustomizedTemplate from models.dataset import PipelineCustomizedTemplate
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
from services.rag_pipeline.rag_pipeline import RagPipelineService from services.rag_pipeline.rag_pipeline import RagPipelineService
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -84,8 +84,8 @@ class CustomizedPipelineTemplateApi(Resource):
) )
args = parser.parse_args() args = parser.parse_args()
pipeline_template_info = PipelineTemplateInfoEntity(**args) pipeline_template_info = PipelineTemplateInfoEntity(**args)
pipeline_template = RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info) RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
return pipeline_template, 200 return 200
@setup_required @setup_required
@login_required @login_required
@ -106,13 +106,41 @@ class CustomizedPipelineTemplateApi(Resource):
) )
if not template: if not template:
raise ValueError("Customized pipeline template not found.") raise ValueError("Customized pipeline template not found.")
pipeline = session.query(Pipeline).filter(Pipeline.id == template.pipeline_id).first()
if not pipeline:
raise ValueError("Pipeline not found.")
dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline, include_secret=True) dsl = yaml.safe_load(template.yaml_content)
return {"data": dsl}, 200 return {"data": dsl}, 200
class CustomizedPipelineTemplateApi(Resource):
    """Resource that publishes a pipeline as a customized pipeline template.

    NOTE(review): the enclosing hunk context shows another class with this
    same name earlier in the module; this definition rebinds the name.
    Confirm which class the route registration is meant to reference.
    """

    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def post(self, pipeline_id: str):
        """Parse the template metadata from the request and publish it.

        :param pipeline_id: identifier of the pipeline to publish as a template
        :return: the integer 200
        """
        parser = reqparse.RequestParser()
        parser.add_argument(
            "name",
            nullable=False,
            required=True,
            help="Name must be between 1 to 40 characters.",
            type=_validate_name,
        )
        parser.add_argument(
            "description",
            type=str,
            nullable=True,
            required=False,
            default="",
        )
        parser.add_argument(
            "icon_info",
            type=dict,
            location="json",
            nullable=True,
        )
        args = parser.parse_args()
        # The service is invoked through the class, so no instance is built here.
        RagPipelineService.publish_customized_pipeline_template(pipeline_id, args)
        # NOTE(review): a bare int is returned as the response body rather than
        # a (body, status) pair - confirm this is what clients expect.
        return 200
api.add_resource( api.add_resource(
PipelineTemplateListApi, PipelineTemplateListApi,

@ -20,11 +20,11 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskSt
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager
from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager
from core.app.apps.pipeline.pipeline_runner import PipelineRunner from core.app.apps.pipeline.pipeline_runner import PipelineRunner
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity, WorkflowAppGenerateEntity from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.entities.knowledge_entities import PipelineDataset, PipelineDocument
from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rag.index_processor.constant.built_in_field import BuiltInField from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
@ -32,6 +32,7 @@ from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchem
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from extensions.ext_database import db from extensions.ext_database import db
from fields.document_fields import dataset_and_document_fields
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.dataset import Document, Pipeline from models.dataset import Document, Pipeline
from models.enums import WorkflowRunTriggeredFrom from models.enums import WorkflowRunTriggeredFrom
@ -54,7 +55,7 @@ class PipelineGenerator(BaseAppGenerator):
streaming: Literal[True], streaming: Literal[True],
call_depth: int, call_depth: int,
workflow_thread_pool_id: Optional[str], workflow_thread_pool_id: Optional[str],
) -> Generator[Mapping | str, None, None] | None: ... ) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None: ...
@overload @overload
def generate( def generate(
@ -101,23 +102,18 @@ class PipelineGenerator(BaseAppGenerator):
pipeline=pipeline, pipeline=pipeline,
workflow=workflow, workflow=workflow,
) )
# Add null check for dataset
dataset = pipeline.dataset
if not dataset:
raise ValueError("Pipeline dataset is required")
inputs: Mapping[str, Any] = args["inputs"] inputs: Mapping[str, Any] = args["inputs"]
start_node_id: str = args["start_node_id"] start_node_id: str = args["start_node_id"]
datasource_type: str = args["datasource_type"] datasource_type: str = args["datasource_type"]
datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"] datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"]
batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))
documents = []
for datasource_info in datasource_info_list:
workflow_run_id = str(uuid.uuid4())
document_id = None
# Add null check for dataset
dataset = pipeline.dataset
if not dataset:
raise ValueError("Pipeline dataset is required")
if invoke_from == InvokeFrom.PUBLISHED: if invoke_from == InvokeFrom.PUBLISHED:
for datasource_info in datasource_info_list:
position = DocumentService.get_documents_position(dataset.id) position = DocumentService.get_documents_position(dataset.id)
document = self._build_document( document = self._build_document(
tenant_id=pipeline.tenant_id, tenant_id=pipeline.tenant_id,
@ -132,9 +128,15 @@ class PipelineGenerator(BaseAppGenerator):
document_form=dataset.chunk_structure, document_form=dataset.chunk_structure,
) )
db.session.add(document) db.session.add(document)
documents.append(document)
db.session.commit() db.session.commit()
document_id = document.id
# init application generate entity # run in child thread
for i, datasource_info in enumerate(datasource_info_list):
workflow_run_id = str(uuid.uuid4())
document_id = None
if invoke_from == InvokeFrom.PUBLISHED:
document_id = documents[i].id
application_generate_entity = RagPipelineGenerateEntity( application_generate_entity = RagPipelineGenerateEntity(
task_id=str(uuid.uuid4()), task_id=str(uuid.uuid4()),
app_config=pipeline_config, app_config=pipeline_config,
@ -159,7 +161,6 @@ class PipelineGenerator(BaseAppGenerator):
workflow_run_id=workflow_run_id, workflow_run_id=workflow_run_id,
) )
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock()) contexts.plugin_tool_providers_lock.set(threading.Lock())
if invoke_from == InvokeFrom.DEBUGGER: if invoke_from == InvokeFrom.DEBUGGER:
@ -183,6 +184,7 @@ class PipelineGenerator(BaseAppGenerator):
) )
if invoke_from == InvokeFrom.DEBUGGER: if invoke_from == InvokeFrom.DEBUGGER:
return self._generate( return self._generate(
flask_app=current_app._get_current_object(),# type: ignore
pipeline=pipeline, pipeline=pipeline,
workflow=workflow, workflow=workflow,
user=user, user=user,
@ -194,21 +196,47 @@ class PipelineGenerator(BaseAppGenerator):
workflow_thread_pool_id=workflow_thread_pool_id, workflow_thread_pool_id=workflow_thread_pool_id,
) )
else: else:
self._generate( # run in child thread
pipeline=pipeline, thread = threading.Thread(
workflow=workflow, target=self._generate,
user=user, kwargs={
application_generate_entity=application_generate_entity, "flask_app": current_app._get_current_object(), # type: ignore
invoke_from=invoke_from, "pipeline": pipeline,
workflow_execution_repository=workflow_execution_repository, "workflow": workflow,
workflow_node_execution_repository=workflow_node_execution_repository, "user": user,
streaming=streaming, "application_generate_entity": application_generate_entity,
workflow_thread_pool_id=workflow_thread_pool_id, "invoke_from": invoke_from,
"workflow_execution_repository": workflow_execution_repository,
"workflow_node_execution_repository": workflow_node_execution_repository,
"streaming": streaming,
"workflow_thread_pool_id": workflow_thread_pool_id,
},
) )
thread.start()
# return batch, dataset, documents
return {
"batch": batch,
"dataset": PipelineDataset(
id=dataset.id,
name=dataset.name,
description=dataset.description,
chunk_structure=dataset.chunk_structure,
).model_dump(),
"documents": [PipelineDocument(
id=document.id,
position=document.position,
data_source_info=document.data_source_info,
name=document.name,
indexing_status=document.indexing_status,
error=document.error,
enabled=document.enabled,
).model_dump() for document in documents
]
}
def _generate( def _generate(
self, self,
*, *,
flask_app: Flask,
pipeline: Pipeline, pipeline: Pipeline,
workflow: Workflow, workflow: Workflow,
user: Union[Account, EndUser], user: Union[Account, EndUser],
@ -232,6 +260,8 @@ class PipelineGenerator(BaseAppGenerator):
:param streaming: is stream :param streaming: is stream
:param workflow_thread_pool_id: workflow thread pool id :param workflow_thread_pool_id: workflow thread pool id
""" """
print(user.id)
with flask_app.app_context():
# init queue manager # init queue manager
queue_manager = PipelineQueueManager( queue_manager = PipelineQueueManager(
task_id=application_generate_entity.task_id, task_id=application_generate_entity.task_id,
@ -317,7 +347,6 @@ class PipelineGenerator(BaseAppGenerator):
call_depth=0, call_depth=0,
workflow_run_id=str(uuid.uuid4()), workflow_run_id=str(uuid.uuid4()),
) )
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock()) contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create workflow node execution repository # Create workflow node execution repository
@ -338,6 +367,7 @@ class PipelineGenerator(BaseAppGenerator):
) )
return self._generate( return self._generate(
flask_app=current_app._get_current_object(),# type: ignore
pipeline=pipeline, pipeline=pipeline,
workflow=workflow, workflow=workflow,
user=user, user=user,
@ -399,7 +429,6 @@ class PipelineGenerator(BaseAppGenerator):
single_loop_run=RagPipelineGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]), single_loop_run=RagPipelineGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
workflow_run_id=str(uuid.uuid4()), workflow_run_id=str(uuid.uuid4()),
) )
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock()) contexts.plugin_tool_providers_lock.set(threading.Lock())
@ -421,6 +450,7 @@ class PipelineGenerator(BaseAppGenerator):
) )
return self._generate( return self._generate(
flask_app=current_app._get_current_object(),# type: ignore
pipeline=pipeline, pipeline=pipeline,
workflow=workflow, workflow=workflow,
user=user, user=user,

@ -17,3 +17,26 @@ class IndexingEstimate(BaseModel):
total_segments: int total_segments: int
preview: list[PreviewDetail] preview: list[PreviewDetail]
qa_preview: Optional[list[QAPreviewDetail]] = None qa_preview: Optional[list[QAPreviewDetail]] = None
class PipelineDataset(BaseModel):
    """Dataset summary carrying its id, name, description and chunk structure."""

    id: str
    name: str
    description: str
    chunk_structure: str
class PipelineDocument(BaseModel):
    """Document summary: identity, position, source info and indexing state."""

    id: str
    position: int
    # NOTE(review): presumably the stored value may be a serialized string
    # rather than a dict - confirm against the Document model.
    data_source_info: dict
    name: str
    indexing_status: str
    # A document that has not failed presumably has no error text, so an
    # absent value must validate instead of raising.
    error: Optional[str] = None
    enabled: bool
class PipelineGenerateResponse(BaseModel):
    """Bundle of a batch identifier, a dataset summary and document summaries."""

    batch: str
    dataset: PipelineDataset
    documents: list[PipelineDocument]

@ -253,6 +253,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
self, self,
workflow_run_id: str, workflow_run_id: str,
order_config: Optional[OrderConfig] = None, order_config: Optional[OrderConfig] = None,
triggered_from: WorkflowNodeExecutionTriggeredFrom = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
) -> Sequence[WorkflowNodeExecution]: ) -> Sequence[WorkflowNodeExecution]:
""" """
Retrieve all WorkflowNodeExecution database models for a specific workflow run. Retrieve all WorkflowNodeExecution database models for a specific workflow run.
@ -274,7 +275,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
stmt = select(WorkflowNodeExecution).where( stmt = select(WorkflowNodeExecution).where(
WorkflowNodeExecution.workflow_run_id == workflow_run_id, WorkflowNodeExecution.workflow_run_id == workflow_run_id,
WorkflowNodeExecution.tenant_id == self._tenant_id, WorkflowNodeExecution.tenant_id == self._tenant_id,
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, WorkflowNodeExecution.triggered_from == triggered_from,
) )
if self._app_id: if self._app_id:
@ -308,6 +309,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
self, self,
workflow_run_id: str, workflow_run_id: str,
order_config: Optional[OrderConfig] = None, order_config: Optional[OrderConfig] = None,
triggered_from: WorkflowNodeExecutionTriggeredFrom = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
) -> Sequence[NodeExecution]: ) -> Sequence[NodeExecution]:
""" """
Retrieve all NodeExecution instances for a specific workflow run. Retrieve all NodeExecution instances for a specific workflow run.
@ -325,7 +327,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
A list of NodeExecution instances A list of NodeExecution instances
""" """
# Get the database models using the new method # Get the database models using the new method
db_models = self.get_db_models_by_workflow_run(workflow_run_id, order_config) db_models = self.get_db_models_by_workflow_run(workflow_run_id, order_config, triggered_from)
# Convert database models to domain models # Convert database models to domain models
domain_models = [] domain_models = []

@ -87,6 +87,7 @@ dataset_detail_fields = {
"runtime_mode": fields.String, "runtime_mode": fields.String,
"chunk_structure": fields.String, "chunk_structure": fields.String,
"icon_info": fields.Nested(icon_info_fields), "icon_info": fields.Nested(icon_info_fields),
"is_published": fields.Boolean,
} }
dataset_query_detail_fields = { dataset_query_detail_fields = {

@ -152,6 +152,8 @@ class Dataset(Base):
@property @property
def doc_form(self): def doc_form(self):
if self.chunk_structure:
return self.chunk_structure
document = db.session.query(Document).filter(Document.dataset_id == self.id).first() document = db.session.query(Document).filter(Document.dataset_id == self.id).first()
if document: if document:
return document.doc_form return document.doc_form
@ -206,6 +208,13 @@ class Dataset(Base):
"external_knowledge_api_name": external_knowledge_api.name, "external_knowledge_api_name": external_knowledge_api.name,
"external_knowledge_api_endpoint": json.loads(external_knowledge_api.settings).get("endpoint", ""), "external_knowledge_api_endpoint": json.loads(external_knowledge_api.settings).get("endpoint", ""),
} }
@property
def is_published(self):
    """Report the published flag of the pipeline linked to this dataset.

    Returns False when the dataset has no pipeline id or when no pipeline
    row matches that id.
    """
    if not self.pipeline_id:
        return False
    linked_pipeline = (
        db.session.query(Pipeline)
        .filter(Pipeline.id == self.pipeline_id)
        .first()
    )
    return linked_pipeline.is_published if linked_pipeline else False
@property @property
def doc_metadata(self): def doc_metadata(self):
@ -1154,10 +1163,11 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined]
__table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),) __table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
pipeline_id = db.Column(StringUUID, nullable=False)
name = db.Column(db.String(255), nullable=False) name = db.Column(db.String(255), nullable=False)
description = db.Column(db.Text, nullable=False) description = db.Column(db.Text, nullable=False)
chunk_structure = db.Column(db.String(255), nullable=False)
icon = db.Column(db.JSON, nullable=False) icon = db.Column(db.JSON, nullable=False)
yaml_content = db.Column(db.Text, nullable=False)
copyright = db.Column(db.String(255), nullable=False) copyright = db.Column(db.String(255), nullable=False)
privacy_policy = db.Column(db.String(255), nullable=False) privacy_policy = db.Column(db.String(255), nullable=False)
position = db.Column(db.Integer, nullable=False) position = db.Column(db.Integer, nullable=False)
@ -1166,9 +1176,6 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def pipeline(self):
return db.session.query(Pipeline).filter(Pipeline.id == self.pipeline_id).first()
class PipelineCustomizedTemplate(Base): # type: ignore[name-defined] class PipelineCustomizedTemplate(Base): # type: ignore[name-defined]
@ -1180,11 +1187,12 @@ class PipelineCustomizedTemplate(Base): # type: ignore[name-defined]
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False)
pipeline_id = db.Column(StringUUID, nullable=False)
name = db.Column(db.String(255), nullable=False) name = db.Column(db.String(255), nullable=False)
description = db.Column(db.Text, nullable=False) description = db.Column(db.Text, nullable=False)
chunk_structure = db.Column(db.String(255), nullable=False)
icon = db.Column(db.JSON, nullable=False) icon = db.Column(db.JSON, nullable=False)
position = db.Column(db.Integer, nullable=False) position = db.Column(db.Integer, nullable=False)
yaml_content = db.Column(db.Text, nullable=False)
install_count = db.Column(db.Integer, nullable=False, default=0) install_count = db.Column(db.Integer, nullable=False, default=0)
language = db.Column(db.String(255), nullable=False) language = db.Column(db.String(255), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

@ -23,8 +23,8 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
result = self.fetch_pipeline_templates_from_builtin(language) result = self.fetch_pipeline_templates_from_builtin(language)
return result return result
def get_pipeline_template_detail(self, pipeline_id: str): def get_pipeline_template_detail(self, template_id: str):
result = self.fetch_pipeline_template_detail_from_builtin(pipeline_id) result = self.fetch_pipeline_template_detail_from_builtin(template_id)
return result return result
@classmethod @classmethod
@ -54,11 +54,11 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
return builtin_data.get("pipeline_templates", {}).get(language, {}) return builtin_data.get("pipeline_templates", {}).get(language, {})
@classmethod @classmethod
def fetch_pipeline_template_detail_from_builtin(cls, pipeline_id: str) -> Optional[dict]: def fetch_pipeline_template_detail_from_builtin(cls, template_id: str) -> Optional[dict]:
""" """
Fetch pipeline template detail from builtin. Fetch pipeline template detail from builtin.
:param pipeline_id: Pipeline ID :param template_id: Template ID
:return: :return:
""" """
builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data() builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data()
return builtin_data.get("pipeline_templates", {}).get(pipeline_id) return builtin_data.get("pipeline_templates", {}).get(template_id)

@ -1,12 +1,13 @@
from typing import Optional from typing import Optional
from flask_login import current_user from flask_login import current_user
import yaml
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Pipeline, PipelineCustomizedTemplate from models.dataset import PipelineCustomizedTemplate
from services.app_dsl_service import AppDslService
from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
@ -35,13 +36,26 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
:param language: language :param language: language
:return: :return:
""" """
pipeline_templates = ( pipeline_customized_templates = (
db.session.query(PipelineCustomizedTemplate) db.session.query(PipelineCustomizedTemplate)
.filter(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language) .filter(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language)
.all() .all()
) )
recommended_pipelines_results = []
for pipeline_customized_template in pipeline_customized_templates:
recommended_pipeline_result = {
"id": pipeline_customized_template.id,
"name": pipeline_customized_template.name,
"description": pipeline_customized_template.description,
"icon": pipeline_customized_template.icon,
"position": pipeline_customized_template.position,
"chunk_structure": pipeline_customized_template.chunk_structure,
}
recommended_pipelines_results.append(recommended_pipeline_result)
return {"pipeline_templates": recommended_pipelines_results}
return {"pipeline_templates": pipeline_templates}
@classmethod @classmethod
def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> Optional[dict]: def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> Optional[dict]:
@ -57,15 +71,9 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
if not pipeline_template: if not pipeline_template:
return None return None
# get pipeline detail
pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_template.pipeline_id).first()
if not pipeline or not pipeline.is_public:
return None
return { return {
"id": pipeline.id, "id": pipeline_template.id,
"name": pipeline.name, "name": pipeline_template.name,
"icon": pipeline.icon, "icon": pipeline_template.icon,
"mode": pipeline.mode, "export_data": yaml.safe_load(pipeline_template.yaml_content),
"export_data": AppDslService.export_dsl(app_model=pipeline),
} }

@ -1,7 +1,9 @@
from typing import Optional from typing import Optional
import yaml
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Dataset, Pipeline, PipelineBuiltInTemplate from models.dataset import PipelineBuiltInTemplate
from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
@ -36,23 +38,17 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
recommended_pipelines_results = [] recommended_pipelines_results = []
for pipeline_built_in_template in pipeline_built_in_templates: for pipeline_built_in_template in pipeline_built_in_templates:
pipeline_model: Pipeline | None = pipeline_built_in_template.pipeline
if not pipeline_model:
continue
recommended_pipeline_result = { recommended_pipeline_result = {
"id": pipeline_built_in_template.id, "id": pipeline_built_in_template.id,
"name": pipeline_built_in_template.name, "name": pipeline_built_in_template.name,
"pipeline_id": pipeline_model.id,
"description": pipeline_built_in_template.description, "description": pipeline_built_in_template.description,
"icon": pipeline_built_in_template.icon, "icon": pipeline_built_in_template.icon,
"copyright": pipeline_built_in_template.copyright, "copyright": pipeline_built_in_template.copyright,
"privacy_policy": pipeline_built_in_template.privacy_policy, "privacy_policy": pipeline_built_in_template.privacy_policy,
"position": pipeline_built_in_template.position, "position": pipeline_built_in_template.position,
"chunk_structure": pipeline_built_in_template.chunk_structure,
} }
dataset: Dataset | None = pipeline_model.dataset
if dataset:
recommended_pipeline_result["chunk_structure"] = dataset.chunk_structure
recommended_pipelines_results.append(recommended_pipeline_result) recommended_pipelines_results.append(recommended_pipeline_result)
return {"pipeline_templates": recommended_pipelines_results} return {"pipeline_templates": recommended_pipelines_results}
@ -64,8 +60,6 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
:param pipeline_id: Pipeline ID :param pipeline_id: Pipeline ID
:return: :return:
""" """
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
# is in public recommended list # is in public recommended list
pipeline_template = ( pipeline_template = (
db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.id == pipeline_id).first() db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.id == pipeline_id).first()
@ -74,19 +68,10 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
if not pipeline_template: if not pipeline_template:
return None return None
# get pipeline detail
pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_template.pipeline_id).first()
if not pipeline or not pipeline.is_public:
return None
dataset: Dataset | None = pipeline.dataset
if not dataset:
return None
return { return {
"id": pipeline.id, "id": pipeline_template.id,
"name": pipeline.name, "name": pipeline_template.name,
"icon": pipeline_template.icon, "icon": pipeline_template.icon,
"chunk_structure": dataset.chunk_structure, "chunk_structure": pipeline_template.chunk_structure,
"export_data": RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline), "export_data": yaml.safe_load(pipeline_template.yaml_content),
} }

@ -1,4 +1,5 @@
from services.rag_pipeline.pipeline_template.built_in.built_in_retrieval import BuiltInPipelineTemplateRetrieval from services.rag_pipeline.pipeline_template.built_in.built_in_retrieval import BuiltInPipelineTemplateRetrieval
from services.rag_pipeline.pipeline_template.customized.customized_retrieval import CustomizedPipelineTemplateRetrieval
from services.rag_pipeline.pipeline_template.database.database_retrieval import DatabasePipelineTemplateRetrieval from services.rag_pipeline.pipeline_template.database.database_retrieval import DatabasePipelineTemplateRetrieval
from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
@ -12,7 +13,7 @@ class PipelineTemplateRetrievalFactory:
case PipelineTemplateType.REMOTE: case PipelineTemplateType.REMOTE:
return RemotePipelineTemplateRetrieval return RemotePipelineTemplateRetrieval
case PipelineTemplateType.CUSTOMIZED: case PipelineTemplateType.CUSTOMIZED:
return DatabasePipelineTemplateRetrieval return CustomizedPipelineTemplateRetrieval
case PipelineTemplateType.DATABASE: case PipelineTemplateType.DATABASE:
return DatabasePipelineTemplateRetrieval return DatabasePipelineTemplateRetrieval
case PipelineTemplateType.BUILTIN: case PipelineTemplateType.BUILTIN:

@ -7,7 +7,7 @@ from typing import Any, Optional, cast
from uuid import uuid4 from uuid import uuid4
from flask_login import current_user from flask_login import current_user
from sqlalchemy import select from sqlalchemy import or_, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
import contexts import contexts
@ -47,16 +47,19 @@ from models.workflow import (
WorkflowType, WorkflowType,
) )
from services.dataset_service import DatasetService from services.dataset_service import DatasetService
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeBaseUpdateConfiguration, KnowledgeConfiguration, PipelineTemplateInfoEntity from services.entities.knowledge_entities.rag_pipeline_entities import (
KnowledgeConfiguration,
PipelineTemplateInfoEntity,
)
from services.errors.app import WorkflowHashNotEqualError from services.errors.app import WorkflowHashNotEqualError
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory
class RagPipelineService: class RagPipelineService:
@staticmethod @classmethod
def get_pipeline_templates( def get_pipeline_templates(
type: str = "built-in", language: str = "en-US" cls, type: str = "built-in", language: str = "en-US"
) -> list[PipelineBuiltInTemplate | PipelineCustomizedTemplate]: ) -> dict:
if type == "built-in": if type == "built-in":
mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
@ -64,12 +67,12 @@ class RagPipelineService:
if not result.get("pipeline_templates") and language != "en-US": if not result.get("pipeline_templates") and language != "en-US":
template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval() template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval()
result = template_retrieval.fetch_pipeline_templates_from_builtin("en-US") result = template_retrieval.fetch_pipeline_templates_from_builtin("en-US")
return [PipelineBuiltInTemplate(**template) for template in result.get("pipeline_templates", [])] return result
else: else:
mode = "customized" mode = "customized"
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
result = retrieval_instance.get_pipeline_templates(language) result = retrieval_instance.get_pipeline_templates(language)
return [PipelineCustomizedTemplate(**template) for template in result.get("pipeline_templates", [])] return result
@classmethod @classmethod
def get_pipeline_template_detail(cls, template_id: str) -> Optional[dict]: def get_pipeline_template_detail(cls, template_id: str) -> Optional[dict]:
@ -684,7 +687,10 @@ class RagPipelineService:
base_query = db.session.query(WorkflowRun).filter( base_query = db.session.query(WorkflowRun).filter(
WorkflowRun.tenant_id == pipeline.tenant_id, WorkflowRun.tenant_id == pipeline.tenant_id,
WorkflowRun.app_id == pipeline.id, WorkflowRun.app_id == pipeline.id,
WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value, or_(
WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value,
WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value
)
) )
if args.get("last_id"): if args.get("last_id"):
@ -765,8 +771,26 @@ class RagPipelineService:
# Use the repository to get the node executions with ordering # Use the repository to get the node executions with ordering
order_config = OrderConfig(order_by=["index"], order_direction="desc") order_config = OrderConfig(order_by=["index"], order_direction="desc")
node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config) node_executions = repository.get_by_workflow_run(workflow_run_id=run_id,
order_config=order_config,
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN)
# Convert domain models to database models # Convert domain models to database models
workflow_node_executions = [repository.to_db_model(node_execution) for node_execution in node_executions] workflow_node_executions = [repository.to_db_model(node_execution) for node_execution in node_executions]
return workflow_node_executions return workflow_node_executions
@classmethod
def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict):
"""
Publish customized pipeline template
"""
pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
if not pipeline:
raise ValueError("Pipeline not found")
if not pipeline.workflow_id:
raise ValueError("Pipeline workflow not found")
workflow = db.session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
if not workflow:
raise ValueError("Workflow not found")
db.session.commit()

@ -1,5 +1,7 @@
import base64 import base64
from datetime import UTC, datetime
import hashlib import hashlib
import json
import logging import logging
import uuid import uuid
from collections.abc import Mapping from collections.abc import Mapping
@ -31,13 +33,12 @@ from extensions.ext_redis import redis_client
from factories import variable_factory from factories import variable_factory
from models import Account from models import Account
from models.dataset import Dataset, DatasetCollectionBinding, Pipeline from models.dataset import Dataset, DatasetCollectionBinding, Pipeline
from models.workflow import Workflow from models.workflow import Workflow, WorkflowType
from services.entities.knowledge_entities.rag_pipeline_entities import ( from services.entities.knowledge_entities.rag_pipeline_entities import (
KnowledgeConfiguration, KnowledgeConfiguration,
RagPipelineDatasetCreateEntity, RagPipelineDatasetCreateEntity,
) )
from services.plugin.dependencies_analysis import DependenciesAnalysisService from services.plugin.dependencies_analysis import DependenciesAnalysisService
from services.rag_pipeline.rag_pipeline import RagPipelineService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -206,12 +207,12 @@ class RagPipelineDslService:
status = _check_version_compatibility(imported_version) status = _check_version_compatibility(imported_version)
# Extract app data # Extract app data
pipeline_data = data.get("pipeline") pipeline_data = data.get("rag_pipeline")
if not pipeline_data: if not pipeline_data:
return RagPipelineImportInfo( return RagPipelineImportInfo(
id=import_id, id=import_id,
status=ImportStatus.FAILED, status=ImportStatus.FAILED,
error="Missing pipeline data in YAML content", error="Missing rag_pipeline data in YAML content",
) )
# If app_id is provided, check if it exists # If app_id is provided, check if it exists
@ -256,7 +257,7 @@ class RagPipelineDslService:
if dependencies: if dependencies:
check_dependencies_pending_data = [PluginDependency.model_validate(d) for d in dependencies] check_dependencies_pending_data = [PluginDependency.model_validate(d) for d in dependencies]
# Create or update app # Create or update pipeline
pipeline = self._create_or_update_pipeline( pipeline = self._create_or_update_pipeline(
pipeline=pipeline, pipeline=pipeline,
data=data, data=data,
@ -278,7 +279,9 @@ class RagPipelineDslService:
if node.get("data", {}).get("type") == "knowledge_index": if node.get("data", {}).get("type") == "knowledge_index":
knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {}) knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {})
knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration) knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)
if not dataset: if dataset and pipeline.is_published and dataset.chunk_structure != knowledge_configuration.chunk_structure:
raise ValueError("Chunk structure is not compatible with the published pipeline")
else:
dataset = Dataset( dataset = Dataset(
tenant_id=account.current_tenant_id, tenant_id=account.current_tenant_id,
name=name, name=name,
@ -295,11 +298,6 @@ class RagPipelineDslService:
runtime_mode="rag_pipeline", runtime_mode="rag_pipeline",
chunk_structure=knowledge_configuration.chunk_structure, chunk_structure=knowledge_configuration.chunk_structure,
) )
else:
dataset.indexing_technique = knowledge_configuration.index_method.indexing_technique
dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump()
dataset.runtime_mode = "rag_pipeline"
dataset.chunk_structure = knowledge_configuration.chunk_structure
if knowledge_configuration.index_method.indexing_technique == "high_quality": if knowledge_configuration.index_method.indexing_technique == "high_quality":
dataset_collection_binding = ( dataset_collection_binding = (
db.session.query(DatasetCollectionBinding) db.session.query(DatasetCollectionBinding)
@ -540,33 +538,6 @@ class RagPipelineDslService:
icon_type = "emoji" icon_type = "emoji"
icon = str(pipeline_data.get("icon", "")) icon = str(pipeline_data.get("icon", ""))
if pipeline:
# Update existing pipeline
pipeline.name = pipeline_data.get("name", pipeline.name)
pipeline.description = pipeline_data.get("description", pipeline.description)
pipeline.updated_by = account.id
else:
if account.current_tenant_id is None:
raise ValueError("Current tenant is not set")
# Create new app
pipeline = Pipeline()
pipeline.id = str(uuid4())
pipeline.tenant_id = account.current_tenant_id
pipeline.name = pipeline_data.get("name", "")
pipeline.description = pipeline_data.get("description", "")
pipeline.created_by = account.id
pipeline.updated_by = account.id
self._session.add(pipeline)
self._session.commit()
# save dependencies
if dependencies:
redis_client.setex(
f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{pipeline.id}",
IMPORT_INFO_REDIS_EXPIRY,
CheckDependenciesPendingData(pipeline_id=pipeline.id, dependencies=dependencies).model_dump_json(),
)
# Initialize pipeline based on mode # Initialize pipeline based on mode
workflow_data = data.get("workflow") workflow_data = data.get("workflow")
@ -583,12 +554,7 @@ class RagPipelineDslService:
] ]
rag_pipeline_variables_list = workflow_data.get("rag_pipeline_variables", []) rag_pipeline_variables_list = workflow_data.get("rag_pipeline_variables", [])
rag_pipeline_service = RagPipelineService()
current_draft_workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
if current_draft_workflow:
unique_hash = current_draft_workflow.unique_hash
else:
unique_hash = None
graph = workflow_data.get("graph", {}) graph = workflow_data.get("graph", {})
for node in graph.get("nodes", []): for node in graph.get("nodes", []):
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value: if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
@ -599,19 +565,77 @@ class RagPipelineDslService:
if ( if (
decrypted_id := self.decrypt_dataset_id( decrypted_id := self.decrypt_dataset_id(
encrypted_data=dataset_id, encrypted_data=dataset_id,
tenant_id=pipeline.tenant_id, tenant_id=account.current_tenant_id,
) )
) )
] ]
rag_pipeline_service.sync_draft_workflow(
pipeline=pipeline, if pipeline:
graph=workflow_data.get("graph", {}), # Update existing pipeline
unique_hash=unique_hash, pipeline.name = pipeline_data.get("name", pipeline.name)
account=account, pipeline.description = pipeline_data.get("description", pipeline.description)
pipeline.updated_by = account.id
else:
if account.current_tenant_id is None:
raise ValueError("Current tenant is not set")
# Create new app
pipeline = Pipeline()
pipeline.id = str(uuid4())
pipeline.tenant_id = account.current_tenant_id
pipeline.name = pipeline_data.get("name", "")
pipeline.description = pipeline_data.get("description", "")
pipeline.created_by = account.id
pipeline.updated_by = account.id
self._session.add(pipeline)
self._session.commit()
# save dependencies
if dependencies:
redis_client.setex(
f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{pipeline.id}",
IMPORT_INFO_REDIS_EXPIRY,
CheckDependenciesPendingData(pipeline_id=pipeline.id, dependencies=dependencies).model_dump_json(),
)
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.version == "draft",
)
.first()
)
# create draft workflow if not found
if not workflow:
workflow = Workflow(
tenant_id=pipeline.tenant_id,
app_id=pipeline.id,
features="{}",
type=WorkflowType.RAG_PIPELINE.value,
version="draft",
graph=json.dumps(graph),
created_by=account.id,
environment_variables=environment_variables, environment_variables=environment_variables,
conversation_variables=conversation_variables, conversation_variables=conversation_variables,
rag_pipeline_variables=rag_pipeline_variables_list, rag_pipeline_variables=rag_pipeline_variables_list,
) )
db.session.add(workflow)
db.session.flush()
pipeline.workflow_id = workflow.id
else:
workflow.graph = json.dumps(graph)
workflow.updated_by = account.id
workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
workflow.environment_variables = environment_variables
workflow.conversation_variables = conversation_variables
workflow.rag_pipeline_variables = rag_pipeline_variables_list
# commit db session changes
db.session.commit()
return pipeline return pipeline
@ -623,16 +647,19 @@ class RagPipelineDslService:
:param include_secret: Whether include secret variable :param include_secret: Whether include secret variable
:return: :return:
""" """
dataset = pipeline.dataset
if not dataset:
raise ValueError("Missing dataset for rag pipeline")
icon_info = dataset.icon_info
export_data = { export_data = {
"version": CURRENT_DSL_VERSION, "version": CURRENT_DSL_VERSION,
"kind": "rag_pipeline", "kind": "rag_pipeline",
"pipeline": { "pipeline": {
"name": pipeline.name, "name": pipeline.name,
"mode": pipeline.mode, "icon": icon_info.get("icon", "📙") if icon_info else "📙",
"icon": "🤖" if pipeline.icon_type == "image" else pipeline.icon, "icon_type": icon_info.get("icon_type", "emoji") if icon_info else "emoji",
"icon_background": "#FFEAD5" if pipeline.icon_type == "image" else pipeline.icon_background, "icon_background": icon_info.get("icon_background", "#FFEAD5") if icon_info else "#FFEAD5",
"description": pipeline.description, "description": pipeline.description,
"use_icon_as_answer_icon": pipeline.use_icon_as_answer_icon,
}, },
} }
@ -647,8 +674,16 @@ class RagPipelineDslService:
:param export_data: export data :param export_data: export data
:param pipeline: Pipeline instance :param pipeline: Pipeline instance
""" """
rag_pipeline_service = RagPipelineService()
workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline) workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.version == "draft",
)
.first()
)
if not workflow: if not workflow:
raise ValueError("Missing draft workflow configuration, please check.") raise ValueError("Missing draft workflow configuration, please check.")
@ -855,14 +890,6 @@ class RagPipelineDslService:
f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists." f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
) )
dataset = Dataset(
name=rag_pipeline_dataset_create_entity.name,
description=rag_pipeline_dataset_create_entity.description,
permission=rag_pipeline_dataset_create_entity.permission,
provider="vendor",
runtime_mode="rag-pipeline",
icon_info=rag_pipeline_dataset_create_entity.icon_info.model_dump(),
)
with Session(db.engine) as session: with Session(db.engine) as session:
rag_pipeline_dsl_service = RagPipelineDslService(session) rag_pipeline_dsl_service = RagPipelineDslService(session)
account = cast(Account, current_user) account = cast(Account, current_user)
@ -870,11 +897,11 @@ class RagPipelineDslService:
account=account, account=account,
import_mode=ImportMode.YAML_CONTENT.value, import_mode=ImportMode.YAML_CONTENT.value,
yaml_content=rag_pipeline_dataset_create_entity.yaml_content, yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
dataset=dataset, dataset=None,
) )
return { return {
"id": rag_pipeline_import_info.id, "id": rag_pipeline_import_info.id,
"dataset_id": dataset.id, "dataset_id": rag_pipeline_import_info.dataset_id,
"pipeline_id": rag_pipeline_import_info.pipeline_id, "pipeline_id": rag_pipeline_import_info.pipeline_id,
"status": rag_pipeline_import_info.status, "status": rag_pipeline_import_info.status,
"imported_dsl_version": rag_pipeline_import_info.imported_dsl_version, "imported_dsl_version": rag_pipeline_import_info.imported_dsl_version,

Loading…
Cancel
Save