parent
5397799aac
commit
b1fd1b3ab3
@ -0,0 +1,114 @@
|
|||||||
|
from typing import Optional, cast
|
||||||
|
|
||||||
|
from langchain.embeddings.base import Embeddings
|
||||||
|
from langchain.schema import Document, BaseRetriever
|
||||||
|
from langchain.vectorstores import VectorStore, milvus
|
||||||
|
from pydantic import BaseModel, root_validator
|
||||||
|
|
||||||
|
from core.index.base import BaseIndex
|
||||||
|
from core.index.vector_index.base import BaseVectorIndex
|
||||||
|
from core.vector_store.milvus_vector_store import MilvusVectorStore
|
||||||
|
from core.vector_store.weaviate_vector_store import WeaviateVectorStore
|
||||||
|
from models.dataset import Dataset
|
||||||
|
|
||||||
|
|
||||||
|
class MilvusConfig(BaseModel):
|
||||||
|
endpoint: str
|
||||||
|
user: str
|
||||||
|
password: str
|
||||||
|
batch_size: int = 100
|
||||||
|
|
||||||
|
@root_validator()
|
||||||
|
def validate_config(cls, values: dict) -> dict:
|
||||||
|
if not values['endpoint']:
|
||||||
|
raise ValueError("config MILVUS_ENDPOINT is required")
|
||||||
|
if not values['user']:
|
||||||
|
raise ValueError("config MILVUS_USER is required")
|
||||||
|
if not values['password']:
|
||||||
|
raise ValueError("config MILVUS_PASSWORD is required")
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
class MilvusVectorIndex(BaseVectorIndex):
|
||||||
|
def __init__(self, dataset: Dataset, config: MilvusConfig, embeddings: Embeddings):
|
||||||
|
super().__init__(dataset, embeddings)
|
||||||
|
self._client = self._init_client(config)
|
||||||
|
|
||||||
|
def get_type(self) -> str:
|
||||||
|
return 'milvus'
|
||||||
|
|
||||||
|
def get_index_name(self, dataset: Dataset) -> str:
|
||||||
|
if self.dataset.index_struct_dict:
|
||||||
|
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
|
||||||
|
if not class_prefix.endswith('_Node'):
|
||||||
|
# original class_prefix
|
||||||
|
class_prefix += '_Node'
|
||||||
|
|
||||||
|
return class_prefix
|
||||||
|
|
||||||
|
dataset_id = dataset.id
|
||||||
|
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
|
||||||
|
|
||||||
|
|
||||||
|
def to_index_struct(self) -> dict:
|
||||||
|
return {
|
||||||
|
"type": self.get_type(),
|
||||||
|
"vector_store": {"class_prefix": self.get_index_name(self.dataset)}
|
||||||
|
}
|
||||||
|
|
||||||
|
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
|
||||||
|
uuids = self._get_uuids(texts)
|
||||||
|
self._vector_store = WeaviateVectorStore.from_documents(
|
||||||
|
texts,
|
||||||
|
self._embeddings,
|
||||||
|
client=self._client,
|
||||||
|
index_name=self.get_index_name(self.dataset),
|
||||||
|
uuids=uuids,
|
||||||
|
by_text=False
|
||||||
|
)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def _get_vector_store(self) -> VectorStore:
|
||||||
|
"""Only for created index."""
|
||||||
|
if self._vector_store:
|
||||||
|
return self._vector_store
|
||||||
|
|
||||||
|
attributes = ['doc_id', 'dataset_id', 'document_id']
|
||||||
|
if self._is_origin():
|
||||||
|
attributes = ['doc_id']
|
||||||
|
|
||||||
|
return WeaviateVectorStore(
|
||||||
|
client=self._client,
|
||||||
|
index_name=self.get_index_name(self.dataset),
|
||||||
|
text_key='text',
|
||||||
|
embedding=self._embeddings,
|
||||||
|
attributes=attributes,
|
||||||
|
by_text=False
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_vector_store_class(self) -> type:
|
||||||
|
return MilvusVectorStore
|
||||||
|
|
||||||
|
def delete_by_document_id(self, document_id: str):
|
||||||
|
if self._is_origin():
|
||||||
|
self.recreate_dataset(self.dataset)
|
||||||
|
return
|
||||||
|
|
||||||
|
vector_store = self._get_vector_store()
|
||||||
|
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||||
|
|
||||||
|
vector_store.del_texts({
|
||||||
|
"operator": "Equal",
|
||||||
|
"path": ["document_id"],
|
||||||
|
"valueText": document_id
|
||||||
|
})
|
||||||
|
|
||||||
|
def _is_origin(self):
|
||||||
|
if self.dataset.index_struct_dict:
|
||||||
|
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
|
||||||
|
if not class_prefix.endswith('_Node'):
|
||||||
|
# original class_prefix
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,38 @@
|
|||||||
|
from langchain.vectorstores import Milvus
|
||||||
|
|
||||||
|
|
||||||
|
class MilvusVectorStore(Milvus):
|
||||||
|
def del_texts(self, where_filter: dict):
|
||||||
|
if not where_filter:
|
||||||
|
raise ValueError('where_filter must not be empty')
|
||||||
|
|
||||||
|
self._client.batch.delete_objects(
|
||||||
|
class_name=self._index_name,
|
||||||
|
where=where_filter,
|
||||||
|
output='minimal'
|
||||||
|
)
|
||||||
|
|
||||||
|
def del_text(self, uuid: str) -> None:
|
||||||
|
self._client.data_object.delete(
|
||||||
|
uuid,
|
||||||
|
class_name=self._index_name
|
||||||
|
)
|
||||||
|
|
||||||
|
def text_exists(self, uuid: str) -> bool:
|
||||||
|
result = self._client.query.get(self._index_name).with_additional(["id"]).with_where({
|
||||||
|
"path": ["doc_id"],
|
||||||
|
"operator": "Equal",
|
||||||
|
"valueText": uuid,
|
||||||
|
}).with_limit(1).do()
|
||||||
|
|
||||||
|
if "errors" in result:
|
||||||
|
raise ValueError(f"Error during query: {result['errors']}")
|
||||||
|
|
||||||
|
entries = result["data"]["Get"][self._index_name]
|
||||||
|
if len(entries) == 0:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def delete(self):
|
||||||
|
self._client.schema.delete_class(self._index_name)
|
||||||
Loading…
Reference in New Issue