From 33d09049819d30f36d77e79bb89bea4fef3a29e9 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Fri, 11 Oct 2024 19:13:19 +0800 Subject: [PATCH] fix knowledge permission update --- api/configs/middleware/__init__.py | 17 +++++++++++++++++ api/configs/middleware/vdb/qdrant_config.py | 10 ---------- api/controllers/console/datasets/external.py | 16 +++++++++++----- .../dataset_retriever/dataset_retriever_tool.py | 13 +++++++------ api/services/knowledge_service.py | 15 ++++++++------- 5 files changed, 43 insertions(+), 28 deletions(-) diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 8626236856..5fec991d6e 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -191,6 +191,22 @@ class CeleryConfig(DatabaseConfig): return self.CELERY_BROKER_URL.startswith("rediss://") if self.CELERY_BROKER_URL else False +class InternalTestConfig(BaseSettings): + """ + Configuration settings for Internal Test + """ + + AWS_SECRET_ACCESS_KEY: Optional[str] = Field( + description="Internal test AWS secret access key", + default=None, + ) + + AWS_ACCESS_KEY_ID: Optional[str] = Field( + description="Internal test AWS access key ID", + default=None, + ) + + class MiddlewareConfig( # place the configs in alphabet order CeleryConfig, @@ -224,5 +240,6 @@ class MiddlewareConfig( TiDBVectorConfig, WeaviateConfig, ElasticsearchConfig, + InternalTestConfig, ): pass diff --git a/api/configs/middleware/vdb/qdrant_config.py b/api/configs/middleware/vdb/qdrant_config.py index bbeb3195b6..b70f624652 100644 --- a/api/configs/middleware/vdb/qdrant_config.py +++ b/api/configs/middleware/vdb/qdrant_config.py @@ -33,13 +33,3 @@ class QdrantConfig(BaseSettings): description="Port number for gRPC connection to Qdrant server (default is 6334)", default=6334, ) - - AWS_SECRET_ACCESS_KEY: Optional[str] = Field( - description="AWS secret access key for authenticating with the Qdrant server", - default=None, - ) - - AWS_ACCESS_KEY_ID: Optional[str] = Field( - description="AWS access key ID for authenticating with the Qdrant server", - default=None, - ) diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index 1ef215961e..2dc054cfbd 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -13,7 +13,7 @@ from libs.login import login_required from services.dataset_service import DatasetService from services.external_knowledge_service import ExternalDatasetService from services.hit_testing_service import HitTestingService -from services.knowledge_service import ExternalDatasetServiceTest +from services.knowledge_service import ExternalDatasetTestService def _validate_name(name): @@ -234,16 +234,21 @@ class ExternalKnowledgeHitTestingApi(Resource): class BedrockRetrievalApi(Resource): - # url : /retrieval + # this api is only for internal testing def post(self): parser = reqparse.RequestParser() parser.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json") - parser.add_argument("query", nullable=False, required=True, type=str, ) + parser.add_argument( + "query", + nullable=False, + required=True, + type=str, + ) parser.add_argument("knowledge_id", nullable=False, required=True, type=str) args = parser.parse_args() # Call the knowledge retrieval service - result = ExternalDatasetServiceTest.knowledge_retrieval( + result = ExternalDatasetTestService.knowledge_retrieval( args["retrieval_setting"], args["query"], args["knowledge_id"] ) return result, 200 @@ -254,4 +259,5 @@ api.add_resource(ExternalDatasetCreateApi, "/datasets/external") api.add_resource(ExternalApiTemplateListApi, "/datasets/external-knowledge-api") api.add_resource(ExternalApiTemplateApi, "/datasets/external-knowledge-api/") api.add_resource(ExternalApiUseCheckApi, "/datasets/external-knowledge-api//use-check") -api.add_resource(BedrockRetrievalApi, "/datasets/retrieval") +# this api is only for internal test +api.add_resource(BedrockRetrievalApi, "/test/retrieval") diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py index 2cb4c6b886..987f94a350 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py @@ -1,11 +1,11 @@ from pydantic import BaseModel, Field from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.models.document import Document as RetrievalDocument from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment -from core.rag.models.document import Document as RetrievalDocument from services.external_knowledge_service import ExternalDatasetService default_retrieval_model = { @@ -65,9 +65,9 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): ) for external_document in external_documents: document = RetrievalDocument( - page_content=external_document.get("content"), - metadata=external_document.get("metadata"), - provider="external", + page_content=external_document.get("content"), + metadata=external_document.get("metadata"), + provider="external", ) document.metadata["score"] = external_document.get("score") document.metadata["title"] = external_document.get("title") @@ -94,7 +94,6 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): return str("\n".join([item.page_content for item in results])) else: - # get retrieval model , if the model is not setting , using default retrieval_model = dataset.retrieval_model or default_retrieval_model if dataset.indexing_technique == "economy": @@ -147,7 +146,9 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): ) for segment in sorted_segments: if segment.answer: - document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}") + document_context_list.append( + f"question:{segment.get_sign_content()} answer:{segment.answer}" + ) else: document_context_list.append(segment.get_sign_content()) if self.return_resource: diff --git a/api/services/knowledge_service.py b/api/services/knowledge_service.py index 208fde4c07..02fe1d19bc 100644 --- a/api/services/knowledge_service.py +++ b/api/services/knowledge_service.py @@ -1,10 +1,10 @@ - import boto3 from configs import dify_config -class ExternalDatasetServiceTest: +class ExternalDatasetTestService: + # this service is only for internal testing @staticmethod def knowledge_retrieval(retrieval_setting: dict, query: str, knowledge_id: str): # get bedrock client @@ -19,7 +19,10 @@ class ExternalDatasetServiceTest: response = client.retrieve( knowledgeBaseId=knowledge_id, retrievalConfiguration={ - "vectorSearchConfiguration": {"numberOfResults": retrieval_setting.get("top_k"), "overrideSearchType": "HYBRID"} + "vectorSearchConfiguration": { + "numberOfResults": retrieval_setting.get("top_k"), + "overrideSearchType": "HYBRID", + } }, retrievalQuery={"text": query}, ) @@ -30,7 +33,7 @@ class ExternalDatasetServiceTest: retrieval_results = response.get("retrievalResults") for retrieval_result in retrieval_results: # filter out results with score less than threshold - if retrieval_result.get("score") < retrieval_setting.get("score_threshold", .0): + if retrieval_result.get("score") < retrieval_setting.get("score_threshold", 0.0): continue result = { "metadata": retrieval_result.get("metadata"), @@ -39,6 +42,4 @@ class ExternalDatasetServiceTest: "content": retrieval_result.get("content").get("text"), } results.append(result) - return { - "records": results - } + return {"records": results}