feat/r2
jyong 10 months ago
parent cdbba1400c
commit f44f0fa34c

@ -804,7 +804,7 @@ class PublishedRagPipelineSecondStepApi(Resource):
if not node_id: if not node_id:
raise ValueError("Node ID is required") raise ValueError("Node ID is required")
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
variables = rag_pipeline_service.get_published_second_step_parameters(pipeline=pipeline, node_id=node_id) variables = rag_pipeline_service.get_second_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=False)
return { return {
"variables": variables, "variables": variables,
} }
@ -829,7 +829,7 @@ class PublishedRagPipelineFirstStepApi(Resource):
if not node_id: if not node_id:
raise ValueError("Node ID is required") raise ValueError("Node ID is required")
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
variables = rag_pipeline_service.get_published_first_step_parameters(pipeline=pipeline, node_id=node_id) variables = rag_pipeline_service.get_first_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=False)
return { return {
"variables": variables, "variables": variables,
} }
@ -854,7 +854,7 @@ class DraftRagPipelineFirstStepApi(Resource):
if not node_id: if not node_id:
raise ValueError("Node ID is required") raise ValueError("Node ID is required")
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
variables = rag_pipeline_service.get_draft_first_step_parameters(pipeline=pipeline, node_id=node_id) variables = rag_pipeline_service.get_first_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=True)
return { return {
"variables": variables, "variables": variables,
} }
@ -880,7 +880,7 @@ class DraftRagPipelineSecondStepApi(Resource):
raise ValueError("Node ID is required") raise ValueError("Node ID is required")
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
variables = rag_pipeline_service.get_draft_second_step_parameters(pipeline=pipeline, node_id=node_id) variables = rag_pipeline_service.get_second_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=True)
return { return {
"variables": variables, "variables": variables,
} }

@ -113,6 +113,14 @@ class RagPipelineService:
) )
if not customized_template: if not customized_template:
raise ValueError("Customized pipeline template not found.") raise ValueError("Customized pipeline template not found.")
# check template name is exist
template_name = template_info.name
if template_name:
template = db.session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.name == template_name,
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
PipelineCustomizedTemplate.id != template_id).first()
if template:
raise ValueError("Template name is already exists")
customized_template.name = template_info.name customized_template.name = template_info.name
customized_template.description = template_info.description customized_template.description = template_info.description
customized_template.icon = template_info.icon_info.model_dump() customized_template.icon = template_info.icon_info.model_dump()
@ -785,7 +793,7 @@ class RagPipelineService:
break break
if not datasource_node_data: if not datasource_node_data:
raise ValueError("Datasource node data not found") raise ValueError("Datasource node data not found")
variables = datasource_node_data.get("variables", {}) variables = published_workflow.rag_pipeline_variables
if variables: if variables:
variables_map = {item["variable"]: item for item in variables} variables_map = {item["variable"]: item for item in variables}
else: else:
@ -793,29 +801,29 @@ class RagPipelineService:
datasource_parameters = datasource_node_data.get("datasource_parameters", {}) datasource_parameters = datasource_node_data.get("datasource_parameters", {})
user_input_variables = [] user_input_variables = []
for key, value in datasource_parameters.items(): for key, value in datasource_parameters.items():
if not re.match(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}", value["value"]): if value.get("value") and isinstance(value.get("value"), str):
user_input_variables.append(variables_map.get(key, {})) if re.match(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}", value["value"]):
user_input_variables.append(variables_map.get(key, {}))
return user_input_variables return user_input_variables
def get_draft_first_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]: def get_first_step_parameters(self, pipeline: Pipeline, node_id: str, is_draft: bool = False) -> list[dict]:
""" """
Get first step parameters of rag pipeline Get first step parameters of rag pipeline
""" """
draft_workflow = self.get_draft_workflow(pipeline=pipeline) workflow = self.get_draft_workflow(pipeline=pipeline) if is_draft else self.get_published_workflow(pipeline=pipeline)
if not draft_workflow: if not workflow:
raise ValueError("Workflow not initialized") raise ValueError("Workflow not initialized")
# get second step node
datasource_node_data = None datasource_node_data = None
datasource_nodes = draft_workflow.graph_dict.get("nodes", []) datasource_nodes = workflow.graph_dict.get("nodes", [])
for datasource_node in datasource_nodes: for datasource_node in datasource_nodes:
if datasource_node.get("id") == node_id: if datasource_node.get("id") == node_id:
datasource_node_data = datasource_node.get("data", {}) datasource_node_data = datasource_node.get("data", {})
break break
if not datasource_node_data: if not datasource_node_data:
raise ValueError("Datasource node data not found") raise ValueError("Datasource node data not found")
variables = datasource_node_data.get("variables", {}) variables = workflow.rag_pipeline_variables
if variables: if variables:
variables_map = {item["variable"]: item for item in variables} variables_map = {item["variable"]: item for item in variables}
else: else:
@ -824,16 +832,21 @@ class RagPipelineService:
user_input_variables = [] user_input_variables = []
for key, value in datasource_parameters.items(): for key, value in datasource_parameters.items():
if not re.match(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}", value["value"]): if value.get("value") and isinstance(value.get("value"), str):
user_input_variables.append(variables_map.get(key, {})) pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
match = re.match(pattern, value["value"])
if match:
full_path = match.group(1)
last_part = full_path.split('.')[-1]
user_input_variables.append(variables_map.get(last_part, {}))
return user_input_variables return user_input_variables
def get_draft_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]: def get_second_step_parameters(self, pipeline: Pipeline, node_id: str, is_draft: bool = False) -> list[dict]:
""" """
Get second step parameters of rag pipeline Get second step parameters of rag pipeline
""" """
workflow = self.get_draft_workflow(pipeline=pipeline) workflow = self.get_draft_workflow(pipeline=pipeline) if is_draft else self.get_published_workflow(pipeline=pipeline)
if not workflow: if not workflow:
raise ValueError("Workflow not initialized") raise ValueError("Workflow not initialized")
@ -841,13 +854,32 @@ class RagPipelineService:
rag_pipeline_variables = workflow.rag_pipeline_variables rag_pipeline_variables = workflow.rag_pipeline_variables
if not rag_pipeline_variables: if not rag_pipeline_variables:
return [] return []
variables_map = {item["variable"]: item for item in rag_pipeline_variables}
# get datasource provider # get datasource node data
datasource_node_data = None
datasource_nodes = workflow.graph_dict.get("nodes", [])
for datasource_node in datasource_nodes:
if datasource_node.get("id") == node_id:
datasource_node_data = datasource_node.get("data", {})
break
if datasource_node_data:
datasource_parameters = datasource_node_data.get("datasource_parameters", {})
for key, value in datasource_parameters.items():
if value.get("value") and isinstance(value.get("value"), str):
pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
match = re.match(pattern, value["value"])
if match:
full_path = match.group(1)
last_part = full_path.split('.')[-1]
variables_map.pop(last_part)
all_second_step_variables = list(variables_map.values())
datasource_provider_variables = [ datasource_provider_variables = [
item item
for item in rag_pipeline_variables for item in all_second_step_variables
if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared" if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared"
] ]
return datasource_provider_variables return datasource_provider_variables
def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination: def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination:
@ -968,6 +1000,16 @@ class RagPipelineService:
dataset = pipeline.dataset dataset = pipeline.dataset
if not dataset: if not dataset:
raise ValueError("Dataset not found") raise ValueError("Dataset not found")
# check template name is exist
template_name = args.get("name")
if template_name:
template = db.session.query(PipelineCustomizedTemplate).filter(
PipelineCustomizedTemplate.name == template_name,
PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id,
).first()
if template:
raise ValueError("Template name is already exists")
max_position = ( max_position = (
db.session.query(func.max(PipelineCustomizedTemplate.position)) db.session.query(func.max(PipelineCustomizedTemplate.position))

Loading…
Cancel
Save