diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
index 3c10205927..99dae3cfc7 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
@@ -947,7 +947,8 @@ class RagPipelineWorkflowLastRunApi(Resource):
         if node_exec is None:
             raise NotFound("last run not found")
         return node_exec
-
+
+
 class RagPipelineTransformApi(Resource):
     @setup_required
     @login_required
@@ -955,8 +956,8 @@ class RagPipelineTransformApi(Resource):
     def post(self, dataset_id):
         dataset_id = str(dataset_id)
         rag_pipeline_transform_service = RagPipelineTransformService()
-        rag_pipeline_transform_service.transform_dataset(dataset_id)
-        return {"message": "success"}
+        result = rag_pipeline_transform_service.transform_dataset(dataset_id)
+        return result
 
 
 api.add_resource(
@@ -1070,4 +1071,4 @@ api.add_resource(
 api.add_resource(
     RagPipelineTransformApi,
     "/rag/pipelines/transform/datasets/<uuid:dataset_id>",
-)
\ No newline at end of file
+)
diff --git a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py
index 76139c4ebe..d2eec72818 100644
--- a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py
+++ b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py
@@ -1,4 +1,5 @@
 import re
+
 from core.app.app_config.entities import RagPipelineVariableEntity, VariableEntity
 from models.workflow import Workflow
 
@@ -56,7 +57,7 @@ class WorkflowVariablesConfigManager:
                 last_part = full_path.split(".")[-1]
                 variables_map.pop(last_part)
         all_second_step_variables = list(variables_map.values())
-
+
         for item in all_second_step_variables:
             if item.get("belong_to_node_id") == start_node_id or item.get("belong_to_node_id") == "shared":
                 variables.append(RagPipelineVariableEntity.model_validate(item))
diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py
index 200fd68bac..bbb043aacd 100644
--- a/api/services/datasource_provider_service.py
+++ b/api/services/datasource_provider_service.py
@@ -171,43 +171,45 @@ class DatasourceProviderService:
                             }
                             for option in credential.options or []
                         ],
-                    } for credential in datasource.declaration.credentials_schema
+                    }
+                    for credential in datasource.declaration.credentials_schema
                 ],
-                "oauth_schema":
-                    {
-                        "client_schema": [
-                            {
-                                "type": client_schema.type.value,
-                                "name": client_schema.name,
-                                "required": client_schema.required,
-                                "default": client_schema.default,
-                                "options": [
-                                    {
-                                        "value": option.value,
-                                        "label": option.label.model_dump(),
-                                    }
-                                    for option in client_schema.options or []
-                                ],
-                            }
-                            for client_schema in datasource.declaration.oauth_schema.client_schema or []
-                        ],
-                        "credentials_schema": [
-                            {
-                                "type": credential.type.value,
-                                "name": credential.name,
-                                "required": credential.required,
-                                "default": credential.default,
-                                "options": [
-                                    {
-                                        "value": option.value,
-                                        "label": option.label.model_dump(),
-                                    }
-                                    for option in credential.options or []
-                                ],
-                            }
-                            for credential in datasource.declaration.oauth_schema.credentials_schema or []
-                        ],
-                    } if datasource.declaration.oauth_schema else None,
+                "oauth_schema": {
+                    "client_schema": [
+                        {
+                            "type": client_schema.type.value,
+                            "name": client_schema.name,
+                            "required": client_schema.required,
+                            "default": client_schema.default,
+                            "options": [
+                                {
+                                    "value": option.value,
+                                    "label": option.label.model_dump(),
+                                }
+                                for option in client_schema.options or []
+                            ],
+                        }
+                        for client_schema in datasource.declaration.oauth_schema.client_schema or []
+                    ],
+                    "credentials_schema": [
+                        {
+                            "type": credential.type.value,
+                            "name": credential.name,
+                            "required": credential.required,
+                            "default": credential.default,
+                            "options": [
+                                {
+                                    "value": option.value,
+                                    "label": option.label.model_dump(),
+                                }
+                                for option in credential.options or []
+                            ],
+                        }
+                        for credential in datasource.declaration.oauth_schema.credentials_schema or []
+                    ],
+                }
+                if datasource.declaration.oauth_schema
+                else None,
             }
         )
         return datasource_credentials
diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
index 4c59610e79..857f20d84d 100644
--- a/api/services/rag_pipeline/rag_pipeline.py
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -54,7 +54,7 @@ from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
 from models.account import Account
-from models.dataset import Dataset, Document, Pipeline, PipelineCustomizedTemplate  # type: ignore
+from models.dataset import Document, Pipeline, PipelineCustomizedTemplate  # type: ignore
 from models.enums import WorkflowRunTriggeredFrom
 from models.model import EndUser
 from models.workflow import (
diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py
index 4fce28990b..cde1e9f182 100644
--- a/api/services/rag_pipeline/rag_pipeline_transform_service.py
+++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py
@@ -15,8 +15,6 @@ from services.entities.knowledge_entities.rag_pipeline_entities import Knowledge
 
 
 class RagPipelineTransformService:
-
-
     def transform_dataset(self, dataset_id: str):
         dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
         if not dataset:
@@ -42,7 +40,10 @@ class RagPipelineTransformService:
 
         new_nodes = []
         for node in nodes:
-            if node.get("data", {}).get("type") == "datasource" and node.get("data", {}).get("provider_type") == "local_file":
+            if (
+                node.get("data", {}).get("type") == "datasource"
+                and node.get("data", {}).get("provider_type") == "local_file"
+            ):
                 node = self._deal_file_extensions(node)
             if node.get("data", {}).get("type") == "knowledge-index":
                 node = self._deal_knowledge_index(dataset, doc_form, indexing_technique, retrieval_model, node)
@@ -66,6 +67,11 @@ class RagPipelineTransformService:
 
         dataset.pipeline_id = pipeline.id
         db.session.commit()
+        return {
+            "pipeline_id": pipeline.id,
+            "dataset_id": dataset_id,
+            "status": "success",
+        }
 
     def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: str):
         if doc_form == "text_model":
@@ -73,29 +79,29 @@ class RagPipelineTransformService:
                 case "upload_file":
                     if indexing_technique == "high_quality":
                         # get graph from transform.file-general-high-quality.yml
-                        with open(f"{Path(__file__).parent}/transform/file-general-high-quality.yml", "r") as f:
+                        with open(f"{Path(__file__).parent}/transform/file-general-high-quality.yml") as f:
                             pipeline_yaml = yaml.safe_load(f)
                     if indexing_technique == "economy":
                         # get graph from transform.file-general-economy.yml
-                        with open(f"{Path(__file__).parent}/transform/file-general-economy.yml", "r") as f:
+                        with open(f"{Path(__file__).parent}/transform/file-general-economy.yml") as f:
                             pipeline_yaml = yaml.safe_load(f)
                 case "notion_import":
                     if indexing_technique == "high_quality":
                         # get graph from transform.notion-general-high-quality.yml
-                        with open(f"{Path(__file__).parent}/transform/notion-general-high-quality.yml", "r") as f:
+                        with open(f"{Path(__file__).parent}/transform/notion-general-high-quality.yml") as f:
                             pipeline_yaml = yaml.safe_load(f)
                     if indexing_technique == "economy":
                         # get graph from transform.notion-general-economy.yml
-                        with open(f"{Path(__file__).parent}/transform/notion-general-economy.yml", "r") as f:
+                        with open(f"{Path(__file__).parent}/transform/notion-general-economy.yml") as f:
                             pipeline_yaml = yaml.safe_load(f)
                 case "website_crawl":
                     if indexing_technique == "high_quality":
                         # get graph from transform.website-crawl-general-high-quality.yml
-                        with open(f"{Path(__file__).parent}/transform/website-crawl-general-high-quality.yml", "r") as f:
+                        with open(f"{Path(__file__).parent}/transform/website-crawl-general-high-quality.yml") as f:
                             pipeline_yaml = yaml.safe_load(f)
                     if indexing_technique == "economy":
                         # get graph from transform.website-crawl-general-economy.yml
-                        with open(f"{Path(__file__).parent}/transform/website-crawl-general-economy.yml", "r") as f:
+                        with open(f"{Path(__file__).parent}/transform/website-crawl-general-economy.yml") as f:
                             pipeline_yaml = yaml.safe_load(f)
                 case _:
                     raise ValueError("Unsupported datasource type")
@@ -103,15 +109,15 @@ class RagPipelineTransformService:
             match datasource_type:
                 case "upload_file":
                     # get graph from transform.file-parent-child.yml
-                    with open(f"{Path(__file__).parent}/transform/file-parent-child.yml", "r") as f:
+                    with open(f"{Path(__file__).parent}/transform/file-parent-child.yml") as f:
                         pipeline_yaml = yaml.safe_load(f)
                 case "notion_import":
                     # get graph from transform.notion-parent-child.yml
-                    with open(f"{Path(__file__).parent}/transform/notion-parent-child.yml", "r") as f:
+                    with open(f"{Path(__file__).parent}/transform/notion-parent-child.yml") as f:
                         pipeline_yaml = yaml.safe_load(f)
                 case "website_crawl":
                     # get graph from transform.website-crawl-parent-child.yml
-                    with open(f"{Path(__file__).parent}/transform/website-crawl-parent-child.yml", "r") as f:
+                    with open(f"{Path(__file__).parent}/transform/website-crawl-parent-child.yml") as f:
                         pipeline_yaml = yaml.safe_load(f)
                 case _:
                     raise ValueError("Unsupported datasource type")
@@ -127,7 +133,9 @@ class RagPipelineTransformService:
                 node["data"]["fileExtensions"] = DOCUMENT_EXTENSIONS
         return node
 
-    def _deal_knowledge_index(self, dataset: Dataset, doc_form: str, indexing_technique: str, retrieval_model: dict, node: dict):
+    def _deal_knowledge_index(
+        self, dataset: Dataset, doc_form: str, indexing_technique: str, retrieval_model: dict, node: dict
+    ):
         knowledge_configuration = node.get("data", {})
         knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)