|
|
|
|
# NOTE(review): stray unified-diff hunk header left in the file; commented out so the module parses: @ -3,12 +3,13 @@ import json
|
|
|
|
|
import math
|
|
|
|
|
import random
|
|
|
|
|
import string
|
|
|
|
|
import threading
|
|
|
|
|
import time
|
|
|
|
|
import uuid
|
|
|
|
|
|
|
|
|
|
import click
|
|
|
|
|
from tqdm import tqdm
|
|
|
|
|
from flask import current_app
|
|
|
|
|
from flask import current_app, Flask
|
|
|
|
|
from langchain.embeddings import OpenAIEmbeddings
|
|
|
|
|
from werkzeug.exceptions import NotFound
|
|
|
|
|
|
|
|
|
|
# NOTE(review): stray unified-diff hunk header left in the file; commented out so the module parses: @ -456,92 +457,92 @@ def update_qdrant_indexes():
|
|
|
|
|
@click.command('normalization-collections', help='restore all collections in one')
def normalization_collections():
    """Restore every 'high_quality' dataset's vector index into a shared collection.

    Pages through all datasets whose indexing technique is 'high_quality'
    (newest first, 100 per page) and restores each one's Qdrant index.  The
    per-dataset work is delegated to deal_dataset_vector(), run in batches of
    up to 5 concurrent threads; each batch is joined before the next starts.

    The merged diff residue in the original contained two conflicting
    counters (`normalization_count = 0` and `= []`) and two page sizes
    (50 / 100); this keeps the final, list-based threaded variant so workers
    can report success by appending (list.append is atomic under the GIL).
    """
    click.echo(click.style('Start normalization collections.', fg='green'))

    # Shared success counter: each worker thread appends one entry per
    # dataset restored without error.
    normalization_count = []

    page = 1
    while True:
        try:
            datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=100)
        except NotFound:
            # Flask-SQLAlchemy's paginate raises NotFound once the page is
            # past the end — that is the normal loop exit.
            break

        datasets_result = datasets.items
        page += 1

        # Fan out restore work in batches of 5 threads, joining each batch
        # before starting the next to bound concurrency.
        for i in range(0, len(datasets_result), 5):
            threads = []
            sub_datasets = datasets_result[i:i + 5]
            for dataset in sub_datasets:
                document_format_thread = threading.Thread(target=deal_dataset_vector, kwargs={
                    # Threads need the real app object, not the proxy, to
                    # push their own app context.
                    'flask_app': current_app._get_current_object(),
                    'dataset': dataset,
                    'normalization_count': normalization_count
                })
                threads.append(document_format_thread)
                document_format_thread.start()
            for thread in threads:
                thread.join()

    click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(len(normalization_count)), fg='green'))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def deal_dataset_vector(flask_app: Flask, dataset: Dataset, normalization_count: list):
    """Restore one dataset's vectors into its shared Qdrant collection.

    Designed to run as a thread target: it pushes its own app context from
    the passed Flask app.  On success it appends one entry to
    `normalization_count` (a list shared with the dispatching command);
    any failure is logged in red and swallowed so one bad dataset does not
    abort the whole batch.

    The merged diff residue in the original interleaved a superseded variant
    (delete_original_collection / `normalization_count += 1` / a stray
    `continue` outside any loop); this keeps the final, list-append variant.

    :param flask_app: the concrete Flask application object (not the proxy)
    :param dataset: the dataset whose index should be restored
    :param normalization_count: shared list used as a cross-thread counter
    """
    with flask_app.app_context():
        try:
            click.echo('restore dataset index: {}'.format(dataset.id))
            try:
                embedding_model = ModelFactory.get_embedding_model(
                    tenant_id=dataset.tenant_id,
                    model_provider_name=dataset.embedding_model_provider,
                    model_name=dataset.embedding_model
                )
            except Exception:
                # The tenant's configured model may be gone; fall back to a
                # placeholder OpenAI provider so the binding lookup below can
                # still resolve a provider/model pair.
                provider = Provider(
                    id='provider_id',
                    tenant_id=dataset.tenant_id,
                    provider_name='openai',
                    provider_type=ProviderType.CUSTOM.value,
                    encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
                    is_valid=True,
                )
                model_provider = OpenAIProvider(provider=provider)
                embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
                                                  model_provider=model_provider)
            embeddings = CacheEmbedding(embedding_model)

            # Find the shared collection binding for this provider/model pair;
            # the oldest binding wins so all datasets converge on one collection.
            dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
                filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,
                       DatasetCollectionBinding.model_name == embedding_model.name). \
                order_by(DatasetCollectionBinding.created_at). \
                first()

            if not dataset_collection_binding:
                # Lazily create the binding with a fresh, collision-safe
                # collection name.
                dataset_collection_binding = DatasetCollectionBinding(
                    provider_name=embedding_model.model_provider.provider_name,
                    model_name=embedding_model.name,
                    collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
                )
                db.session.add(dataset_collection_binding)
                db.session.commit()

            # Local import — presumably to avoid a circular import at module
            # load time; confirm against the package layout.
            from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig

            index = QdrantVectorIndex(
                dataset=dataset,
                config=QdrantConfig(
                    endpoint=current_app.config.get('QDRANT_URL'),
                    api_key=current_app.config.get('QDRANT_API_KEY'),
                    root_path=current_app.root_path
                ),
                embeddings=embeddings
            )
            if index:
                # index.delete_by_group_id(dataset.id)
                index.restore_dataset_in_one(dataset, dataset_collection_binding)
            else:
                click.echo('passed.')
            # list.append is atomic under the GIL, so this is safe across
            # the worker threads sharing the counter.
            normalization_count.append(1)
        except Exception as e:
            click.echo(
                click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
                            fg='red'))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@click.command('update_app_model_configs', help='Migrate data to support paragraph variable.')
|
|
|
|
|
|