Merge branch 'main' into fix/chore-fix
commit
196bfeaaf4
Binary file not shown.
|
After Width: | Height: | Size: 230 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 205 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 44 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 262 KiB |
@ -0,0 +1,173 @@
|
|||||||
|
## Predefined Model Integration
|
||||||
|
|
||||||
|
After completing the vendor integration, the next step is to integrate the models from the vendor.
|
||||||
|
|
||||||
|
First, we need to determine the type of model to be integrated and create the corresponding model type `module` under the respective vendor's directory.
|
||||||
|
|
||||||
|
Currently supported model types are:
|
||||||
|
|
||||||
|
- `llm` Text Generation Model
|
||||||
|
- `text_embedding` Text Embedding Model
|
||||||
|
- `rerank` Rerank Model
|
||||||
|
- `speech2text` Speech-to-Text
|
||||||
|
- `tts` Text-to-Speech
|
||||||
|
- `moderation` Moderation
|
||||||
|
|
||||||
|
Continuing with `Anthropic` as an example, `Anthropic` only supports LLM, so create a `module` named `llm` under `model_providers.anthropic`.
|
||||||
|
|
||||||
|
For predefined models, we first need to create a YAML file named after the model under the `llm` `module`, such as `claude-2.1.yaml`.
|
||||||
|
|
||||||
|
### Prepare Model YAML
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
model: claude-2.1 # Model identifier
|
||||||
|
# Display name of the model, which can be set to en_US English or zh_Hans Chinese. If zh_Hans is not set, it will default to en_US.
|
||||||
|
# This can also be omitted, in which case the model identifier will be used as the label
|
||||||
|
label:
|
||||||
|
en_US: claude-2.1
|
||||||
|
model_type: llm # Model type, claude-2.1 is an LLM
|
||||||
|
features: # Supported features, agent-thought supports Agent reasoning, vision supports image understanding
|
||||||
|
- agent-thought
|
||||||
|
model_properties: # Model properties
|
||||||
|
mode: chat # LLM mode, complete for text completion models, chat for conversation models
|
||||||
|
context_size: 200000 # Maximum context size
|
||||||
|
parameter_rules: # Parameter rules for the model call; only LLM requires this
|
||||||
|
- name: temperature # Parameter variable name
|
||||||
|
# Five default configuration templates are provided: temperature/top_p/max_tokens/presence_penalty/frequency_penalty
|
||||||
|
# The template variable name can be set directly in use_template, which will use the default configuration in entities.defaults.PARAMETER_RULE_TEMPLATE
|
||||||
|
# Additional configuration parameters will override the default configuration if set
|
||||||
|
use_template: temperature
|
||||||
|
- name: top_p
|
||||||
|
use_template: top_p
|
||||||
|
- name: top_k
|
||||||
|
label: # Display name of the parameter
|
||||||
|
zh_Hans: 取样数量
|
||||||
|
en_US: Top k
|
||||||
|
type: int # Parameter type, supports float/int/string/boolean
|
||||||
|
help: # Help information, describing the parameter's function
|
||||||
|
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||||
|
en_US: Only sample from the top K options for each subsequent token.
|
||||||
|
required: false # Whether the parameter is mandatory; can be omitted
|
||||||
|
- name: max_tokens_to_sample
|
||||||
|
use_template: max_tokens
|
||||||
|
default: 4096 # Default value of the parameter
|
||||||
|
min: 1 # Minimum value of the parameter, applicable to float/int only
|
||||||
|
max: 4096 # Maximum value of the parameter, applicable to float/int only
|
||||||
|
pricing: # Pricing information
|
||||||
|
input: '8.00' # Input unit price, i.e., prompt price
|
||||||
|
output: '24.00' # Output unit price, i.e., response content price
|
||||||
|
unit: '0.000001' # Price unit, meaning the above prices are per 100K
|
||||||
|
currency: USD # Price currency
|
||||||
|
```
|
||||||
|
|
||||||
|
It is recommended to prepare all model configurations before starting the implementation of the model code.
|
||||||
|
|
||||||
|
You can also refer to the YAML configuration information under the corresponding model type directories of other vendors in the `model_providers` directory. For the complete YAML rules, refer to: [Schema](schema.md#aimodelentity).
|
||||||
|
|
||||||
|
### Implement the Model Call Code
|
||||||
|
|
||||||
|
Next, create a Python file named `llm.py` under the `llm` `module` to write the implementation code.
|
||||||
|
|
||||||
|
Create an Anthropic LLM class named `AnthropicLargeLanguageModel` (or any other name), inheriting from the `__base.large_language_model.LargeLanguageModel` base class, and implement the following methods:
|
||||||
|
|
||||||
|
- LLM Call
|
||||||
|
|
||||||
|
Implement the core method for calling the LLM, supporting both streaming and synchronous responses.
|
||||||
|
|
||||||
|
```python
|
||||||
|
def _invoke(self, model: str, credentials: dict,
|
||||||
|
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||||
|
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||||
|
stream: bool = True, user: Optional[str] = None) \
|
||||||
|
-> Union[LLMResult, Generator]:
|
||||||
|
"""
|
||||||
|
Invoke large language model
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:param prompt_messages: prompt messages
|
||||||
|
:param model_parameters: model parameters
|
||||||
|
:param tools: tools for tool calling
|
||||||
|
:param stop: stop words
|
||||||
|
:param stream: is stream response
|
||||||
|
:param user: unique user id
|
||||||
|
:return: full response or stream response chunk generator result
|
||||||
|
"""
|
||||||
|
```
|
||||||
|
|
||||||
|
Ensure to use two functions for returning data, one for synchronous returns and the other for streaming returns, because Python identifies functions containing the `yield` keyword as generator functions, fixing the return type to `Generator`. Thus, synchronous and streaming returns need to be implemented separately, as shown below (note that the example uses simplified parameters, for actual implementation follow the above parameter list):
|
||||||
|
|
||||||
|
```python
|
||||||
|
def _invoke(self, stream: bool, **kwargs) \
|
||||||
|
-> Union[LLMResult, Generator]:
|
||||||
|
if stream:
|
||||||
|
return self._handle_stream_response(**kwargs)
|
||||||
|
return self._handle_sync_response(**kwargs)
|
||||||
|
|
||||||
|
def _handle_stream_response(self, **kwargs) -> Generator:
|
||||||
|
for chunk in response:
|
||||||
|
yield chunk
|
||||||
|
def _handle_sync_response(self, **kwargs) -> LLMResult:
|
||||||
|
return LLMResult(**response)
|
||||||
|
```
|
||||||
|
|
||||||
|
- Pre-compute Input Tokens
|
||||||
|
|
||||||
|
If the model does not provide an interface to precompute tokens, return 0 directly.
|
||||||
|
|
||||||
|
```python
|
||||||
|
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||||
|
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||||
|
"""
|
||||||
|
Get number of tokens for given prompt messages
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:param prompt_messages: prompt messages
|
||||||
|
:param tools: tools for tool calling
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
```
|
||||||
|
|
||||||
|
- Validate Model Credentials
|
||||||
|
|
||||||
|
Similar to vendor credential validation, but specific to a single model.
|
||||||
|
|
||||||
|
```python
|
||||||
|
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||||
|
"""
|
||||||
|
Validate model credentials
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
```
|
||||||
|
|
||||||
|
- Map Invoke Errors
|
||||||
|
|
||||||
|
When a model call fails, map it to a specific `InvokeError` type as required by Runtime, allowing Dify to handle different errors accordingly.
|
||||||
|
|
||||||
|
Runtime Errors:
|
||||||
|
|
||||||
|
- `InvokeConnectionError` Connection error
|
||||||
|
|
||||||
|
- `InvokeServerUnavailableError` Service provider unavailable
|
||||||
|
- `InvokeRateLimitError` Rate limit reached
|
||||||
|
- `InvokeAuthorizationError` Authorization failed
|
||||||
|
- `InvokeBadRequestError` Parameter error
|
||||||
|
|
||||||
|
```python
|
||||||
|
@property
|
||||||
|
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||||
|
"""
|
||||||
|
Map model invoke error to unified error
|
||||||
|
The key is the error type thrown to the caller
|
||||||
|
The value is the error type thrown by the model,
|
||||||
|
which needs to be converted into a unified error type for the caller.
|
||||||
|
|
||||||
|
:return: Invoke error mapping
|
||||||
|
"""
|
||||||
|
```
|
||||||
|
|
||||||
|
For interface method explanations, see: [Interfaces](./interfaces.md). For detailed implementation, refer to: [llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py).
|
||||||
@ -1,238 +0,0 @@
|
|||||||
import json
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import boto3
|
|
||||||
from botocore.config import Config
|
|
||||||
from botocore.exceptions import (
|
|
||||||
ClientError,
|
|
||||||
EndpointConnectionError,
|
|
||||||
NoRegionError,
|
|
||||||
ServiceNotInRegionError,
|
|
||||||
UnknownServiceError,
|
|
||||||
)
|
|
||||||
|
|
||||||
from core.embedding.embedding_constant import EmbeddingInputType
|
|
||||||
from core.model_runtime.entities.model_entities import PriceType
|
|
||||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
|
||||||
from core.model_runtime.errors.invoke import (
|
|
||||||
InvokeAuthorizationError,
|
|
||||||
InvokeBadRequestError,
|
|
||||||
InvokeConnectionError,
|
|
||||||
InvokeError,
|
|
||||||
InvokeRateLimitError,
|
|
||||||
InvokeServerUnavailableError,
|
|
||||||
)
|
|
||||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class BedrockTextEmbeddingModel(TextEmbeddingModel):
|
|
||||||
def _invoke(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
credentials: dict,
|
|
||||||
texts: list[str],
|
|
||||||
user: Optional[str] = None,
|
|
||||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
|
||||||
) -> TextEmbeddingResult:
|
|
||||||
"""
|
|
||||||
Invoke text embedding model
|
|
||||||
|
|
||||||
:param model: model name
|
|
||||||
:param credentials: model credentials
|
|
||||||
:param texts: texts to embed
|
|
||||||
:param user: unique user id
|
|
||||||
:param input_type: input type
|
|
||||||
:return: embeddings result
|
|
||||||
"""
|
|
||||||
client_config = Config(region_name=credentials["aws_region"])
|
|
||||||
|
|
||||||
bedrock_runtime = boto3.client(
|
|
||||||
service_name="bedrock-runtime",
|
|
||||||
config=client_config,
|
|
||||||
aws_access_key_id=credentials.get("aws_access_key_id"),
|
|
||||||
aws_secret_access_key=credentials.get("aws_secret_access_key"),
|
|
||||||
)
|
|
||||||
|
|
||||||
embeddings = []
|
|
||||||
token_usage = 0
|
|
||||||
|
|
||||||
model_prefix = model.split(".")[0]
|
|
||||||
|
|
||||||
if model_prefix == "amazon":
|
|
||||||
for text in texts:
|
|
||||||
body = {
|
|
||||||
"inputText": text,
|
|
||||||
}
|
|
||||||
response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
|
|
||||||
embeddings.extend([response_body.get("embedding")])
|
|
||||||
token_usage += response_body.get("inputTextTokenCount")
|
|
||||||
logger.warning(f"Total Tokens: {token_usage}")
|
|
||||||
result = TextEmbeddingResult(
|
|
||||||
model=model,
|
|
||||||
embeddings=embeddings,
|
|
||||||
usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage),
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
if model_prefix == "cohere":
|
|
||||||
input_type = "search_document" if len(texts) > 1 else "search_query"
|
|
||||||
for text in texts:
|
|
||||||
body = {
|
|
||||||
"texts": [text],
|
|
||||||
"input_type": input_type,
|
|
||||||
}
|
|
||||||
response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
|
|
||||||
embeddings.extend(response_body.get("embeddings"))
|
|
||||||
token_usage += len(text)
|
|
||||||
result = TextEmbeddingResult(
|
|
||||||
model=model,
|
|
||||||
embeddings=embeddings,
|
|
||||||
usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage),
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
# others
|
|
||||||
raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
|
|
||||||
|
|
||||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
|
||||||
"""
|
|
||||||
Get number of tokens for given prompt messages
|
|
||||||
|
|
||||||
:param model: model name
|
|
||||||
:param credentials: model credentials
|
|
||||||
:param texts: texts to embed
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
num_tokens = 0
|
|
||||||
for text in texts:
|
|
||||||
num_tokens += self._get_num_tokens_by_gpt2(text)
|
|
||||||
return num_tokens
|
|
||||||
|
|
||||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
|
||||||
"""
|
|
||||||
Validate model credentials
|
|
||||||
|
|
||||||
:param model: model name
|
|
||||||
:param credentials: model credentials
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
|
||||||
"""
|
|
||||||
Map model invoke error to unified error
|
|
||||||
The key is the ermd = genai.GenerativeModel(model) error type thrown to the caller
|
|
||||||
The value is the md = genai.GenerativeModel(model) error type thrown by the model,
|
|
||||||
which needs to be converted into a unified error type for the caller.
|
|
||||||
|
|
||||||
:return: Invoke emd = genai.GenerativeModel(model) error mapping
|
|
||||||
"""
|
|
||||||
return {
|
|
||||||
InvokeConnectionError: [],
|
|
||||||
InvokeServerUnavailableError: [],
|
|
||||||
InvokeRateLimitError: [],
|
|
||||||
InvokeAuthorizationError: [],
|
|
||||||
InvokeBadRequestError: [],
|
|
||||||
}
|
|
||||||
|
|
||||||
def _create_payload(
|
|
||||||
self,
|
|
||||||
model_prefix: str,
|
|
||||||
texts: list[str],
|
|
||||||
model_parameters: dict,
|
|
||||||
stop: Optional[list[str]] = None,
|
|
||||||
stream: bool = True,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Create payload for bedrock api call depending on model provider
|
|
||||||
"""
|
|
||||||
payload = {}
|
|
||||||
|
|
||||||
if model_prefix == "amazon":
|
|
||||||
payload["inputText"] = texts
|
|
||||||
|
|
||||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
|
||||||
"""
|
|
||||||
Calculate response usage
|
|
||||||
|
|
||||||
:param model: model name
|
|
||||||
:param credentials: model credentials
|
|
||||||
:param tokens: input tokens
|
|
||||||
:return: usage
|
|
||||||
"""
|
|
||||||
# get input price info
|
|
||||||
input_price_info = self.get_price(
|
|
||||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
|
||||||
)
|
|
||||||
|
|
||||||
# transform usage
|
|
||||||
usage = EmbeddingUsage(
|
|
||||||
tokens=tokens,
|
|
||||||
total_tokens=tokens,
|
|
||||||
unit_price=input_price_info.unit_price,
|
|
||||||
price_unit=input_price_info.unit,
|
|
||||||
total_price=input_price_info.total_amount,
|
|
||||||
currency=input_price_info.currency,
|
|
||||||
latency=time.perf_counter() - self.started_at,
|
|
||||||
)
|
|
||||||
|
|
||||||
return usage
|
|
||||||
|
|
||||||
def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[InvokeError]:
|
|
||||||
"""
|
|
||||||
Map client error to invoke error
|
|
||||||
|
|
||||||
:param error_code: error code
|
|
||||||
:param error_msg: error message
|
|
||||||
:return: invoke error
|
|
||||||
"""
|
|
||||||
|
|
||||||
if error_code == "AccessDeniedException":
|
|
||||||
return InvokeAuthorizationError(error_msg)
|
|
||||||
elif error_code in {"ResourceNotFoundException", "ValidationException"}:
|
|
||||||
return InvokeBadRequestError(error_msg)
|
|
||||||
elif error_code in {"ThrottlingException", "ServiceQuotaExceededException"}:
|
|
||||||
return InvokeRateLimitError(error_msg)
|
|
||||||
elif error_code in {
|
|
||||||
"ModelTimeoutException",
|
|
||||||
"ModelErrorException",
|
|
||||||
"InternalServerException",
|
|
||||||
"ModelNotReadyException",
|
|
||||||
}:
|
|
||||||
return InvokeServerUnavailableError(error_msg)
|
|
||||||
elif error_code == "ModelStreamErrorException":
|
|
||||||
return InvokeConnectionError(error_msg)
|
|
||||||
|
|
||||||
return InvokeError(error_msg)
|
|
||||||
|
|
||||||
def _invoke_bedrock_embedding(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
bedrock_runtime,
|
|
||||||
body: dict,
|
|
||||||
):
|
|
||||||
accept = "application/json"
|
|
||||||
content_type = "application/json"
|
|
||||||
try:
|
|
||||||
response = bedrock_runtime.invoke_model(
|
|
||||||
body=json.dumps(body), modelId=model, accept=accept, contentType=content_type
|
|
||||||
)
|
|
||||||
response_body = json.loads(response.get("body").read().decode("utf-8"))
|
|
||||||
return response_body
|
|
||||||
except ClientError as ex:
|
|
||||||
error_code = ex.response["Error"]["Code"]
|
|
||||||
full_error_msg = f"{error_code}: {ex.response['Error']['Message']}"
|
|
||||||
raise self._map_client_to_invoke_error(error_code, full_error_msg)
|
|
||||||
|
|
||||||
except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex:
|
|
||||||
raise InvokeConnectionError(str(ex))
|
|
||||||
|
|
||||||
except UnknownServiceError as ex:
|
|
||||||
raise InvokeServerUnavailableError(str(ex))
|
|
||||||
|
|
||||||
except Exception as ex:
|
|
||||||
raise InvokeError(str(ex))
|
|
||||||
@ -1,223 +0,0 @@
|
|||||||
import json
|
|
||||||
import time
|
|
||||||
from decimal import Decimal
|
|
||||||
from typing import Optional
|
|
||||||
from urllib.parse import urljoin
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import requests
|
|
||||||
|
|
||||||
from core.embedding.embedding_constant import EmbeddingInputType
|
|
||||||
from core.model_runtime.entities.common_entities import I18nObject
|
|
||||||
from core.model_runtime.entities.model_entities import (
|
|
||||||
AIModelEntity,
|
|
||||||
FetchFrom,
|
|
||||||
ModelPropertyKey,
|
|
||||||
ModelType,
|
|
||||||
PriceConfig,
|
|
||||||
PriceType,
|
|
||||||
)
|
|
||||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
|
||||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
|
||||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
|
||||||
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
|
|
||||||
|
|
||||||
|
|
||||||
class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel):
|
|
||||||
"""
|
|
||||||
Model class for an OpenAI API-compatible text embedding model.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _invoke(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
credentials: dict,
|
|
||||||
texts: list[str],
|
|
||||||
user: Optional[str] = None,
|
|
||||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
|
||||||
) -> TextEmbeddingResult:
|
|
||||||
"""
|
|
||||||
Invoke text embedding model
|
|
||||||
|
|
||||||
:param model: model name
|
|
||||||
:param credentials: model credentials
|
|
||||||
:param texts: texts to embed
|
|
||||||
:param user: unique user id
|
|
||||||
:param input_type: input type
|
|
||||||
:return: embeddings result
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Prepare headers and payload for the request
|
|
||||||
headers = {"Content-Type": "application/json"}
|
|
||||||
|
|
||||||
api_key = credentials.get("api_key")
|
|
||||||
if api_key:
|
|
||||||
headers["Authorization"] = f"Bearer {api_key}"
|
|
||||||
|
|
||||||
if "endpoint_url" not in credentials or credentials["endpoint_url"] == "":
|
|
||||||
endpoint_url = "https://cloud.perfxlab.cn/v1/"
|
|
||||||
else:
|
|
||||||
endpoint_url = credentials.get("endpoint_url")
|
|
||||||
if not endpoint_url.endswith("/"):
|
|
||||||
endpoint_url += "/"
|
|
||||||
|
|
||||||
endpoint_url = urljoin(endpoint_url, "embeddings")
|
|
||||||
|
|
||||||
extra_model_kwargs = {}
|
|
||||||
if user:
|
|
||||||
extra_model_kwargs["user"] = user
|
|
||||||
|
|
||||||
extra_model_kwargs["encoding_format"] = "float"
|
|
||||||
|
|
||||||
# get model properties
|
|
||||||
context_size = self._get_context_size(model, credentials)
|
|
||||||
max_chunks = self._get_max_chunks(model, credentials)
|
|
||||||
|
|
||||||
inputs = []
|
|
||||||
indices = []
|
|
||||||
used_tokens = 0
|
|
||||||
|
|
||||||
for i, text in enumerate(texts):
|
|
||||||
# Here token count is only an approximation based on the GPT2 tokenizer
|
|
||||||
# TODO: Optimize for better token estimation and chunking
|
|
||||||
num_tokens = self._get_num_tokens_by_gpt2(text)
|
|
||||||
|
|
||||||
if num_tokens >= context_size:
|
|
||||||
cutoff = int(np.floor(len(text) * (context_size / num_tokens)))
|
|
||||||
# if num tokens is larger than context length, only use the start
|
|
||||||
inputs.append(text[0:cutoff])
|
|
||||||
else:
|
|
||||||
inputs.append(text)
|
|
||||||
indices += [i]
|
|
||||||
|
|
||||||
batched_embeddings = []
|
|
||||||
_iter = range(0, len(inputs), max_chunks)
|
|
||||||
|
|
||||||
for i in _iter:
|
|
||||||
# Prepare the payload for the request
|
|
||||||
payload = {"input": inputs[i : i + max_chunks], "model": model, **extra_model_kwargs}
|
|
||||||
|
|
||||||
# Make the request to the OpenAI API
|
|
||||||
response = requests.post(endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300))
|
|
||||||
|
|
||||||
response.raise_for_status() # Raise an exception for HTTP errors
|
|
||||||
response_data = response.json()
|
|
||||||
|
|
||||||
# Extract embeddings and used tokens from the response
|
|
||||||
embeddings_batch = [data["embedding"] for data in response_data["data"]]
|
|
||||||
embedding_used_tokens = response_data["usage"]["total_tokens"]
|
|
||||||
|
|
||||||
used_tokens += embedding_used_tokens
|
|
||||||
batched_embeddings += embeddings_batch
|
|
||||||
|
|
||||||
# calc usage
|
|
||||||
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
|
|
||||||
|
|
||||||
return TextEmbeddingResult(embeddings=batched_embeddings, usage=usage, model=model)
|
|
||||||
|
|
||||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
|
||||||
"""
|
|
||||||
Approximate number of tokens for given messages using GPT2 tokenizer
|
|
||||||
|
|
||||||
:param model: model name
|
|
||||||
:param credentials: model credentials
|
|
||||||
:param texts: texts to embed
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
return sum(self._get_num_tokens_by_gpt2(text) for text in texts)
|
|
||||||
|
|
||||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
|
||||||
"""
|
|
||||||
Validate model credentials
|
|
||||||
|
|
||||||
:param model: model name
|
|
||||||
:param credentials: model credentials
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
headers = {"Content-Type": "application/json"}
|
|
||||||
|
|
||||||
api_key = credentials.get("api_key")
|
|
||||||
|
|
||||||
if api_key:
|
|
||||||
headers["Authorization"] = f"Bearer {api_key}"
|
|
||||||
|
|
||||||
if "endpoint_url" not in credentials or credentials["endpoint_url"] == "":
|
|
||||||
endpoint_url = "https://cloud.perfxlab.cn/v1/"
|
|
||||||
else:
|
|
||||||
endpoint_url = credentials.get("endpoint_url")
|
|
||||||
if not endpoint_url.endswith("/"):
|
|
||||||
endpoint_url += "/"
|
|
||||||
|
|
||||||
endpoint_url = urljoin(endpoint_url, "embeddings")
|
|
||||||
|
|
||||||
payload = {"input": "ping", "model": model}
|
|
||||||
|
|
||||||
response = requests.post(url=endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300))
|
|
||||||
|
|
||||||
if response.status_code != 200:
|
|
||||||
raise CredentialsValidateFailedError(
|
|
||||||
f"Credentials validation failed with status code {response.status_code}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
json_result = response.json()
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
raise CredentialsValidateFailedError("Credentials validation failed: JSON decode error")
|
|
||||||
|
|
||||||
if "model" not in json_result:
|
|
||||||
raise CredentialsValidateFailedError("Credentials validation failed: invalid response")
|
|
||||||
except CredentialsValidateFailedError:
|
|
||||||
raise
|
|
||||||
except Exception as ex:
|
|
||||||
raise CredentialsValidateFailedError(str(ex))
|
|
||||||
|
|
||||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
|
||||||
"""
|
|
||||||
generate custom model entities from credentials
|
|
||||||
"""
|
|
||||||
entity = AIModelEntity(
|
|
||||||
model=model,
|
|
||||||
label=I18nObject(en_US=model),
|
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
|
||||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
|
||||||
model_properties={
|
|
||||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")),
|
|
||||||
ModelPropertyKey.MAX_CHUNKS: 1,
|
|
||||||
},
|
|
||||||
parameter_rules=[],
|
|
||||||
pricing=PriceConfig(
|
|
||||||
input=Decimal(credentials.get("input_price", 0)),
|
|
||||||
unit=Decimal(credentials.get("unit", 0)),
|
|
||||||
currency=credentials.get("currency", "USD"),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
return entity
|
|
||||||
|
|
||||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
|
||||||
"""
|
|
||||||
Calculate response usage
|
|
||||||
|
|
||||||
:param model: model name
|
|
||||||
:param credentials: model credentials
|
|
||||||
:param tokens: input tokens
|
|
||||||
:return: usage
|
|
||||||
"""
|
|
||||||
# get input price info
|
|
||||||
input_price_info = self.get_price(
|
|
||||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
|
||||||
)
|
|
||||||
|
|
||||||
# transform usage
|
|
||||||
usage = EmbeddingUsage(
|
|
||||||
tokens=tokens,
|
|
||||||
total_tokens=tokens,
|
|
||||||
unit_price=input_price_info.unit_price,
|
|
||||||
price_unit=input_price_info.unit,
|
|
||||||
total_price=input_price_info.total_amount,
|
|
||||||
currency=input_price_info.currency,
|
|
||||||
latency=time.perf_counter() - self.started_at,
|
|
||||||
)
|
|
||||||
|
|
||||||
return usage
|
|
||||||
@ -0,0 +1,21 @@
|
|||||||
|
<svg version="1.0" xmlns="http://www.w3.org/2000/svg" width="100.000000pt" height="19.000000pt" viewBox="0 0 300.000000 57.000000" preserveAspectRatio="xMidYMid meet"><g transform="translate(0.000000,57.000000) scale(0.100000,-0.100000)" fill="#000000" stroke="none"><path d="M2505 368 c-38 -84 -86 -188 -106 -230 l-38 -78 27 0 c24 0 30 7 55
|
||||||
|
75 l28 75 100 0 100 0 25 -55 c13 -31 24 -64 24 -75 0 -17 7 -20 44 -20 l43 0
|
||||||
|
-37 73 c-20 39 -68 143 -106 229 -38 87 -74 158 -80 158 -5 0 -41 -69 -79
|
||||||
|
-152z m110 -30 c22 -51 41 -95 42 -98 2 -3 -36 -6 -83 -7 -76 -1 -85 0 -81 15
|
||||||
|
12 40 72 182 77 182 3 0 24 -41 45 -92z"/><path d="M63 493 c19 -61 197 -438 209 -440 10 -2 147 282 216 449 2 4 -10 8
|
||||||
|
-27 8 -23 0 -31 -5 -31 -17 0 -16 -142 -365 -146 -360 -8 11 -144 329 -149
|
||||||
|
350 -6 23 -12 27 -42 27 -29 0 -34 -3 -30 -17z"/><path d="M2855 285 l0 -225 30 0 30 0 0 225 0 225 -30 0 -30 0 0 -225z"/><path d="M588 380 c-55 -30 -82 -74 -86 -145 -3 -50 0 -66 20 -95 39 -58 82
|
||||||
|
-80 153 -80 68 0 110 21 149 73 32 43 30 150 -3 196 -47 66 -158 90 -233 51z
|
||||||
|
m133 -16 c59 -30 89 -156 54 -224 -45 -87 -162 -78 -201 16 -18 44 -18 128 1
|
||||||
|
164 28 55 90 73 146 44z"/><path d="M935 303 l76 -98 -7 -72 -6 -73 33 0 34 0 -3 78 -4 77 71 93 c65 85
|
||||||
|
68 92 46 92 -15 0 -29 -9 -36 -22 -18 -33 -90 -128 -98 -128 -6 1 -67 85 -88
|
||||||
|
122 -8 15 -24 23 -53 25 l-41 4 76 -98z"/><path d="M1257 230 c-82 -169 -83 -170 -57 -170 17 0 27 6 27 15 0 8 7 31 17
|
||||||
|
52 l17 38 79 0 78 1 16 -34 c9 -18 16 -42 16 -52 0 -17 7 -20 41 -20 22 0 39
|
||||||
|
3 37 8 -2 4 -39 80 -83 170 -43 89 -84 162 -92 162 -7 0 -50 -76 -96 -170z
|
||||||
|
m90 -38 c-33 -2 -61 -1 -63 1 -2 2 10 34 26 71 l31 68 33 -68 33 -69 -60 -3z"/><path d="M1665 386 c-37 -16 -84 -63 -97 -96 -13 -35 -12 -104 2 -132 49 -94
|
||||||
|
182 -134 280 -83 24 12 29 22 32 64 3 49 3 49 -30 53 l-33 4 3 -45 c4 -61 -5
|
||||||
|
-71 -60 -71 -93 0 -142 57 -142 164 0 44 5 60 25 85 47 55 136 65 184 20 30
|
||||||
|
-28 35 -20 11 19 -19 31 -22 32 -82 32 -35 -1 -76 -7 -93 -14z"/><path d="M1955 230 l0 -170 91 0 c76 0 93 3 98 16 4 9 5 18 4 20 -2 1 -31 -1
|
||||||
|
-66 -5 -34 -4 -64 -5 -67 -3 -3 3 -5 36 -5 73 l0 68 55 -6 c49 -5 55 -4 55 13
|
||||||
|
0 17 -6 19 -55 16 l-55 -4 0 61 0 61 64 0 c48 0 65 4 70 15 4 13 -10 15 -92
|
||||||
|
15 l-97 0 0 -170z"/></g></svg>
|
||||||
|
After Width: | Height: | Size: 2.2 KiB |
@ -0,0 +1,8 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<svg width="64px" height="64px" viewBox="0 0 64 64" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||||
|
<title>voyage</title>
|
||||||
|
<g id="voyage" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
||||||
|
<rect id="矩形" fill="#333333" x="0" y="0" width="64" height="64" rx="12"></rect>
|
||||||
|
<path d="M12.1128004,51.4376727 C13.8950799,45.8316747 30.5922254,11.1847688 31.7178757,11.0009656 C32.6559176,10.8171624 45.5070913,36.9172188 51.9795803,52.2647871 C52.1671887,52.6323936 51.0415384,53 49.4468672,53 C47.2893709,53 46.5389374,52.540492 46.5389374,51.4376727 C46.5389374,49.967247 33.2187427,17.8935861 32.8435259,18.3530942 C32.0930924,19.3640118 19.3357228,48.5887229 18.8667019,50.5186566 C18.3038768,52.6323936 17.7410516,53 14.926926,53 C12.2066045,53 11.7375836,52.7242952 12.1128004,51.4376727 Z" id="路径" fill="#FFFFFF" transform="translate(32, 32) scale(1, -1) translate(-32, -32)"></path>
|
||||||
|
</g>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 1.0 KiB |
@ -0,0 +1,4 @@
|
|||||||
|
model: rerank-1
|
||||||
|
model_type: rerank
|
||||||
|
model_properties:
|
||||||
|
context_size: 8000
|
||||||
@ -0,0 +1,4 @@
|
|||||||
|
model: rerank-lite-1
|
||||||
|
model_type: rerank
|
||||||
|
model_properties:
|
||||||
|
context_size: 4000
|
||||||
@ -0,0 +1,123 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
|
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
|
||||||
|
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
|
||||||
|
from core.model_runtime.errors.invoke import (
|
||||||
|
InvokeAuthorizationError,
|
||||||
|
InvokeBadRequestError,
|
||||||
|
InvokeConnectionError,
|
||||||
|
InvokeError,
|
||||||
|
InvokeRateLimitError,
|
||||||
|
InvokeServerUnavailableError,
|
||||||
|
)
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
|
||||||
|
|
||||||
|
|
||||||
|
class VoyageRerankModel(RerankModel):
|
||||||
|
"""
|
||||||
|
Model class for Voyage rerank model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _invoke(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
query: str,
|
||||||
|
docs: list[str],
|
||||||
|
score_threshold: Optional[float] = None,
|
||||||
|
top_n: Optional[int] = None,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
) -> RerankResult:
|
||||||
|
"""
|
||||||
|
Invoke rerank model
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:param query: search query
|
||||||
|
:param docs: docs for reranking
|
||||||
|
:param score_threshold: score threshold
|
||||||
|
:param top_n: top n documents to return
|
||||||
|
:param user: unique user id
|
||||||
|
:return: rerank result
|
||||||
|
"""
|
||||||
|
if len(docs) == 0:
|
||||||
|
return RerankResult(model=model, docs=[])
|
||||||
|
|
||||||
|
base_url = credentials.get("base_url", "https://api.voyageai.com/v1")
|
||||||
|
base_url = base_url.removesuffix("/")
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = httpx.post(
|
||||||
|
base_url + "/rerank",
|
||||||
|
json={"model": model, "query": query, "documents": docs, "top_k": top_n, "return_documents": True},
|
||||||
|
headers={"Authorization": f"Bearer {credentials.get('api_key')}", "Content-Type": "application/json"},
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
results = response.json()
|
||||||
|
|
||||||
|
rerank_documents = []
|
||||||
|
for result in results["data"]:
|
||||||
|
rerank_document = RerankDocument(
|
||||||
|
index=result["index"],
|
||||||
|
text=result["document"],
|
||||||
|
score=result["relevance_score"],
|
||||||
|
)
|
||||||
|
if score_threshold is None or result["relevance_score"] >= score_threshold:
|
||||||
|
rerank_documents.append(rerank_document)
|
||||||
|
|
||||||
|
return RerankResult(model=model, docs=rerank_documents)
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
raise InvokeServerUnavailableError(str(e))
|
||||||
|
|
||||||
|
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||||
|
"""
|
||||||
|
Validate model credentials
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self._invoke(
|
||||||
|
model=model,
|
||||||
|
credentials=credentials,
|
||||||
|
query="What is the capital of the United States?",
|
||||||
|
docs=[
|
||||||
|
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
|
||||||
|
"Census, Carson City had a population of 55,274.",
|
||||||
|
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
|
||||||
|
"are a political division controlled by the United States. Its capital is Saipan.",
|
||||||
|
],
|
||||||
|
score_threshold=0.8,
|
||||||
|
)
|
||||||
|
except Exception as ex:
|
||||||
|
raise CredentialsValidateFailedError(str(ex))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||||
|
"""
|
||||||
|
Map model invoke error to unified error
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
InvokeConnectionError: [httpx.ConnectError],
|
||||||
|
InvokeServerUnavailableError: [httpx.RemoteProtocolError],
|
||||||
|
InvokeRateLimitError: [],
|
||||||
|
InvokeAuthorizationError: [httpx.HTTPStatusError],
|
||||||
|
InvokeBadRequestError: [httpx.RequestError],
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||||
|
"""
|
||||||
|
generate custom model entities from credentials
|
||||||
|
"""
|
||||||
|
entity = AIModelEntity(
|
||||||
|
model=model,
|
||||||
|
label=I18nObject(en_US=model),
|
||||||
|
model_type=ModelType.RERANK,
|
||||||
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||||
|
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "8000"))},
|
||||||
|
)
|
||||||
|
|
||||||
|
return entity
|
||||||
@ -0,0 +1,172 @@
|
|||||||
|
import time
|
||||||
|
from json import JSONDecodeError, dumps
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from core.embedding.embedding_constant import EmbeddingInputType
|
||||||
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
|
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
|
||||||
|
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||||
|
from core.model_runtime.errors.invoke import (
|
||||||
|
InvokeAuthorizationError,
|
||||||
|
InvokeBadRequestError,
|
||||||
|
InvokeConnectionError,
|
||||||
|
InvokeError,
|
||||||
|
InvokeRateLimitError,
|
||||||
|
InvokeServerUnavailableError,
|
||||||
|
)
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||||
|
|
||||||
|
|
||||||
|
class VoyageTextEmbeddingModel(TextEmbeddingModel):
|
||||||
|
"""
|
||||||
|
Model class for Voyage text embedding model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
api_base: str = "https://api.voyageai.com/v1"
|
||||||
|
|
||||||
|
def _invoke(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
texts: list[str],
|
||||||
|
user: Optional[str] = None,
|
||||||
|
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||||
|
) -> TextEmbeddingResult:
|
||||||
|
"""
|
||||||
|
Invoke text embedding model
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:param texts: texts to embed
|
||||||
|
:param user: unique user id
|
||||||
|
:param input_type: input type
|
||||||
|
:return: embeddings result
|
||||||
|
"""
|
||||||
|
api_key = credentials["api_key"]
|
||||||
|
if not api_key:
|
||||||
|
raise CredentialsValidateFailedError("api_key is required")
|
||||||
|
|
||||||
|
base_url = credentials.get("base_url", self.api_base)
|
||||||
|
base_url = base_url.removesuffix("/")
|
||||||
|
|
||||||
|
url = base_url + "/embeddings"
|
||||||
|
headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"}
|
||||||
|
voyage_input_type = "null"
|
||||||
|
if input_type is not None:
|
||||||
|
voyage_input_type = input_type.value
|
||||||
|
data = {"model": model, "input": texts, "input_type": voyage_input_type}
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.post(url, headers=headers, data=dumps(data))
|
||||||
|
except Exception as e:
|
||||||
|
raise InvokeConnectionError(str(e))
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
try:
|
||||||
|
resp = response.json()
|
||||||
|
msg = resp["detail"]
|
||||||
|
if response.status_code == 401:
|
||||||
|
raise InvokeAuthorizationError(msg)
|
||||||
|
elif response.status_code == 429:
|
||||||
|
raise InvokeRateLimitError(msg)
|
||||||
|
elif response.status_code == 500:
|
||||||
|
raise InvokeServerUnavailableError(msg)
|
||||||
|
else:
|
||||||
|
raise InvokeBadRequestError(msg)
|
||||||
|
except JSONDecodeError as e:
|
||||||
|
raise InvokeServerUnavailableError(
|
||||||
|
f"Failed to convert response to json: {e} with text: {response.text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
resp = response.json()
|
||||||
|
embeddings = resp["data"]
|
||||||
|
usage = resp["usage"]
|
||||||
|
except Exception as e:
|
||||||
|
raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}")
|
||||||
|
|
||||||
|
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"])
|
||||||
|
|
||||||
|
result = TextEmbeddingResult(
|
||||||
|
model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||||
|
"""
|
||||||
|
Get number of tokens for given prompt messages
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:param texts: texts to embed
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return sum(self._get_num_tokens_by_gpt2(text) for text in texts)
|
||||||
|
|
||||||
|
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||||
|
"""
|
||||||
|
Validate model credentials
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self._invoke(model=model, credentials=credentials, texts=["ping"])
|
||||||
|
except Exception as e:
|
||||||
|
raise CredentialsValidateFailedError(f"Credentials validation failed: {e}")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||||
|
return {
|
||||||
|
InvokeConnectionError: [InvokeConnectionError],
|
||||||
|
InvokeServerUnavailableError: [InvokeServerUnavailableError],
|
||||||
|
InvokeRateLimitError: [InvokeRateLimitError],
|
||||||
|
InvokeAuthorizationError: [InvokeAuthorizationError],
|
||||||
|
InvokeBadRequestError: [KeyError, InvokeBadRequestError],
|
||||||
|
}
|
||||||
|
|
||||||
|
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||||
|
"""
|
||||||
|
Calculate response usage
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:param tokens: input tokens
|
||||||
|
:return: usage
|
||||||
|
"""
|
||||||
|
# get input price info
|
||||||
|
input_price_info = self.get_price(
|
||||||
|
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
# transform usage
|
||||||
|
usage = EmbeddingUsage(
|
||||||
|
tokens=tokens,
|
||||||
|
total_tokens=tokens,
|
||||||
|
unit_price=input_price_info.unit_price,
|
||||||
|
price_unit=input_price_info.unit,
|
||||||
|
total_price=input_price_info.total_amount,
|
||||||
|
currency=input_price_info.currency,
|
||||||
|
latency=time.perf_counter() - self.started_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
return usage
|
||||||
|
|
||||||
|
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||||
|
"""
|
||||||
|
generate custom model entities from credentials
|
||||||
|
"""
|
||||||
|
entity = AIModelEntity(
|
||||||
|
model=model,
|
||||||
|
label=I18nObject(en_US=model),
|
||||||
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||||
|
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
|
||||||
|
)
|
||||||
|
|
||||||
|
return entity
|
||||||
@ -0,0 +1,8 @@
|
|||||||
|
model: voyage-3-lite
|
||||||
|
model_type: text-embedding
|
||||||
|
model_properties:
|
||||||
|
context_size: 32000
|
||||||
|
pricing:
|
||||||
|
input: '0.00002'
|
||||||
|
unit: '0.001'
|
||||||
|
currency: USD
|
||||||
@ -0,0 +1,8 @@
|
|||||||
|
model: voyage-3
|
||||||
|
model_type: text-embedding
|
||||||
|
model_properties:
|
||||||
|
context_size: 32000
|
||||||
|
pricing:
|
||||||
|
input: '0.00006'
|
||||||
|
unit: '0.001'
|
||||||
|
currency: USD
|
||||||
@ -0,0 +1,28 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class VoyageProvider(ModelProvider):
|
||||||
|
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||||
|
"""
|
||||||
|
Validate provider credentials
|
||||||
|
if validate failed, raise exception
|
||||||
|
|
||||||
|
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
model_instance = self.get_model_instance(ModelType.TEXT_EMBEDDING)
|
||||||
|
|
||||||
|
# Use `voyage-3` model for validate,
|
||||||
|
# no matter what model you pass in, text completion model or chat model
|
||||||
|
model_instance.validate_credentials(model="voyage-3", credentials=credentials)
|
||||||
|
except CredentialsValidateFailedError as ex:
|
||||||
|
raise ex
|
||||||
|
except Exception as ex:
|
||||||
|
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
|
||||||
|
raise ex
|
||||||
@ -0,0 +1,31 @@
|
|||||||
|
provider: voyage
|
||||||
|
label:
|
||||||
|
en_US: Voyage
|
||||||
|
description:
|
||||||
|
en_US: Embedding and Rerank Model Supported
|
||||||
|
icon_small:
|
||||||
|
en_US: icon_s_en.svg
|
||||||
|
icon_large:
|
||||||
|
en_US: icon_l_en.svg
|
||||||
|
background: "#EFFDFD"
|
||||||
|
help:
|
||||||
|
title:
|
||||||
|
en_US: Get your API key from Voyage AI
|
||||||
|
zh_Hans: 从 Voyage 获取 API Key
|
||||||
|
url:
|
||||||
|
en_US: https://dash.voyageai.com/
|
||||||
|
supported_model_types:
|
||||||
|
- text-embedding
|
||||||
|
- rerank
|
||||||
|
configurate_methods:
|
||||||
|
- predefined-model
|
||||||
|
provider_credential_schema:
|
||||||
|
credential_form_schemas:
|
||||||
|
- variable: api_key
|
||||||
|
label:
|
||||||
|
en_US: API Key
|
||||||
|
type: secret-input
|
||||||
|
required: true
|
||||||
|
placeholder:
|
||||||
|
zh_Hans: 在此输入您的 API Key
|
||||||
|
en_US: Enter your API Key
|
||||||
@ -1,142 +0,0 @@
|
|||||||
import time
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from core.embedding.embedding_constant import EmbeddingInputType
|
|
||||||
from core.model_runtime.entities.model_entities import PriceType
|
|
||||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
|
||||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
|
||||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
|
||||||
from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI
|
|
||||||
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
|
|
||||||
|
|
||||||
|
|
||||||
class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
|
|
||||||
"""
|
|
||||||
Model class for ZhipuAI text embedding model.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _invoke(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
credentials: dict,
|
|
||||||
texts: list[str],
|
|
||||||
user: Optional[str] = None,
|
|
||||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
|
||||||
) -> TextEmbeddingResult:
|
|
||||||
"""
|
|
||||||
Invoke text embedding model
|
|
||||||
|
|
||||||
:param model: model name
|
|
||||||
:param credentials: model credentials
|
|
||||||
:param texts: texts to embed
|
|
||||||
:param user: unique user id
|
|
||||||
:param input_type: input type
|
|
||||||
:return: embeddings result
|
|
||||||
"""
|
|
||||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
|
||||||
client = ZhipuAI(api_key=credentials_kwargs["api_key"])
|
|
||||||
|
|
||||||
embeddings, embedding_used_tokens = self.embed_documents(model, client, texts)
|
|
||||||
|
|
||||||
return TextEmbeddingResult(
|
|
||||||
embeddings=embeddings,
|
|
||||||
usage=self._calc_response_usage(model, credentials_kwargs, embedding_used_tokens),
|
|
||||||
model=model,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
|
||||||
"""
|
|
||||||
Get number of tokens for given prompt messages
|
|
||||||
|
|
||||||
:param model: model name
|
|
||||||
:param credentials: model credentials
|
|
||||||
:param texts: texts to embed
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
if len(texts) == 0:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
total_num_tokens = 0
|
|
||||||
for text in texts:
|
|
||||||
total_num_tokens += self._get_num_tokens_by_gpt2(text)
|
|
||||||
|
|
||||||
return total_num_tokens
|
|
||||||
|
|
||||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
|
||||||
"""
|
|
||||||
Validate model credentials
|
|
||||||
|
|
||||||
:param model: model name
|
|
||||||
:param credentials: model credentials
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# transform credentials to kwargs for model instance
|
|
||||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
|
||||||
client = ZhipuAI(api_key=credentials_kwargs["api_key"])
|
|
||||||
|
|
||||||
# call embedding model
|
|
||||||
self.embed_documents(
|
|
||||||
model=model,
|
|
||||||
client=client,
|
|
||||||
texts=["ping"],
|
|
||||||
)
|
|
||||||
except Exception as ex:
|
|
||||||
raise CredentialsValidateFailedError(str(ex))
|
|
||||||
|
|
||||||
def embed_documents(self, model: str, client: ZhipuAI, texts: list[str]) -> tuple[list[list[float]], int]:
|
|
||||||
"""Call out to ZhipuAI's embedding endpoint.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
texts: The list of texts to embed.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of embeddings, one for each text.
|
|
||||||
"""
|
|
||||||
embeddings = []
|
|
||||||
embedding_used_tokens = 0
|
|
||||||
|
|
||||||
for text in texts:
|
|
||||||
response = client.embeddings.create(model=model, input=text)
|
|
||||||
data = response.data[0]
|
|
||||||
embeddings.append(data.embedding)
|
|
||||||
embedding_used_tokens += response.usage.total_tokens
|
|
||||||
|
|
||||||
return [list(map(float, e)) for e in embeddings], embedding_used_tokens
|
|
||||||
|
|
||||||
def embed_query(self, text: str) -> list[float]:
|
|
||||||
"""Call out to ZhipuAI's embedding endpoint.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: The text to embed.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Embeddings for the text.
|
|
||||||
"""
|
|
||||||
return self.embed_documents([text])[0]
|
|
||||||
|
|
||||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
|
||||||
"""
|
|
||||||
Calculate response usage
|
|
||||||
|
|
||||||
:param model: model name
|
|
||||||
:param tokens: input tokens
|
|
||||||
:return: usage
|
|
||||||
"""
|
|
||||||
# get input price info
|
|
||||||
input_price_info = self.get_price(
|
|
||||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
|
||||||
)
|
|
||||||
|
|
||||||
# transform usage
|
|
||||||
usage = EmbeddingUsage(
|
|
||||||
tokens=tokens,
|
|
||||||
total_tokens=tokens,
|
|
||||||
unit_price=input_price_info.unit_price,
|
|
||||||
price_unit=input_price_info.unit,
|
|
||||||
total_price=input_price_info.total_amount,
|
|
||||||
currency=input_price_info.currency,
|
|
||||||
latency=time.perf_counter() - self.started_at,
|
|
||||||
)
|
|
||||||
|
|
||||||
return usage
|
|
||||||
@ -0,0 +1,25 @@
|
|||||||
|
import os
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.voyage.voyage import VoyageProvider
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_provider_credentials():
|
||||||
|
provider = VoyageProvider()
|
||||||
|
|
||||||
|
with pytest.raises(CredentialsValidateFailedError):
|
||||||
|
provider.validate_provider_credentials(credentials={"api_key": "hahahaha"})
|
||||||
|
with patch("requests.post") as mock_post:
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"object": "list",
|
||||||
|
"data": [{"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 0}],
|
||||||
|
"model": "voyage-3",
|
||||||
|
"usage": {"total_tokens": 1},
|
||||||
|
}
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
provider.validate_provider_credentials(credentials={"api_key": os.environ.get("VOYAGE_API_KEY")})
|
||||||
@ -0,0 +1,92 @@
|
|||||||
|
import os
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.model_runtime.entities.rerank_entities import RerankResult
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.voyage.rerank.rerank import VoyageRerankModel
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_credentials():
|
||||||
|
model = VoyageRerankModel()
|
||||||
|
|
||||||
|
with pytest.raises(CredentialsValidateFailedError):
|
||||||
|
model.validate_credentials(
|
||||||
|
model="rerank-lite-1",
|
||||||
|
credentials={"api_key": "invalid_key"},
|
||||||
|
)
|
||||||
|
with patch("httpx.post") as mock_post:
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"object": "list",
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"relevance_score": 0.546875,
|
||||||
|
"index": 0,
|
||||||
|
"document": "Carson City is the capital city of the American state of Nevada. At the 2010 United "
|
||||||
|
"States Census, Carson City had a population of 55,274.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"relevance_score": 0.4765625,
|
||||||
|
"index": 1,
|
||||||
|
"document": "The Commonwealth of the Northern Mariana Islands is a group of islands in the "
|
||||||
|
"Pacific Ocean that are a political division controlled by the United States. Its "
|
||||||
|
"capital is Saipan.",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"model": "rerank-lite-1",
|
||||||
|
"usage": {"total_tokens": 96},
|
||||||
|
}
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
model.validate_credentials(
|
||||||
|
model="rerank-lite-1",
|
||||||
|
credentials={
|
||||||
|
"api_key": os.environ.get("VOYAGE_API_KEY"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_invoke_model():
|
||||||
|
model = VoyageRerankModel()
|
||||||
|
with patch("httpx.post") as mock_post:
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"object": "list",
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"relevance_score": 0.84375,
|
||||||
|
"index": 0,
|
||||||
|
"document": "Kasumi is a girl name of Japanese origin meaning mist.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"relevance_score": 0.4765625,
|
||||||
|
"index": 1,
|
||||||
|
"document": "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music and she "
|
||||||
|
"leads a team named PopiParty.",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"model": "rerank-lite-1",
|
||||||
|
"usage": {"total_tokens": 59},
|
||||||
|
}
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
result = model.invoke(
|
||||||
|
model="rerank-lite-1",
|
||||||
|
credentials={
|
||||||
|
"api_key": os.environ.get("VOYAGE_API_KEY"),
|
||||||
|
},
|
||||||
|
query="Who is Kasumi?",
|
||||||
|
docs=[
|
||||||
|
"Kasumi is a girl name of Japanese origin meaning mist.",
|
||||||
|
"Her music is a kawaii bass, a mix of future bass, pop, and kawaii music and she leads a team named "
|
||||||
|
"PopiParty.",
|
||||||
|
],
|
||||||
|
score_threshold=0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, RerankResult)
|
||||||
|
assert len(result.docs) == 1
|
||||||
|
assert result.docs[0].index == 0
|
||||||
|
assert result.docs[0].score >= 0.5
|
||||||
@ -0,0 +1,70 @@
|
|||||||
|
import os
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.voyage.text_embedding.text_embedding import VoyageTextEmbeddingModel
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_credentials():
|
||||||
|
model = VoyageTextEmbeddingModel()
|
||||||
|
|
||||||
|
with pytest.raises(CredentialsValidateFailedError):
|
||||||
|
model.validate_credentials(model="voyage-3", credentials={"api_key": "invalid_key"})
|
||||||
|
with patch("requests.post") as mock_post:
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"object": "list",
|
||||||
|
"data": [{"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 0}],
|
||||||
|
"model": "voyage-3",
|
||||||
|
"usage": {"total_tokens": 1},
|
||||||
|
}
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
model.validate_credentials(model="voyage-3", credentials={"api_key": os.environ.get("VOYAGE_API_KEY")})
|
||||||
|
|
||||||
|
|
||||||
|
def test_invoke_model():
|
||||||
|
model = VoyageTextEmbeddingModel()
|
||||||
|
|
||||||
|
with patch("requests.post") as mock_post:
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"object": "list",
|
||||||
|
"data": [
|
||||||
|
{"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 0},
|
||||||
|
{"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 1},
|
||||||
|
],
|
||||||
|
"model": "voyage-3",
|
||||||
|
"usage": {"total_tokens": 2},
|
||||||
|
}
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
result = model.invoke(
|
||||||
|
model="voyage-3",
|
||||||
|
credentials={
|
||||||
|
"api_key": os.environ.get("VOYAGE_API_KEY"),
|
||||||
|
},
|
||||||
|
texts=["hello", "world"],
|
||||||
|
user="abc-123",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, TextEmbeddingResult)
|
||||||
|
assert len(result.embeddings) == 2
|
||||||
|
assert result.usage.total_tokens == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_num_tokens():
|
||||||
|
model = VoyageTextEmbeddingModel()
|
||||||
|
|
||||||
|
num_tokens = model.get_num_tokens(
|
||||||
|
model="voyage-3",
|
||||||
|
credentials={
|
||||||
|
"api_key": os.environ.get("VOYAGE_API_KEY"),
|
||||||
|
},
|
||||||
|
texts=["ping"],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert num_tokens == 1
|
||||||
Loading…
Reference in New Issue