r2 transform

feat/rag-2
jyong 7 months ago
parent 2012ea3213
commit 384073f025

@@ -947,7 +948,8 @@ class RagPipelineWorkflowLastRunApi(Resource):
if node_exec is None:
raise NotFound("last run not found")
return node_exec
class RagPipelineTransformApi(Resource):
@setup_required
@login_required
@@ -955,8 +956,8 @@ class RagPipelineTransformApi(Resource):
def post(self, dataset_id):
dataset_id = str(dataset_id)
rag_pipeline_transform_service = RagPipelineTransformService()
rag_pipeline_transform_service.transform_dataset(dataset_id)
return {"message": "success"}
result = rag_pipeline_transform_service.transform_dataset(dataset_id)
return result
api.add_resource(
@@ -1070,4 +1071,4 @@ api.add_resource(
api.add_resource(
RagPipelineTransformApi,
"/rag/pipelines/transform/datasets/<uuid:dataset_id>",
)
)
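
The endpoint now forwards the service's return value instead of a hard-coded {"message": "success"}. A minimal caller sketch, assuming a local console API with bearer-token auth (only the URL path comes from this commit; host, port, and auth header are assumptions):

import requests

dataset_id = "00000000-0000-0000-0000-000000000000"  # placeholder UUID
resp = requests.post(
    # path registered above; the base URL is an assumption for illustration
    f"http://localhost:5001/console/api/rag/pipelines/transform/datasets/{dataset_id}",
    headers={"Authorization": "Bearer <console-token>"},  # assumed auth scheme
)
# The body is now whatever transform_dataset() returns, e.g.:
# {"pipeline_id": "...", "dataset_id": "...", "status": "success"}
print(resp.json())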

@@ -1,4 +1,5 @@
import re
from core.app.app_config.entities import RagPipelineVariableEntity, VariableEntity
from models.workflow import Workflow
@@ -56,7 +57,7 @@ class WorkflowVariablesConfigManager:
last_part = full_path.split(".")[-1]
variables_map.pop(last_part)
all_second_step_variables = list(variables_map.values())
for item in all_second_step_variables:
if item.get("belong_to_node_id") == start_node_id or item.get("belong_to_node_id") == "shared":
variables.append(RagPipelineVariableEntity.model_validate(item))
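
The loop above keeps only variables that belong to the datasource start node or are shared across nodes. The same membership test, isolated into a self-contained sketch (the helper name and sample data are hypothetical):

def pick_second_step_variables(variables_map: dict, start_node_id: str) -> list[dict]:
    # Mirrors the condition in the loop: belong_to_node_id must be the
    # start node's id or the literal "shared".
    return [
        item
        for item in variables_map.values()
        if item.get("belong_to_node_id") in (start_node_id, "shared")
    ]

print(pick_second_step_variables(
    {"a.x": {"belong_to_node_id": "node-1"}, "a.y": {"belong_to_node_id": "shared"}},
    "node-1",
))  # both entries pass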

@@ -171,43 +171,45 @@ class DatasourceProviderService:
}
for option in credential.options or []
],
} for credential in datasource.declaration.credentials_schema
}
for credential in datasource.declaration.credentials_schema
],
"oauth_schema":
{
"client_schema": [
{
"type": client_schema.type.value,
"name": client_schema.name,
"required": client_schema.required,
"default": client_schema.default,
"options": [
{
"value": option.value,
"label": option.label.model_dump(),
}
for option in client_schema.options or []
],
}
for client_schema in datasource.declaration.oauth_schema.client_schema or []
],
"credentials_schema": [
{
"type": credential.type.value,
"name": credential.name,
"required": credential.required,
"default": credential.default,
"options": [
{
"value": option.value,
"label": option.label.model_dump(),
}
for option in credential.options or []
],
}
for credential in datasource.declaration.oauth_schema.credentials_schema or []
],
} if datasource.declaration.oauth_schema else None,
"oauth_schema": {
"client_schema": [
{
"type": client_schema.type.value,
"name": client_schema.name,
"required": client_schema.required,
"default": client_schema.default,
"options": [
{
"value": option.value,
"label": option.label.model_dump(),
}
for option in client_schema.options or []
],
}
for client_schema in datasource.declaration.oauth_schema.client_schema or []
],
"credentials_schema": [
{
"type": credential.type.value,
"name": credential.name,
"required": credential.required,
"default": credential.default,
"options": [
{
"value": option.value,
"label": option.label.model_dump(),
}
for option in credential.options or []
],
}
for credential in datasource.declaration.oauth_schema.credentials_schema or []
],
}
if datasource.declaration.oauth_schema
else None,
}
)
return datasource_credentials
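
The hunk above is a pure re-indentation of the nested comprehensions; the serialized shape is unchanged. That repeated per-field pattern, reduced to a self-contained sketch (the dataclasses stand in for the real pydantic schema objects and are assumptions):

from dataclasses import dataclass, field

@dataclass
class Option:
    value: str
    label: dict  # the real code calls option.label.model_dump() on a pydantic model

@dataclass
class SchemaField:
    type_value: str  # stands in for credential.type.value / client_schema.type.value
    name: str
    required: bool = False
    default: str | None = None
    options: list[Option] = field(default_factory=list)

def serialize_field(f: SchemaField) -> dict:
    # One entry of credentials_schema / oauth_schema.client_schema as built above.
    return {
        "type": f.type_value,
        "name": f.name,
        "required": f.required,
        "default": f.default,
        "options": [{"value": o.value, "label": o.label} for o in f.options or []],
    }

print(serialize_field(SchemaField("secret-input", "client_secret", required=True)))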

@@ -54,7 +54,7 @@ from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.account import Account
from models.dataset import Dataset, Document, Pipeline, PipelineCustomizedTemplate # type: ignore
from models.dataset import Document, Pipeline, PipelineCustomizedTemplate # type: ignore
from models.enums import WorkflowRunTriggeredFrom
from models.model import EndUser
from models.workflow import (

@@ -15,8 +15,6 @@ from services.entities.knowledge_entities.rag_pipeline_entities import Knowledge
class RagPipelineTransformService:
def transform_dataset(self, dataset_id: str):
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
@@ -42,7 +40,10 @@ class RagPipelineTransformService:
new_nodes = []
for node in nodes:
if node.get("data", {}).get("type") == "datasource" and node.get("data", {}).get("provider_type") == "local_file":
if (
node.get("data", {}).get("type") == "datasource"
and node.get("data", {}).get("provider_type") == "local_file"
):
node = self._deal_file_extensions(node)
if node.get("data", {}).get("type") == "knowledge-index":
node = self._deal_knowledge_index(dataset, doc_form, indexing_technique, retrieval_model, node)
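
The parenthesized condition above is behaviour-neutral; factored into a hypothetical predicate it reads:

def is_local_file_datasource(node: dict) -> bool:
    # Same check as the wrapped condition above.
    data = node.get("data", {})
    return data.get("type") == "datasource" and data.get("provider_type") == "local_file"

print(is_local_file_datasource(
    {"data": {"type": "datasource", "provider_type": "local_file"}}
))  # True
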
@@ -66,6 +67,11 @@ class RagPipelineTransformService:
dataset.pipeline_id = pipeline.id
db.session.commit()
return {
"pipeline_id": pipeline.id,
"dataset_id": dataset_id,
"status": "success",
}
def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: str):
if doc_form == "text_model":
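
transform_dataset now reports its outcome instead of returning nothing, which is what lets the API layer above pass the result straight through. A hedged usage sketch (the import path is an assumption, and the Flask app/database session setup is omitted):

# assumed module path; adjust to wherever RagPipelineTransformService lives
from services.rag_pipeline_transform_service import RagPipelineTransformService

dataset_id = "00000000-0000-0000-0000-000000000000"  # an existing dataset's UUID
result = RagPipelineTransformService().transform_dataset(dataset_id)
assert result["status"] == "success"
print(result["pipeline_id"], result["dataset_id"])
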
@@ -73,29 +79,29 @@ class RagPipelineTransformService:
case "upload_file":
if indexing_technique == "high_quality":
# get graph from transform.file-general-high-quality.yml
with open(f"{Path(__file__).parent}/transform/file-general-high-quality.yml", "r") as f:
with open(f"{Path(__file__).parent}/transform/file-general-high-quality.yml") as f:
pipeline_yaml = yaml.safe_load(f)
if indexing_technique == "economy":
# get graph from transform.file-general-economy.yml
with open(f"{Path(__file__).parent}/transform/file-general-economy.yml", "r") as f:
with open(f"{Path(__file__).parent}/transform/file-general-economy.yml") as f:
pipeline_yaml = yaml.safe_load(f)
case "notion_import":
if indexing_technique == "high_quality":
# get graph from transform.notion-general-high-quality.yml
with open(f"{Path(__file__).parent}/transform/notion-general-high-quality.yml", "r") as f:
with open(f"{Path(__file__).parent}/transform/notion-general-high-quality.yml") as f:
pipeline_yaml = yaml.safe_load(f)
if indexing_technique == "economy":
# get graph from transform.notion-general-economy.yml
with open(f"{Path(__file__).parent}/transform/notion-general-economy.yml", "r") as f:
with open(f"{Path(__file__).parent}/transform/notion-general-economy.yml") as f:
pipeline_yaml = yaml.safe_load(f)
case "website_crawl":
if indexing_technique == "high_quality":
# get graph from transform.website-crawl-general-high-quality.yml
with open(f"{Path(__file__).parent}/transform/website-crawl-general-high-quality.yml", "r") as f:
with open(f"{Path(__file__).parent}/transform/website-crawl-general-high-quality.yml") as f:
pipeline_yaml = yaml.safe_load(f)
if indexing_technique == "economy":
# get graph from transform.website-crawl-general-economy.yml
with open(f"{Path(__file__).parent}/transform/website-crawl-general-economy.yml", "r") as f:
with open(f"{Path(__file__).parent}/transform/website-crawl-general-economy.yml") as f:
pipeline_yaml = yaml.safe_load(f)
case _:
raise ValueError("Unsupported datasource type")
@@ -103,15 +109,15 @@ class RagPipelineTransformService:
match datasource_type:
case "upload_file":
# get graph from transform.file-parent-child.yml
with open(f"{Path(__file__).parent}/transform/file-parent-child.yml", "r") as f:
with open(f"{Path(__file__).parent}/transform/file-parent-child.yml") as f:
pipeline_yaml = yaml.safe_load(f)
case "notion_import":
# get graph from transform.notion-parent-child.yml
with open(f"{Path(__file__).parent}/transform/notion-parent-child.yml", "r") as f:
with open(f"{Path(__file__).parent}/transform/notion-parent-child.yml") as f:
pipeline_yaml = yaml.safe_load(f)
case "website_crawl":
# get graph from transform.website-crawl-parent-child.yml
with open(f"{Path(__file__).parent}/transform/website-crawl-parent-child.yml", "r") as f:
with open(f"{Path(__file__).parent}/transform/website-crawl-parent-child.yml") as f:
pipeline_yaml = yaml.safe_load(f)
case _:
raise ValueError("Unsupported datasource type")
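
Every open() change in the two hunks above is the same cleanup: text-read is the default mode, so the explicit "r" was redundant. The template lookup, condensed into a hypothetical helper:

from pathlib import Path

import yaml  # PyYAML, as used above

def load_transform_yaml(name: str) -> dict:
    # open() with no mode argument is equivalent to open(..., "r").
    with open(Path(__file__).parent / "transform" / name) as f:
        return yaml.safe_load(f)

# e.g. load_transform_yaml("file-parent-child.yml")
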
@@ -127,7 +133,9 @@ class RagPipelineTransformService:
node["data"]["fileExtensions"] = DOCUMENT_EXTENSIONS
return node
def _deal_knowledge_index(self, dataset: Dataset, doc_form: str, indexing_technique: str, retrieval_model: dict, node: dict):
def _deal_knowledge_index(
self, dataset: Dataset, doc_form: str, indexing_technique: str, retrieval_model: dict, node: dict
):
knowledge_configuration = node.get("data", {})
knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)
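
The final hunk only wraps the long signature; the substantive step is the **-unpacking of the node's data dict into the pydantic model, which validates it in one shot. A minimal sketch with a stand-in model (field names are assumptions; the real fields live in services.entities.knowledge_entities.rag_pipeline_entities):

from pydantic import BaseModel

class KnowledgeConfigurationSketch(BaseModel):
    # stand-in for KnowledgeConfiguration
    indexing_technique: str = "high_quality"
    doc_form: str = "text_model"

cfg = KnowledgeConfigurationSketch(**{"indexing_technique": "economy"})  # wrong types raise ValidationError
print(cfg.indexing_technique, cfg.doc_form)  # economy text_model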
