feat/datasource
jyong 1 year ago
parent 70d2c78176
commit 6d547447d3

@ -161,7 +161,7 @@ class CreateEmptyRagPipelineDatasetApi(Resource):
args = parser.parse_args() args = parser.parse_args()
dataset = DatasetService.create_empty_rag_pipeline_dataset( dataset = DatasetService.create_empty_rag_pipeline_dataset(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
rag_pipeline_dataset_create_entity=args, rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity(**args),
) )
return marshal(dataset, dataset_detail_fields), 201 return marshal(dataset, dataset_detail_fields), 201

@ -8,7 +8,6 @@ from flask_restful.inputs import int_range # type: ignore
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
from models.model import EndUser
import services import services
from configs import dify_config from configs import dify_config
from controllers.console import api from controllers.console import api
@ -40,6 +39,7 @@ from libs.helper import TimestampField, uuid_value
from libs.login import current_user, login_required from libs.login import current_user, login_required
from models.account import Account from models.account import Account
from models.dataset import Pipeline from models.dataset import Pipeline
from models.model import EndUser
from services.errors.app import WorkflowHashNotEqualError from services.errors.app import WorkflowHashNotEqualError
from services.errors.llm import InvokeRateLimitError from services.errors.llm import InvokeRateLimitError
from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService
@ -242,7 +242,7 @@ class DraftRagPipelineRunApi(Resource):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("datasource_type", type=str, required=True, location="json") parser.add_argument("datasource_type", type=str, required=True, location="json")
parser.add_argument("datasource_info", type=list, required=True, location="json") parser.add_argument("datasource_info_list", type=list, required=True, location="json")
parser.add_argument("start_node_id", type=str, required=True, location="json") parser.add_argument("start_node_id", type=str, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
@ -320,6 +320,9 @@ class RagPipelineDatasourceNodeRunApi(Resource):
inputs = args.get("inputs") inputs = args.get("inputs")
if inputs == None: if inputs == None:
raise ValueError("missing inputs") raise ValueError("missing inputs")
datasource_type = args.get("datasource_type")
if datasource_type == None:
raise ValueError("missing datasource_type")
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
result = rag_pipeline_service.run_datasource_workflow_node( result = rag_pipeline_service.run_datasource_workflow_node(
@ -327,7 +330,7 @@ class RagPipelineDatasourceNodeRunApi(Resource):
node_id=node_id, node_id=node_id,
user_inputs=inputs, user_inputs=inputs,
account=current_user, account=current_user,
datasource_type=args.get("datasource_type"), datasource_type=datasource_type,
) )
return result return result

@ -32,6 +32,7 @@ from core.workflow.workflow_app_generate_task_pipeline import WorkflowAppGenerat
from extensions.ext_database import db from extensions.ext_database import db
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.dataset import Document, Pipeline from models.dataset import Document, Pipeline
from models.model import AppMode
from services.dataset_service import DocumentService from services.dataset_service import DocumentService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -91,7 +92,7 @@ class PipelineGenerator(BaseAppGenerator):
streaming: bool = True, streaming: bool = True,
call_depth: int = 0, call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None, workflow_thread_pool_id: Optional[str] = None,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
# convert to app config # convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config( pipeline_config = PipelineConfigManager.get_pipeline_config(
pipeline=pipeline, pipeline=pipeline,
@ -107,19 +108,23 @@ class PipelineGenerator(BaseAppGenerator):
for datasource_info in datasource_info_list: for datasource_info in datasource_info_list:
workflow_run_id = str(uuid.uuid4()) workflow_run_id = str(uuid.uuid4())
document_id = None document_id = None
dataset = pipeline.dataset
if not dataset:
raise ValueError("Dataset not found")
if invoke_from == InvokeFrom.PUBLISHED: if invoke_from == InvokeFrom.PUBLISHED:
position = DocumentService.get_documents_position(pipeline.dataset_id)
position = DocumentService.get_documents_position(pipeline.dataset_id) position = DocumentService.get_documents_position(pipeline.dataset_id)
document = self._build_document( document = self._build_document(
tenant_id=pipeline.tenant_id, tenant_id=pipeline.tenant_id,
dataset_id=pipeline.dataset_id, dataset_id=pipeline.dataset_id,
built_in_field_enabled=pipeline.dataset.built_in_field_enabled, built_in_field_enabled=dataset.built_in_field_enabled,
datasource_type=datasource_type, datasource_type=datasource_type,
datasource_info=datasource_info, datasource_info=datasource_info,
created_from="rag-pipeline", created_from="rag-pipeline",
position=position, position=position,
account=user, account=user,
batch=batch, batch=batch,
document_form=pipeline.dataset.chunk_structure, document_form=dataset.chunk_structure,
) )
db.session.add(document) db.session.add(document)
db.session.commit() db.session.commit()
@ -127,10 +132,12 @@ class PipelineGenerator(BaseAppGenerator):
# init application generate entity # init application generate entity
application_generate_entity = RagPipelineGenerateEntity( application_generate_entity = RagPipelineGenerateEntity(
task_id=str(uuid.uuid4()), task_id=str(uuid.uuid4()),
pipline_config=pipeline_config, app_config=pipeline_config,
pipeline_config=pipeline_config,
datasource_type=datasource_type, datasource_type=datasource_type,
datasource_info=datasource_info, datasource_info=datasource_info,
dataset_id=pipeline.dataset_id, dataset_id=dataset.id,
start_node_id=start_node_id,
batch=batch, batch=batch,
document_id=document_id, document_id=document_id,
inputs=self._prepare_user_inputs( inputs=self._prepare_user_inputs(
@ -160,7 +167,7 @@ class PipelineGenerator(BaseAppGenerator):
app_id=application_generate_entity.app_config.app_id, app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
) )
if invoke_from == InvokeFrom.DEBUGGER:
return self._generate( return self._generate(
pipeline=pipeline, pipeline=pipeline,
workflow=workflow, workflow=workflow,
@ -171,6 +178,17 @@ class PipelineGenerator(BaseAppGenerator):
streaming=streaming, streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id, workflow_thread_pool_id=workflow_thread_pool_id,
) )
else:
self._generate(
pipeline=pipeline,
workflow=workflow,
user=user,
application_generate_entity=application_generate_entity,
invoke_from=invoke_from,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,
)
def _generate( def _generate(
self, self,
@ -201,7 +219,7 @@ class PipelineGenerator(BaseAppGenerator):
task_id=application_generate_entity.task_id, task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id, user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from, invoke_from=application_generate_entity.invoke_from,
app_mode=pipeline.mode, app_mode=AppMode.RAG_PIPELINE,
) )
# new thread # new thread
@ -256,12 +274,18 @@ class PipelineGenerator(BaseAppGenerator):
raise ValueError("inputs is required") raise ValueError("inputs is required")
# convert to app config # convert to app config
app_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow) pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)
# init application generate entity # init application generate entity
application_generate_entity = WorkflowAppGenerateEntity( application_generate_entity = RagPipelineGenerateEntity(
task_id=str(uuid.uuid4()), task_id=str(uuid.uuid4()),
app_config=app_config, app_config=pipeline_config,
pipeline_config=pipeline_config,
datasource_type=args["datasource_type"],
datasource_info=args["datasource_info"],
dataset_id=pipeline.dataset_id,
batch=args["batch"],
document_id=args["document_id"],
inputs={}, inputs={},
files=[], files=[],
user_id=user.id, user_id=user.id,
@ -288,7 +312,7 @@ class PipelineGenerator(BaseAppGenerator):
) )
return self._generate( return self._generate(
app_model=app_model, pipeline=pipeline,
workflow=workflow, workflow=workflow,
user=user, user=user,
invoke_from=InvokeFrom.DEBUGGER, invoke_from=InvokeFrom.DEBUGGER,
@ -299,7 +323,7 @@ class PipelineGenerator(BaseAppGenerator):
def single_loop_generate( def single_loop_generate(
self, self,
app_model: App, pipeline: Pipeline,
workflow: Workflow, workflow: Workflow,
node_id: str, node_id: str,
user: Account | EndUser, user: Account | EndUser,
@ -323,7 +347,7 @@ class PipelineGenerator(BaseAppGenerator):
raise ValueError("inputs is required") raise ValueError("inputs is required")
# convert to app config # convert to app config
app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) app_config = WorkflowAppConfigManager.get_app_config(pipeline=pipeline, workflow=workflow)
# init application generate entity # init application generate entity
application_generate_entity = WorkflowAppGenerateEntity( application_generate_entity = WorkflowAppGenerateEntity(
@ -353,7 +377,7 @@ class PipelineGenerator(BaseAppGenerator):
) )
return self._generate( return self._generate(
app_model=app_model, pipeline=pipeline,
workflow=workflow, workflow=workflow,
user=user, user=user,
invoke_from=InvokeFrom.DEBUGGER, invoke_from=InvokeFrom.DEBUGGER,

@ -1,5 +1,6 @@
import logging import logging
from typing import Optional, cast from collections.abc import Mapping
from typing import Any, Optional, cast
from configs import dify_config from configs import dify_config
from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager
@ -12,6 +13,7 @@ from core.app.entities.app_invoke_entities import (
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Pipeline from models.dataset import Pipeline
@ -100,6 +102,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
SystemVariableKey.DOCUMENT_ID: self.application_generate_entity.document_id, SystemVariableKey.DOCUMENT_ID: self.application_generate_entity.document_id,
SystemVariableKey.BATCH: self.application_generate_entity.batch, SystemVariableKey.BATCH: self.application_generate_entity.batch,
SystemVariableKey.DATASET_ID: self.application_generate_entity.dataset_id, SystemVariableKey.DATASET_ID: self.application_generate_entity.dataset_id,
SystemVariableKey.DATASOURCE_TYPE: self.application_generate_entity.datasource_type,
SystemVariableKey.DATASOURCE_INFO: self.application_generate_entity.datasource_info,
} }
variable_pool = VariablePool( variable_pool = VariablePool(
@ -110,7 +114,10 @@ class PipelineRunner(WorkflowBasedAppRunner):
) )
# init graph # init graph
graph = self._init_graph(graph_config=workflow.graph_dict) graph = self._init_rag_pipeline_graph(
graph_config=workflow.graph_dict,
start_node_id=self.application_generate_entity.start_node_id,
)
# RUN WORKFLOW # RUN WORKFLOW
workflow_entry = WorkflowEntry( workflow_entry = WorkflowEntry(
@ -152,3 +159,43 @@ class PipelineRunner(WorkflowBasedAppRunner):
# return workflow # return workflow
return workflow return workflow
def _init_rag_pipeline_graph(self, graph_config: Mapping[str, Any], start_node_id: Optional[str] = None) -> Graph:
    """
    Init pipeline graph.

    Builds the graph from ``graph_config`` after removing every node whose
    ``data.type`` is ``"datasource"`` unless its id equals ``start_node_id``,
    and every edge whose ``source`` is one of the removed nodes. The mapping
    passed in is not mutated; the filtered lists are placed in a shallow copy.

    :param graph_config: workflow graph config; must contain "nodes" and "edges", both lists
    :param start_node_id: id of the datasource node to keep; with the default
        ``None`` no datasource node id matches, so all of them are removed
    :return: the graph returned by ``Graph.init`` for the filtered config
    :raises ValueError: if "nodes" or "edges" is missing or not a list, or if
        ``Graph.init`` returns a falsy value
    """
    if "nodes" not in graph_config or "edges" not in graph_config:
        raise ValueError("nodes or edges not found in workflow graph")

    nodes = graph_config["nodes"]
    if not isinstance(nodes, list):
        raise ValueError("nodes in workflow graph must be a list")

    edges = graph_config["edges"]
    if not isinstance(edges, list):
        raise ValueError("edges in workflow graph must be a list")

    # A set keeps the per-edge membership test below O(1) instead of scanning a list.
    excluded_node_ids: set[Optional[str]] = set()
    real_run_nodes = []
    for node in nodes:
        node_id = node.get("id")
        is_datasource = node.get("data", {}).get("type", "") == "datasource"
        if is_datasource and node_id != start_node_id:
            # Datasource node that was not selected as the entry point: drop it.
            excluded_node_ids.add(node_id)
            continue
        real_run_nodes.append(node)

    # Only the edge's source is checked here.
    # NOTE(review): edges whose *target* is a removed node are kept — confirm
    # datasource nodes never have incoming edges.
    real_edges = [edge for edge in edges if edge.get("source") not in excluded_node_ids]

    # Shallow copy so the caller's graph config keeps its original node/edge lists.
    filtered_graph_config = {**graph_config, "nodes": real_run_nodes, "edges": real_edges}

    # init graph
    graph = Graph.init(graph_config=filtered_graph_config)
    if not graph:
        raise ValueError("graph not found in workflow")
    return graph

@ -233,14 +233,14 @@ class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
""" """
RAG Pipeline Application Generate Entity. RAG Pipeline Application Generate Entity.
""" """
# pipeline config
# app config pipeline_config: WorkflowUIBasedAppConfig
pipline_config: WorkflowUIBasedAppConfig
datasource_type: str datasource_type: str
datasource_info: Mapping[str, Any] datasource_info: Mapping[str, Any]
dataset_id: str dataset_id: str
batch: str batch: str
document_id: str document_id: Optional[str] = None
start_node_id: Optional[str] = None
class SingleIterationRunEntity(BaseModel): class SingleIterationRunEntity(BaseModel):
""" """

@ -18,3 +18,5 @@ class SystemVariableKey(StrEnum):
DOCUMENT_ID = "document_id" DOCUMENT_ID = "document_id"
BATCH = "batch" BATCH = "batch"
DATASET_ID = "dataset_id" DATASET_ID = "dataset_id"
DATASOURCE_TYPE = "datasource_type"
DATASOURCE_INFO = "datasource_info"

@ -121,6 +121,8 @@ class Graph(BaseModel):
# fetch nodes that have no predecessor node # fetch nodes that have no predecessor node
root_node_configs = [] root_node_configs = []
all_node_id_config_mapping: dict[str, dict] = {} all_node_id_config_mapping: dict[str, dict] = {}
for node_config in node_configs: for node_config in node_configs:
node_id = node_config.get("id") node_id = node_config.get("id")
if not node_id: if not node_id:
@ -141,6 +143,7 @@ class Graph(BaseModel):
node_config.get("id") node_config.get("id")
for node_config in root_node_configs for node_config in root_node_configs
if node_config.get("data", {}).get("type", "") == NodeType.START.value if node_config.get("data", {}).get("type", "") == NodeType.START.value
or node_config.get("data", {}).get("type", "") == NodeType.DATASOURCE.value
), ),
None, None,
) )

@ -6,11 +6,8 @@ from core.datasource.entities.datasource_entities import (
DatasourceProviderType, DatasourceProviderType,
GetOnlineDocumentPageContentRequest, GetOnlineDocumentPageContentRequest,
GetOnlineDocumentPageContentResponse, GetOnlineDocumentPageContentResponse,
GetWebsiteCrawlRequest,
GetWebsiteCrawlResponse,
) )
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
from core.file import File from core.file import File
from core.plugin.impl.exc import PluginDaemonClientSideError from core.plugin.impl.exc import PluginDaemonClientSideError
from core.variables.segments import ArrayAnySegment from core.variables.segments import ArrayAnySegment
@ -42,22 +39,23 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
""" """
node_data = cast(DatasourceNodeData, self.node_data) node_data = cast(DatasourceNodeData, self.node_data)
variable_pool = self.graph_runtime_state.variable_pool
# fetch datasource icon
datasource_info = {
"provider_id": node_data.provider_id,
"plugin_unique_identifier": node_data.plugin_unique_identifier,
}
# get datasource runtime # get datasource runtime
try: try:
from core.datasource.datasource_manager import DatasourceManager from core.datasource.datasource_manager import DatasourceManager
datasource_type = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE.value])
datasource_info = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO.value])
if datasource_type is None:
raise DatasourceNodeError("Datasource type is not set")
datasource_runtime = DatasourceManager.get_datasource_runtime( datasource_runtime = DatasourceManager.get_datasource_runtime(
provider_id=node_data.provider_id, provider_id=node_data.provider_id,
datasource_name=node_data.datasource_name, datasource_name=node_data.datasource_name,
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
datasource_type=DatasourceProviderType(node_data.provider_type), datasource_type=DatasourceProviderType(datasource_type),
) )
except DatasourceNodeError as e: except DatasourceNodeError as e:
yield RunCompletedEvent( yield RunCompletedEvent(
@ -75,12 +73,12 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
datasource_parameters = datasource_runtime.entity.parameters datasource_parameters = datasource_runtime.entity.parameters
parameters = self._generate_parameters( parameters = self._generate_parameters(
datasource_parameters=datasource_parameters, datasource_parameters=datasource_parameters,
variable_pool=self.graph_runtime_state.variable_pool, variable_pool=variable_pool,
node_data=self.node_data, node_data=self.node_data,
) )
parameters_for_log = self._generate_parameters( parameters_for_log = self._generate_parameters(
datasource_parameters=datasource_parameters, datasource_parameters=datasource_parameters,
variable_pool=self.graph_runtime_state.variable_pool, variable_pool=variable_pool,
node_data=self.node_data, node_data=self.node_data,
for_log=True, for_log=True,
) )
@ -106,20 +104,19 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
}, },
) )
) )
elif datasource_runtime.datasource_provider_type == DatasourceProviderType.WEBSITE_CRAWL: elif (
datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) datasource_runtime.datasource_provider_type in (
website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl( DatasourceProviderType.WEBSITE_CRAWL,
user_id=self.user_id, DatasourceProviderType.LOCAL_FILE,
datasource_parameters=GetWebsiteCrawlRequest(**parameters),
provider_type=datasource_runtime.datasource_provider_type(),
) )
):
yield RunCompletedEvent( yield RunCompletedEvent(
run_result=NodeRunResult( run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log, inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={ outputs={
"website": website_crawl_result.result.model_dump(), "website": datasource_info,
"datasource_type": datasource_runtime.datasource_provider_type, "datasource_type": datasource_runtime.datasource_provider_type,
}, },
) )

@ -6,7 +6,7 @@ import random
import time import time
import uuid import uuid
from collections import Counter from collections import Counter
from typing import Any, Optional from typing import Any, Optional, cast
from flask_login import current_user from flask_login import current_user
from sqlalchemy import func, select from sqlalchemy import func, select
@ -298,13 +298,14 @@ class DatasetService:
description=rag_pipeline_dataset_create_entity.description, description=rag_pipeline_dataset_create_entity.description,
permission=rag_pipeline_dataset_create_entity.permission, permission=rag_pipeline_dataset_create_entity.permission,
provider="vendor", provider="vendor",
runtime_mode="rag_pipeline", runtime_mode="rag-pipeline",
icon_info=rag_pipeline_dataset_create_entity.icon_info, icon_info=rag_pipeline_dataset_create_entity.icon_info,
) )
with Session(db.engine) as session: with Session(db.engine) as session:
rag_pipeline_dsl_service = RagPipelineDslService(session) rag_pipeline_dsl_service = RagPipelineDslService(session)
account = cast(Account, current_user)
rag_pipeline_import_info: RagPipelineImportInfo = rag_pipeline_dsl_service.import_rag_pipeline( rag_pipeline_import_info: RagPipelineImportInfo = rag_pipeline_dsl_service.import_rag_pipeline(
account=current_user, account=account,
import_mode=ImportMode.YAML_CONTENT.value, import_mode=ImportMode.YAML_CONTENT.value,
yaml_content=rag_pipeline_dataset_create_entity.yaml_content, yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
dataset=dataset, dataset=dataset,

@ -59,12 +59,12 @@ class RagPipelineService:
if not result.get("pipeline_templates") and language != "en-US": if not result.get("pipeline_templates") and language != "en-US":
template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval() template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval()
result = template_retrieval.fetch_pipeline_templates_from_builtin("en-US") result = template_retrieval.fetch_pipeline_templates_from_builtin("en-US")
return result.get("pipeline_templates") return [PipelineBuiltInTemplate(**template) for template in result.get("pipeline_templates", [])]
else: else:
mode = "customized" mode = "customized"
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
result = retrieval_instance.get_pipeline_templates(language) result = retrieval_instance.get_pipeline_templates(language)
return result.get("pipeline_templates") return [PipelineCustomizedTemplate(**template) for template in result.get("pipeline_templates", [])]
@classmethod @classmethod
def get_pipeline_template_detail(cls, template_id: str) -> Optional[dict]: def get_pipeline_template_detail(cls, template_id: str) -> Optional[dict]:

@ -97,11 +97,6 @@ def _check_version_compatibility(imported_version: str) -> ImportStatus:
class RagPipelinePendingData(BaseModel): class RagPipelinePendingData(BaseModel):
import_mode: str import_mode: str
yaml_content: str yaml_content: str
name: str | None
description: str | None
icon_type: str | None
icon: str | None
icon_background: str | None
pipeline_id: str | None pipeline_id: str | None
@ -302,10 +297,6 @@ class RagPipelineDslService:
dataset.runtime_mode = "rag_pipeline" dataset.runtime_mode = "rag_pipeline"
dataset.chunk_structure = knowledge_configuration.chunk_structure dataset.chunk_structure = knowledge_configuration.chunk_structure
if knowledge_configuration.index_method.indexing_technique == "high_quality": if knowledge_configuration.index_method.indexing_technique == "high_quality":
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
knowledge_configuration.index_method.embedding_setting.embedding_provider_name, # type: ignore
knowledge_configuration.index_method.embedding_setting.embedding_model_name, # type: ignore
)
dataset_collection_binding = ( dataset_collection_binding = (
db.session.query(DatasetCollectionBinding) db.session.query(DatasetCollectionBinding)
.filter( .filter(
@ -445,10 +436,28 @@ class RagPipelineDslService:
dataset.runtime_mode = "rag_pipeline" dataset.runtime_mode = "rag_pipeline"
dataset.chunk_structure = knowledge_configuration.chunk_structure dataset.chunk_structure = knowledge_configuration.chunk_structure
if knowledge_configuration.index_method.indexing_technique == "high_quality": if knowledge_configuration.index_method.indexing_technique == "high_quality":
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( dataset_collection_binding = (
knowledge_configuration.index_method.embedding_setting.embedding_provider_name, # type: ignore db.session.query(DatasetCollectionBinding)
knowledge_configuration.index_method.embedding_setting.embedding_model_name, # type: ignore .filter(
DatasetCollectionBinding.provider_name
== knowledge_configuration.index_method.embedding_setting.embedding_provider_name,
DatasetCollectionBinding.model_name
== knowledge_configuration.index_method.embedding_setting.embedding_model_name,
DatasetCollectionBinding.type == "dataset",
)
.order_by(DatasetCollectionBinding.created_at)
.first()
) )
if not dataset_collection_binding:
dataset_collection_binding = DatasetCollectionBinding(
provider_name=knowledge_configuration.index_method.embedding_setting.embedding_provider_name,
model_name=knowledge_configuration.index_method.embedding_setting.embedding_model_name,
collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
type="dataset",
)
db.session.add(dataset_collection_binding)
db.session.commit()
dataset_collection_binding_id = dataset_collection_binding.id dataset_collection_binding_id = dataset_collection_binding.id
dataset.collection_binding_id = dataset_collection_binding_id dataset.collection_binding_id = dataset_collection_binding_id
dataset.embedding_model = ( dataset.embedding_model = (
@ -602,7 +611,6 @@ class RagPipelineDslService:
rag_pipeline_service.sync_draft_workflow( rag_pipeline_service.sync_draft_workflow(
pipeline=pipeline, pipeline=pipeline,
graph=workflow_data.get("graph", {}), graph=workflow_data.get("graph", {}),
features=workflow_data.get("features", {}),
unique_hash=unique_hash, unique_hash=unique_hash,
account=account, account=account,
environment_variables=environment_variables, environment_variables=environment_variables,

Loading…
Cancel
Save