|
|
|
@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
import logging
|
|
|
|
|
|
|
|
import time
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
from typing import Any, Optional
|
|
|
|
from typing import Any, Optional
|
|
|
|
|
|
|
|
|
|
|
|
@ -13,6 +15,8 @@ from extensions.ext_database import db
|
|
|
|
from extensions.ext_redis import redis_client
|
|
|
|
from extensions.ext_redis import redis_client
|
|
|
|
from models.dataset import Dataset, Whitelist
|
|
|
|
from models.dataset import Dataset, Whitelist
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AbstractVectorFactory(ABC):
|
|
|
|
class AbstractVectorFactory(ABC):
|
|
|
|
@abstractmethod
|
|
|
|
@abstractmethod
|
|
|
|
@ -173,8 +177,20 @@ class Vector:
|
|
|
|
|
|
|
|
|
|
|
|
def create(self, texts: Optional[list] = None, **kwargs):
|
|
|
|
def create(self, texts: Optional[list] = None, **kwargs):
|
|
|
|
if texts:
|
|
|
|
if texts:
|
|
|
|
embeddings = self._embeddings.embed_documents([document.page_content for document in texts])
|
|
|
|
start = time.time()
|
|
|
|
self._vector_processor.create(texts=texts, embeddings=embeddings, **kwargs)
|
|
|
|
logger.info(f"start embedding {len(texts)} texts {start}")
|
|
|
|
|
|
|
|
batch_size = 1000
|
|
|
|
|
|
|
|
total_batches = len(texts) + batch_size - 1
|
|
|
|
|
|
|
|
for i in range(0, len(texts), batch_size):
|
|
|
|
|
|
|
|
batch = texts[i : i + batch_size]
|
|
|
|
|
|
|
|
batch_start = time.time()
|
|
|
|
|
|
|
|
logger.info(f"Processing batch {i // batch_size + 1}/{total_batches} ({len(batch)} texts)")
|
|
|
|
|
|
|
|
batch_embeddings = self._embeddings.embed_documents([document.page_content for document in batch])
|
|
|
|
|
|
|
|
logger.info(
|
|
|
|
|
|
|
|
f"Embedding batch {i // batch_size + 1}/{total_batches} took {time.time() - batch_start:.3f}s"
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
self._vector_processor.create(texts=batch, embeddings=batch_embeddings, **kwargs)
|
|
|
|
|
|
|
|
logger.info(f"Embedding {len(texts)} texts took {time.time() - start:.3f}s")
|
|
|
|
|
|
|
|
|
|
|
|
def add_texts(self, documents: list[Document], **kwargs):
|
|
|
|
def add_texts(self, documents: list[Document], **kwargs):
|
|
|
|
if kwargs.get("duplicate_check", False):
|
|
|
|
if kwargs.get("duplicate_check", False):
|
|
|
|
|