feat/datasource
jyong 12 months ago
parent 7f59ffe7af
commit 797d044714

@ -462,18 +462,6 @@ class PublishedRagPipelineApi(Resource):
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("knowledge_base_setting", type=dict, location="json", help="Invalid knowledge base setting.")
args = parser.parse_args()
if not args.get("knowledge_base_setting"):
raise ValueError("Missing knowledge base setting.")
knowledge_base_setting_data = args.get("knowledge_base_setting")
if not knowledge_base_setting_data:
raise ValueError("Missing knowledge base setting.")
knowledge_base_setting = KnowledgeBaseUpdateConfiguration(**knowledge_base_setting_data)
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
with Session(db.engine) as session: with Session(db.engine) as session:
pipeline = session.merge(pipeline) pipeline = session.merge(pipeline)
@ -481,7 +469,6 @@ class PublishedRagPipelineApi(Resource):
session=session, session=session,
pipeline=pipeline, pipeline=pipeline,
account=current_user, account=current_user,
knowledge_base_setting=knowledge_base_setting,
) )
pipeline.is_published = True pipeline.is_published = True
pipeline.workflow_id = workflow.id pipeline.workflow_id = workflow.id

@ -22,6 +22,7 @@ class PluginDatasourceManager(BasePluginClient):
""" """
def transformer(json_response: dict[str, Any]) -> dict: def transformer(json_response: dict[str, Any]) -> dict:
if json_response.get("data"):
for provider in json_response.get("data", []): for provider in json_response.get("data", []):
declaration = provider.get("declaration", {}) or {} declaration = provider.get("declaration", {}) or {}
provider_name = declaration.get("identity", {}).get("name") provider_name = declaration.get("identity", {}).get("name")

@ -9,6 +9,7 @@ from core.datasource.entities.datasource_entities import (
) )
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.file import File from core.file import File
from core.file.enums import FileTransferMethod, FileType
from core.plugin.impl.exc import PluginDaemonClientSideError from core.plugin.impl.exc import PluginDaemonClientSideError
from core.variables.segments import ArrayAnySegment, FileSegment from core.variables.segments import ArrayAnySegment, FileSegment
from core.variables.variables import ArrayAnyVariable from core.variables.variables import ArrayAnyVariable
@ -118,7 +119,12 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
}, },
) )
case DatasourceProviderType.LOCAL_FILE: case DatasourceProviderType.LOCAL_FILE:
upload_file = db.session.query(UploadFile).filter(UploadFile.id == datasource_info["related_id"]).first() related_id = datasource_info.get("related_id")
if not related_id:
raise DatasourceNodeError(
"File is not exist"
)
upload_file = db.session.query(UploadFile).filter(UploadFile.id == related_id).first()
if not upload_file: if not upload_file:
raise ValueError("Invalid upload file Info") raise ValueError("Invalid upload file Info")
@ -128,14 +134,14 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
extension="." + upload_file.extension, extension="." + upload_file.extension,
mime_type=upload_file.mime_type, mime_type=upload_file.mime_type,
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
type=datasource_info.get("type", ""), type=FileType.CUSTOM,
transfer_method=datasource_info.get("transfer_method", ""), transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url, remote_url=upload_file.source_url,
related_id=upload_file.id, related_id=upload_file.id,
size=upload_file.size, size=upload_file.size,
storage_key=upload_file.key, storage_key=upload_file.key,
) )
variable_pool.add([self.node_id, "file"], [FileSegment(value=file_info)]) variable_pool.add([self.node_id, "file"], [file_info])
for key, value in datasource_info.items(): for key, value in datasource_info.items():
# construct new key list # construct new key list
new_key_list = ["file", key] new_key_list = ["file", key]
@ -147,7 +153,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
inputs=parameters_for_log, inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={ outputs={
"file_info": file_info, "file_info": datasource_info,
"datasource_type": datasource_type, "datasource_type": datasource_type,
}, },
) )

@ -53,6 +53,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
) )
from services.entities.knowledge_entities.rag_pipeline_entities import ( from services.entities.knowledge_entities.rag_pipeline_entities import (
KnowledgeBaseUpdateConfiguration, KnowledgeBaseUpdateConfiguration,
KnowledgeConfiguration,
RagPipelineDatasetCreateEntity, RagPipelineDatasetCreateEntity,
) )
from services.errors.account import InvalidActionError, NoPermissionError from services.errors.account import InvalidActionError, NoPermissionError
@ -495,11 +496,11 @@ class DatasetService:
@staticmethod @staticmethod
def update_rag_pipeline_dataset_settings(session: Session, def update_rag_pipeline_dataset_settings(session: Session,
dataset: Dataset, dataset: Dataset,
knowledge_base_setting: KnowledgeBaseUpdateConfiguration, knowledge_configuration: KnowledgeConfiguration,
has_published: bool = False): has_published: bool = False):
if not has_published: if not has_published:
dataset.chunk_structure = knowledge_base_setting.chunk_structure dataset.chunk_structure = knowledge_configuration.chunk_structure
index_method = knowledge_base_setting.index_method index_method = knowledge_configuration.index_method
dataset.indexing_technique = index_method.indexing_technique dataset.indexing_technique = index_method.indexing_technique
if index_method == "high_quality": if index_method == "high_quality":
model_manager = ModelManager() model_manager = ModelManager()
@ -519,26 +520,26 @@ class DatasetService:
dataset.keyword_number = index_method.economy_setting.keyword_number dataset.keyword_number = index_method.economy_setting.keyword_number
else: else:
raise ValueError("Invalid index method") raise ValueError("Invalid index method")
dataset.retrieval_model = knowledge_base_setting.retrieval_setting.model_dump() dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump()
session.add(dataset) session.add(dataset)
else: else:
if dataset.chunk_structure and dataset.chunk_structure != knowledge_base_setting.chunk_structure: if dataset.chunk_structure and dataset.chunk_structure != knowledge_configuration.chunk_structure:
raise ValueError("Chunk structure is not allowed to be updated.") raise ValueError("Chunk structure is not allowed to be updated.")
action = None action = None
if dataset.indexing_technique != knowledge_base_setting.index_method.indexing_technique: if dataset.indexing_technique != knowledge_configuration.index_method.indexing_technique:
# if update indexing_technique # if update indexing_technique
if knowledge_base_setting.index_method.indexing_technique == "economy": if knowledge_configuration.index_method.indexing_technique == "economy":
raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy.") raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy.")
elif knowledge_base_setting.index_method.indexing_technique == "high_quality": elif knowledge_configuration.index_method.indexing_technique == "high_quality":
action = "add" action = "add"
# get embedding model setting # get embedding model setting
try: try:
model_manager = ModelManager() model_manager = ModelManager()
embedding_model = model_manager.get_model_instance( embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
provider=knowledge_base_setting.index_method.embedding_setting.embedding_provider_name, provider=knowledge_configuration.index_method.embedding_setting.embedding_provider_name,
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
model=knowledge_base_setting.index_method.embedding_setting.embedding_model_name, model=knowledge_configuration.index_method.embedding_setting.embedding_model_name,
) )
dataset.embedding_model = embedding_model.model dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider dataset.embedding_model_provider = embedding_model.provider
@ -607,9 +608,9 @@ class DatasetService:
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ValueError(ex.description) raise ValueError(ex.description)
elif dataset.indexing_technique == "economy": elif dataset.indexing_technique == "economy":
if dataset.keyword_number != knowledge_base_setting.index_method.economy_setting.keyword_number: if dataset.keyword_number != knowledge_configuration.index_method.economy_setting.keyword_number:
dataset.keyword_number = knowledge_base_setting.index_method.economy_setting.keyword_number dataset.keyword_number = knowledge_configuration.index_method.economy_setting.keyword_number
dataset.retrieval_model = knowledge_base_setting.retrieval_setting.model_dump() dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump()
session.add(dataset) session.add(dataset)
session.commit() session.commit()
if action: if action:

@ -47,7 +47,7 @@ from models.workflow import (
WorkflowType, WorkflowType,
) )
from services.dataset_service import DatasetService from services.dataset_service import DatasetService
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeBaseUpdateConfiguration, PipelineTemplateInfoEntity from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeBaseUpdateConfiguration, KnowledgeConfiguration, PipelineTemplateInfoEntity
from services.errors.app import WorkflowHashNotEqualError from services.errors.app import WorkflowHashNotEqualError
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory
@ -262,7 +262,6 @@ class RagPipelineService:
session: Session, session: Session,
pipeline: Pipeline, pipeline: Pipeline,
account: Account, account: Account,
knowledge_base_setting: KnowledgeBaseUpdateConfiguration,
) -> Workflow: ) -> Workflow:
draft_workflow_stmt = select(Workflow).where( draft_workflow_stmt = select(Workflow).where(
Workflow.tenant_id == pipeline.tenant_id, Workflow.tenant_id == pipeline.tenant_id,
@ -291,6 +290,13 @@ class RagPipelineService:
# commit db session changes # commit db session changes
session.add(workflow) session.add(workflow)
graph = workflow.graph_dict
nodes = graph.get("nodes", [])
for node in nodes:
if node.get("data", {}).get("type") == "knowledge_index":
knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {})
knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)
# update dataset # update dataset
dataset = pipeline.dataset dataset = pipeline.dataset
if not dataset: if not dataset:
@ -298,7 +304,7 @@ class RagPipelineService:
DatasetService.update_rag_pipeline_dataset_settings( DatasetService.update_rag_pipeline_dataset_settings(
session=session, session=session,
dataset=dataset, dataset=dataset,
knowledge_base_setting=knowledge_base_setting, knowledge_configuration=knowledge_configuration,
has_published=pipeline.is_published has_published=pipeline.is_published
) )
# return new workflow # return new workflow

Loading…
Cancel
Save