Merge branch 'feat/external-knowledge-api' of github.com:langgenius/dify into feat/external-knowledge-api

feat/external-knowledge-api
Yi 2 years ago
commit 1597f34471

@@ -14,16 +14,11 @@ class TestExternalApi(Resource):
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument( parser.add_argument(
"top_k", "retrieval_setting",
nullable=False, nullable=False,
required=True, required=True,
type=int, type=dict,
) location="json"
parser.add_argument(
"score_threshold",
nullable=False,
required=True,
type=float,
) )
parser.add_argument( parser.add_argument(
"query", "query",
@@ -32,14 +27,14 @@ class TestExternalApi(Resource):
type=str, type=str,
) )
parser.add_argument( parser.add_argument(
"external_knowledge_id", "knowledge_id",
nullable=False, nullable=False,
required=True, required=True,
type=str, type=str,
) )
args = parser.parse_args() args = parser.parse_args()
result = ExternalDatasetService.test_external_knowledge_retrieval( result = ExternalDatasetService.test_external_knowledge_retrieval(
args["top_k"], args["score_threshold"], args["query"], args["external_knowledge_id"] args["retrieval_setting"], args["query"], args["knowledge_id"]
) )
return result, 200 return result, 200

@@ -283,22 +283,28 @@ class ExternalDatasetService:
if settings.get("api_key"): if settings.get("api_key"):
headers["Authorization"] = f"Bearer {settings.get('api_key')}" headers["Authorization"] = f"Bearer {settings.get('api_key')}"
external_retrieval_parameters["query"] = query request_params = {
external_retrieval_parameters["external_knowledge_id"] = external_knowledge_binding.external_knowledge_id "retrieval_setting": {
"top_k": external_retrieval_parameters.get("top_k"),
"score_threshold": external_retrieval_parameters.get("score_threshold"),
},
"query": query,
"knowledge_id": external_knowledge_binding.external_knowledge_id,
}
external_knowledge_api_setting = { external_knowledge_api_setting = {
"url": f"{settings.get('endpoint')}/dify/external-knowledge/retrieval-documents", "url": f"{settings.get('endpoint')}/dify/external-knowledge/retrieval-documents",
"request_method": "post", "request_method": "post",
"headers": headers, "headers": headers,
"params": external_retrieval_parameters, "params": request_params,
} }
response = ExternalDatasetService.process_external_api(ExternalKnowledgeApiSetting(**external_knowledge_api_setting), None) response = ExternalDatasetService.process_external_api(ExternalKnowledgeApiSetting(**external_knowledge_api_setting), None)
if response.status_code == 200: if response.status_code == 200:
return response.json() return response.json().get("records", [])
return [] return []
@staticmethod @staticmethod
def test_external_knowledge_retrieval(top_k: int, score_threshold: float, query: str, external_knowledge_id: str): def test_external_knowledge_retrieval(retrieval_setting: dict, query: str, external_knowledge_id: str):
client = boto3.client( client = boto3.client(
"bedrock-agent-runtime", "bedrock-agent-runtime",
aws_secret_access_key=dify_config.AWS_SECRET_ACCESS_KEY, aws_secret_access_key=dify_config.AWS_SECRET_ACCESS_KEY,
@@ -308,7 +314,7 @@ class ExternalDatasetService:
response = client.retrieve( response = client.retrieve(
knowledgeBaseId=external_knowledge_id, knowledgeBaseId=external_knowledge_id,
retrievalConfiguration={ retrievalConfiguration={
"vectorSearchConfiguration": {"numberOfResults": top_k, "overrideSearchType": "HYBRID"} "vectorSearchConfiguration": {"numberOfResults": retrieval_setting.get("top_k"), "overrideSearchType": "HYBRID"}
}, },
retrievalQuery={"text": query}, retrievalQuery={"text": query},
) )
@@ -317,7 +323,7 @@ class ExternalDatasetService:
if response.get("retrievalResults"): if response.get("retrievalResults"):
retrieval_results = response.get("retrievalResults") retrieval_results = response.get("retrievalResults")
for retrieval_result in retrieval_results: for retrieval_result in retrieval_results:
if retrieval_result.get("score") < score_threshold: if retrieval_result.get("score") < retrieval_setting.get("score_threshold", .0):
continue continue
result = { result = {
"metadata": retrieval_result.get("metadata"), "metadata": retrieval_result.get("metadata"),
@@ -326,4 +332,6 @@ class ExternalDatasetService:
"content": retrieval_result.get("content").get("text"), "content": retrieval_result.get("content").get("text"),
} }
results.append(result) results.append(result)
return results return {
"records": results
}

Loading…
Cancel
Save