|
|
|
|
@ -1,4 +1,4 @@
|
|
|
|
|
import json
|
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
from typing import Any
|
|
|
|
|
|
|
|
|
|
from flask import current_app
|
|
|
|
|
@ -8,9 +8,23 @@ from core.model_manager import ModelManager
|
|
|
|
|
from core.model_runtime.entities.model_entities import ModelType
|
|
|
|
|
from core.rag.datasource.entity.embedding import Embeddings
|
|
|
|
|
from core.rag.datasource.vdb.vector_base import BaseVector
|
|
|
|
|
from core.rag.datasource.vdb.vector_type import VectorType
|
|
|
|
|
from core.rag.models.document import Document
|
|
|
|
|
from extensions.ext_database import db
|
|
|
|
|
from models.dataset import Dataset, DatasetCollectionBinding
|
|
|
|
|
from models.dataset import Dataset
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AbstractVectorFactory(ABC):
    """Abstract factory that produces a ``BaseVector`` client for a dataset.

    Subclasses implement :meth:`init_vector`; the static helper
    :meth:`gen_index_struct_dict` builds the index-struct mapping and needs
    no instance.
    """

    @abstractmethod
    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector:
        """Create the vector-store client for ``dataset``.

        Must be overridden by concrete factories.

        :param dataset: dataset the vector store is created for.
        :param attributes: attribute list handed through to the vector store.
        :param embeddings: embeddings object handed through to the vector store.
        :return: a ``BaseVector`` implementation.
        :raises NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

    @staticmethod
    def gen_index_struct_dict(vector_type: VectorType, collection_name: str) -> dict:
        """Build the index-struct mapping for a vector store.

        :param vector_type: value stored under the ``"type"`` key.
        :param collection_name: value stored under
            ``"vector_store" -> "class_prefix"``.
        :return: a new dict of the form
            ``{"type": ..., "vector_store": {"class_prefix": ...}}``.
        """
        return {
            "type": vector_type,
            "vector_store": {"class_prefix": collection_name},
        }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Vector:
|
|
|
|
|
@ -32,188 +46,35 @@ class Vector:
|
|
|
|
|
if not vector_type:
|
|
|
|
|
raise ValueError("Vector store must be specified.")
|
|
|
|
|
|
|
|
|
|
if vector_type == "weaviate":
|
|
|
|
|
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector
|
|
|
|
|
if self._dataset.index_struct_dict:
|
|
|
|
|
class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
|
|
|
|
|
collection_name = class_prefix
|
|
|
|
|
else:
|
|
|
|
|
dataset_id = self._dataset.id
|
|
|
|
|
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
|
|
|
|
index_struct_dict = {
|
|
|
|
|
"type": 'weaviate',
|
|
|
|
|
"vector_store": {"class_prefix": collection_name}
|
|
|
|
|
}
|
|
|
|
|
self._dataset.index_struct = json.dumps(index_struct_dict)
|
|
|
|
|
return WeaviateVector(
|
|
|
|
|
collection_name=collection_name,
|
|
|
|
|
config=WeaviateConfig(
|
|
|
|
|
endpoint=config.get('WEAVIATE_ENDPOINT'),
|
|
|
|
|
api_key=config.get('WEAVIATE_API_KEY'),
|
|
|
|
|
batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
|
|
|
|
|
),
|
|
|
|
|
attributes=self._attributes
|
|
|
|
|
)
|
|
|
|
|
elif vector_type == "qdrant":
|
|
|
|
|
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector
|
|
|
|
|
if self._dataset.collection_binding_id:
|
|
|
|
|
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
|
|
|
|
|
filter(DatasetCollectionBinding.id == self._dataset.collection_binding_id). \
|
|
|
|
|
one_or_none()
|
|
|
|
|
if dataset_collection_binding:
|
|
|
|
|
collection_name = dataset_collection_binding.collection_name
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError('Dataset Collection Bindings is not exist!')
|
|
|
|
|
else:
|
|
|
|
|
if self._dataset.index_struct_dict:
|
|
|
|
|
class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
|
|
|
|
|
collection_name = class_prefix
|
|
|
|
|
else:
|
|
|
|
|
dataset_id = self._dataset.id
|
|
|
|
|
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
|
|
|
|
|
|
|
|
|
if not self._dataset.index_struct_dict:
|
|
|
|
|
index_struct_dict = {
|
|
|
|
|
"type": 'qdrant',
|
|
|
|
|
"vector_store": {"class_prefix": collection_name}
|
|
|
|
|
}
|
|
|
|
|
self._dataset.index_struct = json.dumps(index_struct_dict)
|
|
|
|
|
|
|
|
|
|
return QdrantVector(
|
|
|
|
|
collection_name=collection_name,
|
|
|
|
|
group_id=self._dataset.id,
|
|
|
|
|
config=QdrantConfig(
|
|
|
|
|
endpoint=config.get('QDRANT_URL'),
|
|
|
|
|
api_key=config.get('QDRANT_API_KEY'),
|
|
|
|
|
root_path=current_app.root_path,
|
|
|
|
|
timeout=config.get('QDRANT_CLIENT_TIMEOUT'),
|
|
|
|
|
grpc_port=config.get('QDRANT_GRPC_PORT'),
|
|
|
|
|
prefer_grpc=config.get('QDRANT_GRPC_ENABLED')
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
elif vector_type == "milvus":
|
|
|
|
|
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector
|
|
|
|
|
if self._dataset.index_struct_dict:
|
|
|
|
|
class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
|
|
|
|
|
collection_name = class_prefix
|
|
|
|
|
else:
|
|
|
|
|
dataset_id = self._dataset.id
|
|
|
|
|
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
|
|
|
|
index_struct_dict = {
|
|
|
|
|
"type": 'milvus',
|
|
|
|
|
"vector_store": {"class_prefix": collection_name}
|
|
|
|
|
}
|
|
|
|
|
self._dataset.index_struct = json.dumps(index_struct_dict)
|
|
|
|
|
return MilvusVector(
|
|
|
|
|
collection_name=collection_name,
|
|
|
|
|
config=MilvusConfig(
|
|
|
|
|
host=config.get('MILVUS_HOST'),
|
|
|
|
|
port=config.get('MILVUS_PORT'),
|
|
|
|
|
user=config.get('MILVUS_USER'),
|
|
|
|
|
password=config.get('MILVUS_PASSWORD'),
|
|
|
|
|
secure=config.get('MILVUS_SECURE'),
|
|
|
|
|
database=config.get('MILVUS_DATABASE'),
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
elif vector_type == "relyt":
|
|
|
|
|
from core.rag.datasource.vdb.relyt.relyt_vector import RelytConfig, RelytVector
|
|
|
|
|
if self._dataset.index_struct_dict:
|
|
|
|
|
class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
|
|
|
|
|
collection_name = class_prefix
|
|
|
|
|
else:
|
|
|
|
|
dataset_id = self._dataset.id
|
|
|
|
|
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
|
|
|
|
index_struct_dict = {
|
|
|
|
|
"type": 'relyt',
|
|
|
|
|
"vector_store": {"class_prefix": collection_name}
|
|
|
|
|
}
|
|
|
|
|
self._dataset.index_struct = json.dumps(index_struct_dict)
|
|
|
|
|
return RelytVector(
|
|
|
|
|
collection_name=collection_name,
|
|
|
|
|
config=RelytConfig(
|
|
|
|
|
host=config.get('RELYT_HOST'),
|
|
|
|
|
port=config.get('RELYT_PORT'),
|
|
|
|
|
user=config.get('RELYT_USER'),
|
|
|
|
|
password=config.get('RELYT_PASSWORD'),
|
|
|
|
|
database=config.get('RELYT_DATABASE'),
|
|
|
|
|
),
|
|
|
|
|
group_id=self._dataset.id
|
|
|
|
|
)
|
|
|
|
|
elif vector_type == "pgvecto_rs":
|
|
|
|
|
from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRS, PgvectoRSConfig
|
|
|
|
|
if self._dataset.index_struct_dict:
|
|
|
|
|
class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
|
|
|
|
|
collection_name = class_prefix.lower()
|
|
|
|
|
else:
|
|
|
|
|
dataset_id = self._dataset.id
|
|
|
|
|
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
|
|
|
|
index_struct_dict = {
|
|
|
|
|
"type": 'pgvecto_rs',
|
|
|
|
|
"vector_store": {"class_prefix": collection_name}
|
|
|
|
|
}
|
|
|
|
|
self._dataset.index_struct = json.dumps(index_struct_dict)
|
|
|
|
|
dim = len(self._embeddings.embed_query("pgvecto_rs"))
|
|
|
|
|
return PGVectoRS(
|
|
|
|
|
collection_name=collection_name,
|
|
|
|
|
config=PgvectoRSConfig(
|
|
|
|
|
host=config.get('PGVECTO_RS_HOST'),
|
|
|
|
|
port=config.get('PGVECTO_RS_PORT'),
|
|
|
|
|
user=config.get('PGVECTO_RS_USER'),
|
|
|
|
|
password=config.get('PGVECTO_RS_PASSWORD'),
|
|
|
|
|
database=config.get('PGVECTO_RS_DATABASE'),
|
|
|
|
|
),
|
|
|
|
|
dim=dim
|
|
|
|
|
)
|
|
|
|
|
elif vector_type == "pgvector":
|
|
|
|
|
from core.rag.datasource.vdb.pgvector.pgvector import PGVector, PGVectorConfig
|
|
|
|
|
|
|
|
|
|
if self._dataset.index_struct_dict:
|
|
|
|
|
class_prefix: str = self._dataset.index_struct_dict["vector_store"]["class_prefix"]
|
|
|
|
|
collection_name = class_prefix
|
|
|
|
|
else:
|
|
|
|
|
dataset_id = self._dataset.id
|
|
|
|
|
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
|
|
|
|
index_struct_dict = {
|
|
|
|
|
"type": "pgvector",
|
|
|
|
|
"vector_store": {"class_prefix": collection_name}}
|
|
|
|
|
self._dataset.index_struct = json.dumps(index_struct_dict)
|
|
|
|
|
return PGVector(
|
|
|
|
|
collection_name=collection_name,
|
|
|
|
|
config=PGVectorConfig(
|
|
|
|
|
host=config.get("PGVECTOR_HOST"),
|
|
|
|
|
port=config.get("PGVECTOR_PORT"),
|
|
|
|
|
user=config.get("PGVECTOR_USER"),
|
|
|
|
|
password=config.get("PGVECTOR_PASSWORD"),
|
|
|
|
|
database=config.get("PGVECTOR_DATABASE"),
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
elif vector_type == "tidb_vector":
|
|
|
|
|
from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVector, TiDBVectorConfig
|
|
|
|
|
|
|
|
|
|
if self._dataset.index_struct_dict:
|
|
|
|
|
class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
|
|
|
|
|
collection_name = class_prefix.lower()
|
|
|
|
|
else:
|
|
|
|
|
dataset_id = self._dataset.id
|
|
|
|
|
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
|
|
|
|
index_struct_dict = {
|
|
|
|
|
"type": 'tidb_vector',
|
|
|
|
|
"vector_store": {"class_prefix": collection_name}
|
|
|
|
|
}
|
|
|
|
|
self._dataset.index_struct = json.dumps(index_struct_dict)
|
|
|
|
|
|
|
|
|
|
return TiDBVector(
|
|
|
|
|
collection_name=collection_name,
|
|
|
|
|
config=TiDBVectorConfig(
|
|
|
|
|
host=config.get('TIDB_VECTOR_HOST'),
|
|
|
|
|
port=config.get('TIDB_VECTOR_PORT'),
|
|
|
|
|
user=config.get('TIDB_VECTOR_USER'),
|
|
|
|
|
password=config.get('TIDB_VECTOR_PASSWORD'),
|
|
|
|
|
database=config.get('TIDB_VECTOR_DATABASE'),
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
|
|
|
|
|
vector_factory_cls = self.get_vector_factory(vector_type)
|
|
|
|
|
return vector_factory_cls().init_vector(self._dataset, self._attributes, self._embeddings)
|
|
|
|
|
|
|
|
|
|
@staticmethod
def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]:
    """Return the factory class registered for ``vector_type``.

    :param vector_type: identifier of the vector store; compared by equality
        against the ``VectorType`` members handled below.
    :return: the matching factory class (the class itself, not an instance).
    :raises ValueError: if ``vector_type`` matches none of the handled members.
    """
    # Each import stays inside its branch, so a backend module is only
    # imported when that backend is the one selected.
    if vector_type == VectorType.MILVUS:
        from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory as factory_cls
    elif vector_type == VectorType.PGVECTOR:
        from core.rag.datasource.vdb.pgvector.pgvector import PGVectorFactory as factory_cls
    elif vector_type == VectorType.PGVECTO_RS:
        from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRSFactory as factory_cls
    elif vector_type == VectorType.QDRANT:
        from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantVectorFactory as factory_cls
    elif vector_type == VectorType.RELYT:
        from core.rag.datasource.vdb.relyt.relyt_vector import RelytVectorFactory as factory_cls
    elif vector_type == VectorType.TIDB_VECTOR:
        from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory as factory_cls
    elif vector_type == VectorType.WEAVIATE:
        from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateVectorFactory as factory_cls
    else:
        raise ValueError(f"Vector store {vector_type} is not supported.")
    return factory_cls
|
|
|
|
|
|
|
|
|
|
def create(self, texts: list = None, **kwargs):
|
|
|
|
|
if texts:
|
|
|
|
|
|