diff --git a/api/controllers/console/workspace/account_ext.py b/api/controllers/console/workspace/account_ext.py index 8e2e811524..42c922621b 100644 --- a/api/controllers/console/workspace/account_ext.py +++ b/api/controllers/console/workspace/account_ext.py @@ -95,16 +95,18 @@ class FullSearchTextApi(Resource): help='List of names' ) parser.add_argument("query_text", type=str, required=True, location="json") + parser.add_argument("file_ids", type=str, required=True, location="json") args = parser.parse_args() dataset_names = args.dataset_names query_text = args.query_text - + file_ids = args.file_ids current_user = flask_login.current_user tenant = current_user.current_tenant search_datas = DocumentExtService.get_full_search_data( dataset_names=dataset_names, tenant_id=tenant.id, - query_text=query_text + query_text=query_text, + file_ids=file_ids ) return search_datas diff --git a/api/services/ext/dataset_ext_service.py b/api/services/ext/dataset_ext_service.py index 3ac217e685..148b98bb8f 100644 --- a/api/services/ext/dataset_ext_service.py +++ b/api/services/ext/dataset_ext_service.py @@ -202,7 +202,13 @@ class DocumentExtService: return next_segment - def get_full_search_data(dataset_names: list[str], tenant_id : str, query_text: str): + def get_full_search_data(dataset_names: list[str], + tenant_id : str, + query_text: str, + file_ids: str) -> list[dict]: + + if not file_ids: + return [] datasets = db.session.query(Dataset).filter(Dataset.name.in_(dataset_names),Dataset.tenant_id == tenant_id).all() dataset_ids = [dataset.id for dataset in datasets] @@ -214,11 +220,12 @@ class DocumentExtService: search_data = { "title": segment.document_name, "content": segment.segment_content, - "doc_metadata": segment.metadata, + "doc_metadata": segment.doc_metadata, "query": query_text } search_datas.append(search_data) + search_datas = DocumentExtService.filter_by_file_ids(search_datas=search_datas, file_ids=file_ids) return search_datas def get_full_search_segments(dataset_ids: list[str], query_text: str): @@ -271,3 +278,10 @@ class DocumentExtService: else: fetch_segments.append(segment_list[1]) return fetch_segments + + def filter_by_file_ids(search_datas: list[dict], file_ids: str) -> list[dict]: + file_id_list = file_ids.split(",") + return [ + item for item in search_datas + if item.get("doc_metadata", {}).get("file_id") in file_ids + ]