parent 1dee5de9b4
commit 3241e4015b
@@ -0,0 +1,43 @@
import tempfile
from pathlib import Path
from typing import List, Union

from langchain.document_loaders import TextLoader, Docx2txtLoader
from langchain.schema import Document

from core.data_loader.loader.csv import CSVLoader
from core.data_loader.loader.excel import ExcelLoader
from core.data_loader.loader.html import HTMLLoader
from core.data_loader.loader.markdown import MarkdownLoader
from core.data_loader.loader.pdf import PdfLoader
from extensions.ext_storage import storage
from models.model import UploadFile


class FileExtractor:
    @classmethod
    def load(cls, upload_file: UploadFile, return_text: bool = False) -> Union[List[Document], str]:
        with tempfile.TemporaryDirectory() as temp_dir:
            suffix = Path(upload_file.key).suffix
            file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
            storage.download(upload_file.key, file_path)

            input_file = Path(file_path)
            delimiter = '\n'
            if input_file.suffix == '.xlsx':
                loader = ExcelLoader(file_path)
            elif input_file.suffix == '.pdf':
                loader = PdfLoader(file_path, upload_file=upload_file)
            elif input_file.suffix in ['.md', '.markdown']:
                loader = MarkdownLoader(file_path, autodetect_encoding=True)
            elif input_file.suffix in ['.htm', '.html']:
                loader = HTMLLoader(file_path)
            elif input_file.suffix == '.docx':
                loader = Docx2txtLoader(file_path)
            elif input_file.suffix == '.csv':
                loader = CSVLoader(file_path, autodetect_encoding=True)
            else:
                # txt
                loader = TextLoader(file_path, autodetect_encoding=True)

            return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()
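Usage sketch for the new FileExtractor entry point. This is illustrative only: it assumes a Flask application context with the storage extension configured, an existing UploadFile row, and that the class lives at core.data_loader.file_extractor (the exact module path is not shown in this diff).

    from core.data_loader.file_extractor import FileExtractor
    from extensions.ext_database import db
    from models.model import UploadFile

    upload_file = db.session.query(UploadFile).first()
    documents = FileExtractor.load(upload_file)               # List[Document]
    text = FileExtractor.load(upload_file, return_text=True)  # newline-joined plain text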
@@ -0,0 +1,67 @@
import csv
import logging
from typing import Optional, Dict, List

from langchain.document_loaders import CSVLoader as LCCSVLoader
from langchain.document_loaders.helpers import detect_file_encodings

from models.dataset import Document

logger = logging.getLogger(__name__)


class CSVLoader(LCCSVLoader):
    def __init__(
        self,
        file_path: str,
        source_column: Optional[str] = None,
        csv_args: Optional[Dict] = None,
        encoding: Optional[str] = None,
        autodetect_encoding: bool = True,
    ):
        self.file_path = file_path
        self.source_column = source_column
        self.encoding = encoding
        self.csv_args = csv_args or {}
        self.autodetect_encoding = autodetect_encoding

    def load(self) -> List[Document]:
        """Load data into document objects."""
        try:
            with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
                docs = self._read_from_file(csvfile)
        except UnicodeDecodeError as e:
            if self.autodetect_encoding:
                detected_encodings = detect_file_encodings(self.file_path)
                for encoding in detected_encodings:
                    logger.debug("Trying encoding: %s", encoding.encoding)
                    try:
                        with open(self.file_path, newline="", encoding=encoding.encoding) as csvfile:
                            docs = self._read_from_file(csvfile)
                            break
                    except UnicodeDecodeError:
                        continue
            else:
                raise RuntimeError(f"Error loading {self.file_path}") from e

        return docs

    def _read_from_file(self, csvfile):
        docs = []
        csv_reader = csv.DictReader(csvfile, **self.csv_args)  # type: ignore
        for i, row in enumerate(csv_reader):
            content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items())
            try:
                source = (
                    row[self.source_column]
                    if self.source_column is not None
                    else ''
                )
            except KeyError:
                raise ValueError(
                    f"Source column '{self.source_column}' not found in CSV file."
                )
            metadata = {"source": source, "row": i}
            doc = Document(page_content=content, metadata=metadata)
            docs.append(doc)

        return docs
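A minimal sketch of the CSV loader above, assuming a hypothetical sample.csv with a header row and a url column; source_column is optional and the source metadata falls back to an empty string when it is omitted.

    from core.data_loader.loader.csv import CSVLoader

    loader = CSVLoader('sample.csv', source_column='url', autodetect_encoding=True)
    docs = loader.load()
    # each row becomes one Document whose page_content is "column: value" lines
    print(docs[0].metadata)  # {'source': ..., 'row': 0}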
@@ -0,0 +1,43 @@
import json
import logging
from typing import List

from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
from openpyxl.reader.excel import load_workbook

logger = logging.getLogger(__name__)


class ExcelLoader(BaseLoader):
    """Load xlsx files.


    Args:
        file_path: Path to the file to load.
    """

    def __init__(
        self,
        file_path: str
    ):
        """Initialize with file path."""
        self._file_path = file_path

    def load(self) -> List[Document]:
        data = []
        keys = []
        wb = load_workbook(filename=self._file_path, read_only=True)
        # loop over all sheets
        for sheet in wb:
            for row in sheet.iter_rows(values_only=True):
                if all(v is None for v in row):
                    continue
                if keys == []:
                    keys = list(map(str, row))
                else:
                    row_dict = dict(zip(keys, row))
                    row_dict = {k: v for k, v in row_dict.items() if v}
                    data.append(json.dumps(row_dict, ensure_ascii=False))

        return [Document(page_content='\n\n'.join(data))]
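Illustrative use of the Excel loader, assuming a hypothetical sample.xlsx whose first non-empty row holds the column names; every later row is serialized as a JSON object and all rows are joined into a single Document.

    from core.data_loader.loader.excel import ExcelLoader

    loader = ExcelLoader('sample.xlsx')
    documents = loader.load()          # a single Document
    print(documents[0].page_content)   # one JSON object per row, separated by blank lines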
@@ -0,0 +1,35 @@
import logging
from typing import List

from bs4 import BeautifulSoup
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document

logger = logging.getLogger(__name__)


class HTMLLoader(BaseLoader):
    """Load html files.


    Args:
        file_path: Path to the file to load.
    """

    def __init__(
        self,
        file_path: str
    ):
        """Initialize with file path."""
        self._file_path = file_path

    def load(self) -> List[Document]:
        return [Document(page_content=self._load_as_text())]

    def _load_as_text(self) -> str:
        with open(self._file_path, "rb") as fp:
            soup = BeautifulSoup(fp, 'html.parser')
            text = soup.get_text()
            text = text.strip() if text else ''

        return text
@@ -0,0 +1,134 @@
import logging
import re
from typing import Optional, List, Tuple, cast

from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document

logger = logging.getLogger(__name__)


class MarkdownLoader(BaseLoader):
    """Load md files.


    Args:
        file_path: Path to the file to load.

        remove_hyperlinks: Whether to remove hyperlinks from the text.

        remove_images: Whether to remove images from the text.

        encoding: File encoding to use. If `None`, the file will be loaded
        with the default system encoding.

        autodetect_encoding: Whether to try to autodetect the file encoding
        if the specified encoding fails.
    """

    def __init__(
        self,
        file_path: str,
        remove_hyperlinks: bool = True,
        remove_images: bool = True,
        encoding: Optional[str] = None,
        autodetect_encoding: bool = True,
    ):
        """Initialize with file path."""
        self._file_path = file_path
        self._remove_hyperlinks = remove_hyperlinks
        self._remove_images = remove_images
        self._encoding = encoding
        self._autodetect_encoding = autodetect_encoding

    def load(self) -> List[Document]:
        tups = self.parse_tups(self._file_path)
        documents = []
        for header, value in tups:
            value = value.strip()
            if header is None:
                documents.append(Document(page_content=value))
            else:
                documents.append(Document(page_content=f"\n\n{header}\n{value}"))

        return documents

    def markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]:
        """Convert markdown text into a list of (header, text) tuples.

        Each tuple pairs a header with the text that appears under it.

        """
        markdown_tups: List[Tuple[Optional[str], str]] = []
        lines = markdown_text.split("\n")

        current_header = None
        current_text = ""

        for line in lines:
            header_match = re.match(r"^#+\s", line)
            if header_match:
                if current_header is not None:
                    markdown_tups.append((current_header, current_text))

                current_header = line
                current_text = ""
            else:
                current_text += line + "\n"
        markdown_tups.append((current_header, current_text))

        if current_header is not None:
            # pass linting, assert keys are defined
            markdown_tups = [
                (re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value))
                for key, value in markdown_tups
            ]
        else:
            markdown_tups = [
                (key, re.sub("\n", "", value)) for key, value in markdown_tups
            ]

        return markdown_tups

    def remove_images(self, content: str) -> str:
        """Remove image links from the markdown content."""
        pattern = r"!{1}\[\[(.*)\]\]"
        content = re.sub(pattern, "", content)
        return content

    def remove_hyperlinks(self, content: str) -> str:
        """Replace markdown hyperlinks with their link text."""
        pattern = r"\[(.*?)\]\((.*?)\)"
        content = re.sub(pattern, r"\1", content)
        return content

    def parse_tups(self, filepath: str) -> List[Tuple[Optional[str], str]]:
        """Parse file into tuples."""
        content = ""
        try:
            with open(filepath, "r", encoding=self._encoding) as f:
                content = f.read()
        except UnicodeDecodeError as e:
            if self._autodetect_encoding:
                detected_encodings = detect_file_encodings(filepath)
                for encoding in detected_encodings:
                    logger.debug("Trying encoding: %s", encoding.encoding)
                    try:
                        with open(filepath, encoding=encoding.encoding) as f:
                            content = f.read()
                            break
                    except UnicodeDecodeError:
                        continue
            else:
                raise RuntimeError(f"Error loading {filepath}") from e
        except Exception as e:
            raise RuntimeError(f"Error loading {filepath}") from e

        if self._remove_hyperlinks:
            content = self.remove_hyperlinks(content)

        if self._remove_images:
            content = self.remove_images(content)

        return self.markdown_to_tups(content)
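Sketch of the markdown splitting behaviour, assuming a local README.md: each heading and the text under it become one Document, with hyperlinks and image references stripped by default.

    from core.data_loader.loader.markdown import MarkdownLoader

    loader = MarkdownLoader('README.md', remove_hyperlinks=True, remove_images=True)
    for document in loader.load():
        print(document.page_content[:80])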
@@ -0,0 +1,55 @@
import logging
from typing import List, Optional

from langchain.document_loaders import PyPDFium2Loader
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document

from extensions.ext_storage import storage
from models.model import UploadFile

logger = logging.getLogger(__name__)


class PdfLoader(BaseLoader):
    """Load pdf files.


    Args:
        file_path: Path to the file to load.
    """

    def __init__(
        self,
        file_path: str,
        upload_file: Optional[UploadFile] = None
    ):
        """Initialize with file path."""
        self._file_path = file_path
        self._upload_file = upload_file

    def load(self) -> List[Document]:
        plaintext_file_key = ''
        plaintext_file_exists = False
        if self._upload_file:
            if self._upload_file.hash:
                plaintext_file_key = 'upload_files/' + self._upload_file.tenant_id + '/' \
                                     + self._upload_file.hash + '.0625.plaintext'
                try:
                    text = storage.load(plaintext_file_key).decode('utf-8')
                    plaintext_file_exists = True
                    return [Document(page_content=text)]
                except FileNotFoundError:
                    pass
        documents = PyPDFium2Loader(file_path=self._file_path).load()
        text_list = []
        for document in documents:
            text_list.append(document.page_content)
        text = "\n\n".join(text_list)

        # save plaintext file for caching
        if not plaintext_file_exists and plaintext_file_key:
            storage.save(plaintext_file_key, text.encode('utf-8'))

        return documents

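Usage sketch for the PDF loader (illustrative; the path is hypothetical). Passing an UploadFile record is optional and only enables the plaintext cache keyed by the file hash; without it the loader simply delegates to PyPDFium2Loader.

    from core.data_loader.loader.pdf import PdfLoader

    loader = PdfLoader('/tmp/sample.pdf')  # optionally: PdfLoader(path, upload_file=upload_file)
    documents = loader.load()              # one Document per page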
@@ -1,51 +0,0 @@
from typing import Any, Dict, Optional, Sequence
from llama_index.docstore.types import BaseDocumentStore
from llama_index.schema import BaseDocument


class EmptyDocumentStore(BaseDocumentStore):
    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "EmptyDocumentStore":
        return cls()

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to dict."""
        return {}

    @property
    def docs(self) -> Dict[str, BaseDocument]:
        return {}

    def add_documents(
        self, docs: Sequence[BaseDocument], allow_update: bool = True
    ) -> None:
        pass

    def document_exists(self, doc_id: str) -> bool:
        """Check if document exists."""
        return False

    def get_document(
        self, doc_id: str, raise_error: bool = True
    ) -> Optional[BaseDocument]:
        return None

    def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
        pass

    def set_document_hash(self, doc_id: str, doc_hash: str) -> None:
        """Set the hash for a given doc_id."""
        pass

    def get_document_hash(self, doc_id: str) -> Optional[str]:
        """Get the stored hash for a document, if it exists."""
        return None

    def update_docstore(self, other: "BaseDocumentStore") -> None:
        """Update docstore.

        Args:
            other (BaseDocumentStore): docstore to update from

        """
        self.add_documents(list(other.docs.values()))
@@ -0,0 +1,72 @@
import logging
from typing import List

from langchain.embeddings.base import Embeddings
from sqlalchemy.exc import IntegrityError

from extensions.ext_database import db
from libs import helper
from models.dataset import Embedding


class CacheEmbedding(Embeddings):
    def __init__(self, embeddings: Embeddings):
        self._embeddings = embeddings

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs."""
        # use doc embedding cache or store if not exists
        # keep the result list aligned with the input order
        text_embeddings = [None] * len(texts)
        embedding_queue_indices = []
        for i, text in enumerate(texts):
            hash = helper.generate_text_hash(text)
            embedding = db.session.query(Embedding).filter_by(hash=hash).first()
            if embedding:
                text_embeddings[i] = embedding.get_embedding()
            else:
                embedding_queue_indices.append(i)

        embedding_results = self._embeddings.embed_documents([texts[i] for i in embedding_queue_indices])

        for i, embedding_result in zip(embedding_queue_indices, embedding_results):
            text_embeddings[i] = embedding_result
            hash = helper.generate_text_hash(texts[i])

            try:
                embedding = Embedding(hash=hash)
                embedding.set_embedding(embedding_result)
                db.session.add(embedding)
                db.session.commit()
            except IntegrityError:
                db.session.rollback()
            except Exception:
                logging.exception('Failed to add embedding to db')

        return text_embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed query text."""
        # use doc embedding cache or store if not exists
        hash = helper.generate_text_hash(text)
        embedding = db.session.query(Embedding).filter_by(hash=hash).first()
        if embedding:
            return embedding.get_embedding()

        embedding_results = self._embeddings.embed_query(text)

        try:
            embedding = Embedding(hash=hash)
            embedding.set_embedding(embedding_results)
            db.session.add(embedding)
            db.session.commit()
        except IntegrityError:
            db.session.rollback()
        except Exception:
            logging.exception('Failed to add embedding to db')

        return embedding_results
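Minimal sketch of wrapping a LangChain embedding model with the cache above; it assumes a configured database session and valid OpenAI credentials.

    from langchain.embeddings import OpenAIEmbeddings
    from core.embedding.cached_embedding import CacheEmbedding

    embeddings = CacheEmbedding(OpenAIEmbeddings(openai_api_key='sk-...'))
    vectors = embeddings.embed_documents(['hello', 'world'])  # misses are embedded and stored
    query_vector = embeddings.embed_query('hello')            # served from the cache on repeat calls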
@@ -1,214 +0,0 @@
from typing import Optional, Any, List

import openai
from llama_index.embeddings.base import BaseEmbedding
from llama_index.embeddings.openai import OpenAIEmbeddingMode, OpenAIEmbeddingModelType, _QUERY_MODE_MODEL_DICT, \
    _TEXT_MODE_MODEL_DICT
from tenacity import wait_random_exponential, retry, stop_after_attempt

from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async


@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embedding(
    text: str,
    engine: Optional[str] = None,
    api_key: Optional[str] = None,
    **kwargs
) -> List[float]:
    """Get embedding.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.

    """
    text = text.replace("\n", " ")
    return openai.Embedding.create(input=[text], engine=engine, api_key=api_key, **kwargs)["data"][0]["embedding"]


@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embedding(text: str, engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs) -> List[
    float]:
    """Asynchronously get embedding.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.

    """
    # replace newlines, which can negatively affect performance.
    text = text.replace("\n", " ")

    return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=api_key, **kwargs))["data"][0][
        "embedding"
    ]


@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embeddings(
    list_of_text: List[str],
    engine: Optional[str] = None,
    api_key: Optional[str] = None,
    **kwargs
) -> List[List[float]]:
    """Get embeddings.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.

    """
    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."

    # replace newlines, which can negatively affect performance.
    list_of_text = [text.replace("\n", " ") for text in list_of_text]

    data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=api_key, **kwargs).data
    data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
    return [d["embedding"] for d in data]


@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embeddings(
    list_of_text: List[str], engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs
) -> List[List[float]]:
    """Asynchronously get embeddings.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.

    """
    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."

    # replace newlines, which can negatively affect performance.
    list_of_text = [text.replace("\n", " ") for text in list_of_text]

    data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=api_key, **kwargs)).data
    data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
    return [d["embedding"] for d in data]


class OpenAIEmbedding(BaseEmbedding):

    def __init__(
        self,
        mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
        model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
        deployment_name: Optional[str] = None,
        openai_api_key: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Init params."""
        new_kwargs = {}

        if 'embed_batch_size' in kwargs:
            new_kwargs['embed_batch_size'] = kwargs['embed_batch_size']

        if 'tokenizer' in kwargs:
            new_kwargs['tokenizer'] = kwargs['tokenizer']

        super().__init__(**new_kwargs)
        self.mode = OpenAIEmbeddingMode(mode)
        self.model = OpenAIEmbeddingModelType(model)
        self.deployment_name = deployment_name
        self.openai_api_key = openai_api_key
        self.openai_api_type = kwargs.get('openai_api_type')
        self.openai_api_version = kwargs.get('openai_api_version')
        self.openai_api_base = kwargs.get('openai_api_base')

    @handle_llm_exceptions
    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _QUERY_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _QUERY_MODE_MODEL_DICT[key]
        return get_embedding(query, engine=engine, api_key=self.openai_api_key,
                             api_type=self.openai_api_type, api_version=self.openai_api_version,
                             api_base=self.openai_api_base)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _TEXT_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _TEXT_MODE_MODEL_DICT[key]
        return get_embedding(text, engine=engine, api_key=self.openai_api_key,
                             api_type=self.openai_api_type, api_version=self.openai_api_version,
                             api_base=self.openai_api_base)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously get text embedding."""
        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _TEXT_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _TEXT_MODE_MODEL_DICT[key]
        return await aget_embedding(text, engine=engine, api_key=self.openai_api_key,
                                    api_type=self.openai_api_type, api_version=self.openai_api_version,
                                    api_base=self.openai_api_base)

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings.

        By default, this is a wrapper around _get_text_embedding.
        Can be overriden for batch queries.

        """
        if self.openai_api_type and self.openai_api_type == 'azure':
            embeddings = []
            for text in texts:
                embeddings.append(self._get_text_embedding(text))

            return embeddings

        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _TEXT_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _TEXT_MODE_MODEL_DICT[key]
        embeddings = get_embeddings(texts, engine=engine, api_key=self.openai_api_key,
                                    api_type=self.openai_api_type, api_version=self.openai_api_version,
                                    api_base=self.openai_api_base)
        return embeddings

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously get text embeddings."""
        if self.openai_api_type and self.openai_api_type == 'azure':
            embeddings = []
            for text in texts:
                embeddings.append(await self._aget_text_embedding(text))

            return embeddings

        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _TEXT_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _TEXT_MODE_MODEL_DICT[key]
        embeddings = await aget_embeddings(texts, engine=engine, api_key=self.openai_api_key,
                                           api_type=self.openai_api_type, api_version=self.openai_api_version,
                                           api_base=self.openai_api_base)
        return embeddings
@@ -0,0 +1,59 @@
from __future__ import annotations
from abc import abstractmethod, ABC
from typing import List, Any

from langchain.schema import Document, BaseRetriever

from models.dataset import Dataset


class BaseIndex(ABC):

    def __init__(self, dataset: Dataset):
        self.dataset = dataset

    @abstractmethod
    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
        raise NotImplementedError

    @abstractmethod
    def add_texts(self, texts: list[Document], **kwargs):
        raise NotImplementedError

    @abstractmethod
    def text_exists(self, id: str) -> bool:
        raise NotImplementedError

    @abstractmethod
    def delete_by_ids(self, ids: list[str]) -> None:
        raise NotImplementedError

    @abstractmethod
    def delete_by_document_id(self, document_id: str):
        raise NotImplementedError

    @abstractmethod
    def get_retriever(self, **kwargs: Any) -> BaseRetriever:
        raise NotImplementedError

    @abstractmethod
    def search(
        self, query: str,
        **kwargs: Any
    ) -> List[Document]:
        raise NotImplementedError

    def delete(self) -> None:
        raise NotImplementedError

    def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
        # iterate over a copy so removing items does not skip elements
        for text in texts[:]:
            doc_id = text.metadata['doc_id']
            exists_duplicate_node = self.text_exists(doc_id)
            if exists_duplicate_node:
                texts.remove(text)

        return texts

    def _get_uuids(self, texts: list[Document]) -> list[str]:
        return [text.metadata['doc_id'] for text in texts]
@@ -0,0 +1,41 @@
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings

from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from models.dataset import Dataset


class IndexBuilder:
    @classmethod
    def get_index(cls, dataset: Dataset, indexing_technique: str, ignore_high_quality_check: bool = False):
        if indexing_technique == "high_quality":
            if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality':
                return None

            model_credentials = LLMBuilder.get_model_credentials(
                tenant_id=dataset.tenant_id,
                model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
                model_name='text-embedding-ada-002'
            )

            embeddings = CacheEmbedding(OpenAIEmbeddings(
                **model_credentials
            ))

            return VectorIndex(
                dataset=dataset,
                config=current_app.config,
                embeddings=embeddings
            )
        elif indexing_technique == "economy":
            return KeywordTableIndex(
                dataset=dataset,
                config=KeywordTableConfig(
                    max_keywords_per_chunk=10
                )
            )
        else:
            raise ValueError('Unknown indexing technique')
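Illustrative call into the builder (the module path is assumed, and an existing Dataset row is required): 'high_quality' yields a vector index backed by cached OpenAI embeddings, 'economy' a keyword-table index.

    from core.index.index_builder import IndexBuilder  # module path assumed
    from extensions.ext_database import db
    from models.dataset import Dataset

    dataset = db.session.query(Dataset).first()
    index = IndexBuilder.get_index(dataset, indexing_technique=dataset.indexing_technique)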
@@ -1,60 +0,0 @@
from langchain.callbacks import CallbackManager
from llama_index import ServiceContext, PromptHelper, LLMPredictor
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.embedding.openai_embedding import OpenAIEmbedding
from core.llm.llm_builder import LLMBuilder


class IndexBuilder:
    @classmethod
    def get_default_service_context(cls, tenant_id: str) -> ServiceContext:
        # set number of output tokens
        num_output = 512

        # only for verbose
        callback_manager = CallbackManager([DifyStdOutCallbackHandler()])

        llm = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name='text-davinci-003',
            temperature=0,
            max_tokens=num_output,
            callback_manager=callback_manager,
        )

        llm_predictor = LLMPredictor(llm=llm)

        # These parameters here will affect the logic of segmenting the final synthesized response.
        # The number of refinement iterations in the synthesis process depends
        # on whether the length of the segmented output exceeds the max_input_size.
        prompt_helper = PromptHelper(
            max_input_size=3500,
            num_output=num_output,
            max_chunk_overlap=20
        )

        provider = LLMBuilder.get_default_provider(tenant_id)

        model_credentials = LLMBuilder.get_model_credentials(
            tenant_id=tenant_id,
            model_provider=provider,
            model_name='text-embedding-ada-002'
        )

        return ServiceContext.from_defaults(
            llm_predictor=llm_predictor,
            prompt_helper=prompt_helper,
            embed_model=OpenAIEmbedding(**model_credentials),
        )

    @classmethod
    def get_fake_llm_service_context(cls, tenant_id: str) -> ServiceContext:
        llm = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name='fake'
        )

        return ServiceContext.from_defaults(
            llm_predictor=LLMPredictor(llm=llm),
            embed_model=OpenAIEmbedding()
        )
@@ -1,159 +0,0 @@
import re
from typing import (
    Any,
    Dict,
    List,
    Set,
    Optional
)

import jieba.analyse

from core.index.keyword_table.stopwords import STOPWORDS
from llama_index.indices.query.base import IS
from llama_index import QueryMode
from llama_index.indices.base import QueryMap
from llama_index.indices.keyword_table.base import BaseGPTKeywordTableIndex
from llama_index.indices.keyword_table.query import BaseGPTKeywordTableQuery
from llama_index.docstore import BaseDocumentStore
from llama_index.indices.postprocessor.node import (
    BaseNodePostprocessor,
)
from llama_index.indices.response.response_builder import ResponseMode
from llama_index.indices.service_context import ServiceContext
from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
from llama_index.prompts.prompts import (
    QuestionAnswerPrompt,
    RefinePrompt,
    SimpleInputPrompt,
)

from core.index.query.synthesizer import EnhanceResponseSynthesizer


def jieba_extract_keywords(
    text_chunk: str,
    max_keywords: Optional[int] = None,
    expand_with_subtokens: bool = True,
) -> Set[str]:
    """Extract keywords with JIEBA tfidf."""
    keywords = jieba.analyse.extract_tags(
        sentence=text_chunk,
        topK=max_keywords,
    )

    if expand_with_subtokens:
        return set(expand_tokens_with_subtokens(keywords))
    else:
        return set(keywords)


def expand_tokens_with_subtokens(tokens: Set[str]) -> Set[str]:
    """Get subtokens from a list of tokens., filtering for stopwords."""
    results = set()
    for token in tokens:
        results.add(token)
        sub_tokens = re.findall(r"\w+", token)
        if len(sub_tokens) > 1:
            results.update({w for w in sub_tokens if w not in list(STOPWORDS)})

    return results


class GPTJIEBAKeywordTableIndex(BaseGPTKeywordTableIndex):
    """GPT JIEBA Keyword Table Index.

    This index uses a JIEBA keyword extractor to extract keywords from the text.

    """

    def _extract_keywords(self, text: str) -> Set[str]:
        """Extract keywords from text."""
        return jieba_extract_keywords(text, max_keywords=self.max_keywords_per_chunk)

    @classmethod
    def get_query_map(self) -> QueryMap:
        """Get query map."""
        super_map = super().get_query_map()
        super_map[QueryMode.DEFAULT] = GPTKeywordTableJIEBAQuery
        return super_map

    def _delete(self, doc_id: str, **delete_kwargs: Any) -> None:
        """Delete a document."""
        # get set of ids that correspond to node
        node_idxs_to_delete = {doc_id}

        # delete node_idxs from keyword to node idxs mapping
        keywords_to_delete = set()
        for keyword, node_idxs in self._index_struct.table.items():
            if node_idxs_to_delete.intersection(node_idxs):
                self._index_struct.table[keyword] = node_idxs.difference(
                    node_idxs_to_delete
                )
                if not self._index_struct.table[keyword]:
                    keywords_to_delete.add(keyword)

        for keyword in keywords_to_delete:
            del self._index_struct.table[keyword]


class GPTKeywordTableJIEBAQuery(BaseGPTKeywordTableQuery):
    """GPT Keyword Table Index JIEBA Query.

    Extracts keywords using JIEBA keyword extractor.
    Set when `mode="jieba"` in `query` method of `GPTKeywordTableIndex`.

    .. code-block:: python

        response = index.query("<query_str>", mode="jieba")

    See BaseGPTKeywordTableQuery for arguments.

    """

    @classmethod
    def from_args(
        cls,
        index_struct: IS,
        service_context: ServiceContext,
        docstore: Optional[BaseDocumentStore] = None,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        verbose: bool = False,
        # response synthesizer args
        response_mode: ResponseMode = ResponseMode.DEFAULT,
        text_qa_template: Optional[QuestionAnswerPrompt] = None,
        refine_template: Optional[RefinePrompt] = None,
        simple_template: Optional[SimpleInputPrompt] = None,
        response_kwargs: Optional[Dict] = None,
        use_async: bool = False,
        streaming: bool = False,
        optimizer: Optional[BaseTokenUsageOptimizer] = None,
        # class-specific args
        **kwargs: Any,
    ) -> "BaseGPTIndexQuery":
        response_synthesizer = EnhanceResponseSynthesizer.from_args(
            service_context=service_context,
            text_qa_template=text_qa_template,
            refine_template=refine_template,
            simple_template=simple_template,
            response_mode=response_mode,
            response_kwargs=response_kwargs,
            use_async=use_async,
            streaming=streaming,
            optimizer=optimizer,
        )
        return cls(
            index_struct=index_struct,
            service_context=service_context,
            response_synthesizer=response_synthesizer,
            docstore=docstore,
            node_postprocessors=node_postprocessors,
            verbose=verbose,
            **kwargs,
        )

    def _get_keywords(self, query_str: str) -> List[str]:
        """Extract keywords."""
        return list(
            jieba_extract_keywords(query_str, max_keywords=self.max_keywords_per_query)
        )
@@ -1,135 +0,0 @@
import json
from typing import List, Optional

from llama_index import ServiceContext, LLMPredictor, OpenAIEmbedding
from llama_index.data_structs import KeywordTable, Node
from llama_index.indices.keyword_table.base import BaseGPTKeywordTableIndex
from llama_index.indices.registry import load_index_struct_from_dict

from core.docstore.dataset_docstore import DatesetDocumentStore
from core.docstore.empty_docstore import EmptyDocumentStore
from core.index.index_builder import IndexBuilder
from core.index.keyword_table.jieba_keyword_table import GPTJIEBAKeywordTableIndex
from core.llm.llm_builder import LLMBuilder
from extensions.ext_database import db
from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment


class KeywordTableIndex:

    def __init__(self, dataset: Dataset):
        self._dataset = dataset

    def add_nodes(self, nodes: List[Node]):
        llm = LLMBuilder.to_llm(
            tenant_id=self._dataset.tenant_id,
            model_name='fake'
        )

        service_context = ServiceContext.from_defaults(
            llm_predictor=LLMPredictor(llm=llm),
            embed_model=OpenAIEmbedding()
        )

        dataset_keyword_table = self.get_keyword_table()
        if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
            index_struct = KeywordTable()
        else:
            index_struct_dict = dataset_keyword_table.keyword_table_dict
            index_struct: KeywordTable = load_index_struct_from_dict(index_struct_dict)

        # create index
        index = GPTJIEBAKeywordTableIndex(
            index_struct=index_struct,
            docstore=EmptyDocumentStore(),
            service_context=service_context
        )

        for node in nodes:
            keywords = index._extract_keywords(node.get_text())
            self.update_segment_keywords(node.doc_id, list(keywords))
            index._index_struct.add_node(list(keywords), node)

        index_struct_dict = index.index_struct.to_dict()

        if not dataset_keyword_table:
            dataset_keyword_table = DatasetKeywordTable(
                dataset_id=self._dataset.id,
                keyword_table=json.dumps(index_struct_dict)
            )
            db.session.add(dataset_keyword_table)
        else:
            dataset_keyword_table.keyword_table = json.dumps(index_struct_dict)

        db.session.commit()

    def del_nodes(self, node_ids: List[str]):
        llm = LLMBuilder.to_llm(
            tenant_id=self._dataset.tenant_id,
            model_name='fake'
        )

        service_context = ServiceContext.from_defaults(
            llm_predictor=LLMPredictor(llm=llm),
            embed_model=OpenAIEmbedding()
        )

        dataset_keyword_table = self.get_keyword_table()
        if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
            return
        else:
            index_struct_dict = dataset_keyword_table.keyword_table_dict
            index_struct: KeywordTable = load_index_struct_from_dict(index_struct_dict)

        # create index
        index = GPTJIEBAKeywordTableIndex(
            index_struct=index_struct,
            docstore=EmptyDocumentStore(),
            service_context=service_context
        )

        for node_id in node_ids:
            index.delete(node_id)

        index_struct_dict = index.index_struct.to_dict()

        if not dataset_keyword_table:
            dataset_keyword_table = DatasetKeywordTable(
                dataset_id=self._dataset.id,
                keyword_table=json.dumps(index_struct_dict)
            )
            db.session.add(dataset_keyword_table)
        else:
            dataset_keyword_table.keyword_table = json.dumps(index_struct_dict)

        db.session.commit()

    @property
    def query_index(self) -> Optional[BaseGPTKeywordTableIndex]:
        docstore = DatesetDocumentStore(
            dataset=self._dataset,
            user_id=self._dataset.created_by,
            embedding_model_name="text-embedding-ada-002"
        )

        service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)

        dataset_keyword_table = self.get_keyword_table()
        if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
            return None

        index_struct: KeywordTable = load_index_struct_from_dict(dataset_keyword_table.keyword_table_dict)

        return GPTJIEBAKeywordTableIndex(index_struct=index_struct, docstore=docstore, service_context=service_context)

    def get_keyword_table(self):
        dataset_keyword_table = self._dataset.dataset_keyword_table
        if dataset_keyword_table:
            return dataset_keyword_table
        return None

    def update_segment_keywords(self, node_id: str, keywords: List[str]):
        document_segment = db.session.query(DocumentSegment).filter(DocumentSegment.index_node_id == node_id).first()
        if document_segment:
            document_segment.keywords = keywords
            db.session.commit()
@@ -0,0 +1,33 @@
import re
from typing import Set

import jieba
from jieba.analyse import default_tfidf

from core.index.keyword_table_index.stopwords import STOPWORDS


class JiebaKeywordTableHandler:

    def __init__(self):
        default_tfidf.stop_words = STOPWORDS

    def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> Set[str]:
        """Extract keywords with JIEBA tfidf."""
        keywords = jieba.analyse.extract_tags(
            sentence=text,
            topK=max_keywords_per_chunk,
        )

        return set(self._expand_tokens_with_subtokens(keywords))

    def _expand_tokens_with_subtokens(self, tokens: Set[str]) -> Set[str]:
        """Get subtokens from a list of tokens, filtering for stopwords."""
        results = set()
        for token in tokens:
            results.add(token)
            sub_tokens = re.findall(r"\w+", token)
            if len(sub_tokens) > 1:
                results.update({w for w in sub_tokens if w not in list(STOPWORDS)})

        return results
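Small standalone sketch of the keyword extractor; jieba ships its own dictionaries, so this runs without the rest of the application (the sample sentence is arbitrary).

    from core.index.keyword_table_index.jieba_keyword_table_handler import JiebaKeywordTableHandler

    handler = JiebaKeywordTableHandler()
    keywords = handler.extract_keywords('自然语言处理让关键词提取变得简单', max_keywords_per_chunk=5)
    print(keywords)  # a set of keywords plus sub-tokens, with stopwords filtered out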
@@ -0,0 +1,238 @@
import json
from collections import defaultdict
from typing import Any, List, Optional, Dict

from langchain.schema import Document, BaseRetriever
from pydantic import BaseModel, Field, Extra

from core.index.base import BaseIndex
from core.index.keyword_table_index.jieba_keyword_table_handler import JiebaKeywordTableHandler
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment, DatasetKeywordTable


class KeywordTableConfig(BaseModel):
    max_keywords_per_chunk: int = 10


class KeywordTableIndex(BaseIndex):
    def __init__(self, dataset: Dataset, config: KeywordTableConfig = KeywordTableConfig()):
        super().__init__(dataset)
        self._config = config

    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
        keyword_table_handler = JiebaKeywordTableHandler()
        keyword_table = {}
        for text in texts:
            keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
            self._update_segment_keywords(text.metadata['doc_id'], list(keywords))
            keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))

        dataset_keyword_table = DatasetKeywordTable(
            dataset_id=self.dataset.id,
            keyword_table=json.dumps({
                '__type__': 'keyword_table',
                '__data__': {
                    "index_id": self.dataset.id,
                    "summary": None,
                    "table": {}
                }
            }, cls=SetEncoder)
        )
        db.session.add(dataset_keyword_table)
        db.session.commit()

        self._save_dataset_keyword_table(keyword_table)

        return self

    def add_texts(self, texts: list[Document], **kwargs):
        keyword_table_handler = JiebaKeywordTableHandler()

        keyword_table = self._get_dataset_keyword_table()
        for text in texts:
            keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
            self._update_segment_keywords(text.metadata['doc_id'], list(keywords))
            keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))

        self._save_dataset_keyword_table(keyword_table)

    def text_exists(self, id: str) -> bool:
        keyword_table = self._get_dataset_keyword_table()
        if not keyword_table:
            return False
        return id in set.union(*keyword_table.values())

    def delete_by_ids(self, ids: list[str]) -> None:
        keyword_table = self._get_dataset_keyword_table()
        keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)

        self._save_dataset_keyword_table(keyword_table)

    def delete_by_document_id(self, document_id: str):
        # get segment ids by document_id
        segments = db.session.query(DocumentSegment).filter(
            DocumentSegment.dataset_id == self.dataset.id,
            DocumentSegment.document_id == document_id
        ).all()

        ids = [segment.id for segment in segments]

        keyword_table = self._get_dataset_keyword_table()
        keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)

        self._save_dataset_keyword_table(keyword_table)

    def get_retriever(self, **kwargs: Any) -> BaseRetriever:
        return KeywordTableRetriever(index=self, **kwargs)

    def search(
        self, query: str,
        **kwargs: Any
    ) -> List[Document]:
        keyword_table = self._get_dataset_keyword_table()

        search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
        k = search_kwargs.get('k') if search_kwargs.get('k') else 4

        sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k)

        documents = []
        for chunk_index in sorted_chunk_indices:
            segment = db.session.query(DocumentSegment).filter(
                DocumentSegment.dataset_id == self.dataset.id,
                DocumentSegment.index_node_id == chunk_index
            ).first()

            if segment:
                documents.append(Document(
                    page_content=segment.content,
                    metadata={
                        "doc_id": chunk_index,
                        "document_id": segment.document_id,
                        "dataset_id": segment.dataset_id,
                    }
                ))

        return documents

    def delete(self) -> None:
        dataset_keyword_table = self.dataset.dataset_keyword_table
        if dataset_keyword_table:
            db.session.delete(dataset_keyword_table)
            db.session.commit()

    def _save_dataset_keyword_table(self, keyword_table):
        keyword_table_dict = {
            '__type__': 'keyword_table',
            '__data__': {
                "index_id": self.dataset.id,
                "summary": None,
                "table": keyword_table
            }
        }
        self.dataset.dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder)
        db.session.commit()

    def _get_dataset_keyword_table(self) -> Optional[dict]:
        dataset_keyword_table = self.dataset.dataset_keyword_table
        if dataset_keyword_table:
            if dataset_keyword_table.keyword_table_dict:
                return dataset_keyword_table.keyword_table_dict['__data__']['table']
        else:
            dataset_keyword_table = DatasetKeywordTable(
                dataset_id=self.dataset.id,
                keyword_table=json.dumps({
                    '__type__': 'keyword_table',
                    '__data__': {
                        "index_id": self.dataset.id,
                        "summary": None,
                        "table": {}
                    }
                }, cls=SetEncoder)
            )
            db.session.add(dataset_keyword_table)
            db.session.commit()

        return {}

    def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]) -> dict:
        for keyword in keywords:
            if keyword not in keyword_table:
                keyword_table[keyword] = set()
            keyword_table[keyword].add(id)
        return keyword_table

    def _delete_ids_from_keyword_table(self, keyword_table: dict, ids: list[str]) -> dict:
        # get set of ids that correspond to node
        node_idxs_to_delete = set(ids)

        # delete node_idxs from keyword to node idxs mapping
        keywords_to_delete = set()
        for keyword, node_idxs in keyword_table.items():
            if node_idxs_to_delete.intersection(node_idxs):
                keyword_table[keyword] = node_idxs.difference(
                    node_idxs_to_delete
                )
                if not keyword_table[keyword]:
                    keywords_to_delete.add(keyword)

        for keyword in keywords_to_delete:
            del keyword_table[keyword]

        return keyword_table

    def _retrieve_ids_by_query(self, keyword_table: dict, query: str, k: int = 4):
        keyword_table_handler = JiebaKeywordTableHandler()
        keywords = keyword_table_handler.extract_keywords(query)

        # go through text chunks in order of most matching keywords
        chunk_indices_count: Dict[str, int] = defaultdict(int)
        keywords = [keyword for keyword in keywords if keyword in set(keyword_table.keys())]
        for keyword in keywords:
            for node_id in keyword_table[keyword]:
                chunk_indices_count[node_id] += 1

        sorted_chunk_indices = sorted(
            list(chunk_indices_count.keys()),
            key=lambda x: chunk_indices_count[x],
            reverse=True,
        )

        return sorted_chunk_indices[:k]

    def _update_segment_keywords(self, node_id: str, keywords: List[str]):
        document_segment = db.session.query(DocumentSegment).filter(DocumentSegment.index_node_id == node_id).first()
        if document_segment:
            document_segment.keywords = keywords
            db.session.commit()


class KeywordTableRetriever(BaseRetriever, BaseModel):
    index: KeywordTableIndex
    search_kwargs: dict = Field(default_factory=dict)

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Get documents relevant for a query.

        Args:
            query: string to find relevant documents for

        Returns:
            List of relevant documents
        """
        return self.index.search(query, **self.search_kwargs)

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        raise NotImplementedError("KeywordTableRetriever does not support async")


class SetEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, set):
            return list(obj)
        return super().default(obj)
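End-to-end sketch of the economy index above (illustrative): it assumes an existing Dataset row bound to `dataset` and DocumentSegment rows whose index_node_id matches the doc_id metadata of the texts being added.

    from langchain.schema import Document
    from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig

    index = KeywordTableIndex(dataset=dataset, config=KeywordTableConfig(max_keywords_per_chunk=10))
    index.add_texts([Document(page_content='...', metadata={'doc_id': 'node-1'})])
    results = index.search('some query', search_kwargs={'k': 4})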
@@ -1,79 +0,0 @@
from typing import (
    Any,
    Dict,
    Optional, Sequence,
)

from llama_index.indices.response.response_synthesis import ResponseSynthesizer
from llama_index.indices.response.response_builder import ResponseMode, BaseResponseBuilder, get_response_builder
from llama_index.indices.service_context import ServiceContext
from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
from llama_index.prompts.prompts import (
    QuestionAnswerPrompt,
    RefinePrompt,
    SimpleInputPrompt,
)
from llama_index.types import RESPONSE_TEXT_TYPE


class EnhanceResponseSynthesizer(ResponseSynthesizer):
    @classmethod
    def from_args(
        cls,
        service_context: ServiceContext,
        streaming: bool = False,
        use_async: bool = False,
        text_qa_template: Optional[QuestionAnswerPrompt] = None,
        refine_template: Optional[RefinePrompt] = None,
        simple_template: Optional[SimpleInputPrompt] = None,
        response_mode: ResponseMode = ResponseMode.DEFAULT,
        response_kwargs: Optional[Dict] = None,
        optimizer: Optional[BaseTokenUsageOptimizer] = None,
    ) -> "ResponseSynthesizer":
        response_builder: Optional[BaseResponseBuilder] = None
        if response_mode != ResponseMode.NO_TEXT:
            if response_mode == 'no_synthesizer':
                response_builder = NoSynthesizer(
                    service_context=service_context,
                    simple_template=simple_template,
                    streaming=streaming,
                )
            else:
                response_builder = get_response_builder(
                    service_context,
                    text_qa_template,
                    refine_template,
                    simple_template,
                    response_mode,
                    use_async=use_async,
                    streaming=streaming,
                )
        return cls(response_builder, response_mode, response_kwargs, optimizer)


class NoSynthesizer(BaseResponseBuilder):
    def __init__(
        self,
        service_context: ServiceContext,
        simple_template: Optional[SimpleInputPrompt] = None,
        streaming: bool = False,
    ) -> None:
        super().__init__(service_context, streaming)

    async def aget_response(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        prev_response: Optional[str] = None,
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        return "\n".join(text_chunks)

    def get_response(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        prev_response: Optional[str] = None,
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        return "\n".join(text_chunks)
@@ -1,22 +0,0 @@
from pathlib import Path
from typing import Dict

from bs4 import BeautifulSoup
from llama_index.readers.file.base_parser import BaseParser


class HTMLParser(BaseParser):
    """HTML parser."""

    def _init_parser(self) -> Dict:
        """Init parser."""
        return {}

    def parse_file(self, file: Path, errors: str = "ignore") -> str:
        """Parse file."""
        with open(file, "rb") as fp:
            soup = BeautifulSoup(fp, 'html.parser')
            text = soup.get_text()
            text = text.strip() if text else ''

        return text
@ -1,111 +0,0 @@
"""Markdown parser.

Contains parser for md files.

"""
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union, cast

from llama_index.readers.file.base_parser import BaseParser


class MarkdownParser(BaseParser):
    """Markdown parser.

    Extract text from markdown files.
    Returns dictionary with keys as headers and values as the text between headers.

    """

    def __init__(
            self,
            *args: Any,
            remove_hyperlinks: bool = True,
            remove_images: bool = True,
            **kwargs: Any,
    ) -> None:
        """Init params."""
        super().__init__(*args, **kwargs)
        self._remove_hyperlinks = remove_hyperlinks
        self._remove_images = remove_images

    def markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]:
        """Convert a markdown file to a dictionary.

        The keys are the headers and the values are the text under each header.

        """
        markdown_tups: List[Tuple[Optional[str], str]] = []
        lines = markdown_text.split("\n")

        current_header = None
        current_text = ""

        for line in lines:
            header_match = re.match(r"^#+\s", line)
            if header_match:
                if current_header is not None:
                    markdown_tups.append((current_header, current_text))

                current_header = line
                current_text = ""
            else:
                current_text += line + "\n"
        markdown_tups.append((current_header, current_text))

        if current_header is not None:
            # pass linting, assert keys are defined
            markdown_tups = [
                (re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value))
                for key, value in markdown_tups
            ]
        else:
            markdown_tups = [
                (key, re.sub("\n", "", value)) for key, value in markdown_tups
            ]

        return markdown_tups

    def remove_images(self, content: str) -> str:
        """Get a dictionary of a markdown file from its path."""
        pattern = r"!{1}\[\[(.*)\]\]"
        content = re.sub(pattern, "", content)
        return content

    def remove_hyperlinks(self, content: str) -> str:
        """Get a dictionary of a markdown file from its path."""
        pattern = r"\[(.*?)\]\((.*?)\)"
        content = re.sub(pattern, r"\1", content)
        return content

    def _init_parser(self) -> Dict:
        """Initialize the parser with the config."""
        return {}

    def parse_tups(
            self, filepath: Path, errors: str = "ignore"
    ) -> List[Tuple[Optional[str], str]]:
        """Parse file into tuples."""
        with open(filepath, "r", encoding="utf-8") as f:
            content = f.read()
        if self._remove_hyperlinks:
            content = self.remove_hyperlinks(content)
        if self._remove_images:
            content = self.remove_images(content)
        markdown_tups = self.markdown_to_tups(content)
        return markdown_tups

    def parse_file(
            self, filepath: Path, errors: str = "ignore"
    ) -> Union[str, List[str]]:
        """Parse file into string."""
        tups = self.parse_tups(filepath, errors=errors)
        results = []
        # TODO: don't include headers right now
        for header, value in tups:
            if header is None:
                results.append(value)
            else:
                results.append(f"\n\n{header}\n{value}")
        return results
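A small illustration of the header/body splitting the parser performs, reduced to a standalone function (a sketch, not from the repository):

import re

def markdown_to_tups(markdown_text):
    # Split a markdown string into (header, text) pairs, as MarkdownParser does.
    tups, current_header, current_text = [], None, ""
    for line in markdown_text.split("\n"):
        if re.match(r"^#+\s", line):
            if current_header is not None:
                tups.append((current_header, current_text))
            current_header, current_text = line, ""
        else:
            current_text += line + "\n"
    tups.append((current_header, current_text))
    return tups

# markdown_to_tups("# Title\nbody\n## Sub\nmore")
# -> [('# Title', 'body\n'), ('## Sub', 'more\n')]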
@ -1,56 +0,0 @@
from pathlib import Path
from typing import Dict

from flask import current_app
from llama_index.readers.file.base_parser import BaseParser
from pypdf import PdfReader

from extensions.ext_storage import storage
from models.model import UploadFile


class PDFParser(BaseParser):
    """PDF parser."""

    def _init_parser(self) -> Dict:
        """Init parser."""
        return {}

    def parse_file(self, file: Path, errors: str = "ignore") -> str:
        """Parse file."""
        if not current_app.config.get('PDF_PREVIEW', True):
            return ''

        plaintext_file_key = ''
        plaintext_file_exists = False
        if self._parser_config and 'upload_file' in self._parser_config and self._parser_config['upload_file']:
            upload_file: UploadFile = self._parser_config['upload_file']
            if upload_file.hash:
                plaintext_file_key = 'upload_files/' + upload_file.tenant_id + '/' + upload_file.hash + '.plaintext'
                try:
                    text = storage.load(plaintext_file_key).decode('utf-8')
                    plaintext_file_exists = True
                    return text
                except FileNotFoundError:
                    pass

        text_list = []
        with open(file, "rb") as fp:
            # Create a PDF object
            pdf = PdfReader(fp)

            # Get the number of pages in the PDF document
            num_pages = len(pdf.pages)

            # Iterate over every page
            for page in range(num_pages):
                # Extract the text from the page
                page_text = pdf.pages[page].extract_text()
                text_list.append(page_text)
        text = "\n".join(text_list)

        # save plaintext file for caching
        if not plaintext_file_exists and plaintext_file_key:
            storage.save(plaintext_file_key, text.encode('utf-8'))

        return text
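The pypdf extraction path can also be shown on its own, without the plaintext cache; a minimal sketch (file name is a placeholder):

from pypdf import PdfReader

def pdf_to_text(path: str) -> str:
    # Extract text page by page and join with newlines, as PDFParser.parse_file does.
    with open(path, "rb") as fp:
        pdf = PdfReader(fp)
        return "\n".join(page.extract_text() or "" for page in pdf.pages)

# print(pdf_to_text("example.pdf"))  # hypothetical input file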
@ -1,33 +0,0 @@
from pathlib import Path
import json
from typing import Dict
from openpyxl import load_workbook

from llama_index.readers.file.base_parser import BaseParser
from flask import current_app


class XLSXParser(BaseParser):
    """XLSX parser."""

    def _init_parser(self) -> Dict:
        """Init parser"""
        return {}

    def parse_file(self, file: Path, errors: str = "ignore") -> str:
        data = []
        keys = []
        with open(file, "r") as fp:
            wb = load_workbook(filename=file, read_only=True)
            # loop over all sheets
            for sheet in wb:
                for row in sheet.iter_rows(values_only=True):
                    if all(v is None for v in row):
                        continue
                    if keys == []:
                        keys = list(map(str, row))
                    else:
                        row_dict = dict(zip(keys, row))
                        row_dict = {k: v for k, v in row_dict.items() if v}
                        data.append(json.dumps(row_dict, ensure_ascii=False))
        return '\n\n'.join(data)
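The row-to-JSON flattening is easy to exercise in isolation; a standalone sketch (workbook path is a placeholder, and the redundant text-mode open is dropped since load_workbook reads the file itself):

import json
from openpyxl import load_workbook

def xlsx_to_json_lines(path: str) -> str:
    # The first non-empty row becomes the keys; every later row becomes a JSON object.
    data, keys = [], []
    wb = load_workbook(filename=path, read_only=True)
    for sheet in wb:
        for row in sheet.iter_rows(values_only=True):
            if all(v is None for v in row):
                continue
            if not keys:
                keys = list(map(str, row))
            else:
                row_dict = {k: v for k, v in zip(keys, row) if v}
                data.append(json.dumps(row_dict, ensure_ascii=False))
    return '\n\n'.join(data)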
@ -1,136 +0,0 @@
import json
import logging
from typing import List, Optional

from llama_index.data_structs import Node
from requests import ReadTimeout
from sqlalchemy.exc import IntegrityError
from tenacity import retry, stop_after_attempt, retry_if_exception_type

from core.index.index_builder import IndexBuilder
from core.vector_store.base import BaseGPTVectorStoreIndex
from extensions.ext_vector_store import vector_store
from extensions.ext_database import db
from models.dataset import Dataset, Embedding


class VectorIndex:

    def __init__(self, dataset: Dataset):
        self._dataset = dataset

    def add_nodes(self, nodes: List[Node], duplicate_check: bool = False):
        if not self._dataset.index_struct_dict:
            index_id = "Vector_index_" + self._dataset.id.replace("-", "_")
            self._dataset.index_struct = json.dumps(vector_store.to_index_struct(index_id))
            db.session.commit()

        service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)

        index = vector_store.get_index(
            service_context=service_context,
            index_struct=self._dataset.index_struct_dict
        )

        if duplicate_check:
            nodes = self._filter_duplicate_nodes(index, nodes)

        embedding_queue_nodes = []
        embedded_nodes = []
        for node in nodes:
            node_hash = node.doc_hash

            # if node hash in cached embedding tables, use cached embedding
            embedding = db.session.query(Embedding).filter_by(hash=node_hash).first()
            if embedding:
                node.embedding = embedding.get_embedding()
                embedded_nodes.append(node)
            else:
                embedding_queue_nodes.append(node)

        if embedding_queue_nodes:
            embedding_results = index._get_node_embedding_results(
                embedding_queue_nodes,
                set(),
            )

            # pre embed nodes for cached embedding
            for embedding_result in embedding_results:
                node = embedding_result.node
                node.embedding = embedding_result.embedding

                try:
                    embedding = Embedding(hash=node.doc_hash)
                    embedding.set_embedding(node.embedding)
                    db.session.add(embedding)
                    db.session.commit()
                except IntegrityError:
                    db.session.rollback()
                    continue
                except:
                    logging.exception('Failed to add embedding to db')
                    continue

                embedded_nodes.append(node)

        self.index_insert_nodes(index, embedded_nodes)

    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
    def index_insert_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]):
        index.insert_nodes(nodes)

    def del_nodes(self, node_ids: List[str]):
        if not self._dataset.index_struct_dict:
            return

        service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id)

        index = vector_store.get_index(
            service_context=service_context,
            index_struct=self._dataset.index_struct_dict
        )

        for node_id in node_ids:
            self.index_delete_node(index, node_id)

    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
    def index_delete_node(self, index: BaseGPTVectorStoreIndex, node_id: str):
        index.delete_node(node_id)

    def del_doc(self, doc_id: str):
        if not self._dataset.index_struct_dict:
            return

        service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id)

        index = vector_store.get_index(
            service_context=service_context,
            index_struct=self._dataset.index_struct_dict
        )

        self.index_delete_doc(index, doc_id)

    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
    def index_delete_doc(self, index: BaseGPTVectorStoreIndex, doc_id: str):
        index.delete(doc_id)

    @property
    def query_index(self) -> Optional[BaseGPTVectorStoreIndex]:
        if not self._dataset.index_struct_dict:
            return None

        service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)

        return vector_store.get_index(
            service_context=service_context,
            index_struct=self._dataset.index_struct_dict
        )

    def _filter_duplicate_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]) -> List[Node]:
        for node in nodes:
            node_id = node.doc_id
            exists_duplicate_node = index.exists_by_node_id(node_id)
            if exists_duplicate_node:
                nodes.remove(node)

        return nodes
@ -0,0 +1,175 @@
import json
import logging
from abc import abstractmethod
from typing import List, Any, cast

from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from weaviate import UnexpectedStatusCodeException

from core.index.base import BaseIndex
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument


class BaseVectorIndex(BaseIndex):

    def __init__(self, dataset: Dataset, embeddings: Embeddings):
        super().__init__(dataset)
        self._embeddings = embeddings
        self._vector_store = None

    def get_type(self) -> str:
        raise NotImplementedError

    @abstractmethod
    def get_index_name(self, dataset: Dataset) -> str:
        raise NotImplementedError

    @abstractmethod
    def to_index_struct(self) -> dict:
        raise NotImplementedError

    @abstractmethod
    def _get_vector_store(self) -> VectorStore:
        raise NotImplementedError

    @abstractmethod
    def _get_vector_store_class(self) -> type:
        raise NotImplementedError

    def search(
            self, query: str,
            **kwargs: Any
    ) -> List[Document]:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        search_type = kwargs.get('search_type') if kwargs.get('search_type') else 'similarity'
        search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}

        if search_type == 'similarity_score_threshold':
            score_threshold = search_kwargs.get("score_threshold")
            if (score_threshold is None) or (not isinstance(score_threshold, float)):
                search_kwargs['score_threshold'] = .0

            docs_with_similarity = vector_store.similarity_search_with_relevance_scores(
                query, **search_kwargs
            )

            docs = []
            for doc, similarity in docs_with_similarity:
                doc.metadata['score'] = similarity
                docs.append(doc)

            return docs

        # similarity k
        # mmr k, fetch_k, lambda_mult
        # similarity_score_threshold k
        return vector_store.as_retriever(
            search_type=search_type,
            search_kwargs=search_kwargs
        ).get_relevant_documents(query)

    def get_retriever(self, **kwargs: Any) -> BaseRetriever:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        return vector_store.as_retriever(**kwargs)

    def add_texts(self, texts: list[Document], **kwargs):
        if self._is_origin():
            self.recreate_dataset(self.dataset)

        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        if kwargs.get('duplicate_check', False):
            texts = self._filter_duplicate_texts(texts)

        uuids = self._get_uuids(texts)
        vector_store.add_documents(texts, uuids=uuids)

    def text_exists(self, id: str) -> bool:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        return vector_store.text_exists(id)

    def delete_by_ids(self, ids: list[str]) -> None:
        if self._is_origin():
            self.recreate_dataset(self.dataset)
            return

        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        for node_id in ids:
            vector_store.del_text(node_id)

    def delete(self) -> None:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        vector_store.delete()

    def _is_origin(self):
        return False

    def recreate_dataset(self, dataset: Dataset):
        logging.info(f"Recreating dataset {dataset.id}")

        try:
            self.delete()
        except UnexpectedStatusCodeException as e:
            if e.status_code != 400:
                # 400 means index not exists
                raise e

        dataset_documents = db.session.query(DatasetDocument).filter(
            DatasetDocument.dataset_id == dataset.id,
            DatasetDocument.indexing_status == 'completed',
            DatasetDocument.enabled == True,
            DatasetDocument.archived == False,
        ).all()

        documents = []
        for dataset_document in dataset_documents:
            segments = db.session.query(DocumentSegment).filter(
                DocumentSegment.document_id == dataset_document.id,
                DocumentSegment.status == 'completed',
                DocumentSegment.enabled == True
            ).all()

            for segment in segments:
                document = Document(
                    page_content=segment.content,
                    metadata={
                        "doc_id": segment.index_node_id,
                        "doc_hash": segment.index_node_hash,
                        "document_id": segment.document_id,
                        "dataset_id": segment.dataset_id,
                    }
                )

                documents.append(document)

        origin_index_struct = self.dataset.index_struct
        self.dataset.index_struct = None

        if documents:
            try:
                self.create(documents)
            except Exception as e:
                self.dataset.index_struct = origin_index_struct
                raise e

        dataset.index_struct = json.dumps(self.to_index_struct())

        db.session.commit()

        self.dataset = dataset
        logging.info(f"Dataset {dataset.id} recreate successfully.")
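How a caller would drive BaseVectorIndex.search, including the score-threshold branch above (a sketch; `index` stands for any concrete subclass such as the Qdrant or Weaviate implementations in this commit):

# Plain similarity search, top 2 documents.
docs = index.search("What is Dify?", search_type='similarity', search_kwargs={'k': 2})

# Threshold search: the branch above attaches the relevance score to doc.metadata['score'].
scored = index.search(
    "What is Dify?",
    search_type='similarity_score_threshold',
    search_kwargs={'k': 4, 'score_threshold': 0.6},
)
for doc in scored:
    print(doc.metadata['score'], doc.page_content[:60])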
@ -0,0 +1,116 @@
import os
from typing import Optional, Any, List, cast

import qdrant_client
from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from pydantic import BaseModel

from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.qdrant_vector_store import QdrantVectorStore
from models.dataset import Dataset


class QdrantConfig(BaseModel):
    endpoint: str
    api_key: Optional[str]
    root_path: Optional[str]

    def to_qdrant_params(self):
        if self.endpoint and self.endpoint.startswith('path:'):
            path = self.endpoint.replace('path:', '')
            if not os.path.isabs(path):
                path = os.path.join(self.root_path, path)

            return {
                'path': path
            }
        else:
            return {
                'url': self.endpoint,
                'api_key': self.api_key,
            }


class QdrantVectorIndex(BaseVectorIndex):
    def __init__(self, dataset: Dataset, config: QdrantConfig, embeddings: Embeddings):
        super().__init__(dataset, embeddings)
        self._client_config = config

    def get_type(self) -> str:
        return 'qdrant'

    def get_index_name(self, dataset: Dataset) -> str:
        if self.dataset.index_struct_dict:
            return self.dataset.index_struct_dict['vector_store']['collection_name']

        dataset_id = dataset.id
        return "Index_" + dataset_id.replace("-", "_")

    def to_index_struct(self) -> dict:
        return {
            "type": self.get_type(),
            "vector_store": {"collection_name": self.get_index_name(self.dataset)}
        }

    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
        uuids = self._get_uuids(texts)
        self._vector_store = QdrantVectorStore.from_documents(
            texts,
            self._embeddings,
            collection_name=self.get_index_name(self.dataset),
            ids=uuids,
            content_payload_key='text',
            **self._client_config.to_qdrant_params()
        )

        return self

    def _get_vector_store(self) -> VectorStore:
        """Only for created index."""
        if self._vector_store:
            return self._vector_store

        client = qdrant_client.QdrantClient(
            **self._client_config.to_qdrant_params()
        )

        return QdrantVectorStore(
            client=client,
            collection_name=self.get_index_name(self.dataset),
            embeddings=self._embeddings,
            content_payload_key='text'
        )

    def _get_vector_store_class(self) -> type:
        return QdrantVectorStore

    def delete_by_document_id(self, document_id: str):
        if self._is_origin():
            self.recreate_dataset(self.dataset)
            return

        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        from qdrant_client.http import models

        vector_store.del_texts(models.Filter(
            must=[
                models.FieldCondition(
                    key="metadata.document_id",
                    match=models.MatchValue(value=document_id),
                ),
            ],
        ))

    def _is_origin(self):
        if self.dataset.index_struct_dict:
            class_prefix: str = self.dataset.index_struct_dict['vector_store']['collection_name']
            if class_prefix.startswith('Vector_'):
                # original class_prefix
                return True

        return False
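QdrantConfig resolves either a local on-disk store or a remote server; a quick sketch of both (the endpoint and key values are placeholders):

# Local, file-backed Qdrant: 'path:' endpoints are resolved against root_path.
local = QdrantConfig(endpoint='path:storage/qdrant', api_key=None, root_path='/app/api')
print(local.to_qdrant_params())   # {'path': '/app/api/storage/qdrant'}

# Remote Qdrant server.
remote = QdrantConfig(endpoint='https://qdrant.example.com:6333', api_key='xxx', root_path=None)
print(remote.to_qdrant_params())  # {'url': 'https://qdrant.example.com:6333', 'api_key': 'xxx'}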
@ -0,0 +1,69 @@
import json

from flask import current_app
from langchain.embeddings.base import Embeddings

from core.index.vector_index.base import BaseVectorIndex
from extensions.ext_database import db
from models.dataset import Dataset, Document


class VectorIndex:
    def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings):
        self._dataset = dataset
        self._embeddings = embeddings
        self._vector_index = self._init_vector_index(dataset, config, embeddings)

    def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings) -> BaseVectorIndex:
        vector_type = config.get('VECTOR_STORE')

        if self._dataset.index_struct_dict:
            vector_type = self._dataset.index_struct_dict['type']

        if not vector_type:
            raise ValueError(f"Vector store must be specified.")

        if vector_type == "weaviate":
            from core.index.vector_index.weaviate_vector_index import WeaviateVectorIndex, WeaviateConfig

            return WeaviateVectorIndex(
                dataset=dataset,
                config=WeaviateConfig(
                    endpoint=config.get('WEAVIATE_ENDPOINT'),
                    api_key=config.get('WEAVIATE_API_KEY'),
                    batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
                ),
                embeddings=embeddings
            )
        elif vector_type == "qdrant":
            from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig

            return QdrantVectorIndex(
                dataset=dataset,
                config=QdrantConfig(
                    endpoint=config.get('QDRANT_URL'),
                    api_key=config.get('QDRANT_API_KEY'),
                    root_path=current_app.root_path
                ),
                embeddings=embeddings
            )
        else:
            raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")

    def add_texts(self, texts: list[Document], **kwargs):
        if not self._dataset.index_struct_dict:
            self._vector_index.create(texts, **kwargs)
            self._dataset.index_struct = json.dumps(self._vector_index.to_index_struct())
            db.session.commit()
            return

        self._vector_index.add_texts(texts, **kwargs)

    def __getattr__(self, name):
        if self._vector_index is not None:
            method = getattr(self._vector_index, name)
            if callable(method):
                return method

        raise AttributeError(f"'VectorIndex' object has no attribute '{name}'")
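Putting the factory together end to end (a sketch; `dataset`, `documents`, and `embeddings` are assumed to come from the database and the embedding layer, and the config keys are the ones read above):

from flask import current_app

# embeddings is any langchain Embeddings implementation, e.g. CacheEmbedding(OpenAIEmbeddings(...)).
index = VectorIndex(dataset=dataset, config=current_app.config, embeddings=embeddings)

# The first call creates the collection/class and persists index_struct on the dataset;
# later calls append to the existing index.
index.add_texts(documents, duplicate_check=True)

# Other BaseVectorIndex methods are proxied through __getattr__.
results = index.search("query text", search_type='similarity', search_kwargs={'k': 2})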
@ -0,0 +1,132 @@
from typing import Optional, cast

import weaviate
from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from pydantic import BaseModel, root_validator

from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.weaviate_vector_store import WeaviateVectorStore
from models.dataset import Dataset


class WeaviateConfig(BaseModel):
    endpoint: str
    api_key: Optional[str]
    batch_size: int = 100

    @root_validator()
    def validate_config(cls, values: dict) -> dict:
        if not values['endpoint']:
            raise ValueError("config WEAVIATE_ENDPOINT is required")
        return values


class WeaviateVectorIndex(BaseVectorIndex):
    def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings):
        super().__init__(dataset, embeddings)
        self._client = self._init_client(config)

    def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
        auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)

        weaviate.connect.connection.has_grpc = False

        client = weaviate.Client(
            url=config.endpoint,
            auth_client_secret=auth_config,
            timeout_config=(5, 60),
            startup_period=None
        )

        client.batch.configure(
            # `batch_size` takes an `int` value to enable auto-batching
            # (`None` is used for manual batching)
            batch_size=config.batch_size,
            # dynamically update the `batch_size` based on import speed
            dynamic=True,
            # `timeout_retries` takes an `int` value to retry on time outs
            timeout_retries=3,
        )

        return client

    def get_type(self) -> str:
        return 'weaviate'

    def get_index_name(self, dataset: Dataset) -> str:
        if self.dataset.index_struct_dict:
            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
            if not class_prefix.endswith('_Node'):
                # original class_prefix
                class_prefix += '_Node'

            return class_prefix

        dataset_id = dataset.id
        return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'

    def to_index_struct(self) -> dict:
        return {
            "type": self.get_type(),
            "vector_store": {"class_prefix": self.get_index_name(self.dataset)}
        }

    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
        uuids = self._get_uuids(texts)
        self._vector_store = WeaviateVectorStore.from_documents(
            texts,
            self._embeddings,
            client=self._client,
            index_name=self.get_index_name(self.dataset),
            uuids=uuids,
            by_text=False
        )

        return self

    def _get_vector_store(self) -> VectorStore:
        """Only for created index."""
        if self._vector_store:
            return self._vector_store

        attributes = ['doc_id', 'dataset_id', 'document_id']
        if self._is_origin():
            attributes = ['doc_id']

        return WeaviateVectorStore(
            client=self._client,
            index_name=self.get_index_name(self.dataset),
            text_key='text',
            embedding=self._embeddings,
            attributes=attributes,
            by_text=False
        )

    def _get_vector_store_class(self) -> type:
        return WeaviateVectorStore

    def delete_by_document_id(self, document_id: str):
        if self._is_origin():
            self.recreate_dataset(self.dataset)
            return

        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        vector_store.del_texts({
            "operator": "Equal",
            "path": ["document_id"],
            "valueText": document_id
        })

    def _is_origin(self):
        if self.dataset.index_struct_dict:
            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
            if not class_prefix.endswith('_Node'):
                # original class_prefix
                return True

        return False
@ -0,0 +1,87 @@
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from langchain.tools import BaseTool

from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from models.dataset import Dataset


class DatasetTool(BaseTool):
    """Tool for querying a Dataset."""

    dataset: Dataset
    k: int = 2

    def _run(self, tool_input: str) -> str:
        if self.dataset.indexing_technique == "economy":
            # use keyword table query
            kw_table_index = KeywordTableIndex(
                dataset=self.dataset,
                config=KeywordTableConfig(
                    max_keywords_per_chunk=5
                )
            )

            documents = kw_table_index.search(tool_input, search_kwargs={'k': self.k})
        else:
            model_credentials = LLMBuilder.get_model_credentials(
                tenant_id=self.dataset.tenant_id,
                model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
                model_name='text-embedding-ada-002'
            )

            embeddings = CacheEmbedding(OpenAIEmbeddings(
                **model_credentials
            ))

            vector_index = VectorIndex(
                dataset=self.dataset,
                config=current_app.config,
                embeddings=embeddings
            )

            documents = vector_index.search(
                tool_input,
                search_type='similarity',
                search_kwargs={
                    'k': self.k
                }
            )

            hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
            hit_callback.on_tool_end(documents)

        return str("\n".join([document.page_content for document in documents]))

    async def _arun(self, tool_input: str) -> str:
        model_credentials = LLMBuilder.get_model_credentials(
            tenant_id=self.dataset.tenant_id,
            model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
            model_name='text-embedding-ada-002'
        )

        embeddings = CacheEmbedding(OpenAIEmbeddings(
            **model_credentials
        ))

        vector_index = VectorIndex(
            dataset=self.dataset,
            config=current_app.config,
            embeddings=embeddings
        )

        documents = await vector_index.asearch(
            tool_input,
            search_type='similarity',
            search_kwargs={
                'k': 10
            }
        )

        hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
        hit_callback.on_tool_end(documents)
        return str("\n".join([document.page_content for document in documents]))
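A sketch of how such a tool might be constructed and invoked during an agent run (illustrative only; the `dataset` object is assumed to be loaded already, and the name/description pattern is an assumption, not taken from this commit):

tool = DatasetTool(
    name=f"dataset-{dataset.id}",  # hypothetical tool name
    description=dataset.description or 'useful for when you want to answer queries about the ' + dataset.name,
    dataset=dataset,
    k=2,
)

# BaseTool.run dispatches to _run above and returns the joined page contents.
answer_context = tool.run("How do I configure the vector store?")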
@ -1,73 +0,0 @@
from typing import Optional

from langchain.callbacks import CallbackManager
from llama_index.langchain_helpers.agents import IndexToolConfig

from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.index.keyword_table_index import KeywordTableIndex
from core.index.vector_index import VectorIndex
from core.prompt.prompts import QUERY_KEYWORD_EXTRACT_TEMPLATE
from core.tool.llama_index_tool import EnhanceLlamaIndexTool
from models.dataset import Dataset


class DatasetToolBuilder:
    @classmethod
    def build_dataset_tool(cls, dataset: Dataset,
                           response_mode: str = "no_synthesizer",
                           callback_handler: Optional[DatasetToolCallbackHandler] = None):
        if dataset.indexing_technique == "economy":
            # use keyword table query
            index = KeywordTableIndex(dataset=dataset).query_index

            if not index:
                return None

            query_kwargs = {
                "mode": "default",
                "response_mode": response_mode,
                "query_keyword_extract_template": QUERY_KEYWORD_EXTRACT_TEMPLATE,
                "max_keywords_per_query": 5,
                # If num_chunks_per_query is too large,
                # it will slow down the synthesis process due to multiple iterations of refinement.
                "num_chunks_per_query": 2
            }
        else:
            index = VectorIndex(dataset=dataset).query_index

            if not index:
                return None

            query_kwargs = {
                "mode": "default",
                "response_mode": response_mode,
                # If top_k is too large,
                # it will slow down the synthesis process due to multiple iterations of refinement.
                "similarity_top_k": 2
            }

        # fulfill description when it is empty
        description = dataset.description
        if not description:
            description = 'useful for when you want to answer queries about the ' + dataset.name

        index_tool_config = IndexToolConfig(
            index=index,
            name=f"dataset-{dataset.id}",
            description=description,
            index_query_kwargs=query_kwargs,
            tool_kwargs={
                "callback_manager": CallbackManager([callback_handler, DifyStdOutCallbackHandler()])
            },
            # tool_kwargs={"return_direct": True},
            # return_direct: Whether to return LLM results directly or process the output data with an Output Parser
        )

        index_callback_handler = DatasetIndexToolCallbackHandler(dataset_id=dataset.id)

        return EnhanceLlamaIndexTool.from_tool_config(
            tool_config=index_tool_config,
            callback_handler=index_callback_handler
        )
@ -1,43 +0,0 @@
from typing import Dict

from langchain.tools import BaseTool
from llama_index.indices.base import BaseGPTIndex
from llama_index.langchain_helpers.agents import IndexToolConfig
from pydantic import Field

from core.callback_handler.index_tool_callback_handler import IndexToolCallbackHandler


class EnhanceLlamaIndexTool(BaseTool):
    """Tool for querying a LlamaIndex."""

    # NOTE: name/description still needs to be set
    index: BaseGPTIndex
    query_kwargs: Dict = Field(default_factory=dict)
    return_sources: bool = False
    callback_handler: IndexToolCallbackHandler

    @classmethod
    def from_tool_config(cls, tool_config: IndexToolConfig,
                         callback_handler: IndexToolCallbackHandler) -> "EnhanceLlamaIndexTool":
        """Create a tool from a tool config."""
        return_sources = tool_config.tool_kwargs.pop("return_sources", False)
        return cls(
            index=tool_config.index,
            callback_handler=callback_handler,
            name=tool_config.name,
            description=tool_config.description,
            return_sources=return_sources,
            query_kwargs=tool_config.index_query_kwargs,
            **tool_config.tool_kwargs,
        )

    def _run(self, tool_input: str) -> str:
        response = self.index.query(tool_input, **self.query_kwargs)
        self.callback_handler.on_tool_end(response)
        return str(response)

    async def _arun(self, tool_input: str) -> str:
        response = await self.index.aquery(tool_input, **self.query_kwargs)
        self.callback_handler.on_tool_end(response)
        return str(response)
@ -1,34 +0,0 @@
from abc import ABC, abstractmethod
from typing import Optional

from llama_index import ServiceContext, GPTVectorStoreIndex
from llama_index.data_structs import Node
from llama_index.vector_stores.types import VectorStore


class BaseVectorStoreClient(ABC):
    @abstractmethod
    def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
        raise NotImplementedError

    @abstractmethod
    def to_index_config(self, index_id: str) -> dict:
        raise NotImplementedError


class BaseGPTVectorStoreIndex(GPTVectorStoreIndex):
    def delete_node(self, node_id: str):
        self._vector_store.delete_node(node_id)

    def exists_by_node_id(self, node_id: str) -> bool:
        return self._vector_store.exists_by_node_id(node_id)


class EnhanceVectorStore(ABC):
    @abstractmethod
    def delete_node(self, node_id: str):
        pass

    @abstractmethod
    def exists_by_node_id(self, node_id: str) -> bool:
        pass
@ -0,0 +1,69 @@
from typing import cast, Any

from langchain.schema import Document
from langchain.vectorstores import Qdrant
from qdrant_client.http.models import Filter, PointIdsList, FilterSelector
from qdrant_client.local.qdrant_local import QdrantLocal


class QdrantVectorStore(Qdrant):
    def del_texts(self, filter: Filter):
        if not filter:
            raise ValueError('filter must not be empty')

        self._reload_if_needed()

        self.client.delete(
            collection_name=self.collection_name,
            points_selector=FilterSelector(
                filter=filter
            ),
        )

    def del_text(self, uuid: str) -> None:
        self._reload_if_needed()

        self.client.delete(
            collection_name=self.collection_name,
            points_selector=PointIdsList(
                points=[uuid],
            ),
        )

    def text_exists(self, uuid: str) -> bool:
        self._reload_if_needed()

        response = self.client.retrieve(
            collection_name=self.collection_name,
            ids=[uuid]
        )

        return len(response) > 0

    def delete(self):
        self._reload_if_needed()

        self.client.delete_collection(collection_name=self.collection_name)

    @classmethod
    def _document_from_scored_point(
            cls,
            scored_point: Any,
            content_payload_key: str,
            metadata_payload_key: str,
    ) -> Document:
        if scored_point.payload.get('doc_id'):
            return Document(
                page_content=scored_point.payload.get(content_payload_key),
                metadata={'doc_id': scored_point.id}
            )

        return Document(
            page_content=scored_point.payload.get(content_payload_key),
            metadata=scored_point.payload.get(metadata_payload_key) or {},
        )

    def _reload_if_needed(self):
        if isinstance(self.client, QdrantLocal):
            self.client = cast(QdrantLocal, self.client)
            self.client._load()
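Deleting every point that belongs to one document goes through a qdrant filter; a short sketch of the call, mirroring QdrantVectorIndex.delete_by_document_id above (`vector_store` is assumed to be a QdrantVectorStore instance and the document id is a placeholder):

from qdrant_client.http import models

vector_store.del_texts(models.Filter(
    must=[
        models.FieldCondition(
            key="metadata.document_id",
            match=models.MatchValue(value="a-document-id"),  # placeholder id
        ),
    ],
))

# Single points are removed by uuid instead:
# vector_store.del_text("point-uuid")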
@ -1,147 +0,0 @@
import os
from typing import cast, List

from llama_index.data_structs import Node
from llama_index.data_structs.node_v2 import DocumentRelationship
from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult
from qdrant_client.http.models import Payload, Filter

import qdrant_client
from llama_index import ServiceContext, GPTVectorStoreIndex, GPTQdrantIndex
from llama_index.data_structs.data_structs_v2 import QdrantIndexDict
from llama_index.vector_stores import QdrantVectorStore
from qdrant_client.local.qdrant_local import QdrantLocal

from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore


class QdrantVectorStoreClient(BaseVectorStoreClient):

    def __init__(self, url: str, api_key: str, root_path: str):
        self._client = self.init_from_config(url, api_key, root_path)

    @classmethod
    def init_from_config(cls, url: str, api_key: str, root_path: str):
        if url and url.startswith('path:'):
            path = url.replace('path:', '')
            if not os.path.isabs(path):
                path = os.path.join(root_path, path)

            return qdrant_client.QdrantClient(
                path=path
            )
        else:
            return qdrant_client.QdrantClient(
                url=url,
                api_key=api_key,
            )

    def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
        index_struct = QdrantIndexDict()

        if self._client is None:
            raise Exception("Vector client is not initialized.")

        # {"collection_name": "Gpt_index_xxx"}
        collection_name = config.get('collection_name')
        if not collection_name:
            raise Exception("collection_name cannot be None.")

        return GPTQdrantEnhanceIndex(
            service_context=service_context,
            index_struct=index_struct,
            vector_store=QdrantEnhanceVectorStore(
                client=self._client,
                collection_name=collection_name
            )
        )

    def to_index_config(self, index_id: str) -> dict:
        return {"collection_name": index_id}


class GPTQdrantEnhanceIndex(GPTQdrantIndex, BaseGPTVectorStoreIndex):
    pass


class QdrantEnhanceVectorStore(QdrantVectorStore, EnhanceVectorStore):
    def delete_node(self, node_id: str):
        """
        Delete node from the index.

        :param node_id: node id
        """
        from qdrant_client.http import models as rest

        self._reload_if_needed()

        self._client.delete(
            collection_name=self._collection_name,
            points_selector=rest.Filter(
                must=[
                    rest.FieldCondition(
                        key="id", match=rest.MatchValue(value=node_id)
                    )
                ]
            ),
        )

    def exists_by_node_id(self, node_id: str) -> bool:
        """
        Get node from the index by node id.

        :param node_id: node id
        """
        self._reload_if_needed()

        response = self._client.retrieve(
            collection_name=self._collection_name,
            ids=[node_id]
        )

        return len(response) > 0

    def query(
            self,
            query: VectorStoreQuery,
    ) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes.

        Args:
            query (VectorStoreQuery): query
        """
        query_embedding = cast(List[float], query.query_embedding)

        self._reload_if_needed()

        response = self._client.search(
            collection_name=self._collection_name,
            query_vector=query_embedding,
            limit=cast(int, query.similarity_top_k),
            query_filter=cast(Filter, self._build_query_filter(query)),
            with_vectors=True
        )

        nodes = []
        similarities = []
        ids = []
        for point in response:
            payload = cast(Payload, point.payload)
            node = Node(
                doc_id=str(point.id),
                text=payload.get("text"),
                embedding=point.vector,
                extra_info=payload.get("extra_info"),
                relationships={
                    DocumentRelationship.SOURCE: payload.get("doc_id", "None"),
                },
            )
            nodes.append(node)
            similarities.append(point.score)
            ids.append(str(point.id))

        return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)

    def _reload_if_needed(self):
        if isinstance(self._client._client, QdrantLocal):
            self._client._client._load()
@ -1,62 +0,0 @@
from flask import Flask
from llama_index import ServiceContext, GPTVectorStoreIndex
from requests import ReadTimeout
from tenacity import retry, retry_if_exception_type, stop_after_attempt

from core.vector_store.qdrant_vector_store_client import QdrantVectorStoreClient
from core.vector_store.weaviate_vector_store_client import WeaviateVectorStoreClient

SUPPORTED_VECTOR_STORES = ['weaviate', 'qdrant']


class VectorStore:

    def __init__(self):
        self._vector_store = None
        self._client = None

    def init_app(self, app: Flask):
        if not app.config['VECTOR_STORE']:
            return

        self._vector_store = app.config['VECTOR_STORE']
        if self._vector_store not in SUPPORTED_VECTOR_STORES:
            raise ValueError(f"Vector store {self._vector_store} is not supported.")

        if self._vector_store == 'weaviate':
            self._client = WeaviateVectorStoreClient(
                endpoint=app.config['WEAVIATE_ENDPOINT'],
                api_key=app.config['WEAVIATE_API_KEY'],
                grpc_enabled=app.config['WEAVIATE_GRPC_ENABLED'],
                batch_size=app.config['WEAVIATE_BATCH_SIZE']
            )
        elif self._vector_store == 'qdrant':
            self._client = QdrantVectorStoreClient(
                url=app.config['QDRANT_URL'],
                api_key=app.config['QDRANT_API_KEY'],
                root_path=app.root_path
            )

        app.extensions['vector_store'] = self

    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
    def get_index(self, service_context: ServiceContext, index_struct: dict) -> GPTVectorStoreIndex:
        vector_store_config: dict = index_struct.get('vector_store')
        index = self.get_client().get_index(
            service_context=service_context,
            config=vector_store_config
        )

        return index

    def to_index_struct(self, index_id: str) -> dict:
        return {
            "type": self._vector_store,
            "vector_store": self.get_client().to_index_config(index_id)
        }

    def get_client(self):
        if not self._client:
            raise Exception("Vector store client is not initialized.")

        return self._client
@ -1,66 +0,0 @@
from llama_index.indices.query.base import IS
from typing import (
    Any,
    Dict,
    List,
    Optional
)

from llama_index.docstore import BaseDocumentStore
from llama_index.indices.postprocessor.node import (
    BaseNodePostprocessor,
)
from llama_index.indices.vector_store import GPTVectorStoreIndexQuery
from llama_index.indices.response.response_builder import ResponseMode
from llama_index.indices.service_context import ServiceContext
from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
from llama_index.prompts.prompts import (
    QuestionAnswerPrompt,
    RefinePrompt,
    SimpleInputPrompt,
)

from core.index.query.synthesizer import EnhanceResponseSynthesizer


class EnhanceGPTVectorStoreIndexQuery(GPTVectorStoreIndexQuery):
    @classmethod
    def from_args(
            cls,
            index_struct: IS,
            service_context: ServiceContext,
            docstore: Optional[BaseDocumentStore] = None,
            node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
            verbose: bool = False,
            # response synthesizer args
            response_mode: ResponseMode = ResponseMode.DEFAULT,
            text_qa_template: Optional[QuestionAnswerPrompt] = None,
            refine_template: Optional[RefinePrompt] = None,
            simple_template: Optional[SimpleInputPrompt] = None,
            response_kwargs: Optional[Dict] = None,
            use_async: bool = False,
            streaming: bool = False,
            optimizer: Optional[BaseTokenUsageOptimizer] = None,
            # class-specific args
            **kwargs: Any,
    ) -> "BaseGPTIndexQuery":
        response_synthesizer = EnhanceResponseSynthesizer.from_args(
            service_context=service_context,
            text_qa_template=text_qa_template,
            refine_template=refine_template,
            simple_template=simple_template,
            response_mode=response_mode,
            response_kwargs=response_kwargs,
            use_async=use_async,
            streaming=streaming,
            optimizer=optimizer,
        )
        return cls(
            index_struct=index_struct,
            service_context=service_context,
            response_synthesizer=response_synthesizer,
            docstore=docstore,
            node_postprocessors=node_postprocessors,
            verbose=verbose,
            **kwargs,
        )
@ -0,0 +1,38 @@
from langchain.vectorstores import Weaviate


class WeaviateVectorStore(Weaviate):
    def del_texts(self, where_filter: dict):
        if not where_filter:
            raise ValueError('where_filter must not be empty')

        self._client.batch.delete_objects(
            class_name=self._index_name,
            where=where_filter,
            output='minimal'
        )

    def del_text(self, uuid: str) -> None:
        self._client.data_object.delete(
            uuid,
            class_name=self._index_name
        )

    def text_exists(self, uuid: str) -> bool:
        result = self._client.query.get(self._index_name).with_additional(["id"]).with_where({
            "path": ["doc_id"],
            "operator": "Equal",
            "valueText": uuid,
        }).with_limit(1).do()

        if "errors" in result:
            raise ValueError(f"Error during query: {result['errors']}")

        entries = result["data"]["Get"][self._index_name]
        if len(entries) == 0:
            return False

        return True

    def delete(self):
        self._client.schema.delete_class(self._index_name)
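The same where-filter shape is what callers pass to del_texts; a short sketch mirroring WeaviateVectorIndex.delete_by_document_id above (`vector_store` is assumed to be a WeaviateVectorStore instance and the document id is a placeholder):

# Remove every object indexed for one dataset document.
vector_store.del_texts({
    "operator": "Equal",
    "path": ["document_id"],
    "valueText": "a-document-id",
})

# Check whether a single segment is present by its node uuid.
exists = vector_store.text_exists("node-uuid")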
@ -1,270 +0,0 @@
import json
import weaviate
from dataclasses import field
from typing import List, Any, Dict, Optional

from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore
from llama_index import ServiceContext, GPTWeaviateIndex, GPTVectorStoreIndex
from llama_index.data_structs.data_structs_v2 import WeaviateIndexDict, Node
from llama_index.data_structs.node_v2 import DocumentRelationship
from llama_index.readers.weaviate.client import _class_name, NODE_SCHEMA, _logger
from llama_index.vector_stores import WeaviateVectorStore
from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult, VectorStoreQueryMode
from llama_index.readers.weaviate.utils import (
    parse_get_response,
    validate_client,
)


class WeaviateVectorStoreClient(BaseVectorStoreClient):

    def __init__(self, endpoint: str, api_key: str, grpc_enabled: bool, batch_size: int):
        self._client = self.init_from_config(endpoint, api_key, grpc_enabled, batch_size)

    def init_from_config(self, endpoint: str, api_key: str, grpc_enabled: bool, batch_size: int):
        auth_config = weaviate.auth.AuthApiKey(api_key=api_key)

        weaviate.connect.connection.has_grpc = grpc_enabled

        client = weaviate.Client(
            url=endpoint,
            auth_client_secret=auth_config,
            timeout_config=(5, 60),
            startup_period=None
        )

        client.batch.configure(
            # `batch_size` takes an `int` value to enable auto-batching
            # (`None` is used for manual batching)
            batch_size=batch_size,
            # dynamically update the `batch_size` based on import speed
            dynamic=True,
            # `timeout_retries` takes an `int` value to retry on time outs
            timeout_retries=3,
        )

        return client

    def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
        index_struct = WeaviateIndexDict()

        if self._client is None:
            raise Exception("Vector client is not initialized.")

        # {"class_prefix": "Gpt_index_xxx"}
        class_prefix = config.get('class_prefix')
        if not class_prefix:
            raise Exception("class_prefix cannot be None.")

        return GPTWeaviateEnhanceIndex(
            service_context=service_context,
            index_struct=index_struct,
            vector_store=WeaviateWithSimilaritiesVectorStore(
                weaviate_client=self._client,
                class_prefix=class_prefix
            )
        )

    def to_index_config(self, index_id: str) -> dict:
        return {"class_prefix": index_id}


class WeaviateWithSimilaritiesVectorStore(WeaviateVectorStore, EnhanceVectorStore):
    def query(self, query: VectorStoreQuery) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes."""
        nodes = self.weaviate_query(
            self._client,
            self._class_prefix,
            query,
        )
        nodes = nodes[: query.similarity_top_k]
        node_idxs = [str(i) for i in range(len(nodes))]

        similarities = []
        for node in nodes:
            similarities.append(node.extra_info['similarity'])
            del node.extra_info['similarity']

        return VectorStoreQueryResult(nodes=nodes, ids=node_idxs, similarities=similarities)

    def weaviate_query(
            self,
            client: Any,
            class_prefix: str,
            query_spec: VectorStoreQuery,
    ) -> List[Node]:
        """Convert to LlamaIndex list."""
        validate_client(client)

        class_name = _class_name(class_prefix)
        prop_names = [p["name"] for p in NODE_SCHEMA]
        vector = query_spec.query_embedding

        # build query
        query = client.query.get(class_name, prop_names).with_additional(["id", "vector", "certainty"])
        if query_spec.mode == VectorStoreQueryMode.DEFAULT:
            _logger.debug("Using vector search")
            if vector is not None:
                query = query.with_near_vector(
                    {
                        "vector": vector,
                    }
                )
        elif query_spec.mode == VectorStoreQueryMode.HYBRID:
            _logger.debug(f"Using hybrid search with alpha {query_spec.alpha}")
            query = query.with_hybrid(
                query=query_spec.query_str,
                alpha=query_spec.alpha,
                vector=vector,
            )
        query = query.with_limit(query_spec.similarity_top_k)
        _logger.debug(f"Using limit of {query_spec.similarity_top_k}")

        # execute query
        query_result = query.do()

        # parse results
        parsed_result = parse_get_response(query_result)
        entries = parsed_result[class_name]
        results = [self._to_node(entry) for entry in entries]
        return results

    def _to_node(self, entry: Dict) -> Node:
        """Convert to Node."""
        extra_info_str = entry["extra_info"]
        if extra_info_str == "":
            extra_info = None
        else:
            extra_info = json.loads(extra_info_str)

        if 'certainty' in entry['_additional']:
            if extra_info:
                extra_info['similarity'] = entry['_additional']['certainty']
            else:
                extra_info = {'similarity': entry['_additional']['certainty']}

        node_info_str = entry["node_info"]
        if node_info_str == "":
            node_info = None
        else:
            node_info = json.loads(node_info_str)

        relationships_str = entry["relationships"]
        relationships: Dict[DocumentRelationship, str]
        if relationships_str == "":
            relationships = field(default_factory=dict)
        else:
            relationships = {
                DocumentRelationship(k): v for k, v in json.loads(relationships_str).items()
            }

        return Node(
            text=entry["text"],
            doc_id=entry["doc_id"],
            embedding=entry["_additional"]["vector"],
            extra_info=extra_info,
            node_info=node_info,
            relationships=relationships,
        )

    def delete(self, doc_id: str, **delete_kwargs: Any) -> None:
        """Delete a document.

        Args:
            doc_id (str): document id

        """
        delete_document(self._client, doc_id, self._class_prefix)

    def delete_node(self, node_id: str):
        """
        Delete node from the index.

        :param node_id: node id
        """
        delete_node(self._client, node_id, self._class_prefix)

    def exists_by_node_id(self, node_id: str) -> bool:
        """
        Get node from the index by node id.

        :param node_id: node id
        """
        entry = get_by_node_id(self._client, node_id, self._class_prefix)
        return True if entry else False


class GPTWeaviateEnhanceIndex(GPTWeaviateIndex, BaseGPTVectorStoreIndex):
    pass


def delete_document(client: Any, ref_doc_id: str, class_prefix: str) -> None:
    """Delete entry."""
    validate_client(client)
    # make sure that each entry
    class_name = _class_name(class_prefix)
    where_filter = {
        "path": ["ref_doc_id"],
        "operator": "Equal",
        "valueString": ref_doc_id,
    }
    query = (
        client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
    )

    query_result = query.do()
    parsed_result = parse_get_response(query_result)
    entries = parsed_result[class_name]
    for entry in entries:
        client.data_object.delete(entry["_additional"]["id"], class_name)

    while len(entries) > 0:
        query_result = query.do()
        parsed_result = parse_get_response(query_result)
        entries = parsed_result[class_name]
        for entry in entries:
            client.data_object.delete(entry["_additional"]["id"], class_name)


def delete_node(client: Any, node_id: str, class_prefix: str) -> None:
    """Delete entry."""
    validate_client(client)
    # make sure that each entry
    class_name = _class_name(class_prefix)
    where_filter = {
        "path": ["doc_id"],
        "operator": "Equal",
        "valueString": node_id,
    }
    query = (
        client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
    )

    query_result = query.do()
    parsed_result = parse_get_response(query_result)
    entries = parsed_result[class_name]
    for entry in entries:
        client.data_object.delete(entry["_additional"]["id"], class_name)


def get_by_node_id(client: Any, node_id: str, class_prefix: str) -> Optional[Dict]:
    """Delete entry."""
    validate_client(client)
    # make sure that each entry
    class_name = _class_name(class_prefix)
    where_filter = {
        "path": ["doc_id"],
        "operator": "Equal",
        "valueString": node_id,
    }
    query = (
        client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
    )

    query_result = query.do()
    parsed_result = parse_get_response(query_result)
    entries = parsed_result[class_name]
    if len(entries) == 0:
        return None

    return entries[0]
@ -1,7 +0,0 @@
from core.vector_store.vector_store import VectorStore

vector_store = VectorStore()


def init_app(app):
    vector_store.init_app(app)