diff --git a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py index cfa59165a4..d79bac8f67 100644 --- a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py +++ b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py @@ -131,9 +131,8 @@ class TableStoreVector(BaseVector): filtered_list = None if document_ids_filter: filtered_list = ["document_id=" + item for item in document_ids_filter] - score_threshold = float(kwargs.get("score_threshold") or 0.0) - return self._search_by_full_text(query, filtered_list, top_k, score_threshold) + return self._search_by_full_text(query, filtered_list, top_k) def delete(self) -> None: self._delete_table_if_exist() @@ -294,12 +293,22 @@ class TableStoreVector(BaseVector): search_query=search_query, columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX), ) + documents = [] + for search_hit in search_response.search_hits: + if search_hit.score > score_threshold: + metadata = json.loads(search_hit.row[1][0][1]) + metadata["score"] = search_hit.score + documents.append( + Document( + page_content=search_hit.row[1][2][1], + vector=json.loads(search_hit.row[1][3][1]), + metadata=metadata, + ) + ) + documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True) + return documents - return self._to_query_result(search_response, score_threshold) - - def _search_by_full_text( - self, query: str, document_ids_filter: list[str], top_k: int, score_threshold: float - ) -> list[Document]: + def _search_by_full_text(self, query: str, document_ids_filter: list[str], top_k: int) -> list[Document]: bool_query = tablestore.BoolQuery() bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value)) @@ -318,23 +327,15 @@ class TableStoreVector(BaseVector): columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX), ) - return self._to_query_result(search_response, score_threshold) - - @staticmethod - def _to_query_result(search_response: tablestore.SearchResponse, score_threshold: float) -> list[Document]: documents = [] for search_hit in search_response.search_hits: - if search_hit.score > score_threshold: - metadata = json.loads(search_hit.row[1][0][1]) - metadata["score"] = search_hit.score - documents.append( - Document( - page_content=search_hit.row[1][2][1], - vector=json.loads(search_hit.row[1][3][1]), - metadata=metadata, - ) + documents.append( + Document( + page_content=search_hit.row[1][2][1], + vector=json.loads(search_hit.row[1][3][1]), + metadata=json.loads(search_hit.row[1][0][1]), ) - documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True) + ) return documents