|
|
|
|
@ -1,10 +1,12 @@
|
|
|
|
|
import copy
|
|
|
|
|
import json
|
|
|
|
|
import logging
|
|
|
|
|
import time
|
|
|
|
|
from typing import Any, Optional
|
|
|
|
|
|
|
|
|
|
from opensearchpy import OpenSearch
|
|
|
|
|
from pydantic import BaseModel, model_validator
|
|
|
|
|
from tenacity import retry, stop_after_attempt, wait_exponential
|
|
|
|
|
|
|
|
|
|
from configs import dify_config
|
|
|
|
|
from core.rag.datasource.vdb.field import Field
|
|
|
|
|
@ -77,31 +79,74 @@ class LindormVectorStore(BaseVector):
|
|
|
|
|
def refresh(self):
|
|
|
|
|
self._client.indices.refresh(index=self._collection_name)
|
|
|
|
|
|
|
|
|
|
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
|
|
|
|
actions = []
|
|
|
|
|
def add_texts(
|
|
|
|
|
self,
|
|
|
|
|
documents: list[Document],
|
|
|
|
|
embeddings: list[list[float]],
|
|
|
|
|
batch_size: int = 64,
|
|
|
|
|
timeout: int = 60,
|
|
|
|
|
**kwargs,
|
|
|
|
|
):
|
|
|
|
|
logger.info(f"Total documents to add: {len(documents)}")
|
|
|
|
|
uuids = self._get_uuids(documents)
|
|
|
|
|
for i in range(len(documents)):
|
|
|
|
|
action_header = {
|
|
|
|
|
"index": {
|
|
|
|
|
"_index": self.collection_name.lower(),
|
|
|
|
|
"_id": uuids[i],
|
|
|
|
|
|
|
|
|
|
total_docs = len(documents)
|
|
|
|
|
num_batches = (total_docs + batch_size - 1) // batch_size
|
|
|
|
|
|
|
|
|
|
@retry(
|
|
|
|
|
stop=stop_after_attempt(3),
|
|
|
|
|
wait=wait_exponential(multiplier=1, min=4, max=10),
|
|
|
|
|
)
|
|
|
|
|
def _bulk_with_retry(actions):
|
|
|
|
|
try:
|
|
|
|
|
response = self._client.bulk(actions, timeout=timeout)
|
|
|
|
|
if response["errors"]:
|
|
|
|
|
error_items = [item for item in response["items"] if "error" in item["index"]]
|
|
|
|
|
error_msg = f"Bulk indexing had {len(error_items)} errors"
|
|
|
|
|
logger.exception(error_msg)
|
|
|
|
|
raise Exception(error_msg)
|
|
|
|
|
return response
|
|
|
|
|
except Exception:
|
|
|
|
|
logger.exception("Bulk indexing error")
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
for batch_num in range(num_batches):
|
|
|
|
|
start_idx = batch_num * batch_size
|
|
|
|
|
end_idx = min((batch_num + 1) * batch_size, total_docs)
|
|
|
|
|
|
|
|
|
|
actions = []
|
|
|
|
|
for i in range(start_idx, end_idx):
|
|
|
|
|
action_header = {
|
|
|
|
|
"index": {
|
|
|
|
|
"_index": self.collection_name.lower(),
|
|
|
|
|
"_id": uuids[i],
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
action_values: dict[str, Any] = {
|
|
|
|
|
Field.CONTENT_KEY.value: documents[i].page_content,
|
|
|
|
|
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here
|
|
|
|
|
Field.METADATA_KEY.value: documents[i].metadata,
|
|
|
|
|
}
|
|
|
|
|
if self._using_ugc:
|
|
|
|
|
action_header["index"]["routing"] = self._routing
|
|
|
|
|
if self._routing_field is not None:
|
|
|
|
|
action_values[self._routing_field] = self._routing
|
|
|
|
|
actions.append(action_header)
|
|
|
|
|
actions.append(action_values)
|
|
|
|
|
response = self._client.bulk(actions)
|
|
|
|
|
if response["errors"]:
|
|
|
|
|
for item in response["items"]:
|
|
|
|
|
print(f"{item['index']['status']}: {item['index']['error']['type']}")
|
|
|
|
|
action_values: dict[str, Any] = {
|
|
|
|
|
Field.CONTENT_KEY.value: documents[i].page_content,
|
|
|
|
|
Field.VECTOR.value: embeddings[i],
|
|
|
|
|
Field.METADATA_KEY.value: documents[i].metadata,
|
|
|
|
|
}
|
|
|
|
|
if self._using_ugc:
|
|
|
|
|
action_header["index"]["routing"] = self._routing
|
|
|
|
|
if self._routing_field is not None:
|
|
|
|
|
action_values[self._routing_field] = self._routing
|
|
|
|
|
|
|
|
|
|
actions.append(action_header)
|
|
|
|
|
actions.append(action_values)
|
|
|
|
|
|
|
|
|
|
logger.info(f"Processing batch {batch_num + 1}/{num_batches} (documents {start_idx + 1} to {end_idx})")
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
_bulk_with_retry(actions)
|
|
|
|
|
logger.info(f"Successfully processed batch {batch_num + 1}")
|
|
|
|
|
# simple latency to avoid too many requests in a short time
|
|
|
|
|
if batch_num < num_batches - 1:
|
|
|
|
|
time.sleep(1)
|
|
|
|
|
|
|
|
|
|
except Exception:
|
|
|
|
|
logger.exception(f"Failed to process batch {batch_num + 1}")
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
def get_ids_by_metadata_field(self, key: str, value: str):
|
|
|
|
|
query: dict[str, Any] = {
|
|
|
|
|
@ -130,7 +175,6 @@ class LindormVectorStore(BaseVector):
|
|
|
|
|
if self._using_ugc:
|
|
|
|
|
params["routing"] = self._routing
|
|
|
|
|
self._client.delete(index=self._collection_name, id=id, params=params)
|
|
|
|
|
self.refresh()
|
|
|
|
|
else:
|
|
|
|
|
logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.")
|
|
|
|
|
|
|
|
|
|
|