feat/datasource
jyong 11 months ago
parent 70d2c78176
commit 6d547447d3

@@ -161,7 +161,7 @@ class CreateEmptyRagPipelineDatasetApi(Resource):
         args = parser.parse_args()
         dataset = DatasetService.create_empty_rag_pipeline_dataset(
             tenant_id=current_user.current_tenant_id,
-            rag_pipeline_dataset_create_entity=args,
+            rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity(**args),
         )
         return marshal(dataset, dataset_detail_fields), 201
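Note: the hunk above stops handing the raw reqparse dict to the service and wraps it in a typed entity first, so field validation happens at the API boundary. A minimal standalone sketch of that pattern with pydantic, using a hypothetical subset of fields (description, permission, and yaml_content are inferred from later hunks in this commit; the real RagPipelineDatasetCreateEntity lives in the service layer):

    from pydantic import BaseModel

    class RagPipelineDatasetCreateEntity(BaseModel):
        # Hypothetical field subset for illustration only.
        name: str
        description: str = ""
        permission: str = "only_me"
        yaml_content: str = ""

    args = {"name": "docs", "description": "product docs"}
    entity = RagPipelineDatasetCreateEntity(**args)  # raises ValidationError on bad input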

@@ -8,7 +8,6 @@ from flask_restful.inputs import int_range  # type: ignore
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
-from models.model import EndUser
 import services
 from configs import dify_config
 from controllers.console import api
@@ -40,6 +39,7 @@ from libs.helper import TimestampField, uuid_value
 from libs.login import current_user, login_required
 from models.account import Account
 from models.dataset import Pipeline
+from models.model import EndUser
 from services.errors.app import WorkflowHashNotEqualError
 from services.errors.llm import InvokeRateLimitError
 from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService
@@ -242,7 +242,7 @@ class DraftRagPipelineRunApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
         parser.add_argument("datasource_type", type=str, required=True, location="json")
-        parser.add_argument("datasource_info", type=list, required=True, location="json")
+        parser.add_argument("datasource_info_list", type=list, required=True, location="json")
         parser.add_argument("start_node_id", type=str, required=True, location="json")
         args = parser.parse_args()
@@ -320,6 +320,9 @@ class RagPipelineDatasourceNodeRunApi(Resource):
         inputs = args.get("inputs")
         if inputs == None:
             raise ValueError("missing inputs")
+        datasource_type = args.get("datasource_type")
+        if datasource_type == None:
+            raise ValueError("missing datasource_type")
         rag_pipeline_service = RagPipelineService()
         result = rag_pipeline_service.run_datasource_workflow_node(
@@ -327,7 +330,7 @@ class RagPipelineDatasourceNodeRunApi(Resource):
             node_id=node_id,
             user_inputs=inputs,
             account=current_user,
-            datasource_type=args.get("datasource_type"),
+            datasource_type=datasource_type,
         )
         return result
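The two hunks above make the handler pull datasource_type out of args once, fail fast on None, and pass the narrowed value to the service, so the service signature can require a plain str instead of Optional[str]. A shape-only sketch of that validate-then-call pattern (the service call is a stand-in, and `is None` is the idiomatic comparison):

    def run_datasource_node(args: dict) -> None:
        inputs = args.get("inputs")
        if inputs is None:
            raise ValueError("missing inputs")
        datasource_type = args.get("datasource_type")
        if datasource_type is None:
            raise ValueError("missing datasource_type")
        # datasource_type is now a concrete value, not Optional.
        print(f"running node with {datasource_type!r} and inputs {inputs!r}")

    run_datasource_node({"inputs": {}, "datasource_type": "website_crawl"})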

@@ -32,6 +32,7 @@ from core.workflow.workflow_app_generate_task_pipeline import WorkflowAppGenerateTaskPipeline
 from extensions.ext_database import db
 from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
 from models.dataset import Document, Pipeline
+from models.model import AppMode
 from services.dataset_service import DocumentService

 logger = logging.getLogger(__name__)
@@ -91,7 +92,7 @@ class PipelineGenerator(BaseAppGenerator):
         streaming: bool = True,
         call_depth: int = 0,
         workflow_thread_pool_id: Optional[str] = None,
-    ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
+    ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
         # convert to app config
         pipeline_config = PipelineConfigManager.get_pipeline_config(
             pipeline=pipeline,
@@ -107,19 +108,23 @@ class PipelineGenerator(BaseAppGenerator):
         for datasource_info in datasource_info_list:
             workflow_run_id = str(uuid.uuid4())
             document_id = None
+            dataset = pipeline.dataset
+            if not dataset:
+                raise ValueError("Dataset not found")
             if invoke_from == InvokeFrom.PUBLISHED:
                 position = DocumentService.get_documents_position(pipeline.dataset_id)
                 document = self._build_document(
                     tenant_id=pipeline.tenant_id,
                     dataset_id=pipeline.dataset_id,
-                    built_in_field_enabled=pipeline.dataset.built_in_field_enabled,
+                    built_in_field_enabled=dataset.built_in_field_enabled,
                     datasource_type=datasource_type,
                     datasource_info=datasource_info,
                     created_from="rag-pipeline",
                     position=position,
                     account=user,
                     batch=batch,
-                    document_form=pipeline.dataset.chunk_structure,
+                    document_form=dataset.chunk_structure,
                 )
                 db.session.add(document)
                 db.session.commit()
@@ -127,10 +132,12 @@ class PipelineGenerator(BaseAppGenerator):
             # init application generate entity
             application_generate_entity = RagPipelineGenerateEntity(
                 task_id=str(uuid.uuid4()),
-                pipline_config=pipeline_config,
+                app_config=pipeline_config,
+                pipeline_config=pipeline_config,
                 datasource_type=datasource_type,
                 datasource_info=datasource_info,
-                dataset_id=pipeline.dataset_id,
+                dataset_id=dataset.id,
+                start_node_id=start_node_id,
                 batch=batch,
                 document_id=document_id,
                 inputs=self._prepare_user_inputs(
@@ -160,17 +167,28 @@ class PipelineGenerator(BaseAppGenerator):
                 app_id=application_generate_entity.app_config.app_id,
                 triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
             )
-            return self._generate(
-                pipeline=pipeline,
-                workflow=workflow,
-                user=user,
-                application_generate_entity=application_generate_entity,
-                invoke_from=invoke_from,
-                workflow_node_execution_repository=workflow_node_execution_repository,
-                streaming=streaming,
-                workflow_thread_pool_id=workflow_thread_pool_id,
-            )
+            if invoke_from == InvokeFrom.DEBUGGER:
+                return self._generate(
+                    pipeline=pipeline,
+                    workflow=workflow,
+                    user=user,
+                    application_generate_entity=application_generate_entity,
+                    invoke_from=invoke_from,
+                    workflow_node_execution_repository=workflow_node_execution_repository,
+                    streaming=streaming,
+                    workflow_thread_pool_id=workflow_thread_pool_id,
+                )
+            else:
+                self._generate(
+                    pipeline=pipeline,
+                    workflow=workflow,
+                    user=user,
+                    application_generate_entity=application_generate_entity,
+                    invoke_from=invoke_from,
+                    workflow_node_execution_repository=workflow_node_execution_repository,
+                    streaming=streaming,
+                    workflow_thread_pool_id=workflow_thread_pool_id,
+                )

     def _generate(
         self,
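This hunk splits the control flow so that only debugger invocations return the streaming generator; published runs iterate the whole datasource_info_list and call _generate per item for its side effects, which is why the method's return annotation earlier in the file gains None. A shape-only sketch of that split, with a hypothetical event type standing in for workflow events:

    from collections.abc import Generator
    from typing import Optional

    def generate(items: list[str], debugger: bool) -> Optional[Generator[str, None, None]]:
        def _generate(item: str) -> Generator[str, None, None]:
            yield f"event:{item}"  # stand-in for real workflow events

        if debugger:
            # Debug mode: the caller consumes the stream of the single run.
            return _generate(items[0])
        for item in items:
            for _ in _generate(item):  # drain for side effects only
                pass
        return None

    assert generate(["a"], debugger=False) is None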
@@ -201,7 +219,7 @@ class PipelineGenerator(BaseAppGenerator):
             task_id=application_generate_entity.task_id,
             user_id=application_generate_entity.user_id,
             invoke_from=application_generate_entity.invoke_from,
-            app_mode=pipeline.mode,
+            app_mode=AppMode.RAG_PIPELINE,
         )

         # new thread
@@ -256,12 +274,18 @@ class PipelineGenerator(BaseAppGenerator):
             raise ValueError("inputs is required")

         # convert to app config
-        app_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)
+        pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)

         # init application generate entity
-        application_generate_entity = WorkflowAppGenerateEntity(
+        application_generate_entity = RagPipelineGenerateEntity(
             task_id=str(uuid.uuid4()),
-            app_config=app_config,
+            app_config=pipeline_config,
+            pipeline_config=pipeline_config,
+            datasource_type=args["datasource_type"],
+            datasource_info=args["datasource_info"],
+            dataset_id=pipeline.dataset_id,
+            batch=args["batch"],
+            document_id=args["document_id"],
             inputs={},
             files=[],
             user_id=user.id,
@@ -288,7 +312,7 @@ class PipelineGenerator(BaseAppGenerator):
         )

         return self._generate(
-            app_model=app_model,
+            pipeline=pipeline,
             workflow=workflow,
             user=user,
             invoke_from=InvokeFrom.DEBUGGER,
@@ -299,7 +323,7 @@
     def single_loop_generate(
         self,
-        app_model: App,
+        pipeline: Pipeline,
         workflow: Workflow,
         node_id: str,
         user: Account | EndUser,
@@ -323,7 +347,7 @@
             raise ValueError("inputs is required")

         # convert to app config
-        app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
+        app_config = WorkflowAppConfigManager.get_app_config(pipeline=pipeline, workflow=workflow)

         # init application generate entity
         application_generate_entity = WorkflowAppGenerateEntity(
@@ -353,7 +377,7 @@
         )

         return self._generate(
-            app_model=app_model,
+            pipeline=pipeline,
             workflow=workflow,
             user=user,
             invoke_from=InvokeFrom.DEBUGGER,

@@ -1,5 +1,6 @@
 import logging
-from typing import Optional, cast
+from collections.abc import Mapping
+from typing import Any, Optional, cast

 from configs import dify_config
 from core.app.apps.base_app_queue_manager import AppQueueManager
@@ -12,6 +13,7 @@ from core.app.entities.app_invoke_entities import (
 from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.enums import SystemVariableKey
+from core.workflow.graph_engine.entities.graph import Graph
 from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
 from models.dataset import Pipeline
@@ -100,6 +102,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
             SystemVariableKey.DOCUMENT_ID: self.application_generate_entity.document_id,
             SystemVariableKey.BATCH: self.application_generate_entity.batch,
             SystemVariableKey.DATASET_ID: self.application_generate_entity.dataset_id,
+            SystemVariableKey.DATASOURCE_TYPE: self.application_generate_entity.datasource_type,
+            SystemVariableKey.DATASOURCE_INFO: self.application_generate_entity.datasource_info,
         }

         variable_pool = VariablePool(
@@ -110,7 +114,10 @@
         )

         # init graph
-        graph = self._init_graph(graph_config=workflow.graph_dict)
+        graph = self._init_rag_pipeline_graph(
+            graph_config=workflow.graph_dict,
+            start_node_id=self.application_generate_entity.start_node_id,
+        )

         # RUN WORKFLOW
         workflow_entry = WorkflowEntry(
@@ -152,3 +159,43 @@
         # return workflow
         return workflow
+
+    def _init_rag_pipeline_graph(self, graph_config: Mapping[str, Any], start_node_id: Optional[str] = None) -> Graph:
+        """
+        Init pipeline graph
+        """
+        if "nodes" not in graph_config or "edges" not in graph_config:
+            raise ValueError("nodes or edges not found in workflow graph")
+
+        if not isinstance(graph_config.get("nodes"), list):
+            raise ValueError("nodes in workflow graph must be a list")
+
+        if not isinstance(graph_config.get("edges"), list):
+            raise ValueError("edges in workflow graph must be a list")
+
+        nodes = graph_config.get("nodes", [])
+        edges = graph_config.get("edges", [])
+        real_run_nodes = []
+        real_edges = []
+        exclude_node_ids = []
+        for node in nodes:
+            node_id = node.get("id")
+            node_type = node.get("data", {}).get("type", "")
+            if node_type == "datasource":
+                if start_node_id != node_id:
+                    exclude_node_ids.append(node_id)
+                    continue
+            real_run_nodes.append(node)
+        for edge in edges:
+            if edge.get("source") in exclude_node_ids:
+                continue
+            real_edges.append(edge)
+        graph_config = dict(graph_config)
+        graph_config["nodes"] = real_run_nodes
+        graph_config["edges"] = real_edges
+        # init graph
+        graph = Graph.init(graph_config=graph_config)
+
+        if not graph:
+            raise ValueError("graph not found in workflow")
+
+        return graph
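The new _init_rag_pipeline_graph keeps only the datasource node matching start_node_id, dropping every other datasource branch and its outgoing edges before Graph.init runs. A self-contained sketch of the same filtering over a toy graph dict (node type "knowledge-index" is hypothetical, chosen just to have a downstream node):

    from typing import Any

    def prune_datasource_branches(graph_config: dict[str, Any], start_node_id: str) -> dict[str, Any]:
        # Exclude datasource nodes other than the selected one.
        exclude = {
            n["id"]
            for n in graph_config["nodes"]
            if n.get("data", {}).get("type") == "datasource" and n["id"] != start_node_id
        }
        return {
            "nodes": [n for n in graph_config["nodes"] if n["id"] not in exclude],
            "edges": [e for e in graph_config["edges"] if e.get("source") not in exclude],
        }

    graph = {
        "nodes": [
            {"id": "ds1", "data": {"type": "datasource"}},
            {"id": "ds2", "data": {"type": "datasource"}},
            {"id": "extract", "data": {"type": "knowledge-index"}},
        ],
        "edges": [
            {"source": "ds1", "target": "extract"},
            {"source": "ds2", "target": "extract"},
        ],
    }
    pruned = prune_datasource_branches(graph, "ds1")
    assert [n["id"] for n in pruned["nodes"]] == ["ds1", "extract"]
    assert pruned["edges"] == [{"source": "ds1", "target": "extract"}]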

@@ -233,14 +233,14 @@ class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
     """
     RAG Pipeline Application Generate Entity.
     """

-    # app config
-    pipline_config: WorkflowUIBasedAppConfig
+    # pipeline config
+    pipeline_config: WorkflowUIBasedAppConfig
     datasource_type: str
     datasource_info: Mapping[str, Any]
     dataset_id: str
     batch: str
-    document_id: str
+    document_id: Optional[str] = None
+    start_node_id: Optional[str] = None


 class SingleIterationRunEntity(BaseModel):
     """

@@ -18,3 +18,5 @@ class SystemVariableKey(StrEnum):
     DOCUMENT_ID = "document_id"
     BATCH = "batch"
     DATASET_ID = "dataset_id"
+    DATASOURCE_TYPE = "datasource_type"
+    DATASOURCE_INFO = "datasource_info"
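Because SystemVariableKey is a StrEnum, the two new members compare and hash equal to their literal string values, so they can key the system-variable mapping interchangeably with plain strings. A quick check:

    from enum import StrEnum  # Python 3.11+

    class SystemVariableKey(StrEnum):
        # Excerpt: the two members added above.
        DATASOURCE_TYPE = "datasource_type"
        DATASOURCE_INFO = "datasource_info"

    pool = {SystemVariableKey.DATASOURCE_TYPE: "website_crawl"}
    assert pool["datasource_type"] == "website_crawl"
    assert SystemVariableKey.DATASOURCE_TYPE == "datasource_type"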

@@ -121,6 +121,8 @@ class Graph(BaseModel):
         # fetch nodes that have no predecessor node
         root_node_configs = []
+        all_node_id_config_mapping: dict[str, dict] = {}
+
         for node_config in node_configs:
             node_id = node_config.get("id")
             if not node_id:
@@ -140,7 +142,8 @@
             (
                 node_config.get("id")
                 for node_config in root_node_configs
-                if node_config.get("data", {}).get("type", "") == NodeType.START.value
+                if node_config.get("data", {}).get("type", "") == NodeType.START.value
+                or node_config.get("data", {}).get("type", "") == NodeType.DATASOURCE.value
             ),
             None,
         )
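Root-node detection now accepts a datasource node as well as a classic start node, since a RAG pipeline begins at a datasource. The selection is a next(...) over the parentless node configs; a standalone sketch (the literal type strings assume NodeType.START.value == "start" and NodeType.DATASOURCE.value == "datasource", consistent with the "datasource" literal used elsewhere in this commit):

    ROOT_TYPES = {"start", "datasource"}

    root_node_configs = [
        {"id": "ds1", "data": {"type": "datasource"}},
        {"id": "llm", "data": {"type": "llm"}},
    ]
    root_node_id = next(
        (
            cfg.get("id")
            for cfg in root_node_configs
            if cfg.get("data", {}).get("type", "") in ROOT_TYPES
        ),
        None,  # falls back to None when no start/datasource root exists
    )
    assert root_node_id == "ds1"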

@@ -6,11 +6,8 @@ from core.datasource.entities.datasource_entities import (
     DatasourceProviderType,
     GetOnlineDocumentPageContentRequest,
     GetOnlineDocumentPageContentResponse,
-    GetWebsiteCrawlRequest,
-    GetWebsiteCrawlResponse,
 )
 from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
-from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
 from core.file import File
 from core.plugin.impl.exc import PluginDaemonClientSideError
 from core.variables.segments import ArrayAnySegment
@@ -42,22 +39,23 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
         """
         node_data = cast(DatasourceNodeData, self.node_data)

         # fetch datasource icon
-        datasource_info = {
-            "provider_id": node_data.provider_id,
-            "plugin_unique_identifier": node_data.plugin_unique_identifier,
-        }
+        variable_pool = self.graph_runtime_state.variable_pool

         # get datasource runtime
         try:
             from core.datasource.datasource_manager import DatasourceManager

+            datasource_type = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE.value])
+            datasource_info = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO.value])
+            if datasource_type is None:
+                raise DatasourceNodeError("Datasource type is not set")
             datasource_runtime = DatasourceManager.get_datasource_runtime(
                 provider_id=node_data.provider_id,
                 datasource_name=node_data.datasource_name,
                 tenant_id=self.tenant_id,
-                datasource_type=DatasourceProviderType(node_data.provider_type),
+                datasource_type=DatasourceProviderType(datasource_type),
             )
         except DatasourceNodeError as e:
             yield RunCompletedEvent(
@@ -75,12 +73,12 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
         datasource_parameters = datasource_runtime.entity.parameters
         parameters = self._generate_parameters(
             datasource_parameters=datasource_parameters,
-            variable_pool=self.graph_runtime_state.variable_pool,
+            variable_pool=variable_pool,
             node_data=self.node_data,
         )
         parameters_for_log = self._generate_parameters(
             datasource_parameters=datasource_parameters,
-            variable_pool=self.graph_runtime_state.variable_pool,
+            variable_pool=variable_pool,
             node_data=self.node_data,
             for_log=True,
         )
@@ -106,20 +104,19 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                     },
                 )
             )
-        elif datasource_runtime.datasource_provider_type == DatasourceProviderType.WEBSITE_CRAWL:
-            datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
-            website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
-                user_id=self.user_id,
-                datasource_parameters=GetWebsiteCrawlRequest(**parameters),
-                provider_type=datasource_runtime.datasource_provider_type(),
+        elif (
+            datasource_runtime.datasource_provider_type in (
+                DatasourceProviderType.WEBSITE_CRAWL,
+                DatasourceProviderType.LOCAL_FILE,
             )
+        ):
             yield RunCompletedEvent(
                 run_result=NodeRunResult(
                     status=WorkflowNodeExecutionStatus.SUCCEEDED,
                     inputs=parameters_for_log,
                     metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
                     outputs={
-                        "website": website_crawl_result.result.model_dump(),
+                        "website": datasource_info,
                         "datasource_type": datasource_runtime.datasource_provider_type,
                     },
                 )
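After this change the node no longer calls the website-crawl plugin during the run; for WEBSITE_CRAWL and LOCAL_FILE providers it simply echoes the datasource_info that the runner injected into the system variables, since the crawl or file selection already happened when the pipeline started. A toy sketch of that read-and-echo behavior, with a plain dict standing in for the variable pool and a dict for the run result:

    from enum import StrEnum

    class SystemVariableKey(StrEnum):
        DATASOURCE_TYPE = "datasource_type"
        DATASOURCE_INFO = "datasource_info"

    def run_datasource_node(variable_pool: dict) -> dict:
        datasource_type = variable_pool.get(("sys", SystemVariableKey.DATASOURCE_TYPE.value))
        datasource_info = variable_pool.get(("sys", SystemVariableKey.DATASOURCE_INFO.value))
        if datasource_type is None:
            raise ValueError("Datasource type is not set")
        # Echo what the runner injected; no plugin round-trip here.
        return {"website": datasource_info, "datasource_type": datasource_type}

    pool = {
        ("sys", "datasource_type"): "website_crawl",
        ("sys", "datasource_info"): {"url": "https://example.com"},
    }
    result = run_datasource_node(pool)
    assert result["datasource_type"] == "website_crawl"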

@@ -6,7 +6,7 @@ import random
 import time
 import uuid
 from collections import Counter
-from typing import Any, Optional
+from typing import Any, Optional, cast

 from flask_login import current_user
 from sqlalchemy import func, select
@@ -298,13 +298,14 @@ class DatasetService:
             description=rag_pipeline_dataset_create_entity.description,
             permission=rag_pipeline_dataset_create_entity.permission,
             provider="vendor",
-            runtime_mode="rag_pipeline",
+            runtime_mode="rag-pipeline",
             icon_info=rag_pipeline_dataset_create_entity.icon_info,
         )
         with Session(db.engine) as session:
             rag_pipeline_dsl_service = RagPipelineDslService(session)
+            account = cast(Account, current_user)
             rag_pipeline_import_info: RagPipelineImportInfo = rag_pipeline_dsl_service.import_rag_pipeline(
-                account=current_user,
+                account=account,
                 import_mode=ImportMode.YAML_CONTENT.value,
                 yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
                 dataset=dataset,
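typing.cast here is purely a static-typing hint: it does no runtime conversion or check, it just tells the type checker to treat flask_login's current_user proxy as an Account so the keyword argument type-checks. A two-line illustration with a stand-in class:

    from typing import cast

    class Account:  # stand-in for models.account.Account
        pass

    current_user: object = Account()
    account = cast(Account, current_user)  # no runtime conversion or check
    assert account is current_user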

@@ -59,12 +59,12 @@ class RagPipelineService:
             if not result.get("pipeline_templates") and language != "en-US":
                 template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval()
                 result = template_retrieval.fetch_pipeline_templates_from_builtin("en-US")
-            return result.get("pipeline_templates")
+            return [PipelineBuiltInTemplate(**template) for template in result.get("pipeline_templates", [])]
         else:
             mode = "customized"
             retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
             result = retrieval_instance.get_pipeline_templates(language)
-            return result.get("pipeline_templates")
+            return [PipelineCustomizedTemplate(**template) for template in result.get("pipeline_templates", [])]

     @classmethod
     def get_pipeline_template_detail(cls, template_id: str) -> Optional[dict]:
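Returning typed template models instead of raw dicts gives callers attribute access and validation, and the .get("pipeline_templates", []) default removes the None the old code could return. Construction via **template unpacking only works when the dict keys match the model fields; a minimal sketch with hypothetical fields (the real models are defined elsewhere in the codebase):

    from pydantic import BaseModel

    class PipelineBuiltInTemplate(BaseModel):
        # Hypothetical field subset for illustration only.
        name: str
        language: str = "en-US"

    result = {"pipeline_templates": [{"name": "web-crawl-to-kb"}]}
    templates = [PipelineBuiltInTemplate(**t) for t in result.get("pipeline_templates", [])]
    assert templates[0].name == "web-crawl-to-kb"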

@@ -97,11 +97,6 @@ def _check_version_compatibility(imported_version: str) -> ImportStatus:
 class RagPipelinePendingData(BaseModel):
     import_mode: str
     yaml_content: str
-    name: str | None
-    description: str | None
-    icon_type: str | None
-    icon: str | None
-    icon_background: str | None
     pipeline_id: str | None
@@ -302,10 +297,6 @@ class RagPipelineDslService:
             dataset.runtime_mode = "rag_pipeline"
             dataset.chunk_structure = knowledge_configuration.chunk_structure
             if knowledge_configuration.index_method.indexing_technique == "high_quality":
-                dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
-                    knowledge_configuration.index_method.embedding_setting.embedding_provider_name,  # type: ignore
-                    knowledge_configuration.index_method.embedding_setting.embedding_model_name,  # type: ignore
-                )
                 dataset_collection_binding = (
                     db.session.query(DatasetCollectionBinding)
                     .filter(
@@ -445,10 +436,28 @@ class RagPipelineDslService:
             dataset.runtime_mode = "rag_pipeline"
             dataset.chunk_structure = knowledge_configuration.chunk_structure
             if knowledge_configuration.index_method.indexing_technique == "high_quality":
-                dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
-                    knowledge_configuration.index_method.embedding_setting.embedding_provider_name,  # type: ignore
-                    knowledge_configuration.index_method.embedding_setting.embedding_model_name,  # type: ignore
-                )
+                dataset_collection_binding = (
+                    db.session.query(DatasetCollectionBinding)
+                    .filter(
+                        DatasetCollectionBinding.provider_name
+                        == knowledge_configuration.index_method.embedding_setting.embedding_provider_name,
+                        DatasetCollectionBinding.model_name
+                        == knowledge_configuration.index_method.embedding_setting.embedding_model_name,
+                        DatasetCollectionBinding.type == "dataset",
+                    )
+                    .order_by(DatasetCollectionBinding.created_at)
+                    .first()
+                )
+                if not dataset_collection_binding:
+                    dataset_collection_binding = DatasetCollectionBinding(
+                        provider_name=knowledge_configuration.index_method.embedding_setting.embedding_provider_name,
+                        model_name=knowledge_configuration.index_method.embedding_setting.embedding_model_name,
+                        collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
+                        type="dataset",
+                    )
+                    db.session.add(dataset_collection_binding)
+                    db.session.commit()
                 dataset_collection_binding_id = dataset_collection_binding.id
                 dataset.collection_binding_id = dataset_collection_binding_id
                 dataset.embedding_model = (
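This hunk replaces the service helper with an inline get-or-create: query for an existing collection binding by provider, model, and type, and insert one when none exists. The same pattern in plain SQLAlchemy against a toy in-memory model (field names follow the hunk; provider and model values are illustrative):

    import uuid
    from datetime import datetime

    from sqlalchemy import Column, DateTime, String, create_engine
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class DatasetCollectionBinding(Base):
        __tablename__ = "dataset_collection_bindings"
        id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
        provider_name = Column(String)
        model_name = Column(String)
        type = Column(String)
        collection_name = Column(String)
        created_at = Column(DateTime, default=datetime.utcnow)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        binding = (
            session.query(DatasetCollectionBinding)
            .filter(
                DatasetCollectionBinding.provider_name == "openai",
                DatasetCollectionBinding.model_name == "text-embedding-3-small",
                DatasetCollectionBinding.type == "dataset",
            )
            .order_by(DatasetCollectionBinding.created_at)
            .first()
        )
        if binding is None:
            # Create-on-miss, mirroring the hunk above.
            binding = DatasetCollectionBinding(
                provider_name="openai",
                model_name="text-embedding-3-small",
                type="dataset",
                collection_name=f"vector_index_{uuid.uuid4()}",
            )
            session.add(binding)
            session.commit()

Note that the read-then-write is not atomic; under concurrent imports, a unique constraint on (provider_name, model_name, type) would be the safer guard.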
@@ -602,7 +611,6 @@ class RagPipelineDslService:
             rag_pipeline_service.sync_draft_workflow(
                 pipeline=pipeline,
                 graph=workflow_data.get("graph", {}),
-                features=workflow_data.get("features", {}),
                 unique_hash=unique_hash,
                 account=account,
                 environment_variables=environment_variables,
