From 227f9c11bc895c277e67c64bd477e7c904a9fdce Mon Sep 17 00:00:00 2001
From: VanKhoa
Date: Fri, 20 Jun 2025 23:40:39 +0700
Subject: [PATCH] feat(api): Validate model provider when create/update dataset/document

Validate the embedding and reranking model settings passed to the dataset
and document service API endpoints before the dataset or document is
created or updated. Embedding settings go through the existing
DatasetService.check_embedding_model_setting; reranking settings go through
the new DatasetService.check_reranking_model_setting.

Every endpoint reads the optional request keys with .get(), so a request
that omits "embedding_model" or sends a null "retrieval_model" skips the
check instead of failing with KeyError/AttributeError.

---
 .../service_api/dataset/dataset.py            | 26 +++++++++++++
 .../service_api/dataset/document.py           | 38 ++++++++++++++++++-
 api/services/dataset_service.py               | 23 ++++++++++++
 3 files changed, 86 insertions(+), 1 deletion(-)

diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py
index 1467dfb6b3..24712cfd77 100644
--- a/api/controllers/service_api/dataset/dataset.py
+++ b/api/controllers/service_api/dataset/dataset.py
@@ -133,6 +133,20 @@ class DatasetListApi(DatasetApiResource):
         parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
 
         args = parser.parse_args()
+
+        if args.get("embedding_model_provider"):
+            DatasetService.check_embedding_model_setting(
+                tenant_id, args.get("embedding_model_provider"), args.get("embedding_model")
+            )
+        if (args.get("retrieval_model") and
+            args.get("retrieval_model").get("reranking_model") and
+            args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")):
+            DatasetService.check_reranking_model_setting(
+                tenant_id,
+                args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
+                args.get("retrieval_model").get("reranking_model").get("reranking_model_name")
+            )
+
         try:
             dataset = DatasetService.create_empty_dataset(
                 tenant_id=tenant_id,
@@ -269,6 +283,18 @@ class DatasetApi(DatasetApiResource):
             DatasetService.check_embedding_model_setting(
                 dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
             )
+        if data.get("embedding_model_provider"):
+            DatasetService.check_embedding_model_setting(
+                dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
+            )
+        if (data.get("retrieval_model") and
+            data.get("retrieval_model").get("reranking_model") and
+            data.get("retrieval_model").get("reranking_model").get("reranking_provider_name")):
+            DatasetService.check_reranking_model_setting(
+                dataset.tenant_id,
+                data.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
+                data.get("retrieval_model").get("reranking_model").get("reranking_model_name")
+            )
 
         # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
         DatasetPermissionService.check_permission(
diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py
index e4779f3bdf..fa75700355 100644
--- a/api/controllers/service_api/dataset/document.py
+++ b/api/controllers/service_api/dataset/document.py
@@ -29,7 +29,7 @@ from extensions.ext_database import db
 from fields.document_fields import document_fields, document_status_fields
 from libs.login import current_user
 from models.dataset import Dataset, Document, DocumentSegment
-from services.dataset_service import DocumentService
+from services.dataset_service import DatasetService, DocumentService
 from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
 from services.file_service import FileService
 
@@ -59,6 +59,7 @@ class DocumentAddByTextApi(DatasetApiResource):
         parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
 
         args = parser.parse_args()
+
         dataset_id = str(dataset_id)
         tenant_id = str(tenant_id)
         dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
@@ -74,6 +75,19 @@ class DocumentAddByTextApi(DatasetApiResource):
         if text is None or name is None:
             raise ValueError("Both 'text' and 'name' must be non-null values.")
 
+        if args.get("embedding_model_provider"):
+            DatasetService.check_embedding_model_setting(
+                tenant_id, args.get("embedding_model_provider"), args.get("embedding_model")
+            )
+        if (args.get("retrieval_model") and
+            args.get("retrieval_model").get("reranking_model") and
+            args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")):
+            DatasetService.check_reranking_model_setting(
+                tenant_id,
+                args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
+                args.get("retrieval_model").get("reranking_model").get("reranking_model_name")
+            )
+
         upload_file = FileService.upload_text(text=str(text), text_name=str(name))
         data_source = {
             "type": "upload_file",
@@ -124,6 +138,15 @@ class DocumentUpdateByTextApi(DatasetApiResource):
         if not dataset:
             raise ValueError("Dataset does not exist.")
 
+        if (args.get("retrieval_model") and
+            args.get("retrieval_model").get("reranking_model") and
+            args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")):
+            DatasetService.check_reranking_model_setting(
+                tenant_id,
+                args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
+                args.get("retrieval_model").get("reranking_model").get("reranking_model_name")
+            )
+
         # indexing_technique is already set in dataset since this is an update
         args["indexing_technique"] = dataset.indexing_technique
 
@@ -188,6 +211,19 @@ class DocumentAddByFileApi(DatasetApiResource):
             raise ValueError("indexing_technique is required.")
         args["indexing_technique"] = indexing_technique
 
+        if args.get("embedding_model_provider"):
+            DatasetService.check_embedding_model_setting(
+                tenant_id, args.get("embedding_model_provider"), args.get("embedding_model")
+            )
+        if (args.get("retrieval_model") and
+            args.get("retrieval_model").get("reranking_model") and
+            args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")):
+            DatasetService.check_reranking_model_setting(
+                tenant_id,
+                args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
+                args.get("retrieval_model").get("reranking_model").get("reranking_model_name")
+            )
+
         # save file info
         file = request.files["file"]
         # check file
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index a29bf92596..ff20010edd 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -276,6 +276,29 @@ class DatasetService:
         except ProviderTokenNotInitError as ex:
             raise ValueError(ex.description)
 
+    @staticmethod
+    def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str):
+        """Check that a rerank model instance can be obtained for the tenant.
+
+        Raises:
+            ValueError: If ModelManager.get_model_instance raises LLMBadRequestError
+                or ProviderTokenNotInitError for the given provider and model.
+        """
+        try:
+            model_manager = ModelManager()
+            model_manager.get_model_instance(
+                tenant_id=tenant_id,
+                provider=reranking_model_provider,
+                model_type=ModelType.RERANK,
+                model=reranking_model,
+            )
+        except LLMBadRequestError:
+            raise ValueError(
+                "No Rerank Model available. Please configure a valid provider in the Settings -> Model Provider."
+            )
+        except ProviderTokenNotInitError as ex:
+            raise ValueError(ex.description)
+
     @staticmethod
     def update_dataset(dataset_id, data, user):
         dataset = DatasetService.get_dataset(dataset_id)