diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
index 1f6f2308c0..0e1fad600f 100644
--- a/api/services/rag_pipeline/rag_pipeline.py
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -20,12 +20,9 @@ from core.datasource.entities.datasource_entities import (
     DatasourceProviderType,
     GetOnlineDocumentPageContentRequest,
     OnlineDocumentPagesMessage,
-    OnlineDriveBrowseFilesRequest,
-    OnlineDriveBrowseFilesResponse,
     WebsiteCrawlMessage,
 )
 from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
-from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
 from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
 from core.rag.entities.event import (
     BaseDatasourceEvent,
@@ -34,9 +31,8 @@ from core.rag.entities.event import (
     DatasourceProcessingEvent,
 )
 from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
-from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput, Variable
+from core.variables.variables import Variable
 from core.workflow.entities.node_entities import NodeRunResult
-from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.entities.workflow_node_execution import (
     WorkflowNodeExecution,
     WorkflowNodeExecutionStatus,
@@ -127,6 +123,20 @@ class RagPipelineService:
         )
         if not customized_template:
             raise ValueError("Customized pipeline template not found.")
+        # check if template name already exists
+        template_name = template_info.name
+        if template_name:
+            template = (
+                db.session.query(PipelineCustomizedTemplate)
+                .filter(
+                    PipelineCustomizedTemplate.name == template_name,
+                    PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
+                    PipelineCustomizedTemplate.id != template_id,
+                )
+                .first()
+            )
+            if template:
+                raise ValueError("Template name already exists")
         customized_template.name = template_info.name
         customized_template.description = template_info.description
         customized_template.icon = template_info.icon_info.model_dump()
@@ -385,17 +395,6 @@ class RagPipelineService:
 
         # run draft workflow node
         start_at = time.perf_counter()
-        rag_pipeline_variables = []
-        if draft_workflow.rag_pipeline_variables:
-            for v in draft_workflow.rag_pipeline_variables:
-                rag_pipeline_variable = RAGPipelineVariable(**v)
-                if rag_pipeline_variable.variable in user_inputs:
-                    rag_pipeline_variables.append(
-                        RAGPipelineVariableInput(
-                            variable=rag_pipeline_variable,
-                            value=user_inputs[rag_pipeline_variable.variable],
-                        )
-                    )
 
         workflow_node_execution = self._handle_node_run_result(
             getter=lambda: WorkflowEntry.single_step_run(
@@ -403,12 +402,6 @@ class RagPipelineService:
                 node_id=node_id,
                 user_inputs=user_inputs,
                 user_id=account.id,
-                variable_pool=VariablePool(
-                    user_inputs=user_inputs,
-                    environment_variables=draft_workflow.environment_variables,
-                    conversation_variables=draft_workflow.conversation_variables,
-                    rag_pipeline_variables=rag_pipeline_variables,
-                ),
             ),
             start_at=start_at,
             tenant_id=pipeline.tenant_id,
@@ -434,17 +427,6 @@ class RagPipelineService:
 
         # run draft workflow node
         start_at = time.perf_counter()
-        rag_pipeline_variables = []
-        if published_workflow.rag_pipeline_variables:
-            for v in published_workflow.rag_pipeline_variables:
-                rag_pipeline_variable = RAGPipelineVariable(**v)
-                if rag_pipeline_variable.variable in user_inputs:
-                    rag_pipeline_variables.append(
-                        RAGPipelineVariableInput(
-                            variable=rag_pipeline_variable,
-                            value=user_inputs[rag_pipeline_variable.variable],
-                        )
-                    )
 
         workflow_node_execution = self._handle_node_run_result(
             getter=lambda: WorkflowEntry.single_step_run(
@@ -452,12 +434,6 @@ class RagPipelineService:
                 node_id=node_id,
                 user_inputs=user_inputs,
                 user_id=account.id,
-                variable_pool=VariablePool(
-                    user_inputs=user_inputs,
-                    environment_variables=published_workflow.environment_variables,
-                    conversation_variables=published_workflow.conversation_variables,
-                    rag_pipeline_variables=rag_pipeline_variables,
-                ),
             ),
             start_at=start_at,
             tenant_id=pipeline.tenant_id,
@@ -549,35 +525,6 @@ class RagPipelineService:
             except Exception as e:
                 logger.exception("Error during online document.")
                 yield DatasourceErrorEvent(error=str(e)).model_dump()
-            case DatasourceProviderType.ONLINE_DRIVE:
-                datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
-                online_drive_result: Generator[OnlineDriveBrowseFilesResponse, None, None] = (
-                    datasource_runtime.online_drive_browse_files(
-                        user_id=account.id,
-                        request=OnlineDriveBrowseFilesRequest(
-                            bucket=user_inputs.get("bucket"),
-                            prefix=user_inputs.get("prefix"),
-                            max_keys=user_inputs.get("max_keys", 20),
-                            start_after=user_inputs.get("start_after"),
-                        ),
-                        provider_type=datasource_runtime.datasource_provider_type(),
-                    )
-                )
-                start_time = time.time()
-                start_event = DatasourceProcessingEvent(
-                    total=0,
-                    completed=0,
-                )
-                yield start_event.model_dump()
-                for message in online_drive_result:
-                    end_time = time.time()
-                    online_drive_event = DatasourceCompletedEvent(
-                        data=message.result,
-                        time_consuming=round(end_time - start_time, 2),
-                        total=None,
-                        completed=None,
-                    )
-                    yield online_drive_event.model_dump()
             case DatasourceProviderType.WEBSITE_CRAWL:
                 datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
                 website_crawl_result: Generator[WebsiteCrawlMessage, None, None] = (
@@ -874,77 +821,26 @@ class RagPipelineService:
 
         return workflow
 
-    def get_published_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
-        """
-        Get second step parameters of rag pipeline
-        """
-
-        workflow = self.get_published_workflow(pipeline=pipeline)
-        if not workflow:
-            raise ValueError("Workflow not initialized")
-
-        # get second step node
-        rag_pipeline_variables = workflow.rag_pipeline_variables
-        if not rag_pipeline_variables:
-            return []
-
-        # get datasource provider
-        datasource_provider_variables = [
-            item
-            for item in rag_pipeline_variables
-            if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared"
-        ]
-        return datasource_provider_variables
-
-    def get_published_first_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
+    def get_first_step_parameters(self, pipeline: Pipeline, node_id: str, is_draft: bool = False) -> list[dict]:
         """
         Get first step parameters of rag pipeline
         """
-        published_workflow = self.get_published_workflow(pipeline=pipeline)
-        if not published_workflow:
-            raise ValueError("Workflow not initialized")
-
-        # get second step node
-        datasource_node_data = None
-        datasource_nodes = published_workflow.graph_dict.get("nodes", [])
-        for datasource_node in datasource_nodes:
-            if datasource_node.get("id") == node_id:
-                datasource_node_data = datasource_node.get("data", {})
-                break
-        if not datasource_node_data:
-            raise ValueError("Datasource node data not found")
-        variables = datasource_node_data.get("variables", {})
-        if variables:
-            variables_map = {item["variable"]: item for item in variables}
-        else:
-            return []
-
-        datasource_parameters = datasource_node_data.get("datasource_parameters", {})
-        user_input_variables = []
-        for key, value in datasource_parameters.items():
-            if not re.match(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}", value["value"]):
-                user_input_variables.append(variables_map.get(key, {}))
-        return user_input_variables
-
-    def get_draft_first_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
-        """
-        Get first step parameters of rag pipeline
-        """
-
-        draft_workflow = self.get_draft_workflow(pipeline=pipeline)
-        if not draft_workflow:
+        workflow = (
+            self.get_draft_workflow(pipeline=pipeline) if is_draft else self.get_published_workflow(pipeline=pipeline)
+        )
+        if not workflow:
             raise ValueError("Workflow not initialized")
 
-        # get second step node
         datasource_node_data = None
-        datasource_nodes = draft_workflow.graph_dict.get("nodes", [])
+        datasource_nodes = workflow.graph_dict.get("nodes", [])
         for datasource_node in datasource_nodes:
             if datasource_node.get("id") == node_id:
                 datasource_node_data = datasource_node.get("data", {})
                 break
         if not datasource_node_data:
             raise ValueError("Datasource node data not found")
-        variables = datasource_node_data.get("variables", {})
+        variables = workflow.rag_pipeline_variables
         if variables:
             variables_map = {item["variable"]: item for item in variables}
         else:
@@ -953,16 +849,23 @@ class RagPipelineService:
         user_input_variables = []
         for key, value in datasource_parameters.items():
-            if not re.match(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}", value["value"]):
-                user_input_variables.append(variables_map.get(key, {}))
+            if value.get("value") and isinstance(value.get("value"), str):
+                pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
+                match = re.match(pattern, value["value"])
+                if match:
+                    full_path = match.group(1)
+                    last_part = full_path.split(".")[-1]
+                    user_input_variables.append(variables_map.get(last_part, {}))
         return user_input_variables
 
-    def get_draft_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
+    def get_second_step_parameters(self, pipeline: Pipeline, node_id: str, is_draft: bool = False) -> list[dict]:
         """
         Get second step parameters of rag pipeline
         """
 
-        workflow = self.get_draft_workflow(pipeline=pipeline)
+        workflow = (
+            self.get_draft_workflow(pipeline=pipeline) if is_draft else self.get_published_workflow(pipeline=pipeline)
+        )
         if not workflow:
             raise ValueError("Workflow not initialized")
 
@@ -970,11 +873,30 @@ class RagPipelineService:
         rag_pipeline_variables = workflow.rag_pipeline_variables
         if not rag_pipeline_variables:
             return []
+        variables_map = {item["variable"]: item for item in rag_pipeline_variables}
+
+        # get datasource node data
+        datasource_node_data = None
+        datasource_nodes = workflow.graph_dict.get("nodes", [])
+        for datasource_node in datasource_nodes:
+            if datasource_node.get("id") == node_id:
+                datasource_node_data = datasource_node.get("data", {})
+                break
+        if datasource_node_data:
+            datasource_parameters = datasource_node_data.get("datasource_parameters", {})
 
-        # get datasource provider
+            for key, value in datasource_parameters.items():
+                if value.get("value") and isinstance(value.get("value"), str):
+                    pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
+                    match = re.match(pattern, value["value"])
+                    if match:
+                        full_path = match.group(1)
+                        last_part = full_path.split(".")[-1]
+                        variables_map.pop(last_part, None)
+        all_second_step_variables = list(variables_map.values())
+
         datasource_provider_variables = [
             item
-            for item in rag_pipeline_variables
+            for item in all_second_step_variables
             if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared"
         ]
         return datasource_provider_variables
@@ -1098,6 +1020,20 @@ class RagPipelineService:
         if not dataset:
             raise ValueError("Dataset not found")
 
+        # check if template name already exists
+        template_name = args.get("name")
+        if template_name:
+            template = (
+                db.session.query(PipelineCustomizedTemplate)
+                .filter(
+                    PipelineCustomizedTemplate.name == template_name,
+                    PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id,
+                )
+                .first()
+            )
+            if template:
+                raise ValueError("Template name already exists")
+
         max_position = (
             db.session.query(func.max(PipelineCustomizedTemplate.position))
             .filter(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id)
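
A minimal, self-contained sketch of the parameter-resolution behavior that the reworked get_first_step_parameters / get_second_step_parameters introduce. The fixtures (sample_variables, sample_parameters) are hypothetical, and the two passes are folded into one script here; in the PR they live in separate methods, and the second-step list is additionally filtered by belong_to_node_id == node_id or "shared":

import re

# Variable-reference pattern taken from the diff: matches {{#node_id.variable#}}
PATTERN = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"

# Hypothetical fixtures, not part of the PR
sample_variables = [
    {"variable": "file_url", "belong_to_node_id": "shared"},
    {"variable": "chunk_size", "belong_to_node_id": "node_1"},
]
sample_parameters = {"url": {"value": "{{#shared.file_url#}}"}}

variables_map = {item["variable"]: item for item in sample_variables}

# First step: a datasource parameter whose value is a variable reference is
# resolved against the pipeline variables by the last path segment
# ("shared.file_url" -> "file_url") and consumed.
first_step = []
for key, value in sample_parameters.items():
    if value.get("value") and isinstance(value.get("value"), str):
        match = re.match(PATTERN, value["value"])
        if match:
            last_part = match.group(1).split(".")[-1]
            first_step.append(variables_map.get(last_part, {}))
            variables_map.pop(last_part, None)

# Second step: whatever the first step did not consume remains.
second_step = list(variables_map.values())

print(first_step)   # [{'variable': 'file_url', 'belong_to_node_id': 'shared'}]
print(second_step)  # [{'variable': 'chunk_size', 'belong_to_node_id': 'node_1'}]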