Merge branch 'main' into feat/plugin
commit
c6f34f5c17
@ -0,0 +1,44 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class AnalyticdbConfig(BaseModel):
|
||||
"""
|
||||
Configuration for connecting to AnalyticDB.
|
||||
Refer to the following documentation for details on obtaining credentials:
|
||||
https://www.alibabacloud.com/help/en/analyticdb-for-postgresql/getting-started/create-an-instance-instances-with-vector-engine-optimization-enabled
|
||||
"""
|
||||
|
||||
ANALYTICDB_KEY_ID : Optional[str] = Field(
|
||||
default=None,
|
||||
description="The Access Key ID provided by Alibaba Cloud for authentication."
|
||||
)
|
||||
ANALYTICDB_KEY_SECRET : Optional[str] = Field(
|
||||
default=None,
|
||||
description="The Secret Access Key corresponding to the Access Key ID for secure access."
|
||||
)
|
||||
ANALYTICDB_REGION_ID : Optional[str] = Field(
|
||||
default=None,
|
||||
description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou')."
|
||||
)
|
||||
ANALYTICDB_INSTANCE_ID : Optional[str] = Field(
|
||||
default=None,
|
||||
description="The unique identifier of the AnalyticDB instance you want to connect to (e.g., 'gp-ab123456').."
|
||||
)
|
||||
ANALYTICDB_ACCOUNT : Optional[str] = Field(
|
||||
default=None,
|
||||
description="The account name used to log in to the AnalyticDB instance."
|
||||
)
|
||||
ANALYTICDB_PASSWORD : Optional[str] = Field(
|
||||
default=None,
|
||||
description="The password associated with the AnalyticDB account for authentication."
|
||||
)
|
||||
ANALYTICDB_NAMESPACE : Optional[str] = Field(
|
||||
default=None,
|
||||
description="The namespace within AnalyticDB for schema isolation."
|
||||
)
|
||||
ANALYTICDB_NAMESPACE_PASSWORD : Optional[str] = Field(
|
||||
default=None,
|
||||
description="The password for accessing the specified namespace within the AnalyticDB instance."
|
||||
)
|
||||
@ -0,0 +1,39 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, PositiveInt
|
||||
|
||||
|
||||
class MyScaleConfig(BaseModel):
|
||||
"""
|
||||
MyScale configs
|
||||
"""
|
||||
|
||||
MYSCALE_HOST: Optional[str] = Field(
|
||||
description='MyScale host',
|
||||
default=None,
|
||||
)
|
||||
|
||||
MYSCALE_PORT: Optional[PositiveInt] = Field(
|
||||
description='MyScale port',
|
||||
default=8123,
|
||||
)
|
||||
|
||||
MYSCALE_USER: Optional[str] = Field(
|
||||
description='MyScale user',
|
||||
default=None,
|
||||
)
|
||||
|
||||
MYSCALE_PASSWORD: Optional[str] = Field(
|
||||
description='MyScale password',
|
||||
default=None,
|
||||
)
|
||||
|
||||
MYSCALE_DATABASE: Optional[str] = Field(
|
||||
description='MyScale database name',
|
||||
default=None,
|
||||
)
|
||||
|
||||
MYSCALE_FTS_PARAMS: Optional[str] = Field(
|
||||
description='MyScale fts index parameters',
|
||||
default=None,
|
||||
)
|
||||
@ -0,0 +1,4 @@
|
||||
TTS_AUTO_PLAY_TIMEOUT = 5
|
||||
|
||||
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
|
||||
TTS_AUTO_PLAY_YIELD_CPU_TIME = 0.02
|
||||
@ -0,0 +1,135 @@
|
||||
import base64
|
||||
import concurrent.futures
|
||||
import logging
|
||||
import queue
|
||||
import re
|
||||
import threading
|
||||
|
||||
from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueTextChunkEvent
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
|
||||
class AudioTrunk:
|
||||
def __init__(self, status: str, audio):
|
||||
self.audio = audio
|
||||
self.status = status
|
||||
|
||||
|
||||
def _invoiceTTS(text_content: str, model_instance, tenant_id: str, voice: str):
|
||||
if not text_content or text_content.isspace():
|
||||
return
|
||||
return model_instance.invoke_tts(
|
||||
content_text=text_content.strip(),
|
||||
user="responding_tts",
|
||||
tenant_id=tenant_id,
|
||||
voice=voice
|
||||
)
|
||||
|
||||
|
||||
def _process_future(future_queue, audio_queue):
|
||||
while True:
|
||||
try:
|
||||
future = future_queue.get()
|
||||
if future is None:
|
||||
break
|
||||
for audio in future.result():
|
||||
audio_base64 = base64.b64encode(bytes(audio))
|
||||
audio_queue.put(AudioTrunk("responding", audio=audio_base64))
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(e)
|
||||
break
|
||||
audio_queue.put(AudioTrunk("finish", b''))
|
||||
|
||||
|
||||
class AppGeneratorTTSPublisher:
|
||||
|
||||
def __init__(self, tenant_id: str, voice: str):
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.tenant_id = tenant_id
|
||||
self.msg_text = ''
|
||||
self._audio_queue = queue.Queue()
|
||||
self._msg_queue = queue.Queue()
|
||||
self.match = re.compile(r'[。.!?]')
|
||||
self.model_manager = ModelManager()
|
||||
self.model_instance = self.model_manager.get_default_model_instance(
|
||||
tenant_id=self.tenant_id,
|
||||
model_type=ModelType.TTS
|
||||
)
|
||||
self.voices = self.model_instance.get_tts_voices()
|
||||
values = [voice.get('value') for voice in self.voices]
|
||||
self.voice = voice
|
||||
if not voice or voice not in values:
|
||||
self.voice = self.voices[0].get('value')
|
||||
self.MAX_SENTENCE = 2
|
||||
self._last_audio_event = None
|
||||
self._runtime_thread = threading.Thread(target=self._runtime).start()
|
||||
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)
|
||||
|
||||
def publish(self, message):
|
||||
try:
|
||||
self._msg_queue.put(message)
|
||||
except Exception as e:
|
||||
self.logger.warning(e)
|
||||
|
||||
def _runtime(self):
|
||||
future_queue = queue.Queue()
|
||||
threading.Thread(target=_process_future, args=(future_queue, self._audio_queue)).start()
|
||||
while True:
|
||||
try:
|
||||
message = self._msg_queue.get()
|
||||
if message is None:
|
||||
if self.msg_text and len(self.msg_text.strip()) > 0:
|
||||
futures_result = self.executor.submit(_invoiceTTS, self.msg_text,
|
||||
self.model_instance, self.tenant_id, self.voice)
|
||||
future_queue.put(futures_result)
|
||||
break
|
||||
elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent):
|
||||
self.msg_text += message.event.chunk.delta.message.content
|
||||
elif isinstance(message.event, QueueTextChunkEvent):
|
||||
self.msg_text += message.event.text
|
||||
self.last_message = message
|
||||
sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
|
||||
if len(sentence_arr) >= min(self.MAX_SENTENCE, 7):
|
||||
self.MAX_SENTENCE += 1
|
||||
text_content = ''.join(sentence_arr)
|
||||
futures_result = self.executor.submit(_invoiceTTS, text_content,
|
||||
self.model_instance,
|
||||
self.tenant_id,
|
||||
self.voice)
|
||||
future_queue.put(futures_result)
|
||||
if text_tmp:
|
||||
self.msg_text = text_tmp
|
||||
else:
|
||||
self.msg_text = ''
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(e)
|
||||
break
|
||||
future_queue.put(None)
|
||||
|
||||
def checkAndGetAudio(self) -> AudioTrunk | None:
|
||||
try:
|
||||
if self._last_audio_event and self._last_audio_event.status == "finish":
|
||||
if self.executor:
|
||||
self.executor.shutdown(wait=False)
|
||||
return self.last_message
|
||||
audio = self._audio_queue.get_nowait()
|
||||
if audio and audio.status == "finish":
|
||||
self.executor.shutdown(wait=False)
|
||||
self._runtime_thread = None
|
||||
if audio:
|
||||
self._last_audio_event = audio
|
||||
return audio
|
||||
except queue.Empty:
|
||||
return None
|
||||
|
||||
def _extract_sentence(self, org_text):
|
||||
tx = self.match.finditer(org_text)
|
||||
start = 0
|
||||
result = []
|
||||
for i in tx:
|
||||
end = i.regs[0][1]
|
||||
result.append(org_text[start:end])
|
||||
start = end
|
||||
return result, org_text[start:]
|
||||
@ -0,0 +1 @@
|
||||
from .rate_limit import RateLimit
|
||||
@ -0,0 +1,120 @@
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from datetime import timedelta
|
||||
from typing import Optional, Union
|
||||
|
||||
from core.errors.error import AppInvokeQuotaExceededError
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RateLimit:
|
||||
_MAX_ACTIVE_REQUESTS_KEY = "dify:rate_limit:{}:max_active_requests"
|
||||
_ACTIVE_REQUESTS_KEY = "dify:rate_limit:{}:active_requests"
|
||||
_UNLIMITED_REQUEST_ID = "unlimited_request_id"
|
||||
_REQUEST_MAX_ALIVE_TIME = 10 * 60 # 10 minutes
|
||||
_ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes
|
||||
_instance_dict = {}
|
||||
|
||||
def __new__(cls: type['RateLimit'], client_id: str, max_active_requests: int):
|
||||
if client_id not in cls._instance_dict:
|
||||
instance = super().__new__(cls)
|
||||
cls._instance_dict[client_id] = instance
|
||||
return cls._instance_dict[client_id]
|
||||
|
||||
def __init__(self, client_id: str, max_active_requests: int):
|
||||
self.max_active_requests = max_active_requests
|
||||
if hasattr(self, 'initialized'):
|
||||
return
|
||||
self.initialized = True
|
||||
self.client_id = client_id
|
||||
self.active_requests_key = self._ACTIVE_REQUESTS_KEY.format(client_id)
|
||||
self.max_active_requests_key = self._MAX_ACTIVE_REQUESTS_KEY.format(client_id)
|
||||
self.last_recalculate_time = float('-inf')
|
||||
self.flush_cache(use_local_value=True)
|
||||
|
||||
def flush_cache(self, use_local_value=False):
|
||||
self.last_recalculate_time = time.time()
|
||||
# flush max active requests
|
||||
if use_local_value or not redis_client.exists(self.max_active_requests_key):
|
||||
with redis_client.pipeline() as pipe:
|
||||
pipe.set(self.max_active_requests_key, self.max_active_requests)
|
||||
pipe.expire(self.max_active_requests_key, timedelta(days=1))
|
||||
pipe.execute()
|
||||
else:
|
||||
with redis_client.pipeline() as pipe:
|
||||
self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode('utf-8'))
|
||||
redis_client.expire(self.max_active_requests_key, timedelta(days=1))
|
||||
|
||||
# flush max active requests (in-transit request list)
|
||||
if not redis_client.exists(self.active_requests_key):
|
||||
return
|
||||
request_details = redis_client.hgetall(self.active_requests_key)
|
||||
redis_client.expire(self.active_requests_key, timedelta(days=1))
|
||||
timeout_requests = [k for k, v in request_details.items() if
|
||||
time.time() - float(v.decode('utf-8')) > RateLimit._REQUEST_MAX_ALIVE_TIME]
|
||||
if timeout_requests:
|
||||
redis_client.hdel(self.active_requests_key, *timeout_requests)
|
||||
|
||||
def enter(self, request_id: Optional[str] = None) -> str:
|
||||
if time.time() - self.last_recalculate_time > RateLimit._ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL:
|
||||
self.flush_cache()
|
||||
if self.max_active_requests <= 0:
|
||||
return RateLimit._UNLIMITED_REQUEST_ID
|
||||
if not request_id:
|
||||
request_id = RateLimit.gen_request_key()
|
||||
|
||||
active_requests_count = redis_client.hlen(self.active_requests_key)
|
||||
if active_requests_count >= self.max_active_requests:
|
||||
raise AppInvokeQuotaExceededError("Too many requests. Please try again later. The current maximum "
|
||||
"concurrent requests allowed is {}.".format(self.max_active_requests))
|
||||
redis_client.hset(self.active_requests_key, request_id, str(time.time()))
|
||||
return request_id
|
||||
|
||||
def exit(self, request_id: str):
|
||||
if request_id == RateLimit._UNLIMITED_REQUEST_ID:
|
||||
return
|
||||
redis_client.hdel(self.active_requests_key, request_id)
|
||||
|
||||
@staticmethod
|
||||
def gen_request_key() -> str:
|
||||
return str(uuid.uuid4())
|
||||
|
||||
def generate(self, generator: Union[Generator, callable, dict], request_id: str):
|
||||
if isinstance(generator, dict):
|
||||
return generator
|
||||
else:
|
||||
return RateLimitGenerator(self, generator, request_id)
|
||||
|
||||
|
||||
class RateLimitGenerator:
|
||||
def __init__(self, rate_limit: RateLimit, generator: Union[Generator, callable], request_id: str):
|
||||
self.rate_limit = rate_limit
|
||||
if callable(generator):
|
||||
self.generator = generator()
|
||||
else:
|
||||
self.generator = generator
|
||||
self.request_id = request_id
|
||||
self.closed = False
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self.closed:
|
||||
raise StopIteration
|
||||
try:
|
||||
return next(self.generator)
|
||||
except StopIteration:
|
||||
self.close()
|
||||
raise
|
||||
|
||||
def close(self):
|
||||
if not self.closed:
|
||||
self.closed = True
|
||||
self.rate_limit.exit(self.request_id)
|
||||
if self.generator is not None and hasattr(self.generator, 'close'):
|
||||
self.generator.close()
|
||||
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 21 KiB |
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 48 KiB |
@ -0,0 +1,6 @@
|
||||
- Qwen2-72B-Instruct-GPTQ-Int4
|
||||
- Qwen2-7B
|
||||
- Qwen1.5-110B-Chat-GPTQ-Int4
|
||||
- Qwen1.5-72B-Chat-GPTQ-Int4
|
||||
- Qwen1.5-7B
|
||||
- Qwen-14B-Chat-Int4
|
||||
@ -0,0 +1,110 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import tiktoken
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
)
|
||||
from core.model_runtime.model_providers.openai.llm.llm import OpenAILargeLanguageModel
|
||||
|
||||
|
||||
class PerfXCloudLargeLanguageModel(OpenAILargeLanguageModel):
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
self._add_custom_parameters(credentials)
|
||||
|
||||
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
self._add_custom_parameters(credentials)
|
||||
super().validate_credentials(model, credentials)
|
||||
|
||||
# refactored from openai model runtime, use cl100k_base for calculate token number
|
||||
def _num_tokens_from_string(self, model: str, text: str,
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
"""
|
||||
Calculate num tokens for text completion model with tiktoken package.
|
||||
|
||||
:param model: model name
|
||||
:param text: prompt text
|
||||
:param tools: tools for tool calling
|
||||
:return: number of tokens
|
||||
"""
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
num_tokens = len(encoding.encode(text))
|
||||
|
||||
if tools:
|
||||
num_tokens += self._num_tokens_for_tools(encoding, tools)
|
||||
|
||||
return num_tokens
|
||||
|
||||
# refactored from openai model runtime, use cl100k_base for calculate token number
|
||||
def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
||||
|
||||
Official documentation: https://github.com/openai/openai-cookbook/blob/
|
||||
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
|
||||
num_tokens = 0
|
||||
messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
|
||||
for message in messages_dict:
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
# Cast str(value) in case the message value is not a string
|
||||
# This occurs with function messages
|
||||
# TODO: The current token calculation method for the image type is not implemented,
|
||||
# which need to download the image and then get the resolution for calculation,
|
||||
# and will increase the request delay
|
||||
if isinstance(value, list):
|
||||
text = ''
|
||||
for item in value:
|
||||
if isinstance(item, dict) and item['type'] == 'text':
|
||||
text += item['text']
|
||||
|
||||
value = text
|
||||
|
||||
if key == "tool_calls":
|
||||
for tool_call in value:
|
||||
for t_key, t_value in tool_call.items():
|
||||
num_tokens += len(encoding.encode(t_key))
|
||||
if t_key == "function":
|
||||
for f_key, f_value in t_value.items():
|
||||
num_tokens += len(encoding.encode(f_key))
|
||||
num_tokens += len(encoding.encode(f_value))
|
||||
else:
|
||||
num_tokens += len(encoding.encode(t_key))
|
||||
num_tokens += len(encoding.encode(t_value))
|
||||
else:
|
||||
num_tokens += len(encoding.encode(str(value)))
|
||||
|
||||
if key == "name":
|
||||
num_tokens += tokens_per_name
|
||||
|
||||
# every reply is primed with <im_start>assistant
|
||||
num_tokens += 3
|
||||
|
||||
if tools:
|
||||
num_tokens += self._num_tokens_for_tools(encoding, tools)
|
||||
|
||||
return num_tokens
|
||||
|
||||
@staticmethod
|
||||
def _add_custom_parameters(credentials: dict) -> None:
|
||||
credentials['mode'] = 'chat'
|
||||
credentials['openai_api_key']=credentials['api_key']
|
||||
if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "":
|
||||
credentials['openai_api_base']='https://cloud.perfxlab.cn'
|
||||
else:
|
||||
parsed_url = urlparse(credentials['endpoint_url'])
|
||||
credentials['openai_api_base']=f"{parsed_url.scheme}://{parsed_url.netloc}"
|
||||
@ -0,0 +1,32 @@
|
||||
import logging
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PerfXCloudProvider(ModelProvider):
|
||||
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Validate provider credentials
|
||||
if validate failed, raise exception
|
||||
|
||||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||
"""
|
||||
try:
|
||||
model_instance = self.get_model_instance(ModelType.LLM)
|
||||
|
||||
# Use `Qwen2_72B_Chat_GPTQ_Int4` model for validate,
|
||||
# no matter what model you pass in, text completion model or chat model
|
||||
model_instance.validate_credentials(
|
||||
model='Qwen2-72B-Instruct-GPTQ-Int4',
|
||||
credentials=credentials
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
|
||||
raise ex
|
||||
@ -0,0 +1,4 @@
|
||||
model: BAAI/bge-m3
|
||||
model_type: text-embedding
|
||||
model_properties:
|
||||
context_size: 32768
|
||||
@ -0,0 +1,250 @@
|
||||
import json
|
||||
import time
|
||||
from decimal import Decimal
|
||||
from typing import Optional
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
FetchFrom,
|
||||
ModelPropertyKey,
|
||||
ModelType,
|
||||
PriceConfig,
|
||||
PriceType,
|
||||
)
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat
|
||||
|
||||
|
||||
class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
|
||||
"""
|
||||
Model class for an OpenAI API-compatible text embedding model.
|
||||
"""
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
texts: list[str], user: Optional[str] = None) \
|
||||
-> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:return: embeddings result
|
||||
"""
|
||||
|
||||
# Prepare headers and payload for the request
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
api_key = credentials.get('api_key')
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "":
|
||||
endpoint_url='https://cloud.perfxlab.cn/v1/'
|
||||
else:
|
||||
endpoint_url = credentials.get('endpoint_url')
|
||||
if not endpoint_url.endswith('/'):
|
||||
endpoint_url += '/'
|
||||
|
||||
endpoint_url = urljoin(endpoint_url, 'embeddings')
|
||||
|
||||
extra_model_kwargs = {}
|
||||
if user:
|
||||
extra_model_kwargs['user'] = user
|
||||
|
||||
extra_model_kwargs['encoding_format'] = 'float'
|
||||
|
||||
# get model properties
|
||||
context_size = self._get_context_size(model, credentials)
|
||||
max_chunks = self._get_max_chunks(model, credentials)
|
||||
|
||||
inputs = []
|
||||
indices = []
|
||||
used_tokens = 0
|
||||
|
||||
for i, text in enumerate(texts):
|
||||
|
||||
# Here token count is only an approximation based on the GPT2 tokenizer
|
||||
# TODO: Optimize for better token estimation and chunking
|
||||
num_tokens = self._get_num_tokens_by_gpt2(text)
|
||||
|
||||
if num_tokens >= context_size:
|
||||
cutoff = int(len(text) * (np.floor(context_size / num_tokens)))
|
||||
# if num tokens is larger than context length, only use the start
|
||||
inputs.append(text[0: cutoff])
|
||||
else:
|
||||
inputs.append(text)
|
||||
indices += [i]
|
||||
|
||||
batched_embeddings = []
|
||||
_iter = range(0, len(inputs), max_chunks)
|
||||
|
||||
for i in _iter:
|
||||
# Prepare the payload for the request
|
||||
payload = {
|
||||
'input': inputs[i: i + max_chunks],
|
||||
'model': model,
|
||||
**extra_model_kwargs
|
||||
}
|
||||
|
||||
# Make the request to the OpenAI API
|
||||
response = requests.post(
|
||||
endpoint_url,
|
||||
headers=headers,
|
||||
data=json.dumps(payload),
|
||||
timeout=(10, 300)
|
||||
)
|
||||
|
||||
response.raise_for_status() # Raise an exception for HTTP errors
|
||||
response_data = response.json()
|
||||
|
||||
# Extract embeddings and used tokens from the response
|
||||
embeddings_batch = [data['embedding'] for data in response_data['data']]
|
||||
embedding_used_tokens = response_data['usage']['total_tokens']
|
||||
|
||||
used_tokens += embedding_used_tokens
|
||||
batched_embeddings += embeddings_batch
|
||||
|
||||
# calc usage
|
||||
usage = self._calc_response_usage(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
tokens=used_tokens
|
||||
)
|
||||
|
||||
return TextEmbeddingResult(
|
||||
embeddings=batched_embeddings,
|
||||
usage=usage,
|
||||
model=model
|
||||
)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Approximate number of tokens for given messages using GPT2 tokenizer
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
return sum(self._get_num_tokens_by_gpt2(text) for text in texts)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
api_key = credentials.get('api_key')
|
||||
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "":
|
||||
endpoint_url='https://cloud.perfxlab.cn/v1/'
|
||||
else:
|
||||
endpoint_url = credentials.get('endpoint_url')
|
||||
if not endpoint_url.endswith('/'):
|
||||
endpoint_url += '/'
|
||||
|
||||
endpoint_url = urljoin(endpoint_url, 'embeddings')
|
||||
|
||||
payload = {
|
||||
'input': 'ping',
|
||||
'model': model
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
url=endpoint_url,
|
||||
headers=headers,
|
||||
data=json.dumps(payload),
|
||||
timeout=(10, 300)
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise CredentialsValidateFailedError(
|
||||
f'Credentials validation failed with status code {response.status_code}')
|
||||
|
||||
try:
|
||||
json_result = response.json()
|
||||
except json.JSONDecodeError as e:
|
||||
raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error')
|
||||
|
||||
if 'model' not in json_result:
|
||||
raise CredentialsValidateFailedError(
|
||||
'Credentials validation failed: invalid response')
|
||||
except CredentialsValidateFailedError:
|
||||
raise
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||
"""
|
||||
generate custom model entities from credentials
|
||||
"""
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model),
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
|
||||
ModelPropertyKey.MAX_CHUNKS: 1,
|
||||
},
|
||||
parameter_rules=[],
|
||||
pricing=PriceConfig(
|
||||
input=Decimal(credentials.get('input_price', 0)),
|
||||
unit=Decimal(credentials.get('unit', 0)),
|
||||
currency=credentials.get('currency', "USD")
|
||||
)
|
||||
)
|
||||
|
||||
return entity
|
||||
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param tokens: input tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
price_type=PriceType.INPUT,
|
||||
tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at
|
||||
)
|
||||
|
||||
return usage
|
||||
@ -0,0 +1,40 @@
|
||||
model: ernie-4.0-turbo-8k-preview
|
||||
label:
|
||||
en_US: Ernie-4.0-turbo-8k-preview
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
min: 0.1
|
||||
max: 1.0
|
||||
default: 0.8
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 2
|
||||
max: 2048
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
default: 1.0
|
||||
min: 1.0
|
||||
max: 2.0
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: disable_search
|
||||
label:
|
||||
zh_Hans: 禁用搜索
|
||||
en_US: Disable Search
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 禁用模型自行进行外部搜索。
|
||||
en_US: Disable the model to perform external search.
|
||||
required: false
|
||||
@ -0,0 +1,332 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
_import_err_msg = (
|
||||
"`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, "
|
||||
"please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
|
||||
)
|
||||
from flask import current_app
|
||||
|
||||
from core.rag.datasource.entity.embedding import Embeddings
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class AnalyticdbConfig(BaseModel):
|
||||
access_key_id: str
|
||||
access_key_secret: str
|
||||
region_id: str
|
||||
instance_id: str
|
||||
account: str
|
||||
account_password: str
|
||||
namespace: str = ("dify",)
|
||||
namespace_password: str = (None,)
|
||||
metrics: str = ("cosine",)
|
||||
read_timeout: int = 60000
|
||||
def to_analyticdb_client_params(self):
|
||||
return {
|
||||
"access_key_id": self.access_key_id,
|
||||
"access_key_secret": self.access_key_secret,
|
||||
"region_id": self.region_id,
|
||||
"read_timeout": self.read_timeout,
|
||||
}
|
||||
|
||||
class AnalyticdbVector(BaseVector):
|
||||
_instance = None
|
||||
_init = False
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, collection_name: str, config: AnalyticdbConfig):
|
||||
# collection_name must be updated every time
|
||||
self._collection_name = collection_name.lower()
|
||||
if AnalyticdbVector._init:
|
||||
return
|
||||
try:
|
||||
from alibabacloud_gpdb20160503.client import Client
|
||||
from alibabacloud_tea_openapi import models as open_api_models
|
||||
except:
|
||||
raise ImportError(_import_err_msg)
|
||||
self.config = config
|
||||
self._client_config = open_api_models.Config(
|
||||
user_agent="dify", **config.to_analyticdb_client_params()
|
||||
)
|
||||
self._client = Client(self._client_config)
|
||||
self._initialize()
|
||||
AnalyticdbVector._init = True
|
||||
|
||||
def _initialize(self) -> None:
|
||||
self._initialize_vector_database()
|
||||
self._create_namespace_if_not_exists()
|
||||
|
||||
def _initialize_vector_database(self) -> None:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
request = gpdb_20160503_models.InitVectorDatabaseRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
manager_account=self.config.account,
|
||||
manager_account_password=self.config.account_password,
|
||||
)
|
||||
self._client.init_vector_database(request)
|
||||
|
||||
def _create_namespace_if_not_exists(self) -> None:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
from Tea.exceptions import TeaException
|
||||
try:
|
||||
request = gpdb_20160503_models.DescribeNamespaceRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
namespace=self.config.namespace,
|
||||
manager_account=self.config.account,
|
||||
manager_account_password=self.config.account_password,
|
||||
)
|
||||
self._client.describe_namespace(request)
|
||||
except TeaException as e:
|
||||
if e.statusCode == 404:
|
||||
request = gpdb_20160503_models.CreateNamespaceRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
manager_account=self.config.account,
|
||||
manager_account_password=self.config.account_password,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
)
|
||||
self._client.create_namespace(request)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"failed to create namespace {self.config.namespace}: {e}"
|
||||
)
|
||||
|
||||
def _create_collection_if_not_exists(self, embedding_dimension: int):
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
from Tea.exceptions import TeaException
|
||||
cache_key = f"vector_indexing_{self._collection_name}"
|
||||
lock_name = f"{cache_key}_lock"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
try:
|
||||
request = gpdb_20160503_models.DescribeCollectionRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
collection=self._collection_name,
|
||||
)
|
||||
self._client.describe_collection(request)
|
||||
except TeaException as e:
|
||||
if e.statusCode == 404:
|
||||
metadata = '{"ref_doc_id":"text","page_content":"text","metadata_":"jsonb"}'
|
||||
full_text_retrieval_fields = "page_content"
|
||||
request = gpdb_20160503_models.CreateCollectionRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
manager_account=self.config.account,
|
||||
manager_account_password=self.config.account_password,
|
||||
namespace=self.config.namespace,
|
||||
collection=self._collection_name,
|
||||
dimension=embedding_dimension,
|
||||
metrics=self.config.metrics,
|
||||
metadata=metadata,
|
||||
full_text_retrieval_fields=full_text_retrieval_fields,
|
||||
)
|
||||
self._client.create_collection(request)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"failed to create collection {self._collection_name}: {e}"
|
||||
)
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.ANALYTICDB
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
dimension = len(embeddings[0])
|
||||
self._create_collection_if_not_exists(dimension)
|
||||
self.add_texts(texts, embeddings)
|
||||
|
||||
def add_texts(
|
||||
self, documents: list[Document], embeddings: list[list[float]], **kwargs
|
||||
):
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
|
||||
for doc, embedding in zip(documents, embeddings, strict=True):
|
||||
metadata = {
|
||||
"ref_doc_id": doc.metadata["doc_id"],
|
||||
"page_content": doc.page_content,
|
||||
"metadata_": json.dumps(doc.metadata),
|
||||
}
|
||||
rows.append(
|
||||
gpdb_20160503_models.UpsertCollectionDataRequestRows(
|
||||
vector=embedding,
|
||||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
request = gpdb_20160503_models.UpsertCollectionDataRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
collection=self._collection_name,
|
||||
rows=rows,
|
||||
)
|
||||
self._client.upsert_collection_data(request)
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
request = gpdb_20160503_models.QueryCollectionDataRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
collection=self._collection_name,
|
||||
metrics=self.config.metrics,
|
||||
include_values=True,
|
||||
vector=None,
|
||||
content=None,
|
||||
top_k=1,
|
||||
filter=f"ref_doc_id='{id}'"
|
||||
)
|
||||
response = self._client.query_collection_data(request)
|
||||
return len(response.body.matches.match) > 0
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
ids_str = ",".join(f"'{id}'" for id in ids)
|
||||
ids_str = f"({ids_str})"
|
||||
request = gpdb_20160503_models.DeleteCollectionDataRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
collection=self._collection_name,
|
||||
collection_data=None,
|
||||
collection_data_filter=f"ref_doc_id IN {ids_str}",
|
||||
)
|
||||
self._client.delete_collection_data(request)
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
request = gpdb_20160503_models.DeleteCollectionDataRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
collection=self._collection_name,
|
||||
collection_data=None,
|
||||
collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
|
||||
)
|
||||
self._client.delete_collection_data(request)
|
||||
|
||||
def search_by_vector(
|
||||
self, query_vector: list[float], **kwargs: Any
|
||||
) -> list[Document]:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
score_threshold = (
|
||||
kwargs.get("score_threshold", 0.0)
|
||||
if kwargs.get("score_threshold", 0.0)
|
||||
else 0.0
|
||||
)
|
||||
request = gpdb_20160503_models.QueryCollectionDataRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
collection=self._collection_name,
|
||||
include_values=kwargs.pop("include_values", True),
|
||||
metrics=self.config.metrics,
|
||||
vector=query_vector,
|
||||
content=None,
|
||||
top_k=kwargs.get("top_k", 4),
|
||||
filter=None,
|
||||
)
|
||||
response = self._client.query_collection_data(request)
|
||||
documents = []
|
||||
for match in response.body.matches.match:
|
||||
if match.score > score_threshold:
|
||||
doc = Document(
|
||||
page_content=match.metadata.get("page_content"),
|
||||
metadata=json.loads(match.metadata.get("metadata_")),
|
||||
)
|
||||
documents.append(doc)
|
||||
return documents
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
score_threshold = (
|
||||
kwargs.get("score_threshold", 0.0)
|
||||
if kwargs.get("score_threshold", 0.0)
|
||||
else 0.0
|
||||
)
|
||||
request = gpdb_20160503_models.QueryCollectionDataRequest(
|
||||
dbinstance_id=self.config.instance_id,
|
||||
region_id=self.config.region_id,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
collection=self._collection_name,
|
||||
include_values=kwargs.pop("include_values", True),
|
||||
metrics=self.config.metrics,
|
||||
vector=None,
|
||||
content=query,
|
||||
top_k=kwargs.get("top_k", 4),
|
||||
filter=None,
|
||||
)
|
||||
response = self._client.query_collection_data(request)
|
||||
documents = []
|
||||
for match in response.body.matches.match:
|
||||
if match.score > score_threshold:
|
||||
doc = Document(
|
||||
page_content=match.metadata.get("page_content"),
|
||||
metadata=json.loads(match.metadata.get("metadata_")),
|
||||
)
|
||||
documents.append(doc)
|
||||
return documents
|
||||
|
||||
def delete(self) -> None:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
request = gpdb_20160503_models.DeleteCollectionRequest(
|
||||
collection=self._collection_name,
|
||||
dbinstance_id=self.config.instance_id,
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
region_id=self.config.region_id,
|
||||
)
|
||||
self._client.delete_collection(request)
|
||||
|
||||
class AnalyticdbVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings):
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"][
|
||||
"class_prefix"
|
||||
]
|
||||
collection_name = class_prefix.lower()
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name)
|
||||
)
|
||||
config = current_app.config
|
||||
return AnalyticdbVector(
|
||||
collection_name,
|
||||
AnalyticdbConfig(
|
||||
access_key_id=config.get("ANALYTICDB_KEY_ID"),
|
||||
access_key_secret=config.get("ANALYTICDB_KEY_SECRET"),
|
||||
region_id=config.get("ANALYTICDB_REGION_ID"),
|
||||
instance_id=config.get("ANALYTICDB_INSTANCE_ID"),
|
||||
account=config.get("ANALYTICDB_ACCOUNT"),
|
||||
account_password=config.get("ANALYTICDB_PASSWORD"),
|
||||
namespace=config.get("ANALYTICDB_NAMESPACE"),
|
||||
namespace_password=config.get("ANALYTICDB_NAMESPACE_PASSWORD"),
|
||||
),
|
||||
)
|
||||
@ -0,0 +1,170 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from clickhouse_connect import get_client
|
||||
from flask import current_app
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.rag.datasource.entity.embedding import Embeddings
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.models.document import Document
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class MyScaleConfig(BaseModel):
|
||||
host: str
|
||||
port: int
|
||||
user: str
|
||||
password: str
|
||||
database: str
|
||||
fts_params: str
|
||||
|
||||
|
||||
class SortOrder(Enum):
|
||||
ASC = "ASC"
|
||||
DESC = "DESC"
|
||||
|
||||
|
||||
class MyScaleVector(BaseVector):
|
||||
|
||||
def __init__(self, collection_name: str, config: MyScaleConfig, metric: str = "Cosine"):
|
||||
super().__init__(collection_name)
|
||||
self._config = config
|
||||
self._metric = metric
|
||||
self._vec_order = SortOrder.ASC if metric.upper() in ["COSINE", "L2"] else SortOrder.DESC
|
||||
self._client = get_client(
|
||||
host=config.host,
|
||||
port=config.port,
|
||||
username=config.user,
|
||||
password=config.password,
|
||||
)
|
||||
self._client.command("SET allow_experimental_object_type=1")
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.MYSCALE
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
dimension = len(embeddings[0])
|
||||
self._create_collection(dimension)
|
||||
return self.add_texts(documents=texts, embeddings=embeddings, **kwargs)
|
||||
|
||||
def _create_collection(self, dimension: int):
|
||||
logging.info(f"create MyScale collection {self._collection_name} with dimension {dimension}")
|
||||
self._client.command(f"CREATE DATABASE IF NOT EXISTS {self._config.database}")
|
||||
fts_params = f"('{self._config.fts_params}')" if self._config.fts_params else ""
|
||||
sql = f"""
|
||||
CREATE TABLE IF NOT EXISTS {self._config.database}.{self._collection_name}(
|
||||
id String,
|
||||
text String,
|
||||
vector Array(Float32),
|
||||
metadata JSON,
|
||||
CONSTRAINT cons_vec_len CHECK length(vector) = {dimension},
|
||||
VECTOR INDEX vidx vector TYPE DEFAULT('metric_type = {self._metric}'),
|
||||
INDEX text_idx text TYPE fts{fts_params}
|
||||
) ENGINE = MergeTree ORDER BY id
|
||||
"""
|
||||
self._client.command(sql)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
ids = []
|
||||
columns = ["id", "text", "vector", "metadata"]
|
||||
values = []
|
||||
for i, doc in enumerate(documents):
|
||||
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
|
||||
row = (
|
||||
doc_id,
|
||||
self.escape_str(doc.page_content),
|
||||
embeddings[i],
|
||||
json.dumps(doc.metadata) if doc.metadata else {}
|
||||
)
|
||||
values.append(str(row))
|
||||
ids.append(doc_id)
|
||||
sql = f"""
|
||||
INSERT INTO {self._config.database}.{self._collection_name}
|
||||
({",".join(columns)}) VALUES {",".join(values)}
|
||||
"""
|
||||
self._client.command(sql)
|
||||
return ids
|
||||
|
||||
@staticmethod
|
||||
def escape_str(value: Any) -> str:
|
||||
return "".join(f"\\{c}" if c in ("\\", "'") else c for c in str(value))
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'")
|
||||
return results.row_count > 0
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
self._client.command(
|
||||
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}")
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
rows = self._client.query(
|
||||
f"SELECT DISTINCT id FROM {self._config.database}.{self._collection_name} WHERE metadata.{key}='{value}'"
|
||||
).result_rows
|
||||
return [row[0] for row in rows]
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
self._client.command(
|
||||
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE metadata.{key}='{value}'"
|
||||
)
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
return self._search(f"distance(vector, {str(query_vector)})", self._vec_order, **kwargs)
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
return self._search(f"TextSearch(text, '{query}')", SortOrder.DESC, **kwargs)
|
||||
|
||||
def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 5)
|
||||
score_threshold = kwargs.get("score_threshold", 0.0)
|
||||
where_str = f"WHERE dist < {1 - score_threshold}" if \
|
||||
self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 else ""
|
||||
sql = f"""
|
||||
SELECT text, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name}
|
||||
{where_str} ORDER BY dist {order.value} LIMIT {top_k}
|
||||
"""
|
||||
try:
|
||||
return [
|
||||
Document(
|
||||
page_content=r["text"],
|
||||
metadata=r["metadata"],
|
||||
)
|
||||
for r in self._client.query(sql).named_results()
|
||||
]
|
||||
except Exception as e:
|
||||
logging.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
|
||||
return []
|
||||
|
||||
def delete(self) -> None:
|
||||
self._client.command(f"DROP TABLE IF EXISTS {self._config.database}.{self._collection_name}")
|
||||
|
||||
|
||||
class MyScaleVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MyScaleVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
collection_name = class_prefix.lower()
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict(VectorType.MYSCALE, collection_name))
|
||||
|
||||
config = current_app.config
|
||||
return MyScaleVector(
|
||||
collection_name=collection_name,
|
||||
config=MyScaleConfig(
|
||||
host=config.get("MYSCALE_HOST", "localhost"),
|
||||
port=int(config.get("MYSCALE_PORT", 8123)),
|
||||
user=config.get("MYSCALE_USER", "default"),
|
||||
password=config.get("MYSCALE_PASSWORD", ""),
|
||||
database=config.get("MYSCALE_DATABASE", "default"),
|
||||
fts_params=config.get("MYSCALE_FTS_PARAMS", ""),
|
||||
),
|
||||
)
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 22 KiB |
@ -0,0 +1,27 @@
|
||||
""" Provide the input parameters type for the cogview provider class """
|
||||
from typing import Any
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.cogview.tools.cogview3 import CogView3Tool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class COGVIEWProvider(BuiltinToolProviderController):
|
||||
""" cogview provider """
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
CogView3Tool().fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
tool_parameters={
|
||||
"prompt": "一个城市在水晶瓶中欢快生活的场景,水彩画风格,展现出微观与珠宝般的美丽。",
|
||||
"size": "square",
|
||||
"n": 1
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e)) from e
|
||||
|
||||
@ -0,0 +1,61 @@
|
||||
identity:
|
||||
author: Waffle
|
||||
name: cogview
|
||||
label:
|
||||
en_US: CogView
|
||||
zh_Hans: CogView 绘画
|
||||
pt_BR: CogView
|
||||
description:
|
||||
en_US: CogView art
|
||||
zh_Hans: CogView 绘画
|
||||
pt_BR: CogView art
|
||||
icon: icon.png
|
||||
tags:
|
||||
- image
|
||||
- productivity
|
||||
credentials_for_provider:
|
||||
zhipuai_api_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: ZhipuAI API key
|
||||
zh_Hans: ZhipuAI API key
|
||||
pt_BR: ZhipuAI API key
|
||||
help:
|
||||
en_US: Please input your ZhipuAI API key
|
||||
zh_Hans: 请输入你的 ZhipuAI API key
|
||||
pt_BR: Please input your ZhipuAI API key
|
||||
placeholder:
|
||||
en_US: Please input your ZhipuAI API key
|
||||
zh_Hans: 请输入你的 ZhipuAI API key
|
||||
pt_BR: Please input your ZhipuAI API key
|
||||
zhipuai_organizaion_id:
|
||||
type: text-input
|
||||
required: false
|
||||
label:
|
||||
en_US: ZhipuAI organization ID
|
||||
zh_Hans: ZhipuAI organization ID
|
||||
pt_BR: ZhipuAI organization ID
|
||||
help:
|
||||
en_US: Please input your ZhipuAI organization ID
|
||||
zh_Hans: 请输入你的 ZhipuAI organization ID
|
||||
pt_BR: Please input your ZhipuAI organization ID
|
||||
placeholder:
|
||||
en_US: Please input your ZhipuAI organization ID
|
||||
zh_Hans: 请输入你的 ZhipuAI organization ID
|
||||
pt_BR: Please input your ZhipuAI organization ID
|
||||
zhipuai_base_url:
|
||||
type: text-input
|
||||
required: false
|
||||
label:
|
||||
en_US: ZhipuAI base URL
|
||||
zh_Hans: ZhipuAI base URL
|
||||
pt_BR: ZhipuAI base URL
|
||||
help:
|
||||
en_US: Please input your ZhipuAI base URL
|
||||
zh_Hans: 请输入你的 ZhipuAI base URL
|
||||
pt_BR: Please input your ZhipuAI base URL
|
||||
placeholder:
|
||||
en_US: Please input your ZhipuAI base URL
|
||||
zh_Hans: 请输入你的 ZhipuAI base URL
|
||||
pt_BR: Please input your ZhipuAI base URL
|
||||
@ -0,0 +1,69 @@
|
||||
import random
|
||||
from typing import Any, Union
|
||||
|
||||
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class CogView3Tool(BuiltinTool):
|
||||
""" CogView3 Tool """
|
||||
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
Invoke CogView3 tool
|
||||
"""
|
||||
client = ZhipuAI(
|
||||
base_url=self.runtime.credentials['zhipuai_base_url'],
|
||||
api_key=self.runtime.credentials['zhipuai_api_key'],
|
||||
)
|
||||
size_mapping = {
|
||||
'square': '1024x1024',
|
||||
'vertical': '1024x1792',
|
||||
'horizontal': '1792x1024',
|
||||
}
|
||||
# prompt
|
||||
prompt = tool_parameters.get('prompt', '')
|
||||
if not prompt:
|
||||
return self.create_text_message('Please input prompt')
|
||||
# get size
|
||||
print(tool_parameters.get('prompt', 'square'))
|
||||
size = size_mapping[tool_parameters.get('size', 'square')]
|
||||
# get n
|
||||
n = tool_parameters.get('n', 1)
|
||||
# get quality
|
||||
quality = tool_parameters.get('quality', 'standard')
|
||||
if quality not in ['standard', 'hd']:
|
||||
return self.create_text_message('Invalid quality')
|
||||
# get style
|
||||
style = tool_parameters.get('style', 'vivid')
|
||||
if style not in ['natural', 'vivid']:
|
||||
return self.create_text_message('Invalid style')
|
||||
# set extra body
|
||||
seed_id = tool_parameters.get('seed_id', self._generate_random_id(8))
|
||||
extra_body = {'seed': seed_id}
|
||||
response = client.images.generations(
|
||||
prompt=prompt,
|
||||
model="cogview-3",
|
||||
size=size,
|
||||
n=n,
|
||||
extra_body=extra_body,
|
||||
style=style,
|
||||
quality=quality,
|
||||
response_format='b64_json'
|
||||
)
|
||||
result = []
|
||||
for image in response.data:
|
||||
result.append(self.create_image_message(image=image.url))
|
||||
result.append(self.create_text_message(
|
||||
f'\nGenerate image source to Seed ID: {seed_id}'))
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _generate_random_id(length=8):
|
||||
characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
|
||||
random_id = ''.join(random.choices(characters, k=length))
|
||||
return random_id
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue