Merge branch 'main' into fix/chore-fix
commit
196bfeaaf4
Binary file not shown.
|
After Width: | Height: | Size: 230 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 205 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 44 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 262 KiB |
@ -0,0 +1,173 @@
|
||||
## Predefined Model Integration
|
||||
|
||||
After completing the vendor integration, the next step is to integrate the models from the vendor.
|
||||
|
||||
First, we need to determine the type of model to be integrated and create the corresponding model type `module` under the respective vendor's directory.
|
||||
|
||||
Currently supported model types are:
|
||||
|
||||
- `llm` Text Generation Model
|
||||
- `text_embedding` Text Embedding Model
|
||||
- `rerank` Rerank Model
|
||||
- `speech2text` Speech-to-Text
|
||||
- `tts` Text-to-Speech
|
||||
- `moderation` Moderation
|
||||
|
||||
Continuing with `Anthropic` as an example, `Anthropic` only supports LLM, so create a `module` named `llm` under `model_providers.anthropic`.
|
||||
|
||||
For predefined models, we first need to create a YAML file named after the model under the `llm` `module`, such as `claude-2.1.yaml`.
|
||||
|
||||
### Prepare Model YAML
|
||||
|
||||
```yaml
|
||||
model: claude-2.1 # Model identifier
|
||||
# Display name of the model, which can be set to en_US English or zh_Hans Chinese. If zh_Hans is not set, it will default to en_US.
|
||||
# This can also be omitted, in which case the model identifier will be used as the label
|
||||
label:
|
||||
en_US: claude-2.1
|
||||
model_type: llm # Model type, claude-2.1 is an LLM
|
||||
features: # Supported features, agent-thought supports Agent reasoning, vision supports image understanding
|
||||
- agent-thought
|
||||
model_properties: # Model properties
|
||||
mode: chat # LLM mode, complete for text completion models, chat for conversation models
|
||||
context_size: 200000 # Maximum context size
|
||||
parameter_rules: # Parameter rules for the model call; only LLM requires this
|
||||
- name: temperature # Parameter variable name
|
||||
# Five default configuration templates are provided: temperature/top_p/max_tokens/presence_penalty/frequency_penalty
|
||||
# The template variable name can be set directly in use_template, which will use the default configuration in entities.defaults.PARAMETER_RULE_TEMPLATE
|
||||
# Additional configuration parameters will override the default configuration if set
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label: # Display name of the parameter
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int # Parameter type, supports float/int/string/boolean
|
||||
help: # Help information, describing the parameter's function
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false # Whether the parameter is mandatory; can be omitted
|
||||
- name: max_tokens_to_sample
|
||||
use_template: max_tokens
|
||||
default: 4096 # Default value of the parameter
|
||||
min: 1 # Minimum value of the parameter, applicable to float/int only
|
||||
max: 4096 # Maximum value of the parameter, applicable to float/int only
|
||||
pricing: # Pricing information
|
||||
input: '8.00' # Input unit price, i.e., prompt price
|
||||
output: '24.00' # Output unit price, i.e., response content price
|
||||
unit: '0.000001' # Price unit, meaning the above prices are per 100K
|
||||
currency: USD # Price currency
|
||||
```
|
||||
|
||||
It is recommended to prepare all model configurations before starting the implementation of the model code.
|
||||
|
||||
You can also refer to the YAML configuration information under the corresponding model type directories of other vendors in the `model_providers` directory. For the complete YAML rules, refer to: [Schema](schema.md#aimodelentity).
|
||||
|
||||
### Implement the Model Call Code
|
||||
|
||||
Next, create a Python file named `llm.py` under the `llm` `module` to write the implementation code.
|
||||
|
||||
Create an Anthropic LLM class named `AnthropicLargeLanguageModel` (or any other name), inheriting from the `__base.large_language_model.LargeLanguageModel` base class, and implement the following methods:
|
||||
|
||||
- LLM Call
|
||||
|
||||
Implement the core method for calling the LLM, supporting both streaming and synchronous responses.
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
```
|
||||
|
||||
Ensure to use two functions for returning data, one for synchronous returns and the other for streaming returns, because Python identifies functions containing the `yield` keyword as generator functions, fixing the return type to `Generator`. Thus, synchronous and streaming returns need to be implemented separately, as shown below (note that the example uses simplified parameters, for actual implementation follow the above parameter list):
|
||||
|
||||
```python
|
||||
def _invoke(self, stream: bool, **kwargs) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
if stream:
|
||||
return self._handle_stream_response(**kwargs)
|
||||
return self._handle_sync_response(**kwargs)
|
||||
|
||||
def _handle_stream_response(self, **kwargs) -> Generator:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
def _handle_sync_response(self, **kwargs) -> LLMResult:
|
||||
return LLMResult(**response)
|
||||
```
|
||||
|
||||
- Pre-compute Input Tokens
|
||||
|
||||
If the model does not provide an interface to precompute tokens, return 0 directly.
|
||||
|
||||
```python
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param tools: tools for tool calling
|
||||
:return:
|
||||
"""
|
||||
```
|
||||
|
||||
- Validate Model Credentials
|
||||
|
||||
Similar to vendor credential validation, but specific to a single model.
|
||||
|
||||
```python
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
```
|
||||
|
||||
- Map Invoke Errors
|
||||
|
||||
When a model call fails, map it to a specific `InvokeError` type as required by Runtime, allowing Dify to handle different errors accordingly.
|
||||
|
||||
Runtime Errors:
|
||||
|
||||
- `InvokeConnectionError` Connection error
|
||||
|
||||
- `InvokeServerUnavailableError` Service provider unavailable
|
||||
- `InvokeRateLimitError` Rate limit reached
|
||||
- `InvokeAuthorizationError` Authorization failed
|
||||
- `InvokeBadRequestError` Parameter error
|
||||
|
||||
```python
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
```
|
||||
|
||||
For interface method explanations, see: [Interfaces](./interfaces.md). For detailed implementation, refer to: [llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py).
|
||||
@ -1,238 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import boto3
|
||||
from botocore.config import Config
|
||||
from botocore.exceptions import (
|
||||
ClientError,
|
||||
EndpointConnectionError,
|
||||
NoRegionError,
|
||||
ServiceNotInRegionError,
|
||||
UnknownServiceError,
|
||||
)
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BedrockTextEmbeddingModel(TextEmbeddingModel):
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
client_config = Config(region_name=credentials["aws_region"])
|
||||
|
||||
bedrock_runtime = boto3.client(
|
||||
service_name="bedrock-runtime",
|
||||
config=client_config,
|
||||
aws_access_key_id=credentials.get("aws_access_key_id"),
|
||||
aws_secret_access_key=credentials.get("aws_secret_access_key"),
|
||||
)
|
||||
|
||||
embeddings = []
|
||||
token_usage = 0
|
||||
|
||||
model_prefix = model.split(".")[0]
|
||||
|
||||
if model_prefix == "amazon":
|
||||
for text in texts:
|
||||
body = {
|
||||
"inputText": text,
|
||||
}
|
||||
response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
|
||||
embeddings.extend([response_body.get("embedding")])
|
||||
token_usage += response_body.get("inputTextTokenCount")
|
||||
logger.warning(f"Total Tokens: {token_usage}")
|
||||
result = TextEmbeddingResult(
|
||||
model=model,
|
||||
embeddings=embeddings,
|
||||
usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage),
|
||||
)
|
||||
return result
|
||||
|
||||
if model_prefix == "cohere":
|
||||
input_type = "search_document" if len(texts) > 1 else "search_query"
|
||||
for text in texts:
|
||||
body = {
|
||||
"texts": [text],
|
||||
"input_type": input_type,
|
||||
}
|
||||
response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
|
||||
embeddings.extend(response_body.get("embeddings"))
|
||||
token_usage += len(text)
|
||||
result = TextEmbeddingResult(
|
||||
model=model,
|
||||
embeddings=embeddings,
|
||||
usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage),
|
||||
)
|
||||
return result
|
||||
|
||||
# others
|
||||
raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
num_tokens = 0
|
||||
for text in texts:
|
||||
num_tokens += self._get_num_tokens_by_gpt2(text)
|
||||
return num_tokens
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the ermd = genai.GenerativeModel(model) error type thrown to the caller
|
||||
The value is the md = genai.GenerativeModel(model) error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke emd = genai.GenerativeModel(model) error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [],
|
||||
InvokeServerUnavailableError: [],
|
||||
InvokeRateLimitError: [],
|
||||
InvokeAuthorizationError: [],
|
||||
InvokeBadRequestError: [],
|
||||
}
|
||||
|
||||
def _create_payload(
|
||||
self,
|
||||
model_prefix: str,
|
||||
texts: list[str],
|
||||
model_parameters: dict,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
):
|
||||
"""
|
||||
Create payload for bedrock api call depending on model provider
|
||||
"""
|
||||
payload = {}
|
||||
|
||||
if model_prefix == "amazon":
|
||||
payload["inputText"] = texts
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param tokens: input tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
|
||||
def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[InvokeError]:
|
||||
"""
|
||||
Map client error to invoke error
|
||||
|
||||
:param error_code: error code
|
||||
:param error_msg: error message
|
||||
:return: invoke error
|
||||
"""
|
||||
|
||||
if error_code == "AccessDeniedException":
|
||||
return InvokeAuthorizationError(error_msg)
|
||||
elif error_code in {"ResourceNotFoundException", "ValidationException"}:
|
||||
return InvokeBadRequestError(error_msg)
|
||||
elif error_code in {"ThrottlingException", "ServiceQuotaExceededException"}:
|
||||
return InvokeRateLimitError(error_msg)
|
||||
elif error_code in {
|
||||
"ModelTimeoutException",
|
||||
"ModelErrorException",
|
||||
"InternalServerException",
|
||||
"ModelNotReadyException",
|
||||
}:
|
||||
return InvokeServerUnavailableError(error_msg)
|
||||
elif error_code == "ModelStreamErrorException":
|
||||
return InvokeConnectionError(error_msg)
|
||||
|
||||
return InvokeError(error_msg)
|
||||
|
||||
def _invoke_bedrock_embedding(
|
||||
self,
|
||||
model: str,
|
||||
bedrock_runtime,
|
||||
body: dict,
|
||||
):
|
||||
accept = "application/json"
|
||||
content_type = "application/json"
|
||||
try:
|
||||
response = bedrock_runtime.invoke_model(
|
||||
body=json.dumps(body), modelId=model, accept=accept, contentType=content_type
|
||||
)
|
||||
response_body = json.loads(response.get("body").read().decode("utf-8"))
|
||||
return response_body
|
||||
except ClientError as ex:
|
||||
error_code = ex.response["Error"]["Code"]
|
||||
full_error_msg = f"{error_code}: {ex.response['Error']['Message']}"
|
||||
raise self._map_client_to_invoke_error(error_code, full_error_msg)
|
||||
|
||||
except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex:
|
||||
raise InvokeConnectionError(str(ex))
|
||||
|
||||
except UnknownServiceError as ex:
|
||||
raise InvokeServerUnavailableError(str(ex))
|
||||
|
||||
except Exception as ex:
|
||||
raise InvokeError(str(ex))
|
||||
@ -1,223 +0,0 @@
|
||||
import json
|
||||
import time
|
||||
from decimal import Decimal
|
||||
from typing import Optional
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
FetchFrom,
|
||||
ModelPropertyKey,
|
||||
ModelType,
|
||||
PriceConfig,
|
||||
PriceType,
|
||||
)
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
|
||||
|
||||
|
||||
class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel):
|
||||
"""
|
||||
Model class for an OpenAI API-compatible text embedding model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
|
||||
# Prepare headers and payload for the request
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
api_key = credentials.get("api_key")
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
if "endpoint_url" not in credentials or credentials["endpoint_url"] == "":
|
||||
endpoint_url = "https://cloud.perfxlab.cn/v1/"
|
||||
else:
|
||||
endpoint_url = credentials.get("endpoint_url")
|
||||
if not endpoint_url.endswith("/"):
|
||||
endpoint_url += "/"
|
||||
|
||||
endpoint_url = urljoin(endpoint_url, "embeddings")
|
||||
|
||||
extra_model_kwargs = {}
|
||||
if user:
|
||||
extra_model_kwargs["user"] = user
|
||||
|
||||
extra_model_kwargs["encoding_format"] = "float"
|
||||
|
||||
# get model properties
|
||||
context_size = self._get_context_size(model, credentials)
|
||||
max_chunks = self._get_max_chunks(model, credentials)
|
||||
|
||||
inputs = []
|
||||
indices = []
|
||||
used_tokens = 0
|
||||
|
||||
for i, text in enumerate(texts):
|
||||
# Here token count is only an approximation based on the GPT2 tokenizer
|
||||
# TODO: Optimize for better token estimation and chunking
|
||||
num_tokens = self._get_num_tokens_by_gpt2(text)
|
||||
|
||||
if num_tokens >= context_size:
|
||||
cutoff = int(np.floor(len(text) * (context_size / num_tokens)))
|
||||
# if num tokens is larger than context length, only use the start
|
||||
inputs.append(text[0:cutoff])
|
||||
else:
|
||||
inputs.append(text)
|
||||
indices += [i]
|
||||
|
||||
batched_embeddings = []
|
||||
_iter = range(0, len(inputs), max_chunks)
|
||||
|
||||
for i in _iter:
|
||||
# Prepare the payload for the request
|
||||
payload = {"input": inputs[i : i + max_chunks], "model": model, **extra_model_kwargs}
|
||||
|
||||
# Make the request to the OpenAI API
|
||||
response = requests.post(endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300))
|
||||
|
||||
response.raise_for_status() # Raise an exception for HTTP errors
|
||||
response_data = response.json()
|
||||
|
||||
# Extract embeddings and used tokens from the response
|
||||
embeddings_batch = [data["embedding"] for data in response_data["data"]]
|
||||
embedding_used_tokens = response_data["usage"]["total_tokens"]
|
||||
|
||||
used_tokens += embedding_used_tokens
|
||||
batched_embeddings += embeddings_batch
|
||||
|
||||
# calc usage
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
|
||||
|
||||
return TextEmbeddingResult(embeddings=batched_embeddings, usage=usage, model=model)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Approximate number of tokens for given messages using GPT2 tokenizer
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
return sum(self._get_num_tokens_by_gpt2(text) for text in texts)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
api_key = credentials.get("api_key")
|
||||
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
if "endpoint_url" not in credentials or credentials["endpoint_url"] == "":
|
||||
endpoint_url = "https://cloud.perfxlab.cn/v1/"
|
||||
else:
|
||||
endpoint_url = credentials.get("endpoint_url")
|
||||
if not endpoint_url.endswith("/"):
|
||||
endpoint_url += "/"
|
||||
|
||||
endpoint_url = urljoin(endpoint_url, "embeddings")
|
||||
|
||||
payload = {"input": "ping", "model": model}
|
||||
|
||||
response = requests.post(url=endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300))
|
||||
|
||||
if response.status_code != 200:
|
||||
raise CredentialsValidateFailedError(
|
||||
f"Credentials validation failed with status code {response.status_code}"
|
||||
)
|
||||
|
||||
try:
|
||||
json_result = response.json()
|
||||
except json.JSONDecodeError as e:
|
||||
raise CredentialsValidateFailedError("Credentials validation failed: JSON decode error")
|
||||
|
||||
if "model" not in json_result:
|
||||
raise CredentialsValidateFailedError("Credentials validation failed: invalid response")
|
||||
except CredentialsValidateFailedError:
|
||||
raise
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||
"""
|
||||
generate custom model entities from credentials
|
||||
"""
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model),
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")),
|
||||
ModelPropertyKey.MAX_CHUNKS: 1,
|
||||
},
|
||||
parameter_rules=[],
|
||||
pricing=PriceConfig(
|
||||
input=Decimal(credentials.get("input_price", 0)),
|
||||
unit=Decimal(credentials.get("unit", 0)),
|
||||
currency=credentials.get("currency", "USD"),
|
||||
),
|
||||
)
|
||||
|
||||
return entity
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param tokens: input tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
@ -0,0 +1,21 @@
|
||||
<svg version="1.0" xmlns="http://www.w3.org/2000/svg" width="100.000000pt" height="19.000000pt" viewBox="0 0 300.000000 57.000000" preserveAspectRatio="xMidYMid meet"><g transform="translate(0.000000,57.000000) scale(0.100000,-0.100000)" fill="#000000" stroke="none"><path d="M2505 368 c-38 -84 -86 -188 -106 -230 l-38 -78 27 0 c24 0 30 7 55
|
||||
75 l28 75 100 0 100 0 25 -55 c13 -31 24 -64 24 -75 0 -17 7 -20 44 -20 l43 0
|
||||
-37 73 c-20 39 -68 143 -106 229 -38 87 -74 158 -80 158 -5 0 -41 -69 -79
|
||||
-152z m110 -30 c22 -51 41 -95 42 -98 2 -3 -36 -6 -83 -7 -76 -1 -85 0 -81 15
|
||||
12 40 72 182 77 182 3 0 24 -41 45 -92z"/><path d="M63 493 c19 -61 197 -438 209 -440 10 -2 147 282 216 449 2 4 -10 8
|
||||
-27 8 -23 0 -31 -5 -31 -17 0 -16 -142 -365 -146 -360 -8 11 -144 329 -149
|
||||
350 -6 23 -12 27 -42 27 -29 0 -34 -3 -30 -17z"/><path d="M2855 285 l0 -225 30 0 30 0 0 225 0 225 -30 0 -30 0 0 -225z"/><path d="M588 380 c-55 -30 -82 -74 -86 -145 -3 -50 0 -66 20 -95 39 -58 82
|
||||
-80 153 -80 68 0 110 21 149 73 32 43 30 150 -3 196 -47 66 -158 90 -233 51z
|
||||
m133 -16 c59 -30 89 -156 54 -224 -45 -87 -162 -78 -201 16 -18 44 -18 128 1
|
||||
164 28 55 90 73 146 44z"/><path d="M935 303 l76 -98 -7 -72 -6 -73 33 0 34 0 -3 78 -4 77 71 93 c65 85
|
||||
68 92 46 92 -15 0 -29 -9 -36 -22 -18 -33 -90 -128 -98 -128 -6 1 -67 85 -88
|
||||
122 -8 15 -24 23 -53 25 l-41 4 76 -98z"/><path d="M1257 230 c-82 -169 -83 -170 -57 -170 17 0 27 6 27 15 0 8 7 31 17
|
||||
52 l17 38 79 0 78 1 16 -34 c9 -18 16 -42 16 -52 0 -17 7 -20 41 -20 22 0 39
|
||||
3 37 8 -2 4 -39 80 -83 170 -43 89 -84 162 -92 162 -7 0 -50 -76 -96 -170z
|
||||
m90 -38 c-33 -2 -61 -1 -63 1 -2 2 10 34 26 71 l31 68 33 -68 33 -69 -60 -3z"/><path d="M1665 386 c-37 -16 -84 -63 -97 -96 -13 -35 -12 -104 2 -132 49 -94
|
||||
182 -134 280 -83 24 12 29 22 32 64 3 49 3 49 -30 53 l-33 4 3 -45 c4 -61 -5
|
||||
-71 -60 -71 -93 0 -142 57 -142 164 0 44 5 60 25 85 47 55 136 65 184 20 30
|
||||
-28 35 -20 11 19 -19 31 -22 32 -82 32 -35 -1 -76 -7 -93 -14z"/><path d="M1955 230 l0 -170 91 0 c76 0 93 3 98 16 4 9 5 18 4 20 -2 1 -31 -1
|
||||
-66 -5 -34 -4 -64 -5 -67 -3 -3 3 -5 36 -5 73 l0 68 55 -6 c49 -5 55 -4 55 13
|
||||
0 17 -6 19 -55 16 l-55 -4 0 61 0 61 64 0 c48 0 65 4 70 15 4 13 -10 15 -92
|
||||
15 l-97 0 0 -170z"/></g></svg>
|
||||
|
After Width: | Height: | Size: 2.2 KiB |
@ -0,0 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg width="64px" height="64px" viewBox="0 0 64 64" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||
<title>voyage</title>
|
||||
<g id="voyage" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
||||
<rect id="矩形" fill="#333333" x="0" y="0" width="64" height="64" rx="12"></rect>
|
||||
<path d="M12.1128004,51.4376727 C13.8950799,45.8316747 30.5922254,11.1847688 31.7178757,11.0009656 C32.6559176,10.8171624 45.5070913,36.9172188 51.9795803,52.2647871 C52.1671887,52.6323936 51.0415384,53 49.4468672,53 C47.2893709,53 46.5389374,52.540492 46.5389374,51.4376727 C46.5389374,49.967247 33.2187427,17.8935861 32.8435259,18.3530942 C32.0930924,19.3640118 19.3357228,48.5887229 18.8667019,50.5186566 C18.3038768,52.6323936 17.7410516,53 14.926926,53 C12.2066045,53 11.7375836,52.7242952 12.1128004,51.4376727 Z" id="路径" fill="#FFFFFF" transform="translate(32, 32) scale(1, -1) translate(-32, -32)"></path>
|
||||
</g>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.0 KiB |
@ -0,0 +1,4 @@
|
||||
model: rerank-1
|
||||
model_type: rerank
|
||||
model_properties:
|
||||
context_size: 8000
|
||||
@ -0,0 +1,4 @@
|
||||
model: rerank-lite-1
|
||||
model_type: rerank
|
||||
model_properties:
|
||||
context_size: 4000
|
||||
@ -0,0 +1,123 @@
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
|
||||
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
|
||||
|
||||
|
||||
class VoyageRerankModel(RerankModel):
|
||||
"""
|
||||
Model class for Voyage rerank model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
query: str,
|
||||
docs: list[str],
|
||||
score_threshold: Optional[float] = None,
|
||||
top_n: Optional[int] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> RerankResult:
|
||||
"""
|
||||
Invoke rerank model
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param query: search query
|
||||
:param docs: docs for reranking
|
||||
:param score_threshold: score threshold
|
||||
:param top_n: top n documents to return
|
||||
:param user: unique user id
|
||||
:return: rerank result
|
||||
"""
|
||||
if len(docs) == 0:
|
||||
return RerankResult(model=model, docs=[])
|
||||
|
||||
base_url = credentials.get("base_url", "https://api.voyageai.com/v1")
|
||||
base_url = base_url.removesuffix("/")
|
||||
|
||||
try:
|
||||
response = httpx.post(
|
||||
base_url + "/rerank",
|
||||
json={"model": model, "query": query, "documents": docs, "top_k": top_n, "return_documents": True},
|
||||
headers={"Authorization": f"Bearer {credentials.get('api_key')}", "Content-Type": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
results = response.json()
|
||||
|
||||
rerank_documents = []
|
||||
for result in results["data"]:
|
||||
rerank_document = RerankDocument(
|
||||
index=result["index"],
|
||||
text=result["document"],
|
||||
score=result["relevance_score"],
|
||||
)
|
||||
if score_threshold is None or result["relevance_score"] >= score_threshold:
|
||||
rerank_documents.append(rerank_document)
|
||||
|
||||
return RerankResult(model=model, docs=rerank_documents)
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise InvokeServerUnavailableError(str(e))
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
self._invoke(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
query="What is the capital of the United States?",
|
||||
docs=[
|
||||
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
|
||||
"Census, Carson City had a population of 55,274.",
|
||||
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
|
||||
"are a political division controlled by the United States. Its capital is Saipan.",
|
||||
],
|
||||
score_threshold=0.8,
|
||||
)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [httpx.ConnectError],
|
||||
InvokeServerUnavailableError: [httpx.RemoteProtocolError],
|
||||
InvokeRateLimitError: [],
|
||||
InvokeAuthorizationError: [httpx.HTTPStatusError],
|
||||
InvokeBadRequestError: [httpx.RequestError],
|
||||
}
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||
"""
|
||||
generate custom model entities from credentials
|
||||
"""
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model),
|
||||
model_type=ModelType.RERANK,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "8000"))},
|
||||
)
|
||||
|
||||
return entity
|
||||
@ -0,0 +1,172 @@
|
||||
import time
|
||||
from json import JSONDecodeError, dumps
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
|
||||
|
||||
class VoyageTextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
Model class for Voyage text embedding model.
|
||||
"""
|
||||
|
||||
api_base: str = "https://api.voyageai.com/v1"
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
api_key = credentials["api_key"]
|
||||
if not api_key:
|
||||
raise CredentialsValidateFailedError("api_key is required")
|
||||
|
||||
base_url = credentials.get("base_url", self.api_base)
|
||||
base_url = base_url.removesuffix("/")
|
||||
|
||||
url = base_url + "/embeddings"
|
||||
headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"}
|
||||
voyage_input_type = "null"
|
||||
if input_type is not None:
|
||||
voyage_input_type = input_type.value
|
||||
data = {"model": model, "input": texts, "input_type": voyage_input_type}
|
||||
|
||||
try:
|
||||
response = requests.post(url, headers=headers, data=dumps(data))
|
||||
except Exception as e:
|
||||
raise InvokeConnectionError(str(e))
|
||||
|
||||
if response.status_code != 200:
|
||||
try:
|
||||
resp = response.json()
|
||||
msg = resp["detail"]
|
||||
if response.status_code == 401:
|
||||
raise InvokeAuthorizationError(msg)
|
||||
elif response.status_code == 429:
|
||||
raise InvokeRateLimitError(msg)
|
||||
elif response.status_code == 500:
|
||||
raise InvokeServerUnavailableError(msg)
|
||||
else:
|
||||
raise InvokeBadRequestError(msg)
|
||||
except JSONDecodeError as e:
|
||||
raise InvokeServerUnavailableError(
|
||||
f"Failed to convert response to json: {e} with text: {response.text}"
|
||||
)
|
||||
|
||||
try:
|
||||
resp = response.json()
|
||||
embeddings = resp["data"]
|
||||
usage = resp["usage"]
|
||||
except Exception as e:
|
||||
raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}")
|
||||
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"])
|
||||
|
||||
result = TextEmbeddingResult(
|
||||
model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
return sum(self._get_num_tokens_by_gpt2(text) for text in texts)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
self._invoke(model=model, credentials=credentials, texts=["ping"])
|
||||
except Exception as e:
|
||||
raise CredentialsValidateFailedError(f"Credentials validation failed: {e}")
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
return {
|
||||
InvokeConnectionError: [InvokeConnectionError],
|
||||
InvokeServerUnavailableError: [InvokeServerUnavailableError],
|
||||
InvokeRateLimitError: [InvokeRateLimitError],
|
||||
InvokeAuthorizationError: [InvokeAuthorizationError],
|
||||
InvokeBadRequestError: [KeyError, InvokeBadRequestError],
|
||||
}
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param tokens: input tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||
"""
|
||||
generate custom model entities from credentials
|
||||
"""
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model),
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
|
||||
)
|
||||
|
||||
return entity
|
||||
@ -0,0 +1,8 @@
|
||||
model: voyage-3-lite
|
||||
model_type: text-embedding
|
||||
model_properties:
|
||||
context_size: 32000
|
||||
pricing:
|
||||
input: '0.00002'
|
||||
unit: '0.001'
|
||||
currency: USD
|
||||
@ -0,0 +1,8 @@
|
||||
model: voyage-3
|
||||
model_type: text-embedding
|
||||
model_properties:
|
||||
context_size: 32000
|
||||
pricing:
|
||||
input: '0.00006'
|
||||
unit: '0.001'
|
||||
currency: USD
|
||||
@ -0,0 +1,28 @@
|
||||
import logging
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VoyageProvider(ModelProvider):
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Validate provider credentials
|
||||
if validate failed, raise exception
|
||||
|
||||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||
"""
|
||||
try:
|
||||
model_instance = self.get_model_instance(ModelType.TEXT_EMBEDDING)
|
||||
|
||||
# Use `voyage-3` model for validate,
|
||||
# no matter what model you pass in, text completion model or chat model
|
||||
model_instance.validate_credentials(model="voyage-3", credentials=credentials)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
|
||||
raise ex
|
||||
@ -0,0 +1,31 @@
|
||||
provider: voyage
|
||||
label:
|
||||
en_US: Voyage
|
||||
description:
|
||||
en_US: Embedding and Rerank Model Supported
|
||||
icon_small:
|
||||
en_US: icon_s_en.svg
|
||||
icon_large:
|
||||
en_US: icon_l_en.svg
|
||||
background: "#EFFDFD"
|
||||
help:
|
||||
title:
|
||||
en_US: Get your API key from Voyage AI
|
||||
zh_Hans: 从 Voyage 获取 API Key
|
||||
url:
|
||||
en_US: https://dash.voyageai.com/
|
||||
supported_model_types:
|
||||
- text-embedding
|
||||
- rerank
|
||||
configurate_methods:
|
||||
- predefined-model
|
||||
provider_credential_schema:
|
||||
credential_form_schemas:
|
||||
- variable: api_key
|
||||
label:
|
||||
en_US: API Key
|
||||
type: secret-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
@ -1,142 +0,0 @@
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from core.embedding.embedding_constant import EmbeddingInputType
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI
|
||||
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
|
||||
|
||||
|
||||
class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
|
||||
"""
|
||||
Model class for ZhipuAI text embedding model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
client = ZhipuAI(api_key=credentials_kwargs["api_key"])
|
||||
|
||||
embeddings, embedding_used_tokens = self.embed_documents(model, client, texts)
|
||||
|
||||
return TextEmbeddingResult(
|
||||
embeddings=embeddings,
|
||||
usage=self._calc_response_usage(model, credentials_kwargs, embedding_used_tokens),
|
||||
model=model,
|
||||
)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
if len(texts) == 0:
|
||||
return 0
|
||||
|
||||
total_num_tokens = 0
|
||||
for text in texts:
|
||||
total_num_tokens += self._get_num_tokens_by_gpt2(text)
|
||||
|
||||
return total_num_tokens
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
# transform credentials to kwargs for model instance
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
client = ZhipuAI(api_key=credentials_kwargs["api_key"])
|
||||
|
||||
# call embedding model
|
||||
self.embed_documents(
|
||||
model=model,
|
||||
client=client,
|
||||
texts=["ping"],
|
||||
)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def embed_documents(self, model: str, client: ZhipuAI, texts: list[str]) -> tuple[list[list[float]], int]:
|
||||
"""Call out to ZhipuAI's embedding endpoint.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
embeddings = []
|
||||
embedding_used_tokens = 0
|
||||
|
||||
for text in texts:
|
||||
response = client.embeddings.create(model=model, input=text)
|
||||
data = response.data[0]
|
||||
embeddings.append(data.embedding)
|
||||
embedding_used_tokens += response.usage.total_tokens
|
||||
|
||||
return [list(map(float, e)) for e in embeddings], embedding_used_tokens
|
||||
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
"""Call out to ZhipuAI's embedding endpoint.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
return self.embed_documents([text])[0]
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param tokens: input tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at,
|
||||
)
|
||||
|
||||
return usage
|
||||
@ -0,0 +1,25 @@
|
||||
import os
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.voyage.voyage import VoyageProvider
|
||||
|
||||
|
||||
def test_validate_provider_credentials():
|
||||
provider = VoyageProvider()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
provider.validate_provider_credentials(credentials={"api_key": "hahahaha"})
|
||||
with patch("requests.post") as mock_post:
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {
|
||||
"object": "list",
|
||||
"data": [{"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 0}],
|
||||
"model": "voyage-3",
|
||||
"usage": {"total_tokens": 1},
|
||||
}
|
||||
mock_response.status_code = 200
|
||||
mock_post.return_value = mock_response
|
||||
provider.validate_provider_credentials(credentials={"api_key": os.environ.get("VOYAGE_API_KEY")})
|
||||
@ -0,0 +1,92 @@
|
||||
import os
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.rerank_entities import RerankResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.voyage.rerank.rerank import VoyageRerankModel
|
||||
|
||||
|
||||
def test_validate_credentials():
|
||||
model = VoyageRerankModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model="rerank-lite-1",
|
||||
credentials={"api_key": "invalid_key"},
|
||||
)
|
||||
with patch("httpx.post") as mock_post:
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"relevance_score": 0.546875,
|
||||
"index": 0,
|
||||
"document": "Carson City is the capital city of the American state of Nevada. At the 2010 United "
|
||||
"States Census, Carson City had a population of 55,274.",
|
||||
},
|
||||
{
|
||||
"relevance_score": 0.4765625,
|
||||
"index": 1,
|
||||
"document": "The Commonwealth of the Northern Mariana Islands is a group of islands in the "
|
||||
"Pacific Ocean that are a political division controlled by the United States. Its "
|
||||
"capital is Saipan.",
|
||||
},
|
||||
],
|
||||
"model": "rerank-lite-1",
|
||||
"usage": {"total_tokens": 96},
|
||||
}
|
||||
mock_response.status_code = 200
|
||||
mock_post.return_value = mock_response
|
||||
model.validate_credentials(
|
||||
model="rerank-lite-1",
|
||||
credentials={
|
||||
"api_key": os.environ.get("VOYAGE_API_KEY"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
model = VoyageRerankModel()
|
||||
with patch("httpx.post") as mock_post:
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"relevance_score": 0.84375,
|
||||
"index": 0,
|
||||
"document": "Kasumi is a girl name of Japanese origin meaning mist.",
|
||||
},
|
||||
{
|
||||
"relevance_score": 0.4765625,
|
||||
"index": 1,
|
||||
"document": "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music and she "
|
||||
"leads a team named PopiParty.",
|
||||
},
|
||||
],
|
||||
"model": "rerank-lite-1",
|
||||
"usage": {"total_tokens": 59},
|
||||
}
|
||||
mock_response.status_code = 200
|
||||
mock_post.return_value = mock_response
|
||||
result = model.invoke(
|
||||
model="rerank-lite-1",
|
||||
credentials={
|
||||
"api_key": os.environ.get("VOYAGE_API_KEY"),
|
||||
},
|
||||
query="Who is Kasumi?",
|
||||
docs=[
|
||||
"Kasumi is a girl name of Japanese origin meaning mist.",
|
||||
"Her music is a kawaii bass, a mix of future bass, pop, and kawaii music and she leads a team named "
|
||||
"PopiParty.",
|
||||
],
|
||||
score_threshold=0.5,
|
||||
)
|
||||
|
||||
assert isinstance(result, RerankResult)
|
||||
assert len(result.docs) == 1
|
||||
assert result.docs[0].index == 0
|
||||
assert result.docs[0].score >= 0.5
|
||||
@ -0,0 +1,70 @@
|
||||
import os
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.voyage.text_embedding.text_embedding import VoyageTextEmbeddingModel
|
||||
|
||||
|
||||
def test_validate_credentials():
|
||||
model = VoyageTextEmbeddingModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(model="voyage-3", credentials={"api_key": "invalid_key"})
|
||||
with patch("requests.post") as mock_post:
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {
|
||||
"object": "list",
|
||||
"data": [{"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 0}],
|
||||
"model": "voyage-3",
|
||||
"usage": {"total_tokens": 1},
|
||||
}
|
||||
mock_response.status_code = 200
|
||||
mock_post.return_value = mock_response
|
||||
model.validate_credentials(model="voyage-3", credentials={"api_key": os.environ.get("VOYAGE_API_KEY")})
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
model = VoyageTextEmbeddingModel()
|
||||
|
||||
with patch("requests.post") as mock_post:
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {
|
||||
"object": "list",
|
||||
"data": [
|
||||
{"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 0},
|
||||
{"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 1},
|
||||
],
|
||||
"model": "voyage-3",
|
||||
"usage": {"total_tokens": 2},
|
||||
}
|
||||
mock_response.status_code = 200
|
||||
mock_post.return_value = mock_response
|
||||
result = model.invoke(
|
||||
model="voyage-3",
|
||||
credentials={
|
||||
"api_key": os.environ.get("VOYAGE_API_KEY"),
|
||||
},
|
||||
texts=["hello", "world"],
|
||||
user="abc-123",
|
||||
)
|
||||
|
||||
assert isinstance(result, TextEmbeddingResult)
|
||||
assert len(result.embeddings) == 2
|
||||
assert result.usage.total_tokens == 2
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = VoyageTextEmbeddingModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model="voyage-3",
|
||||
credentials={
|
||||
"api_key": os.environ.get("VOYAGE_API_KEY"),
|
||||
},
|
||||
texts=["ping"],
|
||||
)
|
||||
|
||||
assert num_tokens == 1
|
||||
Loading…
Reference in New Issue