feat/datasource
jyong 12 months ago
parent 3fb02a7933
commit 0486aa3445

@ -664,7 +664,7 @@ class DocumentDetailApi(DocumentResource):
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details} response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
elif metadata == "without": elif metadata == "without":
dataset_process_rules = DatasetService.get_process_rules(dataset_id) dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict() document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
data_source_info = document.data_source_detail_dict data_source_info = document.data_source_detail_dict
response = { response = {
"id": document.id, "id": document.id,

@ -39,8 +39,6 @@ from libs.helper import TimestampField, uuid_value
from libs.login import current_user, login_required from libs.login import current_user, login_required
from models.account import Account from models.account import Account
from models.dataset import Pipeline from models.dataset import Pipeline
from models.model import EndUser
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeBaseUpdateConfiguration
from services.errors.app import WorkflowHashNotEqualError from services.errors.app import WorkflowHashNotEqualError
from services.errors.llm import InvokeRateLimitError from services.errors.llm import InvokeRateLimitError
from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService

@ -12,7 +12,7 @@ from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Dataset, Document from models.dataset import Dataset, Document, DocumentSegment
from models.workflow import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionStatus
from ..base import BaseNode from ..base import BaseNode
@ -61,11 +61,11 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks is required." status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks is required."
) )
outputs = self._get_preview_output(node_data.chunk_structure, chunks)
# retrieve knowledge # index knowledge
try: try:
if is_preview: if is_preview:
outputs = self._get_preview_output(node_data.chunk_structure, chunks)
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables, inputs=variables,
@ -116,6 +116,18 @@ class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
document.indexing_status = "completed" document.indexing_status = "completed"
document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.add(document) db.session.add(document)
#update document segment status
db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == document.id,
DocumentSegment.dataset_id == dataset.id,
).update(
{
DocumentSegment.status: "completed",
DocumentSegment.enabled: True,
DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
}
)
db.session.commit() db.session.commit()
return { return {

@ -1,3 +1,4 @@
from calendar import day_abbr
import copy import copy
import datetime import datetime
import json import json
@ -52,7 +53,6 @@ from services.entities.knowledge_entities.knowledge_entities import (
SegmentUpdateArgs, SegmentUpdateArgs,
) )
from services.entities.knowledge_entities.rag_pipeline_entities import ( from services.entities.knowledge_entities.rag_pipeline_entities import (
KnowledgeBaseUpdateConfiguration,
KnowledgeConfiguration, KnowledgeConfiguration,
RagPipelineDatasetCreateEntity, RagPipelineDatasetCreateEntity,
) )
@ -498,17 +498,17 @@ class DatasetService:
dataset: Dataset, dataset: Dataset,
knowledge_configuration: KnowledgeConfiguration, knowledge_configuration: KnowledgeConfiguration,
has_published: bool = False): has_published: bool = False):
dataset = session.merge(dataset)
if not has_published: if not has_published:
dataset.chunk_structure = knowledge_configuration.chunk_structure dataset.chunk_structure = knowledge_configuration.chunk_structure
index_method = knowledge_configuration.index_method dataset.indexing_technique = knowledge_configuration.indexing_technique
dataset.indexing_technique = index_method.indexing_technique if knowledge_configuration.indexing_technique == "high_quality":
if index_method == "high_quality":
model_manager = ModelManager() model_manager = ModelManager()
embedding_model = model_manager.get_model_instance( embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
provider=index_method.embedding_setting.embedding_provider_name, provider=knowledge_configuration.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
model=index_method.embedding_setting.embedding_model_name, model=knowledge_configuration.embedding_model,
) )
dataset.embedding_model = embedding_model.model dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider dataset.embedding_model_provider = embedding_model.provider
@ -516,30 +516,30 @@ class DatasetService:
embedding_model.provider, embedding_model.model embedding_model.provider, embedding_model.model
) )
dataset.collection_binding_id = dataset_collection_binding.id dataset.collection_binding_id = dataset_collection_binding.id
elif index_method == "economy": elif knowledge_configuration.indexing_technique == "economy":
dataset.keyword_number = index_method.economy_setting.keyword_number dataset.keyword_number = knowledge_configuration.keyword_number
else: else:
raise ValueError("Invalid index method") raise ValueError("Invalid index method")
dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump() dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
session.add(dataset) session.add(dataset)
else: else:
if dataset.chunk_structure and dataset.chunk_structure != knowledge_configuration.chunk_structure: if dataset.chunk_structure and dataset.chunk_structure != knowledge_configuration.chunk_structure:
raise ValueError("Chunk structure is not allowed to be updated.") raise ValueError("Chunk structure is not allowed to be updated.")
action = None action = None
if dataset.indexing_technique != knowledge_configuration.index_method.indexing_technique: if dataset.indexing_technique != knowledge_configuration.indexing_technique:
# if update indexing_technique # if update indexing_technique
if knowledge_configuration.index_method.indexing_technique == "economy": if knowledge_configuration.indexing_technique == "economy":
raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy.") raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy.")
elif knowledge_configuration.index_method.indexing_technique == "high_quality": elif knowledge_configuration.indexing_technique == "high_quality":
action = "add" action = "add"
# get embedding model setting # get embedding model setting
try: try:
model_manager = ModelManager() model_manager = ModelManager()
embedding_model = model_manager.get_model_instance( embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
provider=knowledge_configuration.index_method.embedding_setting.embedding_provider_name, provider=knowledge_configuration.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
model=knowledge_configuration.index_method.embedding_setting.embedding_model_name, model=knowledge_configuration.embedding_model,
) )
dataset.embedding_model = embedding_model.model dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider dataset.embedding_model_provider = embedding_model.provider
@ -567,7 +567,7 @@ class DatasetService:
plugin_model_provider_str = str(ModelProviderID(plugin_model_provider)) plugin_model_provider_str = str(ModelProviderID(plugin_model_provider))
# Handle new model provider from request # Handle new model provider from request
new_plugin_model_provider = knowledge_base_setting.index_method.embedding_setting.embedding_provider_name new_plugin_model_provider = knowledge_configuration.embedding_model_provider
new_plugin_model_provider_str = None new_plugin_model_provider_str = None
if new_plugin_model_provider: if new_plugin_model_provider:
new_plugin_model_provider_str = str(ModelProviderID(new_plugin_model_provider)) new_plugin_model_provider_str = str(ModelProviderID(new_plugin_model_provider))
@ -575,16 +575,16 @@ class DatasetService:
# Only update embedding model if both values are provided and different from current # Only update embedding model if both values are provided and different from current
if ( if (
plugin_model_provider_str != new_plugin_model_provider_str plugin_model_provider_str != new_plugin_model_provider_str
or knowledge_base_setting.index_method.embedding_setting.embedding_model_name != dataset.embedding_model or knowledge_configuration.embedding_model != dataset.embedding_model
): ):
action = "update" action = "update"
model_manager = ModelManager() model_manager = ModelManager()
try: try:
embedding_model = model_manager.get_model_instance( embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
provider=knowledge_base_setting.index_method.embedding_setting.embedding_provider_name, provider=knowledge_configuration.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
model=knowledge_base_setting.index_method.embedding_setting.embedding_model_name, model=knowledge_configuration.embedding_model,
) )
except ProviderTokenNotInitError: except ProviderTokenNotInitError:
# If we can't get the embedding model, skip updating it # If we can't get the embedding model, skip updating it
@ -608,9 +608,9 @@ class DatasetService:
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ValueError(ex.description) raise ValueError(ex.description)
elif dataset.indexing_technique == "economy": elif dataset.indexing_technique == "economy":
if dataset.keyword_number != knowledge_configuration.index_method.economy_setting.keyword_number: if dataset.keyword_number != knowledge_configuration.keyword_number:
dataset.keyword_number = knowledge_configuration.index_method.economy_setting.keyword_number dataset.keyword_number = knowledge_configuration.keyword_number
dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump() dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
session.add(dataset) session.add(dataset)
session.commit() session.commit()
if action: if action:

@ -105,18 +105,11 @@ class IndexMethod(BaseModel):
class KnowledgeConfiguration(BaseModel): class KnowledgeConfiguration(BaseModel):
""" """
Knowledge Configuration. Knowledge Base Configuration.
""" """
chunk_structure: str chunk_structure: str
index_method: IndexMethod indexing_technique: Literal["high_quality", "economy"]
retrieval_setting: RetrievalSetting embedding_model_provider: Optional[str] = ""
embedding_model: Optional[str] = ""
keyword_number: Optional[int] = 10
class KnowledgeBaseUpdateConfiguration(BaseModel): retrieval_model: RetrievalSetting
"""
Knowledge Base Update Configuration.
"""
index_method: IndexMethod
chunk_structure: str
retrieval_setting: RetrievalSetting

@ -296,8 +296,8 @@ class RagPipelineService:
graph = workflow.graph_dict graph = workflow.graph_dict
nodes = graph.get("nodes", []) nodes = graph.get("nodes", [])
for node in nodes: for node in nodes:
if node.get("data", {}).get("type") == "knowledge_index": if node.get("data", {}).get("type") == "knowledge-index":
knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {}) knowledge_configuration = node.get("data", {})
knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration) knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)
# update dataset # update dataset

@ -1,10 +1,10 @@
import base64 import base64
from datetime import UTC, datetime
import hashlib import hashlib
import json import json
import logging import logging
import uuid import uuid
from collections.abc import Mapping from collections.abc import Mapping
from datetime import UTC, datetime
from enum import StrEnum from enum import StrEnum
from typing import Optional, cast from typing import Optional, cast
from urllib.parse import urlparse from urllib.parse import urlparse
@ -292,20 +292,20 @@ class RagPipelineDslService:
"background": icon_background, "background": icon_background,
"url": icon_url, "url": icon_url,
}, },
indexing_technique=knowledge_configuration.index_method.indexing_technique, indexing_technique=knowledge_configuration.indexing_technique,
created_by=account.id, created_by=account.id,
retrieval_model=knowledge_configuration.retrieval_setting.model_dump(), retrieval_model=knowledge_configuration.retrieval_model.model_dump(),
runtime_mode="rag_pipeline", runtime_mode="rag_pipeline",
chunk_structure=knowledge_configuration.chunk_structure, chunk_structure=knowledge_configuration.chunk_structure,
) )
if knowledge_configuration.index_method.indexing_technique == "high_quality": if knowledge_configuration.indexing_technique == "high_quality":
dataset_collection_binding = ( dataset_collection_binding = (
db.session.query(DatasetCollectionBinding) db.session.query(DatasetCollectionBinding)
.filter( .filter(
DatasetCollectionBinding.provider_name DatasetCollectionBinding.provider_name
== knowledge_configuration.index_method.embedding_setting.embedding_provider_name, == knowledge_configuration.embedding_model_provider,
DatasetCollectionBinding.model_name DatasetCollectionBinding.model_name
== knowledge_configuration.index_method.embedding_setting.embedding_model_name, == knowledge_configuration.embedding_model,
DatasetCollectionBinding.type == "dataset", DatasetCollectionBinding.type == "dataset",
) )
.order_by(DatasetCollectionBinding.created_at) .order_by(DatasetCollectionBinding.created_at)
@ -314,8 +314,8 @@ class RagPipelineDslService:
if not dataset_collection_binding: if not dataset_collection_binding:
dataset_collection_binding = DatasetCollectionBinding( dataset_collection_binding = DatasetCollectionBinding(
provider_name=knowledge_configuration.index_method.embedding_setting.embedding_provider_name, provider_name=knowledge_configuration.embedding_model_provider,
model_name=knowledge_configuration.index_method.embedding_setting.embedding_model_name, model_name=knowledge_configuration.embedding_model,
collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())), collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
type="dataset", type="dataset",
) )
@ -324,13 +324,13 @@ class RagPipelineDslService:
dataset_collection_binding_id = dataset_collection_binding.id dataset_collection_binding_id = dataset_collection_binding.id
dataset.collection_binding_id = dataset_collection_binding_id dataset.collection_binding_id = dataset_collection_binding_id
dataset.embedding_model = ( dataset.embedding_model = (
knowledge_configuration.index_method.embedding_setting.embedding_model_name knowledge_configuration.embedding_model
) )
dataset.embedding_model_provider = ( dataset.embedding_model_provider = (
knowledge_configuration.index_method.embedding_setting.embedding_provider_name knowledge_configuration.embedding_model_provider
) )
elif knowledge_configuration.index_method.indexing_technique == "economy": elif knowledge_configuration.indexing_technique == "economy":
dataset.keyword_number = knowledge_configuration.index_method.economy_setting.keyword_number dataset.keyword_number = knowledge_configuration.keyword_number
dataset.pipeline_id = pipeline.id dataset.pipeline_id = pipeline.id
self._session.add(dataset) self._session.add(dataset)
self._session.commit() self._session.commit()
@ -426,25 +426,25 @@ class RagPipelineDslService:
"background": icon_background, "background": icon_background,
"url": icon_url, "url": icon_url,
}, },
indexing_technique=knowledge_configuration.index_method.indexing_technique, indexing_technique=knowledge_configuration.indexing_technique,
created_by=account.id, created_by=account.id,
retrieval_model=knowledge_configuration.retrieval_setting.model_dump(), retrieval_model=knowledge_configuration.retrieval_model.model_dump(),
runtime_mode="rag_pipeline", runtime_mode="rag_pipeline",
chunk_structure=knowledge_configuration.chunk_structure, chunk_structure=knowledge_configuration.chunk_structure,
) )
else: else:
dataset.indexing_technique = knowledge_configuration.index_method.indexing_technique dataset.indexing_technique = knowledge_configuration.indexing_technique
dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump() dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
dataset.runtime_mode = "rag_pipeline" dataset.runtime_mode = "rag_pipeline"
dataset.chunk_structure = knowledge_configuration.chunk_structure dataset.chunk_structure = knowledge_configuration.chunk_structure
if knowledge_configuration.index_method.indexing_technique == "high_quality": if knowledge_configuration.indexing_technique == "high_quality":
dataset_collection_binding = ( dataset_collection_binding = (
db.session.query(DatasetCollectionBinding) db.session.query(DatasetCollectionBinding)
.filter( .filter(
DatasetCollectionBinding.provider_name DatasetCollectionBinding.provider_name
== knowledge_configuration.index_method.embedding_setting.embedding_provider_name, == knowledge_configuration.embedding_model_provider,
DatasetCollectionBinding.model_name DatasetCollectionBinding.model_name
== knowledge_configuration.index_method.embedding_setting.embedding_model_name, == knowledge_configuration.embedding_model,
DatasetCollectionBinding.type == "dataset", DatasetCollectionBinding.type == "dataset",
) )
.order_by(DatasetCollectionBinding.created_at) .order_by(DatasetCollectionBinding.created_at)
@ -453,8 +453,8 @@ class RagPipelineDslService:
if not dataset_collection_binding: if not dataset_collection_binding:
dataset_collection_binding = DatasetCollectionBinding( dataset_collection_binding = DatasetCollectionBinding(
provider_name=knowledge_configuration.index_method.embedding_setting.embedding_provider_name, provider_name=knowledge_configuration.embedding_model_provider,
model_name=knowledge_configuration.index_method.embedding_setting.embedding_model_name, model_name=knowledge_configuration.embedding_model,
collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())), collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
type="dataset", type="dataset",
) )
@ -463,13 +463,13 @@ class RagPipelineDslService:
dataset_collection_binding_id = dataset_collection_binding.id dataset_collection_binding_id = dataset_collection_binding.id
dataset.collection_binding_id = dataset_collection_binding_id dataset.collection_binding_id = dataset_collection_binding_id
dataset.embedding_model = ( dataset.embedding_model = (
knowledge_configuration.index_method.embedding_setting.embedding_model_name knowledge_configuration.embedding_model
) )
dataset.embedding_model_provider = ( dataset.embedding_model_provider = (
knowledge_configuration.index_method.embedding_setting.embedding_provider_name knowledge_configuration.embedding_model_provider
) )
elif knowledge_configuration.index_method.indexing_technique == "economy": elif knowledge_configuration.indexing_technique == "economy":
dataset.keyword_number = knowledge_configuration.index_method.economy_setting.keyword_number dataset.keyword_number = knowledge_configuration.keyword_number
dataset.pipeline_id = pipeline.id dataset.pipeline_id = pipeline.id
self._session.add(dataset) self._session.add(dataset)
self._session.commit() self._session.commit()

Loading…
Cancel
Save