add tests

pull/22774/head
xiaozhiqing.xzq 7 months ago
parent f4153e9c45
commit a143a4fde4

@ -248,7 +248,7 @@ class TableStoreVector(BaseVector):
limit=1000,
get_total_count=False,
)
rows = []
rows: list[str] = []
next_token = None
while True:
if next_token is not None:
@ -264,7 +264,7 @@ class TableStoreVector(BaseVector):
)
if search_response is not None:
rows.extend(row[0][0][1] for row in search_response.rows)
rows.extend([row[0][0][1] for row in search_response.rows])
if search_response is None or search_response.next_token == b"":
break
@ -274,7 +274,7 @@ class TableStoreVector(BaseVector):
return rows
def _search_by_vector(
self, query_vector: list[float], document_ids_filter: list[str], top_k: int, score_threshold: float
self, query_vector: list[float], document_ids_filter: list[str] | None, top_k: int, score_threshold: float
) -> list[Document]:
knn_vector_query = tablestore.KnnVectorQuery(
field_name=Field.VECTOR.value,
@ -308,7 +308,7 @@ class TableStoreVector(BaseVector):
documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
return documents
def _search_by_full_text(self, query: str, document_ids_filter: list[str], top_k: int) -> list[Document]:
def _search_by_full_text(self, query: str, document_ids_filter: list[str] | None, top_k: int) -> list[Document]:
bool_query = tablestore.BoolQuery()
bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value))

@ -1,4 +1,7 @@
import os
import uuid
import tablestore
from core.rag.datasource.vdb.tablestore.tablestore_vector import (
TableStoreConfig,
@ -7,6 +10,8 @@ from core.rag.datasource.vdb.tablestore.tablestore_vector import (
from tests.integration_tests.vdb.test_vector_store import (
AbstractVectorTest,
setup_mock_redis,
get_example_document,
get_example_text,
)
@ -29,6 +34,49 @@ class TableStoreVectorTest(AbstractVectorTest):
assert len(ids) == 1
assert ids[0] == self.example_doc_id
def create_vector(self):
    """Insert the example document, then wait until the search index reports it.

    After calling ``self.vector.create`` this polls the search index with a
    match-all query (``limit=0``, ``get_total_count=True``) until the reported
    total count is exactly 1.

    Raises:
        TimeoutError: if the index does not report exactly one row before the
            wait budget is exhausted.
    """
    # Function-scope import so this method stays self-contained.
    import time

    self.vector.create(
        texts=[get_example_document(doc_id=self.example_doc_id)],
        embeddings=[self.example_embedding],
    )

    # NOTE(review): presumably the search index is synced asynchronously after
    # the write, which is why polling is needed -- confirm against the
    # Tablestore documentation. The budget below is deliberately generous.
    timeout_seconds = 300.0
    poll_interval_seconds = 0.5
    deadline = time.monotonic() + timeout_seconds

    while True:
        search_response = self.vector._tablestore_client.search(
            table_name=self.vector._table_name,
            index_name=self.vector._index_name,
            search_query=tablestore.SearchQuery(query=tablestore.MatchAllQuery(), get_total_count=True, limit=0),
            columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX),
        )
        if search_response.total_count == 1:
            break
        # Fail loudly instead of spinning forever when the row never shows up.
        if time.monotonic() >= deadline:
            raise TimeoutError(
                f"search index did not report exactly 1 row within {timeout_seconds:.0f}s "
                f"(last total_count={search_response.total_count})"
            )
        # Pause between polls so the remote service is not hammered in a tight loop.
        time.sleep(poll_interval_seconds)
def search_by_vector(self):
    """Run the inherited vector-search checks, then verify ``document_ids_filter``.

    A filter containing the example document id must return exactly that
    document with a positive score; a filter containing a random UUID must
    return nothing.
    """
    super().search_by_vector()

    # Filter that matches the stored example document.
    matching_filter = [self.example_doc_id]
    hits = self.vector.search_by_vector(self.example_embedding, document_ids_filter=matching_filter)
    assert len(hits) == 1
    top_hit = hits[0]
    assert top_hit.metadata["doc_id"] == self.example_doc_id
    assert top_hit.metadata["score"] > 0

    # Filter built from a fresh random id, so no stored document should match.
    unknown_filter = [str(uuid.uuid4())]
    misses = self.vector.search_by_vector(self.example_embedding, document_ids_filter=unknown_filter)
    assert len(misses) == 0
def search_by_full_text(self):
    """Run the inherited full-text checks, then verify ``document_ids_filter``.

    A filter containing the example document id must return exactly that
    document; a filter containing a random UUID must return nothing.
    """
    super().search_by_full_text()

    # Filter that matches the stored example document.
    matching_filter = [self.example_doc_id]
    hits = self.vector.search_by_full_text(get_example_text(), document_ids_filter=matching_filter)
    assert len(hits) == 1
    top_hit = hits[0]
    assert top_hit.metadata["doc_id"] == self.example_doc_id
    # NOTE(review): this checks for a ``score`` attribute on the document
    # object itself, whereas the vector-search test reads the score from
    # ``metadata["score"]`` -- confirm this is the intended check.
    assert not hasattr(top_hit, "score")

    # Filter built from a fresh random id, so no stored document should match.
    unknown_filter = [str(uuid.uuid4())]
    misses = self.vector.search_by_full_text(get_example_text(), document_ids_filter=unknown_filter)
    assert len(misses) == 0
def run_all_tests(self):
    """Best-effort delete of the vector store, then run the inherited test sequence.

    Any exception raised by ``self.vector.delete()`` is tolerated so the run
    can proceed, but it is logged at debug level rather than being discarded
    silently.

    Returns:
        Whatever the parent class's ``run_all_tests`` returns.
    """
    # Function-scope import so this method stays self-contained.
    import logging

    try:
        self.vector.delete()
    except Exception:
        # Cleanup is deliberately best-effort; keep the failure visible in
        # debug logs instead of swallowing it without a trace.
        logging.getLogger(__name__).debug(
            "Ignoring failure while deleting the vector store before the test run", exc_info=True
        )
    return super().run_all_tests()
def test_tablestore_vector(setup_mock_redis):
    """Pytest entry point: run the whole TableStore vector test sequence.

    ``setup_mock_redis`` is requested only as a fixture; its value is unused.
    """
    suite = TableStoreVectorTest()
    suite.run_all_tests()

Loading…
Cancel
Save