|
|
|
|
@@ -184,7 +184,7 @@ class DatasetService:
|
|
|
|
|
return dataset
|
|
|
|
|
|
|
|
|
|
@staticmethod
def get_dataset(dataset_id) -> Dataset:
    """Look up a dataset by its primary id.

    Args:
        dataset_id: Value matched against ``Dataset.id``.

    Returns:
        The first ``Dataset`` row whose ``id`` equals ``dataset_id``.
    """
    # NOTE(review): ``Query.first()`` yields ``None`` when nothing matches, so
    # the ``Dataset`` return annotation is optimistic -- callers should still
    # guard against a missing dataset.
    return Dataset.query.filter_by(id=dataset_id).first()
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
@@ -225,81 +225,103 @@ class DatasetService:
|
|
|
|
|
|
|
|
|
|
@staticmethod
def update_dataset(dataset_id, data, user):
    """Apply an update payload to a dataset and return the dataset.

    Two paths exist, selected by ``dataset.provider``:

    * ``"external"``: only the name, description, retrieval model and the
      external-knowledge binding ids are updated.
    * anything else: the payload is filtered and written to the dataset row;
      when the indexing technique or embedding model changes, a vector-index
      task is queued with action ``"remove"``, ``"add"`` or ``"update"``.

    Args:
        dataset_id: Id of the dataset to update.
        data: Update payload (dict). For external datasets
            ``external_knowledge_id`` and ``external_knowledge_api_id`` are
            required. For other datasets ``indexing_technique`` and
            ``retrieval_model`` are read by subscript and must be present,
            as must ``embedding_model_provider`` / ``embedding_model`` unless
            the technique is switched to ``"economy"`` (or to a value other
            than ``"high_quality"``). ``partial_member_list`` is popped from
            the dict in place on the non-external path.
        user: Acting user; passed to the permission check and recorded via
            ``user.id`` as ``updated_by``.

    Returns:
        The dataset object loaded at the start of the call.

    Raises:
        ValueError: If the dataset or its external-knowledge binding does not
            exist, a required external id is missing, or the embedding model
            cannot be resolved.
    """
    dataset = DatasetService.get_dataset(dataset_id)
    # get_dataset() uses .first(), which yields None for an unknown id.
    if dataset is None:
        raise ValueError("Dataset not found")

    DatasetService.check_dataset_permission(dataset, user)
    if dataset.provider == "external":
        # Validate the whole payload before touching the ORM objects, so a
        # rejected request does not leave half-applied changes in the session.
        external_knowledge_id = data.get("external_knowledge_id", None)
        if not external_knowledge_id:
            raise ValueError("External knowledge id is required.")
        external_knowledge_api_id = data.get("external_knowledge_api_id", None)
        if not external_knowledge_api_id:
            raise ValueError("External knowledge api id is required.")
        external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by(dataset_id=dataset_id).first()
        if external_knowledge_binding is None:
            raise ValueError("External knowledge binding not found.")

        dataset.retrieval_model = data.get("external_retrieval_model", None)
        dataset.name = data.get("name", dataset.name)
        dataset.description = data.get("description", "")
        db.session.add(dataset)
        if (
            external_knowledge_binding.external_knowledge_id != external_knowledge_id
            or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id
        ):
            external_knowledge_binding.external_knowledge_id = external_knowledge_id
            external_knowledge_binding.external_knowledge_api_id = external_knowledge_api_id
            db.session.add(external_knowledge_binding)
        db.session.commit()
    else:
        data.pop("partial_member_list", None)
        # Drop unset fields, but keep "description" even when it is None.
        filtered_data = {k: v for k, v in data.items() if v is not None or k == "description"}
        action = None
        if dataset.indexing_technique != data["indexing_technique"]:
            # if update indexing_technique
            if data["indexing_technique"] == "economy":
                action = "remove"
                filtered_data["embedding_model"] = None
                filtered_data["embedding_model_provider"] = None
                filtered_data["collection_binding_id"] = None
            elif data["indexing_technique"] == "high_quality":
                action = "add"
                # get embedding model setting
                # NOTE(review): the tenant comes from the global
                # ``current_user`` rather than the ``user`` argument -- confirm
                # the two always refer to the same account.
                try:
                    model_manager = ModelManager()
                    embedding_model = model_manager.get_model_instance(
                        tenant_id=current_user.current_tenant_id,
                        provider=data["embedding_model_provider"],
                        model_type=ModelType.TEXT_EMBEDDING,
                        model=data["embedding_model"],
                    )
                    filtered_data["embedding_model"] = embedding_model.model
                    filtered_data["embedding_model_provider"] = embedding_model.provider
                    dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
                        embedding_model.provider, embedding_model.model
                    )
                    filtered_data["collection_binding_id"] = dataset_collection_binding.id
                except LLMBadRequestError:
                    raise ValueError(
                        "No Embedding Model available. Please configure a valid provider "
                        "in the Settings -> Model Provider."
                    )
                except ProviderTokenNotInitError as ex:
                    raise ValueError(ex.description)
        else:
            # Indexing technique unchanged: re-index only when the embedding
            # model or its provider differs from what the dataset stores.
            if (
                data["embedding_model_provider"] != dataset.embedding_model_provider
                or data["embedding_model"] != dataset.embedding_model
            ):
                action = "update"
                try:
                    model_manager = ModelManager()
                    embedding_model = model_manager.get_model_instance(
                        tenant_id=current_user.current_tenant_id,
                        provider=data["embedding_model_provider"],
                        model_type=ModelType.TEXT_EMBEDDING,
                        model=data["embedding_model"],
                    )
                    filtered_data["embedding_model"] = embedding_model.model
                    filtered_data["embedding_model_provider"] = embedding_model.provider
                    dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
                        embedding_model.provider, embedding_model.model
                    )
                    filtered_data["collection_binding_id"] = dataset_collection_binding.id
                except LLMBadRequestError:
                    raise ValueError(
                        "No Embedding Model available. Please configure a valid provider "
                        "in the Settings -> Model Provider."
                    )
                except ProviderTokenNotInitError as ex:
                    raise ValueError(ex.description)

        filtered_data["updated_by"] = user.id
        filtered_data["updated_at"] = datetime.datetime.now()

        # update Retrieval model
        filtered_data["retrieval_model"] = data["retrieval_model"]

        dataset.query.filter_by(id=dataset_id).update(filtered_data)

        db.session.commit()
        # Queue the index task only after the commit above.
        if action:
            deal_dataset_vector_index_task.delay(dataset_id, action)
    return dataset
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|