From f44f0fa34cf61a8e8cee8efe7eb1f350567a924b Mon Sep 17 00:00:00 2001
From: jyong <718720800@qq.com>
Date: Tue, 1 Jul 2025 14:23:46 +0800
Subject: [PATCH] r2

---
 .../rag_pipeline/rag_pipeline_workflow.py |  8 +-
 api/services/rag_pipeline/rag_pipeline.py | 78 ++++++++++++++-----
 2 files changed, 64 insertions(+), 22 deletions(-)

diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
index 28ab4b1635..3ef0c42d0f 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
@@ -804,7 +804,7 @@ class PublishedRagPipelineSecondStepApi(Resource):
         if not node_id:
             raise ValueError("Node ID is required")
         rag_pipeline_service = RagPipelineService()
-        variables = rag_pipeline_service.get_published_second_step_parameters(pipeline=pipeline, node_id=node_id)
+        variables = rag_pipeline_service.get_second_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=False)
         return {
             "variables": variables,
         }
@@ -829,7 +829,7 @@ class PublishedRagPipelineFirstStepApi(Resource):
         if not node_id:
             raise ValueError("Node ID is required")
         rag_pipeline_service = RagPipelineService()
-        variables = rag_pipeline_service.get_published_first_step_parameters(pipeline=pipeline, node_id=node_id)
+        variables = rag_pipeline_service.get_first_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=False)
         return {
             "variables": variables,
         }
@@ -854,7 +854,7 @@ class DraftRagPipelineFirstStepApi(Resource):
         if not node_id:
             raise ValueError("Node ID is required")
         rag_pipeline_service = RagPipelineService()
-        variables = rag_pipeline_service.get_draft_first_step_parameters(pipeline=pipeline, node_id=node_id)
+        variables = rag_pipeline_service.get_first_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=True)
         return {
             "variables": variables,
         }
@@ -880,7 +880,7 @@ class DraftRagPipelineSecondStepApi(Resource):
             raise ValueError("Node ID is required")
         rag_pipeline_service = RagPipelineService()
 
-        variables = rag_pipeline_service.get_draft_second_step_parameters(pipeline=pipeline, node_id=node_id)
+        variables = rag_pipeline_service.get_second_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=True)
         return {
             "variables": variables,
         }
diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
index 26036dc2c5..f379a4b930 100644
--- a/api/services/rag_pipeline/rag_pipeline.py
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -113,6 +113,14 @@ class RagPipelineService:
         )
         if not customized_template:
             raise ValueError("Customized pipeline template not found.")
+        # check whether the template name already exists
+        template_name = template_info.name
+        if template_name:
+            template = db.session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.name == template_name,
+                                                                           PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
+                                                                           PipelineCustomizedTemplate.id != template_id).first()
+            if template:
+                raise ValueError("Template name already exists")
         customized_template.name = template_info.name
         customized_template.description = template_info.description
         customized_template.icon = template_info.icon_info.model_dump()
@@ -785,7 +793,7 @@ class RagPipelineService:
                 break
         if not datasource_node_data:
             raise ValueError("Datasource node data not found")
-        variables = datasource_node_data.get("variables", {})
+        variables = published_workflow.rag_pipeline_variables
         if variables:
             variables_map = {item["variable"]: item for item in variables}
         else:
@@ -793,29 +801,29 @@
         datasource_parameters = datasource_node_data.get("datasource_parameters", {})
         user_input_variables = []
         for key, value in datasource_parameters.items():
-            if not re.match(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}", value["value"]):
-                user_input_variables.append(variables_map.get(key, {}))
+            if value.get("value") and isinstance(value.get("value"), str):
+                if re.match(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}", value["value"]):
+                    user_input_variables.append(variables_map.get(key, {}))
 
         return user_input_variables
 
-    def get_draft_first_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
+    def get_first_step_parameters(self, pipeline: Pipeline, node_id: str, is_draft: bool = False) -> list[dict]:
         """
         Get first step parameters of rag pipeline
         """
-        draft_workflow = self.get_draft_workflow(pipeline=pipeline)
-        if not draft_workflow:
+        workflow = self.get_draft_workflow(pipeline=pipeline) if is_draft else self.get_published_workflow(pipeline=pipeline)
+        if not workflow:
             raise ValueError("Workflow not initialized")
 
-        # get second step node
         datasource_node_data = None
-        datasource_nodes = draft_workflow.graph_dict.get("nodes", [])
+        datasource_nodes = workflow.graph_dict.get("nodes", [])
         for datasource_node in datasource_nodes:
             if datasource_node.get("id") == node_id:
                 datasource_node_data = datasource_node.get("data", {})
                 break
         if not datasource_node_data:
             raise ValueError("Datasource node data not found")
-        variables = datasource_node_data.get("variables", {})
+        variables = workflow.rag_pipeline_variables
         if variables:
             variables_map = {item["variable"]: item for item in variables}
         else:
@@ -824,16 +832,21 @@
         datasource_parameters = datasource_node_data.get("datasource_parameters", {})
         user_input_variables = []
         for key, value in datasource_parameters.items():
-            if not re.match(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}", value["value"]):
-                user_input_variables.append(variables_map.get(key, {}))
+            if value.get("value") and isinstance(value.get("value"), str):
+                pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
+                match = re.match(pattern, value["value"])
+                if match:
+                    full_path = match.group(1)
+                    last_part = full_path.split('.')[-1]
+                    user_input_variables.append(variables_map.get(last_part, {}))
 
         return user_input_variables
 
-    def get_draft_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]:
+    def get_second_step_parameters(self, pipeline: Pipeline, node_id: str, is_draft: bool = False) -> list[dict]:
         """
         Get second step parameters of rag pipeline
         """
-        workflow = self.get_draft_workflow(pipeline=pipeline)
+        workflow = self.get_draft_workflow(pipeline=pipeline) if is_draft else self.get_published_workflow(pipeline=pipeline)
         if not workflow:
             raise ValueError("Workflow not initialized")
 
@@ -841,13 +854,32 @@
         rag_pipeline_variables = workflow.rag_pipeline_variables
         if not rag_pipeline_variables:
             return []
+        variables_map = {item["variable"]: item for item in rag_pipeline_variables}
 
-        # get datasource provider
+        # get datasource node data
+        datasource_node_data = None
+        datasource_nodes = workflow.graph_dict.get("nodes", [])
+        for datasource_node in datasource_nodes:
+            if datasource_node.get("id") == node_id:
+                datasource_node_data = datasource_node.get("data", {})
+                break
+        if datasource_node_data:
+            datasource_parameters = datasource_node_data.get("datasource_parameters", {})
+
+            for key, value in datasource_parameters.items():
+                if value.get("value") and isinstance(value.get("value"), str):
+                    pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
+                    match = re.match(pattern, value["value"])
+                    if match:
+                        full_path = match.group(1)
+                        last_part = full_path.split('.')[-1]
+                        variables_map.pop(last_part, None)
+        all_second_step_variables = list(variables_map.values())
         datasource_provider_variables = [
-            item
-            for item in rag_pipeline_variables
-            if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared"
-        ]
+            item
+            for item in all_second_step_variables
+            if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared"
+        ]
         return datasource_provider_variables
 
     def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination:
@@ -968,6 +1000,16 @@
         dataset = pipeline.dataset
         if not dataset:
             raise ValueError("Dataset not found")
+
+        # check whether the template name already exists
+        template_name = args.get("name")
+        if template_name:
+            template = db.session.query(PipelineCustomizedTemplate).filter(
+                PipelineCustomizedTemplate.name == template_name,
+                PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id,
+            ).first()
+            if template:
+                raise ValueError("Template name already exists")
         max_position = (
             db.session.query(func.max(PipelineCustomizedTemplate.position))
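
A quick standalone sketch of the reference-extraction behaviour the patch relies on, for reviewers: the regex is copied verbatim from the diff, while the helper name and the sample datasource_parameters dict below are illustrative only and not part of the codebase.

import re

# Pattern copied from the patch: matches parameter values of the form
# {{#<node_or_scope>.<variable>#}} (with up to ten nested segments).
VARIABLE_REF_PATTERN = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"


def referenced_variable_name(value: str) -> str | None:
    """Return the last path segment of a {{#...#}} reference, or None for literal values."""
    match = re.match(VARIABLE_REF_PATTERN, value)
    if not match:
        return None
    full_path = match.group(1)       # e.g. "shared.start_url"
    return full_path.split(".")[-1]  # e.g. "start_url"


# Illustrative data: one parameter bound to a pipeline variable, one literal value.
datasource_parameters = {
    "url": {"value": "{{#shared.start_url#}}"},
    "limit": {"value": "10"},
}
for key, param in datasource_parameters.items():
    value = param.get("value")
    if value and isinstance(value, str):
        print(key, "->", referenced_variable_name(value))
# url -> start_url
# limit -> None

Read against the diff: get_first_step_parameters surfaces the rag_pipeline_variables entries whose names are referenced this way by the datasource node, while get_second_step_parameters pops those same names out of variables_map, so the second step only returns the variables still unbound for that node (or shared).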