feat:support baidu vector db (#9185)
parent
793205afc5
commit
2ec6ffe478
@ -0,0 +1,45 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import Field, NonNegativeInt, PositiveInt
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
|
class BaiduVectorDBConfig(BaseSettings):
    """
    Settings for connecting to, and provisioning tables in, a Baidu Vector Database.

    The endpoint, account, API key and database name default to ``None``;
    the timeout, shard count and replica count carry numeric defaults.
    """

    # --- where the service lives and how long to wait for it ---
    BAIDU_VECTOR_DB_ENDPOINT: Optional[str] = Field(
        default=None,
        description="URL of the Baidu Vector Database service (e.g., 'http://vdb.bj.baidubce.com')",
    )

    BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS: PositiveInt = Field(
        default=30000,
        description="Timeout in milliseconds for Baidu Vector Database operations (default is 30000 milliseconds)",
    )

    # --- credentials ---
    BAIDU_VECTOR_DB_ACCOUNT: Optional[str] = Field(
        default=None,
        description="Account for authenticating with the Baidu Vector Database",
    )

    BAIDU_VECTOR_DB_API_KEY: Optional[str] = Field(
        default=None,
        description="API key for authenticating with the Baidu Vector Database service",
    )

    # --- target database and table layout ---
    BAIDU_VECTOR_DB_DATABASE: Optional[str] = Field(
        default=None,
        description="Name of the specific Baidu Vector Database to connect to",
    )

    BAIDU_VECTOR_DB_SHARD: PositiveInt = Field(
        default=1,
        description="Number of shards for the Baidu Vector Database (default is 1)",
    )

    BAIDU_VECTOR_DB_REPLICAS: NonNegativeInt = Field(
        default=3,
        description="Number of replicas for the Baidu Vector Database (default is 3)",
    )
|
||||||
@ -0,0 +1,272 @@
|
|||||||
|
import json
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, model_validator
|
||||||
|
from pymochow import MochowClient
|
||||||
|
from pymochow.auth.bce_credentials import BceCredentials
|
||||||
|
from pymochow.configuration import Configuration
|
||||||
|
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, TableState
|
||||||
|
from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex
|
||||||
|
from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
from core.rag.datasource.entity.embedding import Embeddings
|
||||||
|
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||||
|
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||||
|
from core.rag.datasource.vdb.vector_type import VectorType
|
||||||
|
from core.rag.models.document import Document
|
||||||
|
from extensions.ext_redis import redis_client
|
||||||
|
from models.dataset import Dataset
|
||||||
|
|
||||||
|
|
||||||
|
class BaiduConfig(BaseModel):
    """
    Connection and table-provisioning options for a Baidu Vector Database client.

    ``endpoint``, ``account``, ``api_key`` and ``database`` must be supplied and
    non-empty; every other field has a default.
    """

    endpoint: str
    connection_timeout_in_mills: int = 30 * 1000
    account: str
    api_key: str
    database: str
    # Names looked up in pymochow's IndexType / MetricType enums when a table is created.
    index_type: str = "HNSW"
    metric_type: str = "L2"
    shard: int = 1
    replicas: int = 3

    @model_validator(mode="before")
    @classmethod
    def validate_config(cls, values: dict) -> dict:
        """
        Reject input whose mandatory connection values are missing or empty.

        Runs on the raw input mapping, before per-field validation.

        :param values: raw keyword values passed to the model
        :return: the same mapping, unchanged
        :raises ValueError: naming the BAIDU_VECTOR_DB_* setting that must be provided
        """
        required = (
            ("endpoint", "BAIDU_VECTOR_DB_ENDPOINT"),
            ("account", "BAIDU_VECTOR_DB_ACCOUNT"),
            ("api_key", "BAIDU_VECTOR_DB_API_KEY"),
            ("database", "BAIDU_VECTOR_DB_DATABASE"),
        )
        for field_name, setting_name in required:
            # .get() rather than indexing: an omitted key must produce the same
            # descriptive ValueError as an empty value, not a bare KeyError.
            if not values.get(field_name):
                raise ValueError(f"config {setting_name} is required")
        return values
|
||||||
|
|
||||||
|
|
||||||
|
class BaiduVector(BaseVector):
    """
    Vector store backed by a Baidu Vector Database (pymochow) table.

    One collection corresponds to one table. Each row stores the document id,
    its embedding, the text, the metadata serialised as a JSON string, and the
    ``app_id`` / ``annotation_id`` values copied out of the metadata.
    """

    # Column names of the backing table.
    field_id: str = "id"
    field_vector: str = "vector"
    field_text: str = "text"
    field_metadata: str = "metadata"
    field_app_id: str = "app_id"
    field_annotation_id: str = "annotation_id"
    # Name of the vector index created over ``field_vector``.
    index_vector: str = "vector_idx"

    def __init__(self, collection_name: str, config: BaiduConfig):
        """
        Create the client and open (or create) the configured database.

        :param collection_name: name of the table this instance operates on
        :param config: connection and provisioning options
        """
        super().__init__(collection_name)
        self._client_config = config
        self._client = self._init_client(config)
        self._db = self._init_database()

    def get_type(self) -> str:
        """Return the vector-store type identifier for Baidu."""
        return VectorType.BAIDU

    def to_index_struct(self) -> dict:
        """Return the index-struct dict recording the store type and collection name."""
        return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        """
        Create the table (dimension taken from the first embedding) and insert the documents.

        :param texts: documents to store
        :param embeddings: one embedding per document; must be non-empty
        """
        self._create_table(len(embeddings[0]))
        self.add_texts(texts, embeddings)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        """
        Upsert documents in batches of 1000, then rebuild the vector index and wait for it.

        Returns immediately when ``documents`` is empty, so no index rebuild is
        triggered when nothing was written.

        :param documents: documents to store
        :param embeddings: ``embeddings[i]`` is the vector for ``documents[i]``
        """
        if not documents:
            return

        batch_size = 1000
        total_count = len(documents)
        table = self._db.table(self._collection_name)

        # upsert texts and embeddings batch by batch
        for start in range(0, total_count, batch_size):
            end = min(start + batch_size, total_count)
            rows = []
            for i in range(start, end):
                # A document without metadata is stored with an empty JSON object.
                metadata = documents[i].metadata or {}
                rows.append(
                    Row(
                        # Fall back to a random id when the metadata carries no doc_id.
                        id=metadata.get("doc_id", str(uuid.uuid4())),
                        vector=embeddings[i],
                        text=documents[i].page_content,
                        metadata=json.dumps(metadata),
                        app_id=metadata.get("app_id", ""),
                        annotation_id=metadata.get("annotation_id", ""),
                    )
                )
            table.upsert(rows=rows)

        # rebuild vector index after upsert finished, polling once per second
        # until the index reports the NORMAL state
        table.rebuild_index(self.index_vector)
        while True:
            time.sleep(1)
            index = table.describe_index(self.index_vector)
            if index.state == IndexState.NORMAL:
                break

    def text_exists(self, id: str) -> bool:
        """
        Report whether a primary-key query for ``id`` returns a response with code 0.

        :param id: document id to look up
        """
        res = self._db.table(self._collection_name).query(primary_key={self.field_id: id})
        if res and res.code == 0:
            return True
        return False

    def delete_by_ids(self, ids: list[str]) -> None:
        """
        Delete the rows whose id is in ``ids``; a no-op for an empty list.

        :param ids: document ids to remove
        """
        if not ids:
            # An empty list would otherwise yield the malformed filter "id IN()".
            return
        quoted_ids = [f"'{doc_id}'" for doc_id in ids]
        self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})")

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        """
        Delete the rows matching the filter ``key = 'value'``.

        NOTE(review): ``key`` is used directly as a column name, while the table
        built by ``_create_table`` only has the id, metadata, app_id,
        annotation_id, text and vector columns - confirm callers pass one of those.
        """
        self._db.table(self._collection_name).delete(filter=f"{key} = '{value}'")

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        """
        Run an HNSW approximate-nearest-neighbour search.

        Recognised keyword arguments: ``ef`` (default 10), ``top_k`` (default 4)
        and ``score_threshold`` (default 0.0; only rows scoring strictly above it
        are returned).

        :param query_vector: embedding to search with
        :return: matching documents with the score stored in their metadata
        """
        anns = AnnSearch(
            vector_field=self.field_vector,
            vector_floats=query_vector,
            params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
        )
        res = self._db.table(self._collection_name).search(
            anns=anns,
            projections=[self.field_id, self.field_text, self.field_metadata],
            retrieve_vector=True,
        )
        score_threshold = float(kwargs.get("score_threshold") or 0.0)
        return self._get_search_res(res, score_threshold)

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        """Return an empty list; full-text search is not implemented for this store."""
        # baidu vector database doesn't support bm25 search on current version
        return []

    def _get_search_res(self, res, score_threshold: float) -> list[Document]:
        """
        Convert a search response into documents scoring strictly above the threshold.

        :param res: search response exposing ``rows``, each a mapping with "row" and "score"
        :param score_threshold: exclusive lower bound on the score
        """
        docs = []
        for row in res.rows:
            row_data = row.get("row", {})
            raw_meta = row_data.get(self.field_metadata)
            meta = json.loads(raw_meta) if raw_meta is not None else None
            if meta is None:
                # Missing metadata column or JSON null: start from an empty dict
                # so the score can still be attached below.
                meta = {}
            score = row.get("score", 0.0)
            if score > score_threshold:
                meta["score"] = score
                docs.append(Document(page_content=row_data.get(self.field_text), metadata=meta))

        return docs

    def delete(self) -> None:
        """Drop the table backing this collection."""
        self._db.drop_table(table_name=self._collection_name)

    def _init_client(self, config) -> MochowClient:
        """
        Build a pymochow client from the account, API key and endpoint in ``config``.

        NOTE(review): ``config.connection_timeout_in_mills`` is not forwarded to
        the client configuration here - confirm whether that is intended.
        """
        client_configuration = Configuration(
            credentials=BceCredentials(config.account, config.api_key), endpoint=config.endpoint
        )
        return MochowClient(client_configuration)

    def _init_database(self):
        """Return a handle to the configured database, creating the database if it is not listed."""
        database_name = self._client_config.database
        if any(db.database_name == database_name for db in self._client.list_databases()):
            return self._client.database(database_name)
        # Create database if not existed
        return self._client.create_database(database_name=database_name)

    def _table_existed(self) -> bool:
        """Report whether a table named after this collection is listed in the database."""
        tables = self._db.list_table()
        return any(table.table_name == self._collection_name for table in tables)

    def _create_table(self, dimension: int) -> None:
        """
        Create the collection's table and vector index unless it already exists.

        Runs under a Redis lock; a Redis key (one-hour expiry) records that the
        table has been created so later calls can skip the existence check.

        :param dimension: dimensionality of the stored vectors
        :raises ValueError: if the configured index_type or metric_type is not a
            member name of the corresponding pymochow enum
        """
        # Try to grab distributed lock and create table
        lock_name = f"vector_indexing_lock_{self._collection_name}"
        with redis_client.lock(lock_name, timeout=20):
            table_exist_cache_key = f"vector_indexing_{self._collection_name}"
            if redis_client.get(table_exist_cache_key):
                return

            if self._table_existed():
                return

            # The table was just confirmed absent, so there is nothing to drop
            # before creating it.

            # check IndexType and MetricType
            index_type = IndexType.__members__.get(self._client_config.index_type)
            if index_type is None:
                raise ValueError("unsupported index_type")
            metric_type = MetricType.__members__.get(self._client_config.metric_type)
            if metric_type is None:
                raise ValueError("unsupported metric_type")

            # Construct field schema
            fields = [
                Field(
                    self.field_id,
                    FieldType.STRING,
                    primary_key=True,
                    partition_key=True,
                    auto_increment=False,
                    not_null=True,
                ),
                Field(self.field_metadata, FieldType.STRING, not_null=True),
                Field(self.field_app_id, FieldType.STRING),
                Field(self.field_annotation_id, FieldType.STRING),
                Field(self.field_text, FieldType.TEXT, not_null=True),
                Field(self.field_vector, FieldType.FLOAT_VECTOR, not_null=True, dimension=dimension),
            ]

            # Construct vector index params, using the same names the rest of
            # the class uses to rebuild and describe the index
            indexes = [
                VectorIndex(
                    index_name=self.index_vector,
                    index_type=index_type,
                    field=self.field_vector,
                    metric_type=metric_type,
                    params=HNSWParams(m=16, efconstruction=200),
                )
            ]

            # Create table
            self._db.create_table(
                table_name=self._collection_name,
                replication=self._client_config.replicas,
                partition=Partition(partition_num=self._client_config.shard),
                schema=Schema(fields=fields, indexes=indexes),
                description="Table for Dify",
            )

            redis_client.set(table_exist_cache_key, 1, ex=3600)

            # Wait for table created, polling once per second
            while True:
                time.sleep(1)
                table = self._db.describe_table(self._collection_name)
                if table.state == TableState.NORMAL:
                    break
|
||||||
|
|
||||||
|
|
||||||
|
class BaiduVectorFactory(AbstractVectorFactory):
    """Factory producing :class:`BaiduVector` stores configured from ``dify_config``."""

    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaiduVector:
        """
        Build the Baidu vector store for ``dataset``.

        The lower-cased collection name comes from the dataset's existing index
        struct when there is one; otherwise it is generated from the dataset id
        and a new index struct is written back onto the dataset.

        :param dataset: dataset whose vectors are stored
        :param attributes: not used by this factory
        :param embeddings: not used by this factory
        :return: a store bound to the resolved collection name
        """
        index_struct = dataset.index_struct_dict
        if not index_struct:
            collection_name = Dataset.gen_collection_name_by_id(dataset.id).lower()
            generated_struct = self.gen_index_struct_dict(VectorType.BAIDU, collection_name)
            dataset.index_struct = json.dumps(generated_struct)
        else:
            prefix: str = index_struct["vector_store"]["class_prefix"]
            collection_name = prefix.lower()

        baidu_config = BaiduConfig(
            endpoint=dify_config.BAIDU_VECTOR_DB_ENDPOINT,
            connection_timeout_in_mills=dify_config.BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS,
            account=dify_config.BAIDU_VECTOR_DB_ACCOUNT,
            api_key=dify_config.BAIDU_VECTOR_DB_API_KEY,
            database=dify_config.BAIDU_VECTOR_DB_DATABASE,
            shard=dify_config.BAIDU_VECTOR_DB_SHARD,
            replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS,
        )
        return BaiduVector(collection_name=collection_name, config=baidu_config)
|
||||||
@ -0,0 +1,154 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from _pytest.monkeypatch import MonkeyPatch
|
||||||
|
from pymochow import MochowClient
|
||||||
|
from pymochow.model.database import Database
|
||||||
|
from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState
|
||||||
|
from pymochow.model.schema import HNSWParams, VectorIndex
|
||||||
|
from pymochow.model.table import Table
|
||||||
|
from requests.adapters import HTTPAdapter
|
||||||
|
|
||||||
|
|
||||||
|
class MockBaiduVectorDBClass:
    """
    Canned replacements for pymochow client, database and table methods.

    The functions are written as methods so they can be patched onto the real
    classes; ``self`` is then whatever object the patched method is called on.
    """

    def mock_vector_db_client(
        self,
        config=None,
        adapter: HTTPAdapter = None,
    ):
        """Replacement constructor: stores no connection and no configuration."""
        self._config = None
        self._conn = None

    def list_databases(self, config=None) -> list[Database]:
        """Return a single database named "dify"."""
        dify_database = Database(
            conn=self._conn,
            database_name="dify",
            config=self._config,
        )
        return [dify_database]

    def create_database(self, database_name: str, config=None) -> Database:
        """Return a database object carrying the requested name."""
        created = Database(conn=self._conn, database_name=database_name, config=config)
        return created

    def list_table(self, config=None) -> list[Table]:
        """Report that the database holds no tables."""
        no_tables: list[Table] = []
        return no_tables

    def drop_table(self, table_name: str, config=None):
        """Pretend the table was dropped successfully."""
        outcome = {"code": 0, "msg": "Success"}
        return outcome

    def create_table(
        self,
        table_name: str,
        replication: int,
        partition: int,
        schema,
        enable_dynamic_field=False,
        description: str = "",
        config=None,
    ) -> Table:
        """Return a table object built from the given arguments."""
        created = Table(self, table_name, replication, partition, schema, enable_dynamic_field, description, config)
        return created

    def describe_table(self, table_name: str, config=None) -> Table:
        """Return a table object for ``table_name`` in the NORMAL state."""
        described = Table(
            self,
            table_name,
            3,
            1,
            None,
            enable_dynamic_field=False,
            description="table for dify",
            config=config,
            state=TableState.NORMAL,
        )
        return described

    def upsert(self, rows, config=None):
        """Pretend the rows were written, always reporting one affected row."""
        outcome = {"code": 0, "msg": "operation success", "affectedCount": 1}
        return outcome

    def rebuild_index(self, index_name: str, config=None):
        """Pretend the index rebuild was accepted."""
        outcome = {"code": 0, "msg": "Success"}
        return outcome

    def describe_index(self, index_name: str, config=None):
        """Return an HNSW / L2 index description in the NORMAL state."""
        described = VectorIndex(
            index_name=index_name,
            index_type=IndexType.HNSW,
            field="vector",
            metric_type=MetricType.L2,
            params=HNSWParams(m=16, efconstruction=200),
            auto_build=False,
            state=IndexState.NORMAL,
        )
        return described

    def query(
        self,
        primary_key,
        partition_key=None,
        projections=None,
        retrieve_vector=False,
        read_consistency=ReadConsistency.EVENTUAL,
        config=None,
    ):
        """Return one fixed row regardless of the requested key."""
        stored_row = {
            "id": "doc_id_001",
            "vector": [0.23432432, 0.8923744, 0.89238432],
            "text": "text",
            "metadata": {"doc_id": "doc_id_001"},
        }
        return {
            "row": stored_row,
            "code": 0,
            "msg": "Success",
        }

    def delete(self, primary_key=None, partition_key=None, filter=None, config=None):
        """Pretend the delete succeeded."""
        outcome = {"code": 0, "msg": "Success"}
        return outcome

    def search(
        self,
        anns,
        partition_key=None,
        projections=None,
        retrieve_vector=False,
        read_consistency=ReadConsistency.EVENTUAL,
        config=None,
    ):
        """Return a single fixed hit with distance 0.1 and score 0.5."""
        stored_row = {
            "id": "doc_id_001",
            "vector": [0.23432432, 0.8923744, 0.89238432],
            "text": "text",
            "metadata": {"doc_id": "doc_id_001"},
        }
        single_hit = {
            "row": stored_row,
            "distance": 0.1,
            "score": 0.5,
        }
        return {
            "rows": [single_hit],
            "code": 0,
            "msg": "Success",
        }
|
||||||
|
|
||||||
|
|
||||||
|
# Mocking is switched on only when MOCK_SWITCH is set to "true" (any letter case).
MOCK = os.environ.get("MOCK_SWITCH", "false").lower() == "true"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def setup_baiduvectordb_mock(request, monkeypatch: MonkeyPatch):
    """
    Patch the pymochow client, database and table classes with canned methods.

    Does nothing unless ``MOCK`` is true. The patches are undone after the test.
    """
    if MOCK:
        patches = (
            (MochowClient, "__init__", MockBaiduVectorDBClass.mock_vector_db_client),
            (MochowClient, "list_databases", MockBaiduVectorDBClass.list_databases),
            (MochowClient, "create_database", MockBaiduVectorDBClass.create_database),
            # Database.table and Database.describe_table share one replacement.
            (Database, "table", MockBaiduVectorDBClass.describe_table),
            (Database, "list_table", MockBaiduVectorDBClass.list_table),
            (Database, "create_table", MockBaiduVectorDBClass.create_table),
            (Database, "drop_table", MockBaiduVectorDBClass.drop_table),
            (Database, "describe_table", MockBaiduVectorDBClass.describe_table),
            # upsert and query have canned implementations too; without these two
            # entries those calls would reach the real Table methods.
            (Table, "upsert", MockBaiduVectorDBClass.upsert),
            (Table, "rebuild_index", MockBaiduVectorDBClass.rebuild_index),
            (Table, "describe_index", MockBaiduVectorDBClass.describe_index),
            (Table, "query", MockBaiduVectorDBClass.query),
            (Table, "delete", MockBaiduVectorDBClass.delete),
            (Table, "search", MockBaiduVectorDBClass.search),
        )
        for target, attribute, replacement in patches:
            monkeypatch.setattr(target, attribute, replacement)

    yield

    if MOCK:
        monkeypatch.undo()
|
||||||
@ -0,0 +1,36 @@
|
|||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from core.rag.datasource.vdb.baidu.baidu_vector import BaiduConfig, BaiduVector
|
||||||
|
from tests.integration_tests.vdb.__mock.baiduvectordb import setup_baiduvectordb_mock
|
||||||
|
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
|
||||||
|
|
||||||
|
# Stand-in client whose list_databases() returns one entry named "test".
mock_client = MagicMock(**{"list_databases.return_value": [{"name": "test"}]})
|
||||||
|
|
||||||
|
|
||||||
|
class BaiduVectorTest(AbstractVectorTest):
    """Vector-store test case bound to a BaiduVector for the "dify" collection."""

    def __init__(self):
        """Create the store under test against a local endpoint."""
        super().__init__()
        config = BaiduConfig(
            endpoint="http://127.0.0.1:5287",
            account="root",
            api_key="dify",
            database="dify",
            shard=1,
            replicas=3,
        )
        self.vector = BaiduVector("dify", config)

    def search_by_vector(self):
        """A vector search with the example embedding must return exactly one hit."""
        hits = self.vector.search_by_vector(query_vector=self.example_embedding)
        assert len(hits) == 1

    def search_by_full_text(self):
        """A full-text search with the example text must return no hits."""
        hits = self.vector.search_by_full_text(query=get_example_text())
        assert len(hits) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_baidu_vector(setup_mock_redis, setup_baiduvectordb_mock):
    """Run every test of the Baidu vector test case with the Redis and Baidu fixtures active."""
    suite = BaiduVectorTest()
    suite.run_all_tests()
|
||||||
Loading…
Reference in New Issue