feat/datasource
jyong 11 months ago
parent 7f59ffe7af
commit 797d044714

@@ -462,18 +462,6 @@ class PublishedRagPipelineApi(Resource):
         if not isinstance(current_user, Account):
             raise Forbidden()
-        parser = reqparse.RequestParser()
-        parser.add_argument("knowledge_base_setting", type=dict, location="json", help="Invalid knowledge base setting.")
-        args = parser.parse_args()
-        if not args.get("knowledge_base_setting"):
-            raise ValueError("Missing knowledge base setting.")
-        knowledge_base_setting_data = args.get("knowledge_base_setting")
-        if not knowledge_base_setting_data:
-            raise ValueError("Missing knowledge base setting.")
-        knowledge_base_setting = KnowledgeBaseUpdateConfiguration(**knowledge_base_setting_data)
         rag_pipeline_service = RagPipelineService()
         with Session(db.engine) as session:
             pipeline = session.merge(pipeline)
@@ -481,7 +469,6 @@ class PublishedRagPipelineApi(Resource):
                 session=session,
                 pipeline=pipeline,
                 account=current_user,
-                knowledge_base_setting=knowledge_base_setting,
             )
             pipeline.is_published = True
             pipeline.workflow_id = workflow.id
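Note: with this hunk the publish endpoint stops parsing knowledge_base_setting from the request body; the configuration now comes from the knowledge_index node of the published workflow graph (see the RagPipelineService hunk below). A minimal sketch of the client-side difference, assuming an illustrative endpoint path and token (not the real API surface):

    import requests

    BASE = "https://example.com/console/api"  # illustrative base URL
    headers = {"Authorization": "Bearer <token>"}

    # before: the publish request had to carry the knowledge base settings
    requests.post(
        f"{BASE}/rag/pipelines/<pipeline_id>/workflows/publish",
        headers=headers,
        json={"knowledge_base_setting": {"chunk_structure": "text_model"}},
    )

    # after: no body is required; settings are read from the workflow graph
    requests.post(f"{BASE}/rag/pipelines/<pipeline_id>/workflows/publish", headers=headers)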

@@ -22,11 +22,12 @@ class PluginDatasourceManager(BasePluginClient):
         """
         def transformer(json_response: dict[str, Any]) -> dict:
-            for provider in json_response.get("data", []):
-                declaration = provider.get("declaration", {}) or {}
-                provider_name = declaration.get("identity", {}).get("name")
-                for datasource in declaration.get("datasources", []):
-                    datasource["identity"]["provider"] = provider_name
+            if json_response.get("data"):
+                for provider in json_response.get("data", []):
+                    declaration = provider.get("declaration", {}) or {}
+                    provider_name = declaration.get("identity", {}).get("name")
+                    for datasource in declaration.get("datasources", []):
+                        datasource["identity"]["provider"] = provider_name
             return json_response
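Note: the added truthiness guard matters when the daemon responds with "data": null rather than omitting the key: dict.get only applies its default when the key is absent, so the old loop would raise a TypeError. A standalone sketch of the difference:

    # dict.get applies the default only when the key is missing, not when
    # its value is None, so a "data": null payload breaks the unguarded loop.
    resp_null = {"data": None}

    try:
        for provider in resp_null.get("data", []):  # .get returns None here
            pass
    except TypeError as e:
        print("unguarded:", e)  # 'NoneType' object is not iterable

    # guarded, as in the new transformer: None is falsy, so the loop is skipped
    if resp_null.get("data"):
        for provider in resp_null.get("data", []):
            pass
    print("guarded: skipped safely")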

@@ -9,6 +9,7 @@ from core.datasource.entities.datasource_entities import (
 )
 from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
 from core.file import File
+from core.file.enums import FileTransferMethod, FileType
 from core.plugin.impl.exc import PluginDaemonClientSideError
 from core.variables.segments import ArrayAnySegment, FileSegment
 from core.variables.variables import ArrayAnyVariable
@@ -118,7 +119,12 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                     },
                 )
             case DatasourceProviderType.LOCAL_FILE:
-                upload_file = db.session.query(UploadFile).filter(UploadFile.id == datasource_info["related_id"]).first()
+                related_id = datasource_info.get("related_id")
+                if not related_id:
+                    raise DatasourceNodeError(
+                        "File does not exist"
+                    )
+                upload_file = db.session.query(UploadFile).filter(UploadFile.id == related_id).first()
                 if not upload_file:
                     raise ValueError("Invalid upload file Info")
@@ -128,14 +134,14 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                     extension="." + upload_file.extension,
                     mime_type=upload_file.mime_type,
                     tenant_id=self.tenant_id,
-                    type=datasource_info.get("type", ""),
-                    transfer_method=datasource_info.get("transfer_method", ""),
+                    type=FileType.CUSTOM,
+                    transfer_method=FileTransferMethod.LOCAL_FILE,
                     remote_url=upload_file.source_url,
                     related_id=upload_file.id,
                     size=upload_file.size,
                     storage_key=upload_file.key,
                 )
-                variable_pool.add([self.node_id, "file"], [FileSegment(value=file_info)])
+                variable_pool.add([self.node_id, "file"], [file_info])
                 for key, value in datasource_info.items():
                     # construct new key list
                     new_key_list = ["file", key]
@@ -147,7 +153,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                 inputs=parameters_for_log,
                 metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
                 outputs={
-                    "file_info": file_info,
+                    "file_info": datasource_info,
                     "datasource_type": datasource_type,
                 },
             )
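Note: in the LOCAL_FILE branch the file type and transfer method are known statically, so hard-coding FileType.CUSTOM and FileTransferMethod.LOCAL_FILE is safer than trusting strings from datasource_info, whose .get(..., "") defaults could feed "" into the File model. Switching the logged output to "file_info": datasource_info likewise keeps outputs a plain dict rather than a File object, presumably for serializability. A small sketch with assumed enum values mirroring core.file.enums:

    from enum import Enum

    class FileTransferMethod(str, Enum):  # assumed shape of core.file.enums
        LOCAL_FILE = "local_file"

    class FileType(str, Enum):
        CUSTOM = "custom"

    def pick_transfer_method(datasource_info: dict) -> FileTransferMethod:
        # old: FileTransferMethod(datasource_info.get("transfer_method", ""))
        # would raise ValueError on "" or any unexpected string
        return FileTransferMethod.LOCAL_FILE  # fixed by the match arm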
@@ -220,7 +226,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
         variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
         assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
         return list(variable.value) if variable else []
 
     def _append_variables_recursively(self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
         """

@@ -53,6 +53,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
 )
 from services.entities.knowledge_entities.rag_pipeline_entities import (
     KnowledgeBaseUpdateConfiguration,
+    KnowledgeConfiguration,
     RagPipelineDatasetCreateEntity,
 )
 from services.errors.account import InvalidActionError, NoPermissionError
@@ -495,11 +496,11 @@ class DatasetService:
     @staticmethod
     def update_rag_pipeline_dataset_settings(session: Session,
                                              dataset: Dataset,
-                                             knowledge_base_setting: KnowledgeBaseUpdateConfiguration,
+                                             knowledge_configuration: KnowledgeConfiguration,
                                              has_published: bool = False):
         if not has_published:
-            dataset.chunk_structure = knowledge_base_setting.chunk_structure
-            index_method = knowledge_base_setting.index_method
+            dataset.chunk_structure = knowledge_configuration.chunk_structure
+            index_method = knowledge_configuration.index_method
             dataset.indexing_technique = index_method.indexing_technique
             if index_method == "high_quality":
                 model_manager = ModelManager()
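Reviewer note on the context lines above (pre-existing, not changed by this hunk): index_method is a structured entity taken from knowledge_configuration.index_method, yet it is compared directly to the string "high_quality", which can never match; the surrounding branches suggest index_method.indexing_technique was intended. A quick demonstration with a hypothetical stand-in entity:

    from dataclasses import dataclass

    @dataclass
    class IndexMethod:  # hypothetical stand-in for the real entity
        indexing_technique: str

    index_method = IndexMethod(indexing_technique="high_quality")

    print(index_method == "high_quality")                     # False: entity vs str
    print(index_method.indexing_technique == "high_quality")  # True: likely intent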
@@ -519,26 +520,26 @@ class DatasetService:
                 dataset.keyword_number = index_method.economy_setting.keyword_number
             else:
                 raise ValueError("Invalid index method")
-            dataset.retrieval_model = knowledge_base_setting.retrieval_setting.model_dump()
+            dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump()
             session.add(dataset)
         else:
-            if dataset.chunk_structure and dataset.chunk_structure != knowledge_base_setting.chunk_structure:
+            if dataset.chunk_structure and dataset.chunk_structure != knowledge_configuration.chunk_structure:
                 raise ValueError("Chunk structure is not allowed to be updated.")
             action = None
-            if dataset.indexing_technique != knowledge_base_setting.index_method.indexing_technique:
+            if dataset.indexing_technique != knowledge_configuration.index_method.indexing_technique:
                 # if update indexing_technique
-                if knowledge_base_setting.index_method.indexing_technique == "economy":
+                if knowledge_configuration.index_method.indexing_technique == "economy":
                     raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy.")
-                elif knowledge_base_setting.index_method.indexing_technique == "high_quality":
+                elif knowledge_configuration.index_method.indexing_technique == "high_quality":
                     action = "add"
                     # get embedding model setting
                     try:
                         model_manager = ModelManager()
                         embedding_model = model_manager.get_model_instance(
                             tenant_id=current_user.current_tenant_id,
-                            provider=knowledge_base_setting.index_method.embedding_setting.embedding_provider_name,
+                            provider=knowledge_configuration.index_method.embedding_setting.embedding_provider_name,
                             model_type=ModelType.TEXT_EMBEDDING,
-                            model=knowledge_base_setting.index_method.embedding_setting.embedding_model_name,
+                            model=knowledge_configuration.index_method.embedding_setting.embedding_model_name,
                         )
                         dataset.embedding_model = embedding_model.model
                         dataset.embedding_model_provider = embedding_model.provider
@@ -607,9 +608,9 @@ class DatasetService:
             except ProviderTokenNotInitError as ex:
                 raise ValueError(ex.description)
         elif dataset.indexing_technique == "economy":
-            if dataset.keyword_number != knowledge_base_setting.index_method.economy_setting.keyword_number:
-                dataset.keyword_number = knowledge_base_setting.index_method.economy_setting.keyword_number
-        dataset.retrieval_model = knowledge_base_setting.retrieval_setting.model_dump()
+            if dataset.keyword_number != knowledge_configuration.index_method.economy_setting.keyword_number:
+                dataset.keyword_number = knowledge_configuration.index_method.economy_setting.keyword_number
+        dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump()
         session.add(dataset)
         session.commit()
         if action:
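Note: the hunks above exercise a handful of fields on the new KnowledgeConfiguration entity. A hypothetical sketch of its shape, inferred only from the attributes this diff touches; the real definition in services/entities/knowledge_entities/rag_pipeline_entities.py may differ:

    from pydantic import BaseModel

    class EmbeddingSetting(BaseModel):
        embedding_provider_name: str
        embedding_model_name: str

    class EconomySetting(BaseModel):
        keyword_number: int

    class IndexMethod(BaseModel):
        indexing_technique: str  # "high_quality" | "economy"
        embedding_setting: EmbeddingSetting | None = None
        economy_setting: EconomySetting | None = None

    class RetrievalSetting(BaseModel):
        search_method: str  # e.g. "semantic_search"; fields assumed
        top_k: int = 3

    class KnowledgeConfiguration(BaseModel):
        chunk_structure: str
        index_method: IndexMethod
        retrieval_setting: RetrievalSetting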

@@ -47,7 +47,7 @@ from models.workflow import (
     WorkflowType,
 )
 from services.dataset_service import DatasetService
-from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeBaseUpdateConfiguration, PipelineTemplateInfoEntity
+from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeBaseUpdateConfiguration, KnowledgeConfiguration, PipelineTemplateInfoEntity
 from services.errors.app import WorkflowHashNotEqualError
 from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory
@@ -262,7 +262,6 @@ class RagPipelineService:
         session: Session,
         pipeline: Pipeline,
         account: Account,
-        knowledge_base_setting: KnowledgeBaseUpdateConfiguration,
     ) -> Workflow:
         draft_workflow_stmt = select(Workflow).where(
             Workflow.tenant_id == pipeline.tenant_id,
@@ -291,16 +290,23 @@ class RagPipelineService:
             # commit db session changes
             session.add(workflow)
-            # update dataset
-            dataset = pipeline.dataset
-            if not dataset:
-                raise ValueError("Dataset not found")
-            DatasetService.update_rag_pipeline_dataset_settings(
-                session=session,
-                dataset=dataset,
-                knowledge_base_setting=knowledge_base_setting,
-                has_published=pipeline.is_published
-            )
+            graph = workflow.graph_dict
+            nodes = graph.get("nodes", [])
+            for node in nodes:
+                if node.get("data", {}).get("type") == "knowledge_index":
+                    knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {})
+                    knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)
+
+                    # update dataset
+                    dataset = pipeline.dataset
+                    if not dataset:
+                        raise ValueError("Dataset not found")
+                    DatasetService.update_rag_pipeline_dataset_settings(
+                        session=session,
+                        dataset=dataset,
+                        knowledge_configuration=knowledge_configuration,
+                        has_published=pipeline.is_published
+                    )
             # return new workflow
             return workflow
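Note: this is the other half of the controller change above: at publish time the service walks workflow.graph_dict and builds the KnowledgeConfiguration from the knowledge_index node instead of a request payload, so the dataset update now only happens when such a node is present. A standalone sketch of the walk over a made-up minimal graph:

    # made-up minimal graph; only the shape the loop relies on is shown
    graph = {
        "nodes": [
            {"id": "start", "data": {"type": "datasource"}},
            {
                "id": "index",
                "data": {
                    "type": "knowledge_index",
                    "knowledge_configuration": {
                        "chunk_structure": "text_model",
                        "index_method": {"indexing_technique": "economy",
                                         "economy_setting": {"keyword_number": 10}},
                        "retrieval_setting": {"search_method": "keyword_search"},
                    },
                },
            },
        ]
    }

    for node in graph.get("nodes", []):
        if node.get("data", {}).get("type") == "knowledge_index":
            config = node["data"].get("knowledge_configuration", {})
            # the service then validates: KnowledgeConfiguration(**config)
            print(config["chunk_structure"])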
