diff --git a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py index d79bac8f67..55326fd60f 100644 --- a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py +++ b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py @@ -248,7 +248,7 @@ class TableStoreVector(BaseVector): limit=1000, get_total_count=False, ) - rows = [] + rows: list[str] = [] next_token = None while True: if next_token is not None: @@ -264,7 +264,7 @@ class TableStoreVector(BaseVector): ) if search_response is not None: - rows.extend(row[0][0][1] for row in search_response.rows) + rows.extend([row[0][0][1] for row in search_response.rows]) if search_response is None or search_response.next_token == b"": break @@ -274,7 +274,7 @@ class TableStoreVector(BaseVector): return rows def _search_by_vector( - self, query_vector: list[float], document_ids_filter: list[str], top_k: int, score_threshold: float + self, query_vector: list[float], document_ids_filter: list[str] | None, top_k: int, score_threshold: float ) -> list[Document]: knn_vector_query = tablestore.KnnVectorQuery( field_name=Field.VECTOR.value, @@ -308,7 +308,7 @@ class TableStoreVector(BaseVector): documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True) return documents - def _search_by_full_text(self, query: str, document_ids_filter: list[str], top_k: int) -> list[Document]: + def _search_by_full_text(self, query: str, document_ids_filter: list[str] | None, top_k: int) -> list[Document]: bool_query = tablestore.BoolQuery() bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value)) diff --git a/api/tests/integration_tests/vdb/tablestore/test_tablestore.py b/api/tests/integration_tests/vdb/tablestore/test_tablestore.py index da890d0b7c..898af53ddf 100644 --- a/api/tests/integration_tests/vdb/tablestore/test_tablestore.py +++ b/api/tests/integration_tests/vdb/tablestore/test_tablestore.py @@ -1,4 +1,7 @@ import os +import uuid + +import tablestore from core.rag.datasource.vdb.tablestore.tablestore_vector import ( TableStoreConfig, @@ -7,6 +10,8 @@ from core.rag.datasource.vdb.tablestore.tablestore_vector import ( from tests.integration_tests.vdb.test_vector_store import ( AbstractVectorTest, setup_mock_redis, + get_example_document, + get_example_text, ) @@ -29,6 +34,49 @@ class TableStoreVectorTest(AbstractVectorTest): assert len(ids) == 1 assert ids[0] == self.example_doc_id + def create_vector(self): + self.vector.create( + texts=[get_example_document(doc_id=self.example_doc_id)], + embeddings=[self.example_embedding], + ) + while True: + search_response = self.vector._tablestore_client.search( + table_name=self.vector._table_name, + index_name=self.vector._index_name, + search_query=tablestore.SearchQuery(query=tablestore.MatchAllQuery(), get_total_count=True, limit=0), + columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX), + ) + if search_response.total_count == 1: + break + + def search_by_vector(self): + super().search_by_vector() + docs = self.vector.search_by_vector(self.example_embedding, document_ids_filter=[self.example_doc_id]) + assert len(docs) == 1 + assert docs[0].metadata["doc_id"] == self.example_doc_id + assert docs[0].metadata["score"] > 0 + + docs = self.vector.search_by_vector(self.example_embedding, document_ids_filter=[str(uuid.uuid4())]) + assert len(docs) == 0 + + def search_by_full_text(self): + super().search_by_full_text() + docs = self.vector.search_by_full_text(get_example_text(), document_ids_filter=[self.example_doc_id]) + assert len(docs) == 1 + assert docs[0].metadata["doc_id"] == self.example_doc_id + assert not hasattr(docs[0], "score") + + docs = self.vector.search_by_full_text(get_example_text(), document_ids_filter=[str(uuid.uuid4())]) + assert len(docs) == 0 + + def run_all_tests(self): + try: + self.vector.delete() + except Exception: + pass + + return super().run_all_tests() + def test_tablestore_vector(setup_mock_redis): TableStoreVectorTest().run_all_tests()