feat: move model request to plugin daemon

pull/9184/head
takatost 2 years ago
parent d9cced8419
commit 1c3213184e

@@ -132,7 +132,7 @@ class ModelProviderIconApi(Resource):
def get(self, provider: str, icon_type: str, lang: str):
model_provider_service = ModelProviderService()
icon, mimetype = model_provider_service.get_model_provider_icon(
provider=provider, icon_type=icon_type, lang=lang
tenant_id=current_user.current_tenant_id, provider=provider, icon_type=icon_type, lang=lang
)
return send_file(io.BytesIO(icon), mimetype=mimetype)
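Provider icons are now resolved per tenant, presumably because they come from the tenant's installed model plugins rather than from assets bundled with the API, so the service call carries the tenant id. A hedged usage sketch of the new signature (the import path and the icon_type value are assumptions, not shown in this hunk):

from services.model_provider_service import ModelProviderService

service = ModelProviderService()
icon_bytes, mimetype = service.get_model_provider_icon(
    tenant_id="tenant-123",   # illustrative tenant id
    provider="anthropic",
    icon_type="icon_small",   # assumed value
    lang="en_US",
)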

@@ -4,7 +4,8 @@ from core.app.app_config.entities import EasyUIBasedAppConfig
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.model_entities import ModelStatus
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.provider_manager import ProviderManager
@@ -67,14 +68,14 @@ class ModelConfigConverter:
stop = completion_params["stop"]
del completion_params["stop"]
model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
# get model mode
model_mode = model_config.mode
if not model_mode:
mode_enum = model_type_instance.get_model_mode(model=model_config.model, credentials=model_credentials)
model_mode = mode_enum.value
model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
model_mode = LLMMode.CHAT.value
if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE):
model_mode = LLMMode.value_of(model_schema.model_properties[ModelPropertyKey.MODE]).value
if not skip_check and not model_schema:
raise ValueError(f"Model {model_name} not exist.")
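The converter no longer asks the model-type instance for its mode; it reads the mode off the model schema fetched above and falls back to chat when the MODE property is missing. A minimal sketch of that lookup (the helper name is illustrative, not part of the change):

from typing import Optional

from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey


def resolve_model_mode(model_schema: Optional[AIModelEntity]) -> str:
    # Default to chat mode, then override from the schema's MODE property if present.
    model_mode = LLMMode.CHAT.value
    if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE):
        model_mode = LLMMode.value_of(model_schema.model_properties[ModelPropertyKey.MODE]).value
    return model_mode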

@@ -1,6 +1,6 @@
from core.app.app_config.entities import ModelConfigEntity
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.provider_manager import ProviderManager
@@ -50,6 +50,7 @@ class ModelConfigManager:
raise ValueError("model must be of object type")
# model.provider
model_provider_factory = ModelProviderFactory(tenant_id)
provider_entities = model_provider_factory.get_providers()
model_provider_names = [provider.provider for provider in provider_entities]
if "provider" not in config["model"] or config["model"]["provider"] not in model_provider_names:

@@ -2,7 +2,7 @@ import datetime
import json
import logging
from collections import defaultdict
from collections.abc import Iterator
from collections.abc import Iterator, Sequence
from json import JSONDecodeError
from typing import Optional
@@ -18,16 +18,15 @@ from core.entities.provider_entities import (
)
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.model_runtime.entities.model_entities import FetchFrom, ModelType
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.entities.provider_entities import (
ConfigurateMethod,
CredentialFormSchema,
FormType,
ProviderEntity,
)
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from extensions.ext_database import db
from models.provider import (
LoadBalancingModelConfig,
@@ -100,7 +99,9 @@ class ProviderConfiguration(BaseModel):
restrict_models = quota_configuration.restrict_models
copy_credentials = self.system_configuration.credentials.copy()
copy_credentials = (
self.system_configuration.credentials.copy() if self.system_configuration.credentials else {}
)
if restrict_models:
for restrict_model in restrict_models:
if (
@@ -137,6 +138,9 @@ class ProviderConfiguration(BaseModel):
(q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None
)
if not current_quota_configuration:
return SystemConfigurationStatus.UNSUPPORTED
return (
SystemConfigurationStatus.ACTIVE
if current_quota_configuration.is_valid
@@ -172,7 +176,7 @@ class ProviderConfiguration(BaseModel):
else [],
)
def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]:
def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]:
"""
Validate custom credentials.
:param credentials: provider credentials
@@ -216,6 +220,7 @@ class ProviderConfiguration(BaseModel):
if value == HIDDEN_VALUE and key in original_credentials:
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
model_provider_factory = ModelProviderFactory(self.tenant_id)
credentials = model_provider_factory.provider_credentials_validate(
provider=self.provider.provider, credentials=credentials
)
@@ -243,13 +248,13 @@ class ProviderConfiguration(BaseModel):
provider_record.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.commit()
else:
provider_record = Provider(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(credentials),
is_valid=True,
)
provider_record = Provider()
provider_record.tenant_id = self.tenant_id
provider_record.provider_name = self.provider.provider
provider_record.provider_type = ProviderType.CUSTOM.value
provider_record.encrypted_config = json.dumps(credentials)
provider_record.is_valid = True
db.session.add(provider_record)
db.session.commit()
@@ -324,7 +329,7 @@ class ProviderConfiguration(BaseModel):
def custom_model_credentials_validate(
self, model_type: ModelType, model: str, credentials: dict
) -> tuple[ProviderModel, dict]:
) -> tuple[ProviderModel | None, dict]:
"""
Validate custom model credentials.
@@ -367,6 +372,7 @@ class ProviderConfiguration(BaseModel):
if value == HIDDEN_VALUE and key in original_credentials:
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
model_provider_factory = ModelProviderFactory(self.tenant_id)
credentials = model_provider_factory.model_credentials_validate(
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
)
@@ -397,14 +403,13 @@ class ProviderConfiguration(BaseModel):
provider_model_record.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.commit()
else:
provider_model_record = ProviderModel(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_name=model,
model_type=model_type.to_origin_model_type(),
encrypted_config=json.dumps(credentials),
is_valid=True,
)
provider_model_record = ProviderModel()
provider_model_record.tenant_id = self.tenant_id
provider_model_record.provider_name = self.provider.provider
provider_model_record.model_name = model
provider_model_record.model_type = model_type.to_origin_model_type()
provider_model_record.encrypted_config = json.dumps(credentials)
provider_model_record.is_valid = True
db.session.add(provider_model_record)
db.session.commit()
@@ -471,13 +476,12 @@ class ProviderConfiguration(BaseModel):
model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.commit()
else:
model_setting = ProviderModelSetting(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_type=model_type.to_origin_model_type(),
model_name=model,
enabled=True,
)
model_setting = ProviderModelSetting()
model_setting.tenant_id = self.tenant_id
model_setting.provider_name = self.provider.provider
model_setting.model_type = model_type.to_origin_model_type()
model_setting.model_name = model
model_setting.enabled = True
db.session.add(model_setting)
db.session.commit()
@@ -506,13 +510,12 @@ class ProviderConfiguration(BaseModel):
model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.commit()
else:
model_setting = ProviderModelSetting(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_type=model_type.to_origin_model_type(),
model_name=model,
enabled=False,
)
model_setting = ProviderModelSetting()
model_setting.tenant_id = self.tenant_id
model_setting.provider_name = self.provider.provider
model_setting.model_type = model_type.to_origin_model_type()
model_setting.model_name = model
model_setting.enabled = False
db.session.add(model_setting)
db.session.commit()
@@ -573,13 +576,12 @@ class ProviderConfiguration(BaseModel):
model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.commit()
else:
model_setting = ProviderModelSetting(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_type=model_type.to_origin_model_type(),
model_name=model,
load_balancing_enabled=True,
)
model_setting = ProviderModelSetting()
model_setting.tenant_id = self.tenant_id
model_setting.provider_name = self.provider.provider
model_setting.model_type = model_type.to_origin_model_type()
model_setting.model_name = model
model_setting.load_balancing_enabled = True
db.session.add(model_setting)
db.session.commit()
@@ -608,25 +610,17 @@ class ProviderConfiguration(BaseModel):
model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.commit()
else:
model_setting = ProviderModelSetting(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_type=model_type.to_origin_model_type(),
model_name=model,
load_balancing_enabled=False,
)
model_setting = ProviderModelSetting()
model_setting.tenant_id = self.tenant_id
model_setting.provider_name = self.provider.provider
model_setting.model_type = model_type.to_origin_model_type()
model_setting.model_name = model
model_setting.load_balancing_enabled = False
db.session.add(model_setting)
db.session.commit()
return model_setting
def get_provider_instance(self) -> ModelProvider:
"""
Get provider instance.
:return:
"""
return model_provider_factory.get_provider_instance(self.provider.provider)
def get_model_type_instance(self, model_type: ModelType) -> AIModel:
"""
Get current model type instance.
@@ -634,11 +628,19 @@ class ProviderConfiguration(BaseModel):
:param model_type: model type
:return:
"""
# Get provider instance
provider_instance = self.get_provider_instance()
model_provider_factory = ModelProviderFactory(self.tenant_id)
# Get model instance of LLM
return provider_instance.get_model_instance(model_type)
return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
def get_model_schema(self, model_type: ModelType, model: str, credentials: dict) -> AIModelEntity | None:
"""
Get model schema
"""
model_provider_factory = ModelProviderFactory(self.tenant_id)
return model_provider_factory.get_model_schema(
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
)
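Both helpers now delegate to the tenant-scoped factory, which presumably resolves model-type instances and schemas through the plugin daemon rather than from bundled provider modules. A hedged usage sketch (identifiers and the credential key are illustrative):

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory

factory = ModelProviderFactory("tenant-123")
schema = factory.get_model_schema(
    provider="anthropic",
    model_type=ModelType.LLM,
    model="claude-3-5-sonnet-20240620",
    credentials={"anthropic_api_key": "sk-ant-..."},
)
if schema:
    print(schema.model, schema.model_properties)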
def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
"""
@@ -665,11 +667,10 @@ class ProviderConfiguration(BaseModel):
if preferred_model_provider:
preferred_model_provider.preferred_provider_type = provider_type.value
else:
preferred_model_provider = TenantPreferredModelProvider(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
preferred_provider_type=provider_type.value,
)
preferred_model_provider = TenantPreferredModelProvider()
preferred_model_provider.tenant_id = self.tenant_id
preferred_model_provider.provider_name = self.provider.provider
preferred_model_provider.preferred_provider_type = provider_type.value
db.session.add(preferred_model_provider)
db.session.commit()
@@ -734,13 +735,14 @@ class ProviderConfiguration(BaseModel):
:param only_active: only active models
:return:
"""
provider_instance = self.get_provider_instance()
model_provider_factory = ModelProviderFactory(self.tenant_id)
provider_schema = model_provider_factory.get_provider_schema(self.provider.provider)
model_types = []
if model_type:
model_types.append(model_type)
else:
model_types = provider_instance.get_provider_schema().supported_model_types
model_types = provider_schema.supported_model_types
# Group model settings by model type and model
model_setting_map = defaultdict(dict)
@@ -749,11 +751,11 @@ class ProviderConfiguration(BaseModel):
if self.using_provider_type == ProviderType.SYSTEM:
provider_models = self._get_system_provider_models(
model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map
model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map
)
else:
provider_models = self._get_custom_provider_models(
model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map
model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map
)
if only_active:
@@ -764,23 +766,26 @@ class ProviderConfiguration(BaseModel):
def _get_system_provider_models(
self,
model_types: list[ModelType],
provider_instance: ModelProvider,
model_types: Sequence[ModelType],
provider_schema: ProviderEntity,
model_setting_map: dict[ModelType, dict[str, ModelSettings]],
) -> list[ModelWithProviderEntity]:
"""
Get system provider models.
:param model_types: model types
:param provider_instance: provider instance
:param provider_schema: provider schema
:param model_setting_map: model setting map
:return:
"""
provider_models = []
for model_type in model_types:
for m in provider_instance.models(model_type):
for m in provider_schema.models:
if m.model_type != model_type:
continue
status = ModelStatus.ACTIVE
if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
if m.model in model_setting_map:
model_setting = model_setting_map[m.model_type][m.model]
if model_setting.enabled is False:
status = ModelStatus.DISABLED
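Model listing now iterates the provider schema's models collection and filters by type instead of calling provider_instance.models(model_type). A minimal equivalent helper (illustrative, not part of the change):

from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.entities.provider_entities import ProviderEntity


def models_of_type(provider_schema: ProviderEntity, model_type: ModelType) -> list[AIModelEntity]:
    # Keep only the schema models that match the requested type.
    return [m for m in provider_schema.models if m.model_type == model_type]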
@@ -801,7 +806,7 @@ class ProviderConfiguration(BaseModel):
if self.provider.provider not in original_provider_configurate_methods:
original_provider_configurate_methods[self.provider.provider] = []
for configurate_method in provider_instance.get_provider_schema().configurate_methods:
for configurate_method in provider_schema.configurate_methods:
original_provider_configurate_methods[self.provider.provider].append(configurate_method)
should_use_custom_model = False
@@ -822,14 +827,20 @@ class ProviderConfiguration(BaseModel):
]:
# only customizable model
for restrict_model in restrict_models:
copy_credentials = self.system_configuration.credentials.copy()
copy_credentials = (
self.system_configuration.credentials.copy()
if self.system_configuration.credentials
else {}
)
if restrict_model.base_model_name:
copy_credentials["base_model_name"] = restrict_model.base_model_name
try:
custom_model_schema = provider_instance.get_model_instance(
restrict_model.model_type
).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials)
custom_model_schema = self.get_model_schema(
model_type=restrict_model.model_type,
model=restrict_model.model,
credentials=copy_credentials,
)
except Exception as ex:
logger.warning(f"get custom model schema failed, {ex}")
continue
@@ -875,15 +886,15 @@ class ProviderConfiguration(BaseModel):
def _get_custom_provider_models(
self,
model_types: list[ModelType],
provider_instance: ModelProvider,
model_types: Sequence[ModelType],
provider_schema: ProviderEntity,
model_setting_map: dict[ModelType, dict[str, ModelSettings]],
) -> list[ModelWithProviderEntity]:
"""
Get custom provider models.
:param model_types: model types
:param provider_instance: provider instance
:param provider_schema: provider schema
:param model_setting_map: model setting map
:return:
"""
@@ -897,8 +908,10 @@ class ProviderConfiguration(BaseModel):
if model_type not in self.provider.supported_model_types:
continue
models = provider_instance.models(model_type)
for m in models:
for m in provider_schema.models:
if m.model_type != model_type:
continue
status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
load_balancing_enabled = False
if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
@@ -930,10 +943,10 @@ class ProviderConfiguration(BaseModel):
continue
try:
custom_model_schema = provider_instance.get_model_instance(
model_configuration.model_type
).get_customizable_model_schema_from_credentials(
model_configuration.model, model_configuration.credentials
custom_model_schema = self.get_model_schema(
model_type=model_configuration.model_type,
model=model_configuration.model,
credentials=model_configuration.credentials,
)
except Exception as ex:
logger.warning(f"get custom model schema failed, {ex}")
@@ -1043,7 +1056,7 @@ class ProviderConfigurations(BaseModel):
return iter(self.configurations)
def values(self) -> Iterator[ProviderConfiguration]:
return self.configurations.values()
return iter(self.configurations.values())
def get(self, key, default=None):
return self.configurations.get(key, default)
@@ -1055,7 +1068,6 @@ class ProviderModelBundle(BaseModel):
"""
configuration: ProviderConfiguration
provider_instance: ModelProvider
model_type_instance: AIModel
# pydantic configs

@@ -23,6 +23,9 @@ def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str)
if using_provider_type == ProviderType.SYSTEM and provider_name in moderation_config.providers:
hosting_openai_config = hosting_configuration.provider_map["openai"]
if hosting_openai_config.credentials is None:
return False
# 2000 text per chunk
length = 2000
text_chunks = [text[i : i + length] for i in range(0, len(text), length)]

@@ -5,7 +5,7 @@ from typing import Optional
from pydantic import BaseModel, ConfigDict
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import ModelType, ProviderModel
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
class ConfigurateMethod(Enum):
@@ -101,7 +101,7 @@ class SimpleProviderEntity(BaseModel):
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
supported_model_types: Sequence[ModelType]
models: list[ProviderModel] = []
models: list[AIModelEntity] = []
class ProviderHelpEntity(BaseModel):
@@ -127,7 +127,7 @@ class ProviderEntity(BaseModel):
help: Optional[ProviderHelpEntity] = None
supported_model_types: Sequence[ModelType]
configurate_methods: list[ConfigurateMethod]
models: list[ProviderModel] = []
models: list[AIModelEntity] = []
provider_credential_schema: Optional[ProviderCredentialSchema] = None
model_credential_schema: Optional[ModelCredentialSchema] = None

@@ -1,10 +1,9 @@
import decimal
import os
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import Optional
from pydantic import ConfigDict
from pydantic import ConfigDict, Field
from core.helper.position_helper import get_position_map, sort_by_position_map
from core.model_runtime.entities.common_entities import I18nObject
@@ -20,34 +19,26 @@ from core.model_runtime.entities.model_entities import (
)
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.tools.utils.yaml_utils import load_yaml_file
class AIModel(ABC):
class AIModel:
"""
Base class for all models.
"""
model_type: ModelType
model_schemas: Optional[list[AIModelEntity]] = None
started_at: float = 0
tenant_id: str = Field(description="Tenant ID")
model_type: ModelType = Field(description="Model type")
plugin_id: str = Field(description="Plugin ID")
provider_name: str = Field(description="Provider")
plugin_model_provider: PluginModelProviderEntity = Field(description="Plugin model provider")
started_at: float = Field(description="Invoke start time", default=0)
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
@abstractmethod
def validate_credentials(self, model: str, credentials: Mapping) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
raise NotImplementedError
@property
@abstractmethod
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
@@ -66,20 +57,18 @@ class AIModel(ABC):
:param error: model invoke error
:return: unified error
"""
provider_name = self.__class__.__module__.split(".")[-3]
for invoke_error, model_errors in self._invoke_error_mapping.items():
if isinstance(error, tuple(model_errors)):
if invoke_error == InvokeAuthorizationError:
return invoke_error(
description=(
f"[{provider_name}] Incorrect model credentials provided, please check and try again."
f"[{self.provider_name}] Incorrect model credentials provided, please check and try again."
)
)
return invoke_error(description=f"[{provider_name}] {invoke_error.description}, {str(error)}")
return invoke_error(description=f"[{self.provider_name}] {invoke_error.description}, {str(error)}")
return InvokeError(description=f"[{provider_name}] Error: {str(error)}")
return InvokeError(description=f"[{self.provider_name}] Error: {str(error)}")
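The provider tag in unified error messages now comes from the instance's provider_name field instead of being parsed out of the module path. A self-contained sketch of the mapping pattern (the classes here are stand-ins, not the real error types):

class InvokeError(Exception):
    description = "Invoke error"


class InvokeAuthorizationError(InvokeError):
    description = "Incorrect model credentials provided, please check and try again"


def transform_invoke_error(
    provider_name: str,
    error: Exception,
    mapping: dict[type[InvokeError], list[type[Exception]]],
) -> InvokeError:
    # Walk the provider-specific mapping and wrap the SDK error in a unified type.
    for invoke_error, sdk_errors in mapping.items():
        if isinstance(error, tuple(sdk_errors)):
            return invoke_error(f"[{provider_name}] {invoke_error.description}, {error}")
    return InvokeError(f"[{provider_name}] Error: {error}")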
def get_price(self, model: str, credentials: dict, price_type: PriceType, tokens: int) -> PriceInfo:
"""

@@ -1,32 +1,25 @@
import logging
import os
import re
import time
from abc import abstractmethod
from collections.abc import Generator, Mapping
from collections.abc import Generator
from typing import Optional, Union
from pydantic import ConfigDict
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.callbacks.logging_callback import LoggingCallback
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import (
ModelPropertyKey,
ModelType,
ParameterRule,
ParameterType,
PriceType,
)
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.plugin.manager.model import PluginModelManager
logger = logging.getLogger(__name__)
@@ -71,8 +64,6 @@ class LargeLanguageModel(AIModel):
if model_parameters is None:
model_parameters = {}
model_parameters = self._validate_and_filter_model_parameters(model, model_parameters, credentials)
self.started_at = time.perf_counter()
callbacks = callbacks or []
@@ -94,20 +85,43 @@ class LargeLanguageModel(AIModel):
)
try:
if "response_format" in model_parameters:
result = self._code_block_mode_wrapper(
plugin_model_manager = PluginModelManager()
result = plugin_model_manager.invoke_llm(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model=model,
credentials=credentials,
model_parameters=model_parameters,
prompt_messages=prompt_messages,
tools=tools,
stop=stop,
stream=stream,
)
if not stream:
content = ""
content_list = []
usage = LLMUsage.empty_usage()
system_fingerprint = None
for chunk in result:
if isinstance(chunk.delta.message.content, str):
content += chunk.delta.message.content
elif isinstance(chunk.delta.message.content, list):
content_list.extend(chunk.delta.message.content)
usage = chunk.delta.usage or LLMUsage.empty_usage()
system_fingerprint = chunk.system_fingerprint
break
result = LLMResult(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks,
message=AssistantPromptMessage(content=content or content_list),
usage=usage,
system_fingerprint=system_fingerprint,
)
else:
result = self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
except Exception as e:
self._trigger_invoke_error_callbacks(
model=model,
@@ -122,6 +136,7 @@ class LargeLanguageModel(AIModel):
callbacks=callbacks,
)
# TODO
raise self._transform_invoke_error(e)
if stream and isinstance(result, Generator):
@@ -153,244 +168,6 @@ class LargeLanguageModel(AIModel):
return result
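The invoke path above no longer calls a provider SDK in-process; it forwards the request to the plugin daemon through PluginModelManager.invoke_llm and, for non-streaming calls, folds the returned chunks into a single LLMResult. A hedged call sketch using only the arguments visible in the hunk (identifier and credential values are illustrative):

from core.model_runtime.entities.message_entities import UserPromptMessage
from core.plugin.manager.model import PluginModelManager

manager = PluginModelManager()
chunks = manager.invoke_llm(
    tenant_id="tenant-123",            # illustrative
    user_id="user-456",                # illustrative
    plugin_id="langgenius/anthropic",  # illustrative plugin id
    provider="anthropic",
    model="claude-3-5-sonnet-20240620",
    credentials={"anthropic_api_key": "sk-ant-..."},
    model_parameters={"temperature": 0.7},
    prompt_messages=[UserPromptMessage(content="ping")],
    tools=None,
    stop=None,
    stream=True,
)
for chunk in chunks:
    if isinstance(chunk.delta.message.content, str):
        print(chunk.delta.message.content, end="")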
def _code_block_mode_wrapper(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper, ensure the response is a code block with output markdown quote
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:param callbacks: callbacks
:return: full response or stream response chunk generator result
"""
block_prompts = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.
<instructions>
{{instructions}}
</instructions>
""" # noqa: E501
code_block = model_parameters.get("response_format", "")
if not code_block:
return self._invoke(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
)
model_parameters.pop("response_format")
stop = stop or []
stop.extend(["\n```", "```\n"])
block_prompts = block_prompts.replace("{{block}}", code_block)
# check if there is a system message
if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
# override the system message
prompt_messages[0] = SystemPromptMessage(
content=block_prompts.replace("{{instructions}}", str(prompt_messages[0].content))
)
else:
# insert the system message
prompt_messages.insert(
0,
SystemPromptMessage(
content=block_prompts.replace("{{instructions}}", f"Please output a valid {code_block} object.")
),
)
if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
# add ```JSON\n to the last text message
if isinstance(prompt_messages[-1].content, str):
prompt_messages[-1].content += f"\n```{code_block}\n"
elif isinstance(prompt_messages[-1].content, list):
for i in range(len(prompt_messages[-1].content) - 1, -1, -1):
if prompt_messages[-1].content[i].type == PromptMessageContentType.TEXT:
prompt_messages[-1].content[i].data += f"\n```{code_block}\n"
break
else:
# append a user message
prompt_messages.append(UserPromptMessage(content=f"```{code_block}\n"))
response = self._invoke(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
)
if isinstance(response, Generator):
first_chunk = next(response)
def new_generator():
yield first_chunk
yield from response
if first_chunk.delta.message.content and first_chunk.delta.message.content.startswith("`"):
return self._code_block_mode_stream_processor_with_backtick(
model=model, prompt_messages=prompt_messages, input_generator=new_generator()
)
else:
return self._code_block_mode_stream_processor(
model=model, prompt_messages=prompt_messages, input_generator=new_generator()
)
return response
def _code_block_mode_stream_processor(
self, model: str, prompt_messages: list[PromptMessage], input_generator: Generator[LLMResultChunk, None, None]
) -> Generator[LLMResultChunk, None, None]:
"""
Code block mode stream processor, ensure the response is a code block with output markdown quote
:param model: model name
:param prompt_messages: prompt messages
:param input_generator: input generator
:return: output generator
"""
state = "normal"
backtick_count = 0
for piece in input_generator:
if piece.delta.message.content:
content = piece.delta.message.content
piece.delta.message.content = ""
yield piece
piece = content
else:
yield piece
continue
new_piece: str = ""
for char in piece:
char = str(char)
if state == "normal":
if char == "`":
state = "in_backticks"
backtick_count = 1
else:
new_piece += char
elif state == "in_backticks":
if char == "`":
backtick_count += 1
if backtick_count == 3:
state = "skip_content"
backtick_count = 0
else:
new_piece += "`" * backtick_count + char
state = "normal"
backtick_count = 0
elif state == "skip_content":
if char.isspace():
state = "normal"
if new_piece:
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=new_piece, tool_calls=[]),
),
)
def _code_block_mode_stream_processor_with_backtick(
self, model: str, prompt_messages: list, input_generator: Generator[LLMResultChunk, None, None]
) -> Generator[LLMResultChunk, None, None]:
"""
Code block mode stream processor, ensure the response is a code block with output markdown quote.
This version skips the language identifier that follows the opening triple backticks.
:param model: model name
:param prompt_messages: prompt messages
:param input_generator: input generator
:return: output generator
"""
state = "search_start"
backtick_count = 0
for piece in input_generator:
if piece.delta.message.content:
content = piece.delta.message.content
# Reset content to ensure we're only processing and yielding the relevant parts
piece.delta.message.content = ""
# Yield a piece with cleared content before processing it to maintain the generator structure
yield piece
piece = content
else:
# Yield pieces without content directly
yield piece
continue
if state == "done":
continue
new_piece: str = ""
for char in piece:
if state == "search_start":
if char == "`":
backtick_count += 1
if backtick_count == 3:
state = "skip_language"
backtick_count = 0
else:
backtick_count = 0
elif state == "skip_language":
# Skip everything until the first newline, marking the end of the language identifier
if char == "\n":
state = "in_code_block"
elif state == "in_code_block":
if char == "`":
backtick_count += 1
if backtick_count == 3:
state = "done"
break
else:
if backtick_count > 0:
# If backticks were counted but we're still collecting content, it was a false start
new_piece += "`" * backtick_count
backtick_count = 0
new_piece += str(char)
elif state == "done":
break
if new_piece:
# Only yield content collected within the code block
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=new_piece, tool_calls=[]),
),
)
def _invoke_result_generator(
self,
model: str,
@@ -462,34 +239,6 @@ if you are not sure about the structure.
callbacks=callbacks,
)
@abstractmethod
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
raise NotImplementedError
@abstractmethod
def get_num_tokens(
self,
model: str,
@@ -506,41 +255,18 @@ if you are not sure about the structure.
:param tools: tools for tool calling
:return:
"""
raise NotImplementedError
def enforce_stop_tokens(self, text: str, stop: list[str]) -> str:
"""Cut off the text as soon as any stop words occur."""
return re.split("|".join(stop), text, maxsplit=1)[0]
def get_parameter_rules(self, model: str, credentials: dict) -> list[ParameterRule]:
"""
Get parameter rules
:param model: model name
:param credentials: model credentials
:return: parameter rules
"""
model_schema = self.get_model_schema(model, credentials)
if model_schema:
return model_schema.parameter_rules
return []
def get_model_mode(self, model: str, credentials: Optional[Mapping] = None) -> LLMMode:
"""
Get model mode
:param model: model name
:param credentials: model credentials
:return: model mode
"""
model_schema = self.get_model_schema(model, credentials)
mode = LLMMode.CHAT
if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE):
mode = LLMMode.value_of(model_schema.model_properties[ModelPropertyKey.MODE])
return mode
plugin_model_manager = PluginModelManager()
return plugin_model_manager.get_llm_num_tokens(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model_type=self.model_type.value,
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
tools=tools,
)
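Token counting is likewise proxied to the daemon via get_llm_num_tokens. A hedged usage sketch, assuming the tenant-scoped factory hands back an LLM-typed instance with its plugin fields populated (values illustrative):

from core.model_runtime.entities.message_entities import UserPromptMessage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory

factory = ModelProviderFactory("tenant-123")
llm = factory.get_model_type_instance(provider="anthropic", model_type=ModelType.LLM)
num_tokens = llm.get_num_tokens(
    model="claude-3-haiku-20240307",
    credentials={"anthropic_api_key": "sk-ant-..."},
    prompt_messages=[UserPromptMessage(content="hello")],
)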
def _calc_response_usage(
self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int
@@ -772,98 +498,3 @@ if you are not sure about the structure.
raise e
else:
logger.warning(f"Callback {callback.__class__.__name__} on_invoke_error failed with error {e}")
def _validate_and_filter_model_parameters(self, model: str, model_parameters: dict, credentials: dict) -> dict:
"""
Validate model parameters
:param model: model name
:param model_parameters: model parameters
:param credentials: model credentials
:return:
"""
parameter_rules = self.get_parameter_rules(model, credentials)
# validate model parameters
filtered_model_parameters = {}
for parameter_rule in parameter_rules:
parameter_name = parameter_rule.name
parameter_value = model_parameters.get(parameter_name)
if parameter_value is None:
if parameter_rule.use_template and parameter_rule.use_template in model_parameters:
# if parameter value is None, use template value variable name instead
parameter_value = model_parameters[parameter_rule.use_template]
else:
if parameter_rule.required:
if parameter_rule.default is not None:
filtered_model_parameters[parameter_name] = parameter_rule.default
continue
else:
raise ValueError(f"Model Parameter {parameter_name} is required.")
else:
continue
# validate parameter value type
if parameter_rule.type == ParameterType.INT:
if not isinstance(parameter_value, int):
raise ValueError(f"Model Parameter {parameter_name} should be int.")
# validate parameter value range
if parameter_rule.min is not None and parameter_value < parameter_rule.min:
raise ValueError(
f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}."
)
if parameter_rule.max is not None and parameter_value > parameter_rule.max:
raise ValueError(
f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}."
)
elif parameter_rule.type == ParameterType.FLOAT:
if not isinstance(parameter_value, float | int):
raise ValueError(f"Model Parameter {parameter_name} should be float.")
# validate parameter value precision
if parameter_rule.precision is not None:
if parameter_rule.precision == 0:
if parameter_value != int(parameter_value):
raise ValueError(f"Model Parameter {parameter_name} should be int.")
else:
if parameter_value != round(parameter_value, parameter_rule.precision):
raise ValueError(
f"Model Parameter {parameter_name} should be round to {parameter_rule.precision}"
f" decimal places."
)
# validate parameter value range
if parameter_rule.min is not None and parameter_value < parameter_rule.min:
raise ValueError(
f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}."
)
if parameter_rule.max is not None and parameter_value > parameter_rule.max:
raise ValueError(
f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}."
)
elif parameter_rule.type == ParameterType.BOOLEAN:
if not isinstance(parameter_value, bool):
raise ValueError(f"Model Parameter {parameter_name} should be bool.")
elif parameter_rule.type == ParameterType.STRING:
if not isinstance(parameter_value, str):
raise ValueError(f"Model Parameter {parameter_name} should be string.")
# validate options
if parameter_rule.options and parameter_value not in parameter_rule.options:
raise ValueError(f"Model Parameter {parameter_name} should be one of {parameter_rule.options}.")
elif parameter_rule.type == ParameterType.TEXT:
if not isinstance(parameter_value, str):
raise ValueError(f"Model Parameter {parameter_name} should be text.")
# validate options
if parameter_rule.options and parameter_value not in parameter_rule.options:
raise ValueError(f"Model Parameter {parameter_name} should be one of {parameter_rule.options}.")
else:
raise ValueError(f"Model Parameter {parameter_name} type {parameter_rule.type} is not supported.")
filtered_model_parameters[parameter_name] = parameter_value
return filtered_model_parameters

@@ -8,6 +8,7 @@ from core.embedding.embedding_constant import EmbeddingInputType
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.plugin.manager.model import PluginModelManager
class TextEmbeddingModel(AIModel):
@@ -66,7 +67,6 @@ class TextEmbeddingModel(AIModel):
"""
raise NotImplementedError
@abstractmethod
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
Get number of tokens for given prompt messages
@@ -76,7 +76,17 @@ class TextEmbeddingModel(AIModel):
:param texts: texts to embed
:return:
"""
raise NotImplementedError
plugin_model_manager = PluginModelManager()
return plugin_model_manager.get_text_embedding_num_tokens(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model_type=self.model_type.value,
model=model,
credentials=credentials,
texts=texts,
)
def _get_context_size(self, model: str, credentials: dict) -> int:
"""

@@ -1,3 +0,0 @@
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
model_provider_factory = ModelProviderFactory()
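The module-level singleton above is removed because the factory now requires a tenant id; call sites construct a factory per tenant instead. A before/after sketch (tenant id illustrative):

# Before: process-wide singleton, no tenant context.
#   from core.model_runtime.model_providers import model_provider_factory
#   providers = model_provider_factory.get_providers()

# After: construct the factory per tenant at the call site.
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory

factory = ModelProviderFactory("tenant-123")
providers = factory.get_providers()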

@@ -1,78 +0,0 @@
<svg width="90" height="20" viewBox="0 0 90 20" fill="none" xmlns="http://www.w3.org/2000/svg">
<g clip-path="url(#clip0_8587_60274)">
<mask id="mask0_8587_60274" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="4" width="90" height="11">
<path d="M89.375 4.99805H0V14.998H89.375V4.99805Z" fill="white"/>
</mask>
<g mask="url(#mask0_8587_60274)">
<mask id="mask1_8587_60274" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="4" width="90" height="11">
<path d="M0 4.99609H89.375V14.9961H0V4.99609Z" fill="white"/>
</mask>
<g mask="url(#mask1_8587_60274)">
<mask id="mask2_8587_60274" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="4" width="90" height="11">
<path d="M0 4.99414H89.375V14.9941H0V4.99414Z" fill="white"/>
</mask>
<g mask="url(#mask2_8587_60274)">
<mask id="mask3_8587_60274" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="4" width="90" height="11">
<path d="M0 4.99219H89.375V14.9922H0V4.99219Z" fill="white"/>
</mask>
<g mask="url(#mask3_8587_60274)">
<path d="M18.1273 11.9244L13.7773 5.15625H11.4297V14.825H13.4321V8.05688L17.7821 14.825H20.1297V5.15625H18.1273V11.9244Z" fill="black" fill-opacity="0.92"/>
</g>
<mask id="mask4_8587_60274" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="4" width="90" height="11">
<path d="M0 4.99219H89.375V14.9922H0V4.99219Z" fill="white"/>
</mask>
<g mask="url(#mask4_8587_60274)">
<path d="M21.7969 7.02094H25.0423V14.825H27.1139V7.02094H30.3594V5.15625H21.7969V7.02094Z" fill="black" fill-opacity="0.92"/>
</g>
<mask id="mask5_8587_60274" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="4" width="90" height="11">
<path d="M0 4.99219H89.375V14.9922H0V4.99219Z" fill="white"/>
</mask>
<g mask="url(#mask5_8587_60274)">
<path d="M38.6442 9.00994H34.0871V5.15625H32.0156V14.825H34.0871V10.8746H38.6442V14.825H40.7156V5.15625H38.6442V9.00994Z" fill="black" fill-opacity="0.92"/>
</g>
<mask id="mask6_8587_60274" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="4" width="90" height="11">
<path d="M0 4.99219H89.375V14.9922H0V4.99219Z" fill="white"/>
</mask>
<g mask="url(#mask6_8587_60274)">
<path d="M45.3376 7.02094H47.893C48.9152 7.02094 49.4539 7.39387 49.4539 8.09831C49.4539 8.80275 48.9152 9.17569 47.893 9.17569H45.3376V7.02094ZM51.5259 8.09831C51.5259 6.27506 50.186 5.15625 47.9897 5.15625H43.2656V14.825H45.3376V11.0404H47.6443L49.7164 14.825H52.0094L49.715 10.7521C50.8666 10.3094 51.5259 9.37721 51.5259 8.09831Z" fill="black" fill-opacity="0.92"/>
</g>
<mask id="mask7_8587_60274" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="4" width="90" height="11">
<path d="M0 4.99219H89.375V14.9922H0V4.99219Z" fill="white"/>
</mask>
<g mask="url(#mask7_8587_60274)">
<path d="M57.8732 13.0565C56.2438 13.0565 55.2496 11.8963 55.2496 10.004C55.2496 8.08416 56.2438 6.92394 57.8732 6.92394C59.4887 6.92394 60.4691 8.08416 60.4691 10.004C60.4691 11.8963 59.4887 13.0565 57.8732 13.0565ZM57.8732 4.99023C55.0839 4.99023 53.1094 7.06206 53.1094 10.004C53.1094 12.9184 55.0839 14.9902 57.8732 14.9902C60.6486 14.9902 62.6094 12.9184 62.6094 10.004C62.6094 7.06206 60.6486 4.99023 57.8732 4.99023Z" fill="black" fill-opacity="0.92"/>
</g>
<mask id="mask8_8587_60274" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="4" width="90" height="11">
<path d="M0 4.99219H89.375V14.9922H0V4.99219Z" fill="white"/>
</mask>
<g mask="url(#mask8_8587_60274)">
<path d="M69.1794 9.45194H66.6233V7.02094H69.1794C70.2019 7.02094 70.7407 7.43532 70.7407 8.23644C70.7407 9.03756 70.2019 9.45194 69.1794 9.45194ZM69.2762 5.15625H64.5508V14.825H66.6233V11.3166H69.2762C71.473 11.3166 72.8133 10.1564 72.8133 8.23644C72.8133 6.3165 71.473 5.15625 69.2762 5.15625Z" fill="black" fill-opacity="0.92"/>
</g>
<mask id="mask9_8587_60274" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="4" width="90" height="11">
<path d="M0 4.99219H89.375V14.9922H0V4.99219Z" fill="white"/>
</mask>
<g mask="url(#mask9_8587_60274)">
<path d="M86.8413 11.5786C86.4823 12.5179 85.7642 13.0565 84.7837 13.0565C83.1542 13.0565 82.16 11.8963 82.16 10.004C82.16 8.08416 83.1542 6.92394 84.7837 6.92394C85.7642 6.92394 86.4823 7.46261 86.8413 8.40183H89.0369C88.4984 6.33002 86.8827 4.99023 84.7837 4.99023C81.9942 4.99023 80.0195 7.06206 80.0195 10.004C80.0195 12.9184 81.9942 14.9902 84.7837 14.9902C86.8965 14.9902 88.5122 13.6366 89.0508 11.5786H86.8413Z" fill="black" fill-opacity="0.92"/>
</g>
<mask id="mask10_8587_60274" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="4" width="90" height="11">
<path d="M0 4.99219H89.375V14.9922H0V4.99219Z" fill="white"/>
</mask>
<g mask="url(#mask10_8587_60274)">
<path d="M73.6484 5.15625L77.5033 14.825H79.6172L75.7624 5.15625H73.6484Z" fill="black" fill-opacity="0.92"/>
</g>
<mask id="mask11_8587_60274" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="4" width="90" height="11">
<path d="M0 4.99219H89.375V14.9922H0V4.99219Z" fill="white"/>
</mask>
<g mask="url(#mask11_8587_60274)">
<path d="M3.64038 10.9989L4.95938 7.60106L6.27838 10.9989H3.64038ZM3.85422 5.15625L0 14.825H2.15505L2.9433 12.7946H6.97558L7.76371 14.825H9.91875L6.06453 5.15625H3.85422Z" fill="black" fill-opacity="0.92"/>
</g>
</g>
</g>
</g>
</g>
<defs>
<clipPath id="clip0_8587_60274">
<rect width="89.375" height="10" fill="white" transform="translate(0 5)"/>
</clipPath>
</defs>
</svg>


@@ -1,4 +0,0 @@
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<rect width="24" height="24" rx="6" fill="#CA9F7B"/>
<path d="M15.3843 6.43481H12.9687L17.3739 17.5652H19.7896L15.3843 6.43481ZM8.40522 6.43481L4 17.5652H6.4633L7.36417 15.2279H11.9729L12.8737 17.5652H15.337L10.9318 6.43481H8.40522ZM8.16104 13.1607L9.66852 9.24907L11.176 13.1607H8.16104Z" fill="#191918"/>
</svg>


@@ -1,28 +0,0 @@
import logging
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class AnthropicProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
if validate failed, raise exception
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
try:
model_instance = self.get_model_instance(ModelType.LLM)
# Use `claude-3-opus-20240229` model for validate,
model_instance.validate_credentials(model="claude-3-opus-20240229", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
raise ex

@@ -1,39 +0,0 @@
provider: anthropic
label:
en_US: Anthropic
description:
en_US: Anthropics powerful models, such as Claude 3.
zh_Hans: Anthropic 的强大模型,例如 Claude 3。
icon_small:
en_US: icon_s_en.svg
icon_large:
en_US: icon_l_en.svg
background: "#F0F0EB"
help:
title:
en_US: Get your API Key from Anthropic
zh_Hans: 从 Anthropic 获取 API Key
url:
en_US: https://console.anthropic.com/account/keys
supported_model_types:
- llm
configurate_methods:
- predefined-model
provider_credential_schema:
credential_form_schemas:
- variable: anthropic_api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
- variable: anthropic_api_url
label:
en_US: API URL
type: text-input
required: false
placeholder:
zh_Hans: 在此输入您的 API URL
en_US: Enter your API URL

@@ -1,8 +0,0 @@
- claude-3-5-sonnet-20240620
- claude-3-haiku-20240307
- claude-3-opus-20240229
- claude-3-sonnet-20240229
- claude-2.1
- claude-instant-1.2
- claude-2
- claude-instant-1

@@ -1,36 +0,0 @@
model: claude-2.1
label:
en_US: claude-2.1
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 200000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
use_template: max_tokens
required: true
default: 4096
min: 1
max: 4096
- name: response_format
use_template: response_format
pricing:
input: '8.00'
output: '24.00'
unit: '0.000001'
currency: USD

@@ -1,37 +0,0 @@
model: claude-2
label:
en_US: claude-2
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 100000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
use_template: max_tokens
required: true
default: 4096
min: 1
max: 4096
- name: response_format
use_template: response_format
pricing:
input: '8.00'
output: '24.00'
unit: '0.000001'
currency: USD
deprecated: true

@@ -1,39 +0,0 @@
model: claude-3-5-sonnet-20240620
label:
en_US: claude-3-5-sonnet-20240620
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
pricing:
input: '3.00'
output: '15.00'
unit: '0.000001'
currency: USD

@@ -1,39 +0,0 @@
model: claude-3-haiku-20240307
label:
en_US: claude-3-haiku-20240307
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 4096
min: 1
max: 4096
- name: response_format
use_template: response_format
pricing:
input: '0.25'
output: '1.25'
unit: '0.000001'
currency: USD

@@ -1,39 +0,0 @@
model: claude-3-opus-20240229
label:
en_US: claude-3-opus-20240229
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 4096
min: 1
max: 4096
- name: response_format
use_template: response_format
pricing:
input: '15.00'
output: '75.00'
unit: '0.000001'
currency: USD

@@ -1,39 +0,0 @@
model: claude-3-sonnet-20240229
label:
en_US: claude-3-sonnet-20240229
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 4096
min: 1
max: 4096
- name: response_format
use_template: response_format
pricing:
input: '3.00'
output: '15.00'
unit: '0.000001'
currency: USD

@@ -1,36 +0,0 @@
model: claude-instant-1.2
label:
en_US: claude-instant-1.2
model_type: llm
features: [ ]
model_properties:
mode: chat
context_size: 100000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 4096
min: 1
max: 4096
- name: response_format
use_template: response_format
pricing:
input: '1.63'
output: '5.51'
unit: '0.000001'
currency: USD
deprecated: true

@@ -1,36 +0,0 @@
model: claude-instant-1
label:
en_US: claude-instant-1
model_type: llm
features: [ ]
model_properties:
mode: chat
context_size: 100000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
use_template: max_tokens
required: true
default: 4096
min: 1
max: 4096
- name: response_format
use_template: response_format
pricing:
input: '1.63'
output: '5.51'
unit: '0.000001'
currency: USD
deprecated: true

@@ -1,624 +0,0 @@
import base64
import io
import json
from collections.abc import Generator
from typing import Optional, Union, cast
import anthropic
import requests
from anthropic import Anthropic, Stream
from anthropic.types import (
ContentBlockDeltaEvent,
Message,
MessageDeltaEvent,
MessageStartEvent,
MessageStopEvent,
MessageStreamEvent,
completion_create_params,
)
from anthropic.types.beta.tools import ToolsBetaMessage
from httpx import Timeout
from PIL import Image
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.
<instructions>
{{instructions}}
</instructions>
""" # noqa: E501
class AnthropicLargeLanguageModel(LargeLanguageModel):
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
# invoke model
return self._chat_generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def _chat_generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke llm chat model
:param model: model name
:param credentials: credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
# transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials)
# transform model parameters from completion api of anthropic to chat api
if "max_tokens_to_sample" in model_parameters:
model_parameters["max_tokens"] = model_parameters.pop("max_tokens_to_sample")
# init model client
client = Anthropic(**credentials_kwargs)
extra_model_kwargs = {}
if stop:
extra_model_kwargs["stop_sequences"] = stop
if user:
extra_model_kwargs["metadata"] = completion_create_params.Metadata(user_id=user)
system, prompt_message_dicts = self._convert_prompt_messages(prompt_messages)
if system:
extra_model_kwargs["system"] = system
# Add the new header for claude-3-5-sonnet-20240620 model
extra_headers = {}
if model == "claude-3-5-sonnet-20240620":
if model_parameters.get("max_tokens") > 4096:
extra_headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15"
if tools:
extra_model_kwargs["tools"] = [self._transform_tool_prompt(tool) for tool in tools]
response = client.beta.tools.messages.create(
model=model,
messages=prompt_message_dicts,
stream=stream,
extra_headers=extra_headers,
**model_parameters,
**extra_model_kwargs,
)
else:
# chat model
response = client.messages.create(
model=model,
messages=prompt_message_dicts,
stream=stream,
extra_headers=extra_headers,
**model_parameters,
**extra_model_kwargs,
)
if stream:
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)
return self._handle_chat_generate_response(model, credentials, response, prompt_messages)
def _code_block_mode_wrapper(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: list[Callback] = None,
) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper for invoking large language model
"""
if model_parameters.get("response_format"):
stop = stop or []
# chat model
self._transform_chat_json_prompts(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
response_format=model_parameters["response_format"],
)
model_parameters.pop("response_format")
return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def _transform_tool_prompt(self, tool: PromptMessageTool) -> dict:
return {"name": tool.name, "description": tool.description, "input_schema": tool.parameters}
def _transform_chat_json_prompts(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
response_format: str = "JSON",
) -> None:
"""
Transform json prompts
"""
if "```\n" not in stop:
stop.append("```\n")
if "\n```" not in stop:
stop.append("\n```")
# check if there is a system message
if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
# override the system message
prompt_messages[0] = SystemPromptMessage(
content=ANTHROPIC_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace(
"{{block}}", response_format
)
)
prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
else:
# insert the system message
prompt_messages.insert(
0,
SystemPromptMessage(
content=ANTHROPIC_BLOCK_MODE_PROMPT.replace(
"{{instructions}}", f"Please output a valid {response_format} object."
).replace("{{block}}", response_format)
),
)
prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param tools: tools for tool calling
:return:
"""
prompt = self._convert_messages_to_prompt_anthropic(prompt_messages)
client = Anthropic(api_key="")
tokens = client.count_tokens(prompt)
tool_call_inner_prompts_tokens_map = {
"claude-3-opus-20240229": 395,
"claude-3-haiku-20240307": 264,
"claude-3-sonnet-20240229": 159,
}
if model in tool_call_inner_prompts_tokens_map and tools:
tokens += tool_call_inner_prompts_tokens_map[model]
return tokens
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
self._chat_generate(
model=model,
credentials=credentials,
prompt_messages=[
UserPromptMessage(content="ping"),
],
model_parameters={
"temperature": 0,
"max_tokens": 20,
},
stream=False,
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _handle_chat_generate_response(
self,
model: str,
credentials: dict,
response: Union[Message, ToolsBetaMessage],
prompt_messages: list[PromptMessage],
) -> LLMResult:
"""
Handle llm chat response
:param model: model name
:param credentials: credentials
:param response: response
:param prompt_messages: prompt messages
:return: llm response
"""
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(content="", tool_calls=[])
for content in response.content:
if content.type == "text":
assistant_prompt_message.content += content.text
elif content.type == "tool_use":
tool_call = AssistantPromptMessage.ToolCall(
id=content.id,
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=content.name, arguments=json.dumps(content.input)
),
)
assistant_prompt_message.tool_calls.append(tool_call)
# calculate num tokens
if response.usage:
# transform usage
prompt_tokens = response.usage.input_tokens
completion_tokens = response.usage.output_tokens
else:
# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
# transform response
response = LLMResult(
model=response.model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage
)
return response
def _handle_chat_generate_stream_response(
self, model: str, credentials: dict, response: Stream[MessageStreamEvent], prompt_messages: list[PromptMessage]
) -> Generator:
"""
Handle llm chat stream response
:param model: model name
:param response: response
:param prompt_messages: prompt messages
:return: llm response chunk generator
"""
full_assistant_content = ""
return_model = None
input_tokens = 0
output_tokens = 0
finish_reason = None
index = 0
tool_calls: list[AssistantPromptMessage.ToolCall] = []
for chunk in response:
if isinstance(chunk, MessageStartEvent):
if hasattr(chunk, "content_block"):
content_block = chunk.content_block
if isinstance(content_block, dict):
if content_block.get("type") == "tool_use":
tool_call = AssistantPromptMessage.ToolCall(
id=content_block.get("id"),
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=content_block.get("name"), arguments=""
),
)
tool_calls.append(tool_call)
elif hasattr(chunk, "delta"):
delta = chunk.delta
if isinstance(delta, dict) and len(tool_calls) > 0:
if delta.get("type") == "input_json_delta":
tool_calls[-1].function.arguments += delta.get("partial_json", "")
elif chunk.message:
return_model = chunk.message.model
input_tokens = chunk.message.usage.input_tokens
elif isinstance(chunk, MessageDeltaEvent):
output_tokens = chunk.usage.output_tokens
finish_reason = chunk.delta.stop_reason
elif isinstance(chunk, MessageStopEvent):
# transform usage
usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens)
# transform empty tool call arguments to {}
for tool_call in tool_calls:
if not tool_call.function.arguments:
tool_call.function.arguments = "{}"
yield LLMResultChunk(
model=return_model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index + 1,
message=AssistantPromptMessage(content="", tool_calls=tool_calls),
finish_reason=finish_reason,
usage=usage,
),
)
elif isinstance(chunk, ContentBlockDeltaEvent):
chunk_text = chunk.delta.text or ""
full_assistant_content += chunk_text
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(content=chunk_text)
index = chunk.index
yield LLMResultChunk(
model=return_model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=chunk.index,
message=assistant_prompt_message,
),
)
def _to_credential_kwargs(self, credentials: dict) -> dict:
"""
Transform credentials to kwargs for model instance
:param credentials:
:return:
"""
credentials_kwargs = {
"api_key": credentials["anthropic_api_key"],
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
"max_retries": 1,
}
if credentials.get("anthropic_api_url"):
credentials["anthropic_api_url"] = credentials["anthropic_api_url"].rstrip("/")
credentials_kwargs["base_url"] = credentials["anthropic_api_url"]
return credentials_kwargs
def _convert_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]:
"""
Convert prompt messages to dict list and system
"""
system = ""
first_loop = True
for message in prompt_messages:
if isinstance(message, SystemPromptMessage):
message.content = message.content.strip()
if first_loop:
system = message.content
first_loop = False
else:
system += "\n"
system += message.content
prompt_message_dicts = []
for message in prompt_messages:
if not isinstance(message, SystemPromptMessage):
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
prompt_message_dicts.append(message_dict)
else:
sub_messages = []
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(TextPromptMessageContent, message_content)
sub_message_dict = {"type": "text", "text": message_content.data}
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
if not message_content.data.startswith("data:"):
# fetch image data from url
try:
image_content = requests.get(message_content.data, timeout=30).content
with Image.open(io.BytesIO(image_content)) as img:
mime_type = f"image/{img.format.lower()}"
base64_data = base64.b64encode(image_content).decode("utf-8")
except Exception as ex:
raise ValueError(
f"Failed to fetch image data from url {message_content.data}, {ex}"
)
else:
data_split = message_content.data.split(";base64,")
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
raise ValueError(
f"Unsupported image type {mime_type}, "
f"only support image/jpeg, image/png, image/gif, and image/webp"
)
sub_message_dict = {
"type": "image",
"source": {"type": "base64", "media_type": mime_type, "data": base64_data},
}
sub_messages.append(sub_message_dict)
prompt_message_dicts.append({"role": "user", "content": sub_messages})
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
content = []
if message.tool_calls:
for tool_call in message.tool_calls:
content.append(
{
"type": "tool_use",
"id": tool_call.id,
"name": tool_call.function.name,
"input": json.loads(tool_call.function.arguments),
}
)
if message.content:
content.append({"type": "text", "text": message.content})
if prompt_message_dicts and prompt_message_dicts[-1]["role"] == "assistant":
prompt_message_dicts[-1]["content"].extend(content)
else:
prompt_message_dicts.append({"role": "assistant", "content": content})
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {
"role": "user",
"content": [
{"type": "tool_result", "tool_use_id": message.tool_call_id, "content": message.content}
],
}
prompt_message_dicts.append(message_dict)
else:
raise ValueError(f"Got unknown type {message}")
return system, prompt_message_dicts
def _convert_one_message_to_text(self, message: PromptMessage) -> str:
"""
Convert a single message to a string.
:param message: PromptMessage to convert.
:return: String representation of the message.
"""
human_prompt = "\n\nHuman:"
ai_prompt = "\n\nAssistant:"
content = message.content
if isinstance(message, UserPromptMessage):
message_text = f"{human_prompt} {content}"
if not isinstance(message.content, list):
message_text = f"{ai_prompt} {content}"
else:
message_text = ""
for sub_message in message.content:
if sub_message.type == PromptMessageContentType.TEXT:
message_text += f"{human_prompt} {sub_message.data}"
elif sub_message.type == PromptMessageContentType.IMAGE:
message_text += f"{human_prompt} [IMAGE]"
elif isinstance(message, AssistantPromptMessage):
if not isinstance(message.content, list):
message_text = f"{ai_prompt} {content}"
else:
message_text = ""
for sub_message in message.content:
if sub_message.type == PromptMessageContentType.TEXT:
message_text += f"{ai_prompt} {sub_message.data}"
elif sub_message.type == PromptMessageContentType.IMAGE:
message_text += f"{ai_prompt} [IMAGE]"
elif isinstance(message, SystemPromptMessage):
message_text = content
elif isinstance(message, ToolPromptMessage):
message_text = f"{human_prompt} {message.content}"
else:
raise ValueError(f"Got unknown type {message}")
return message_text
def _convert_messages_to_prompt_anthropic(self, messages: list[PromptMessage]) -> str:
"""
Format a list of messages into a full prompt for the Anthropic model
:param messages: List of PromptMessage to combine.
:return: Combined string with necessary human_prompt and ai_prompt tags.
"""
if not messages:
return ""
messages = messages.copy() # don't mutate the original list
if not isinstance(messages[-1], AssistantPromptMessage):
messages.append(AssistantPromptMessage(content=""))
text = "".join(self._convert_one_message_to_text(message) for message in messages)
# trim off the trailing ' ' that might come from the "Assistant: "
return text.rstrip()
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [anthropic.APIConnectionError, anthropic.APITimeoutError],
InvokeServerUnavailableError: [anthropic.InternalServerError],
InvokeRateLimitError: [anthropic.RateLimitError],
InvokeAuthorizationError: [anthropic.AuthenticationError, anthropic.PermissionDeniedError],
InvokeBadRequestError: [
anthropic.BadRequestError,
anthropic.NotFoundError,
anthropic.UnprocessableEntityError,
anthropic.APIError,
],
}
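For reference, a minimal standalone sketch of the JSON block-mode transformation performed by _transform_chat_json_prompts above, using plain role/content dicts instead of the PromptMessage classes (the template text and message shapes here are simplified assumptions, not the runtime types):

BLOCK_MODE_PROMPT = (
    "You should always follow the instructions and output a valid {{block}} object.\n"
    "<instructions>\n{{instructions}}\n</instructions>\n"
)


def transform_json_prompts(messages: list[dict], stop: list[str], response_format: str = "JSON") -> None:
    # Stop at the closing code fence so only the structured block is generated.
    for seq in ("```\n", "\n```"):
        if seq not in stop:
            stop.append(seq)
    # Wrap an existing system message (or insert one) with the block-mode instructions.
    if messages and messages[0]["role"] == "system":
        messages[0]["content"] = (
            BLOCK_MODE_PROMPT
            .replace("{{instructions}}", messages[0]["content"])
            .replace("{{block}}", response_format)
        )
    else:
        messages.insert(
            0,
            {
                "role": "system",
                "content": BLOCK_MODE_PROMPT.replace(
                    "{{instructions}}", f"Please output a valid {response_format} object."
                ).replace("{{block}}", response_format),
            },
        )
    # Prefill the assistant turn with an opening fence to steer the model into the block.
    messages.append({"role": "assistant", "content": f"\n```{response_format}"})


example_messages = [{"role": "system", "content": "Answer as JSON."}, {"role": "user", "content": "ping"}]
example_stop: list[str] = []
transform_json_prompts(example_messages, example_stop)

Combined with the stop sequences, the reply is effectively constrained to the fenced JSON block.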

Binary file not shown.

Binary file not shown.

@ -1,17 +0,0 @@
import logging
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class AzureAIStudioProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
If validation fails, raise an exception.
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
pass

@ -1,65 +0,0 @@
provider: azure_ai_studio
label:
zh_Hans: Azure AI Studio
en_US: Azure AI Studio
icon_small:
en_US: icon_s_en.png
icon_large:
en_US: icon_l_en.png
description:
en_US: Azure AI Studio
zh_Hans: Azure AI Studio
background: "#93c5fd"
help:
title:
en_US: How to deploy customized model on Azure AI Studio
zh_Hans: 如何在Azure AI Studio上的私有化部署的模型
url:
en_US: https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models
zh_Hans: https://learn.microsoft.com/zh-cn/azure/ai-studio/how-to/deploy-models
supported_model_types:
- llm
- rerank
configurate_methods:
- customizable-model
model_credential_schema:
model:
label:
en_US: Model Name
zh_Hans: 模型名称
placeholder:
en_US: Enter your model name
zh_Hans: 输入模型名称
credential_form_schemas:
- variable: endpoint
label:
en_US: Azure AI Studio Endpoint
type: text-input
required: true
placeholder:
zh_Hans: 请输入你的Azure AI Studio推理端点
en_US: 'Enter your API Endpoint, eg: https://example.com'
- variable: api_key
required: true
label:
en_US: API Key
zh_Hans: API Key
type: secret-input
placeholder:
en_US: Enter your Azure AI Studio API Key
zh_Hans: 在此输入您的 Azure AI Studio API Key
show_on:
- variable: __model_type
value: llm
- variable: jwt_token
required: true
label:
en_US: JWT Token
zh_Hans: JWT令牌
type: secret-input
placeholder:
en_US: Enter your Azure AI Studio JWT Token
zh_Hans: 在此输入您的 Azure AI Studio 推理 API Key
show_on:
- variable: __model_type
value: rerank

@ -1,334 +0,0 @@
import logging
from collections.abc import Generator
from typing import Any, Optional, Union
from azure.ai.inference import ChatCompletionsClient
from azure.ai.inference.models import StreamingChatCompletionsUpdate
from azure.core.credentials import AzureKeyCredential
from azure.core.exceptions import (
ClientAuthenticationError,
DecodeError,
DeserializationError,
HttpResponseError,
ResourceExistsError,
ResourceModifiedError,
ResourceNotFoundError,
ResourceNotModifiedError,
SerializationError,
ServiceRequestError,
ServiceResponseError,
)
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageTool,
)
from core.model_runtime.entities.model_entities import (
AIModelEntity,
FetchFrom,
I18nObject,
ModelType,
ParameterRule,
ParameterType,
)
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
logger = logging.getLogger(__name__)
class AzureAIStudioLargeLanguageModel(LargeLanguageModel):
"""
Model class for Azure AI Studio large language model.
"""
client: Any = None
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
if not self.client:
endpoint = credentials.get("endpoint")
api_key = credentials.get("api_key")
self.client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(api_key))
messages = [{"role": msg.role.value, "content": msg.content} for msg in prompt_messages]
payload = {
"messages": messages,
"max_tokens": model_parameters.get("max_tokens", 4096),
"temperature": model_parameters.get("temperature", 0),
"top_p": model_parameters.get("top_p", 1),
"stream": stream,
}
if stop:
payload["stop"] = stop
if tools:
payload["tools"] = [tool.model_dump() for tool in tools]
try:
response = self.client.complete(**payload)
if stream:
return self._handle_stream_response(response, model, prompt_messages)
else:
return self._handle_non_stream_response(response, model, prompt_messages, credentials)
except Exception as e:
raise self._transform_invoke_error(e)
def _handle_stream_response(self, response, model: str, prompt_messages: list[PromptMessage]) -> Generator:
for chunk in response:
if isinstance(chunk, StreamingChatCompletionsUpdate):
if chunk.choices:
delta = chunk.choices[0].delta
if delta.content:
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=delta.content, tool_calls=[]),
),
)
def _handle_non_stream_response(
self, response, model: str, prompt_messages: list[PromptMessage], credentials: dict
) -> LLMResult:
assistant_text = response.choices[0].message.content
assistant_prompt_message = AssistantPromptMessage(content=assistant_text)
usage = self._calc_response_usage(
model, credentials, response.usage.prompt_tokens, response.usage.completion_tokens
)
result = LLMResult(model=model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage)
if hasattr(response, "system_fingerprint"):
result.system_fingerprint = response.system_fingerprint
return result
def _invoke_result_generator(
self,
model: str,
result: Generator,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
) -> Generator:
"""
Invoke result generator
:param result: result generator
:return: result generator
"""
callbacks = callbacks or []
prompt_message = AssistantPromptMessage(content="")
usage = None
system_fingerprint = None
real_model = model
try:
for chunk in result:
if isinstance(chunk, dict):
content = chunk["choices"][0]["message"]["content"]
usage = chunk["usage"]
chunk = LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=content, tool_calls=[]),
),
system_fingerprint=chunk.get("system_fingerprint"),
)
yield chunk
self._trigger_new_chunk_callbacks(
chunk=chunk,
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks,
)
prompt_message.content += chunk.delta.message.content
real_model = chunk.model
if hasattr(chunk.delta, "usage"):
usage = chunk.delta.usage
if chunk.system_fingerprint:
system_fingerprint = chunk.system_fingerprint
except Exception as e:
raise self._transform_invoke_error(e)
self._trigger_after_invoke_callbacks(
model=model,
result=LLMResult(
model=real_model,
prompt_messages=prompt_messages,
message=prompt_message,
usage=usage or LLMUsage.empty_usage(),
system_fingerprint=system_fingerprint,
),
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks,
)
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param tools: tools for tool calling
:return:
"""
# Implement token counting logic here
# Might need to use a tokenizer specific to the Azure AI Studio model
return 0
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
endpoint = credentials.get("endpoint")
api_key = credentials.get("api_key")
client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(api_key))
client.get_model_info()
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
ServiceRequestError,
],
InvokeServerUnavailableError: [
ServiceResponseError,
],
InvokeAuthorizationError: [
ClientAuthenticationError,
],
InvokeBadRequestError: [
HttpResponseError,
DecodeError,
ResourceExistsError,
ResourceNotFoundError,
ResourceModifiedError,
ResourceNotModifiedError,
SerializationError,
DeserializationError,
],
}
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
Used to define customizable model schema
"""
rules = [
ParameterRule(
name="temperature",
type=ParameterType.FLOAT,
use_template="temperature",
label=I18nObject(zh_Hans="温度", en_US="Temperature"),
),
ParameterRule(
name="top_p",
type=ParameterType.FLOAT,
use_template="top_p",
label=I18nObject(zh_Hans="Top P", en_US="Top P"),
),
ParameterRule(
name="max_tokens",
type=ParameterType.INT,
use_template="max_tokens",
min=1,
default=512,
label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"),
),
]
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM,
features=[],
model_properties={},
parameter_rules=rules,
)
return entity
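Note that get_num_tokens in AzureAIStudioLargeLanguageModel above is a stub that returns 0. A deliberately rough fallback, assuming roughly four characters per token (an assumption, not the Azure AI Studio tokenizer), might look like:

def estimate_num_tokens(prompt_messages: list) -> int:
    # Concatenate all message contents and approximate at ~4 characters per token.
    text = "".join(str(getattr(m, "content", "") or "") for m in prompt_messages)
    return max(1, len(text) // 4)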

@ -1,164 +0,0 @@
import json
import logging
import os
import ssl
import urllib.request
from typing import Optional
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
logger = logging.getLogger(__name__)
class AzureRerankModel(RerankModel):
"""
Model class for Azure AI Studio rerank model.
"""
def _allow_self_signed_https(self, allowed):
# bypass the server certificate verification on client side
if allowed and not os.environ.get("PYTHONHTTPSVERIFY", "") and getattr(ssl, "_create_unverified_context", None):
ssl._create_default_https_context = ssl._create_unverified_context
def _azure_rerank(self, query_input: str, docs: list[str], endpoint: str, api_key: str):
# self._allow_self_signed_https(True) # Enable if using self-signed certificate
data = {"inputs": query_input, "docs": docs}
body = json.dumps(data).encode("utf-8")
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
req = urllib.request.Request(endpoint, body, headers)
try:
with urllib.request.urlopen(req) as response:
result = response.read()
return json.loads(result)
except urllib.error.HTTPError as error:
logger.error(f"The request failed with status code: {error.code}")
logger.error(error.info())
logger.error(error.read().decode("utf8", "ignore"))
raise
def _invoke(
self,
model: str,
credentials: dict,
query: str,
docs: list[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
) -> RerankResult:
"""
Invoke rerank model
:param model: model name
:param credentials: model credentials
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
try:
if len(docs) == 0:
return RerankResult(model=model, docs=[])
endpoint = credentials.get("endpoint")
api_key = credentials.get("jwt_token")
if not endpoint or not api_key:
raise ValueError("Azure endpoint and API key must be provided in credentials")
result = self._azure_rerank(query, docs, endpoint, api_key)
logger.info(f"Azure rerank result: {result}")
rerank_documents = []
for idx, (doc, score_dict) in enumerate(zip(docs, result)):
score = score_dict["score"]
rerank_document = RerankDocument(index=idx, text=doc, score=score)
if score_threshold is None or score >= score_threshold:
rerank_documents.append(rerank_document)
rerank_documents.sort(key=lambda x: x.score, reverse=True)
if top_n:
rerank_documents = rerank_documents[:top_n]
return RerankResult(model=model, docs=rerank_documents)
except Exception as e:
logger.exception(f"Exception in Azure rerank: {e}")
raise
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
self._invoke(
model=model,
credentials=credentials,
query="What is the capital of the United States?",
docs=[
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
"Census, Carson City had a population of 55,274.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
"are a political division controlled by the United States. Its capital is Saipan.",
],
score_threshold=0.8,
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [urllib.error.URLError],
InvokeServerUnavailableError: [urllib.error.HTTPError],
InvokeRateLimitError: [InvokeRateLimitError],
InvokeAuthorizationError: [InvokeAuthorizationError],
InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError, json.JSONDecodeError],
}
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
used to define customizable model schema
"""
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.RERANK,
model_properties={},
parameter_rules=[],
)
return entity
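For clarity, a standalone sketch of the post-processing that _invoke above applies to the rerank response: pair each document with its score, filter by the optional threshold, sort descending, and truncate to top_n. The [{"score": ...}] response shape mirrors what _azure_rerank is assumed to return.

def postprocess_rerank(
    docs: list[str],
    scores: list[dict],
    score_threshold: float | None = None,
    top_n: int | None = None,
) -> list[dict]:
    # Keep only documents that pass the threshold, then rank by score.
    ranked = [
        {"index": i, "text": doc, "score": s["score"]}
        for i, (doc, s) in enumerate(zip(docs, scores))
        if score_threshold is None or s["score"] >= score_threshold
    ]
    ranked.sort(key=lambda d: d["score"], reverse=True)
    return ranked[:top_n] if top_n else ranked


print(postprocess_rerank(["doc a", "doc b"], [{"score": 0.2}, {"score": 0.9}], top_n=1))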

Binary file not shown.

@ -1,8 +0,0 @@
<svg width="21" height="22" viewBox="0 0 21 22" fill="none" xmlns="http://www.w3.org/2000/svg">
<g id="Microsoft">
<rect id="Rectangle 1010" y="0.5" width="10" height="10" fill="#EF4F21"/>
<rect id="Rectangle 1012" y="11.5" width="10" height="10" fill="#03A4EE"/>
<rect id="Rectangle 1011" x="11" y="0.5" width="10" height="10" fill="#7EB903"/>
<rect id="Rectangle 1013" x="11" y="11.5" width="10" height="10" fill="#FBB604"/>
</g>
</svg>

@ -1,42 +0,0 @@
import openai
from httpx import Timeout
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.model_providers.azure_openai._constant import AZURE_OPENAI_API_VERSION
class _CommonAzureOpenAI:
@staticmethod
def _to_credential_kwargs(credentials: dict) -> dict:
api_version = credentials.get("openai_api_version", AZURE_OPENAI_API_VERSION)
credentials_kwargs = {
"api_key": credentials["openai_api_key"],
"azure_endpoint": credentials["openai_api_base"],
"api_version": api_version,
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
"max_retries": 1,
}
return credentials_kwargs
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
return {
InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError],
InvokeServerUnavailableError: [openai.InternalServerError],
InvokeRateLimitError: [openai.RateLimitError],
InvokeAuthorizationError: [openai.AuthenticationError, openai.PermissionDeniedError],
InvokeBadRequestError: [
openai.BadRequestError,
openai.NotFoundError,
openai.UnprocessableEntityError,
openai.APIError,
],
}
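A minimal sketch of how a unified-error mapping such as _invoke_error_mapping is typically consumed; `mapping` is any dict of the same shape, and this is a simplified stand-in, not the actual base-class dispatch:

def to_unified_error(mapping: dict[type[Exception], list[type[Exception]]], ex: Exception) -> Exception:
    # Return the first unified error type whose SDK exception list matches, else the original.
    for unified_cls, sdk_classes in mapping.items():
        if isinstance(ex, tuple(sdk_classes)):
            return unified_cls(str(ex))
    return ex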

@ -1,10 +0,0 @@
import logging
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class AzureOpenAIProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
pass

@ -1,227 +0,0 @@
provider: azure_openai
label:
en_US: Azure OpenAI Service Model
icon_small:
en_US: icon_s_en.svg
icon_large:
en_US: icon_l_en.png
background: "#E3F0FF"
help:
title:
en_US: Get your API key from Azure
zh_Hans: 从 Azure 获取 API Key
url:
en_US: https://azure.microsoft.com/en-us/products/ai-services/openai-service
supported_model_types:
- llm
- text-embedding
- speech2text
- tts
configurate_methods:
- customizable-model
model_credential_schema:
model:
label:
en_US: Deployment Name
zh_Hans: 部署名称
placeholder:
en_US: Enter your Deployment Name here, matching the Azure deployment name.
zh_Hans: 在此输入您的部署名称,与 Azure 部署名称匹配。
credential_form_schemas:
- variable: openai_api_base
label:
en_US: API Endpoint URL
zh_Hans: API 域名
type: text-input
required: true
placeholder:
zh_Hans: '在此输入您的 API 域名,如:https://example.com/xxx'
en_US: 'Enter your API Endpoint, eg: https://example.com/xxx'
- variable: openai_api_key
label:
en_US: API Key
zh_Hans: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API key here
- variable: openai_api_version
label:
zh_Hans: API 版本
en_US: API Version
type: select
required: true
options:
- label:
en_US: 2024-08-01-preview
value: 2024-08-01-preview
- label:
en_US: 2024-07-01-preview
value: 2024-07-01-preview
- label:
en_US: 2024-05-01-preview
value: 2024-05-01-preview
- label:
en_US: 2024-04-01-preview
value: 2024-04-01-preview
- label:
en_US: 2024-03-01-preview
value: 2024-03-01-preview
- label:
en_US: 2024-02-15-preview
value: 2024-02-15-preview
- label:
en_US: 2023-12-01-preview
value: 2023-12-01-preview
- label:
en_US: '2024-02-01'
value: '2024-02-01'
- label:
en_US: '2024-06-01'
value: '2024-06-01'
placeholder:
zh_Hans: 在此选择您的 API 版本
en_US: Select your API Version here
- variable: base_model_name
label:
en_US: Base Model
zh_Hans: 基础模型
type: select
required: true
options:
- label:
en_US: gpt-35-turbo
value: gpt-35-turbo
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-35-turbo-0125
value: gpt-35-turbo-0125
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-35-turbo-16k
value: gpt-35-turbo-16k
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4
value: gpt-4
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4-32k
value: gpt-4-32k
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4o-mini
value: gpt-4o-mini
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4o-mini-2024-07-18
value: gpt-4o-mini-2024-07-18
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4o
value: gpt-4o
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4o-2024-05-13
value: gpt-4o-2024-05-13
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4o-2024-08-06
value: gpt-4o-2024-08-06
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4-turbo
value: gpt-4-turbo
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4-turbo-2024-04-09
value: gpt-4-turbo-2024-04-09
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4-0125-preview
value: gpt-4-0125-preview
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4-1106-preview
value: gpt-4-1106-preview
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4-vision-preview
value: gpt-4-vision-preview
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-35-turbo-instruct
value: gpt-35-turbo-instruct
show_on:
- variable: __model_type
value: llm
- label:
en_US: text-embedding-ada-002
value: text-embedding-ada-002
show_on:
- variable: __model_type
value: text-embedding
- label:
en_US: text-embedding-3-small
value: text-embedding-3-small
show_on:
- variable: __model_type
value: text-embedding
- label:
en_US: text-embedding-3-large
value: text-embedding-3-large
show_on:
- variable: __model_type
value: text-embedding
- label:
en_US: whisper-1
value: whisper-1
show_on:
- variable: __model_type
value: speech2text
- label:
en_US: tts-1
value: tts-1
show_on:
- variable: __model_type
value: tts
- label:
en_US: tts-1-hd
value: tts-1-hd
show_on:
- variable: __model_type
value: tts
placeholder:
zh_Hans: 在此输入您的模型版本
en_US: Enter your model version

@ -1,665 +0,0 @@
import copy
import json
import logging
from collections.abc import Generator, Sequence
from typing import Optional, Union, cast
import tiktoken
from openai import AzureOpenAI, Stream
from openai.types import Completion
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
PromptMessageFunction,
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
from core.model_runtime.model_providers.azure_openai._constant import LLM_BASE_MODELS
from core.model_runtime.utils import helper
logger = logging.getLogger(__name__)
class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
base_model_name = credentials.get("base_model_name")
if not base_model_name:
raise ValueError("Base Model Name is required")
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
if ai_model_entity and ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
# chat model
return self._chat_generate(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
)
else:
# text completion model
return self._generate(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
stop=stop,
stream=stream,
user=user,
)
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
base_model_name = credentials.get("base_model_name")
if not base_model_name:
raise ValueError("Base Model Name is required")
model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
if not model_entity:
raise ValueError(f"Base Model Name {base_model_name} is invalid")
model_mode = model_entity.entity.model_properties.get(ModelPropertyKey.MODE)
if model_mode == LLMMode.CHAT.value:
# chat model
return self._num_tokens_from_messages(credentials, prompt_messages, tools)
else:
# text completion model, do not support tool calling
content = prompt_messages[0].content
assert isinstance(content, str)
return self._num_tokens_from_string(credentials, content)
def validate_credentials(self, model: str, credentials: dict) -> None:
if "openai_api_base" not in credentials:
raise CredentialsValidateFailedError("Azure OpenAI API Base Endpoint is required")
if "openai_api_key" not in credentials:
raise CredentialsValidateFailedError("Azure OpenAI API key is required")
if "base_model_name" not in credentials:
raise CredentialsValidateFailedError("Base Model Name is required")
base_model_name = credentials.get("base_model_name")
if not base_model_name:
raise CredentialsValidateFailedError("Base Model Name is required")
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
if not ai_model_entity:
raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
try:
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
# chat model
client.chat.completions.create(
messages=[{"role": "user", "content": "ping"}],
model=model,
temperature=0,
max_tokens=20,
stream=False,
)
else:
# text completion model
client.completions.create(
prompt="ping",
model=model,
temperature=0,
max_tokens=20,
stream=False,
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
base_model_name = credentials.get("base_model_name")
if not base_model_name:
raise ValueError("Base Model Name is required")
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
return ai_model_entity.entity if ai_model_entity else None
def _generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
extra_model_kwargs = {}
if stop:
extra_model_kwargs["stop"] = stop
if user:
extra_model_kwargs["user"] = user
# text completion model
response = client.completions.create(
prompt=prompt_messages[0].content, model=model, stream=stream, **model_parameters, **extra_model_kwargs
)
if stream:
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
return self._handle_generate_response(model, credentials, response, prompt_messages)
def _handle_generate_response(
self, model: str, credentials: dict, response: Completion, prompt_messages: list[PromptMessage]
):
assistant_text = response.choices[0].text
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(content=assistant_text)
# calculate num tokens
if response.usage:
# transform usage
prompt_tokens = response.usage.prompt_tokens
completion_tokens = response.usage.completion_tokens
else:
# calculate num tokens
content = prompt_messages[0].content
assert isinstance(content, str)
prompt_tokens = self._num_tokens_from_string(credentials, content)
completion_tokens = self._num_tokens_from_string(credentials, assistant_text)
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
# transform response
result = LLMResult(
model=response.model,
prompt_messages=prompt_messages,
message=assistant_prompt_message,
usage=usage,
system_fingerprint=response.system_fingerprint,
)
return result
def _handle_generate_stream_response(
self, model: str, credentials: dict, response: Stream[Completion], prompt_messages: list[PromptMessage]
) -> Generator:
full_text = ""
for chunk in response:
if len(chunk.choices) == 0:
continue
delta = chunk.choices[0]
if delta.finish_reason is None and (delta.text is None or delta.text == ""):
continue
# transform assistant message to prompt message
text = delta.text or ""
assistant_prompt_message = AssistantPromptMessage(content=text)
full_text += text
if delta.finish_reason is not None:
# calculate num tokens
if chunk.usage:
# transform usage
prompt_tokens = chunk.usage.prompt_tokens
completion_tokens = chunk.usage.completion_tokens
else:
# calculate num tokens
content = prompt_messages[0].content
assert isinstance(content, str)
prompt_tokens = self._num_tokens_from_string(credentials, content)
completion_tokens = self._num_tokens_from_string(credentials, full_text)
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
yield LLMResultChunk(
model=chunk.model,
prompt_messages=prompt_messages,
system_fingerprint=chunk.system_fingerprint,
delta=LLMResultChunkDelta(
index=delta.index,
message=assistant_prompt_message,
finish_reason=delta.finish_reason,
usage=usage,
),
)
else:
yield LLMResultChunk(
model=chunk.model,
prompt_messages=prompt_messages,
system_fingerprint=chunk.system_fingerprint,
delta=LLMResultChunkDelta(
index=delta.index,
message=assistant_prompt_message,
),
)
def _chat_generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
response_format = model_parameters.get("response_format")
if response_format:
if response_format == "json_schema":
json_schema = model_parameters.get("json_schema")
if not json_schema:
raise ValueError("Must define JSON Schema when the response format is json_schema")
try:
schema = json.loads(json_schema)
except (TypeError, ValueError):
raise ValueError(f"Invalid json_schema format: {json_schema}")
model_parameters.pop("json_schema")
model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema}
else:
model_parameters["response_format"] = {"type": response_format}
extra_model_kwargs = {}
if tools:
extra_model_kwargs["tools"] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
# extra_model_kwargs['functions'] = [{
# "name": tool.name,
# "description": tool.description,
# "parameters": tool.parameters
# } for tool in tools]
if stop:
extra_model_kwargs["stop"] = stop
if user:
extra_model_kwargs["user"] = user
# chat model
messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
response = client.chat.completions.create(
messages=messages,
model=model,
stream=stream,
**model_parameters,
**extra_model_kwargs,
)
if stream:
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
def _handle_chat_generate_response(
self,
model: str,
credentials: dict,
response: ChatCompletion,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
):
assistant_message = response.choices[0].message
assistant_message_tool_calls = assistant_message.tool_calls
# extract tool calls from response
tool_calls = []
self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=assistant_message_tool_calls)
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls)
# calculate num tokens
if response.usage:
# transform usage
prompt_tokens = response.usage.prompt_tokens
completion_tokens = response.usage.completion_tokens
else:
# calculate num tokens
prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools)
completion_tokens = self._num_tokens_from_messages(credentials, [assistant_prompt_message])
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
# transform response
result = LLMResult(
model=response.model or model,
prompt_messages=prompt_messages,
message=assistant_prompt_message,
usage=usage,
system_fingerprint=response.system_fingerprint,
)
return result
def _handle_chat_generate_stream_response(
self,
model: str,
credentials: dict,
response: Stream[ChatCompletionChunk],
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
):
index = 0
full_assistant_content = ""
real_model = model
system_fingerprint = None
completion = ""
tool_calls = []
for chunk in response:
if len(chunk.choices) == 0:
continue
delta = chunk.choices[0]
# NOTE: For fix https://github.com/langgenius/dify/issues/5790
if delta.delta is None:
continue
# extract tool calls from response
self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=delta.delta.tool_calls)
# Handling exceptions when content filters' streaming mode is set to asynchronous modified filter
if delta.finish_reason is None and not delta.delta.content:
continue
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls)
full_assistant_content += delta.delta.content or ""
real_model = chunk.model
system_fingerprint = chunk.system_fingerprint
completion += delta.delta.content or ""
yield LLMResultChunk(
model=real_model,
prompt_messages=prompt_messages,
system_fingerprint=system_fingerprint,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
),
)
index += 1
# calculate num tokens
prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools)
full_assistant_prompt_message = AssistantPromptMessage(content=completion)
completion_tokens = self._num_tokens_from_messages(credentials, [full_assistant_prompt_message])
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
yield LLMResultChunk(
model=real_model,
prompt_messages=prompt_messages,
system_fingerprint=system_fingerprint,
delta=LLMResultChunkDelta(
index=index, message=AssistantPromptMessage(content=""), finish_reason="stop", usage=usage
),
)
@staticmethod
def _update_tool_calls(
tool_calls: list[AssistantPromptMessage.ToolCall],
tool_calls_response: Optional[Sequence[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]],
) -> None:
if tool_calls_response:
for response_tool_call in tool_calls_response:
if isinstance(response_tool_call, ChatCompletionMessageToolCall):
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_tool_call.function.name, arguments=response_tool_call.function.arguments
)
tool_call = AssistantPromptMessage.ToolCall(
id=response_tool_call.id, type=response_tool_call.type, function=function
)
tool_calls.append(tool_call)
elif isinstance(response_tool_call, ChoiceDeltaToolCall):
index = response_tool_call.index
if index < len(tool_calls):
tool_calls[index].id = response_tool_call.id or tool_calls[index].id
tool_calls[index].type = response_tool_call.type or tool_calls[index].type
if response_tool_call.function:
tool_calls[index].function.name = (
response_tool_call.function.name or tool_calls[index].function.name
)
tool_calls[index].function.arguments += response_tool_call.function.arguments or ""
else:
assert response_tool_call.id is not None
assert response_tool_call.type is not None
assert response_tool_call.function is not None
assert response_tool_call.function.name is not None
assert response_tool_call.function.arguments is not None
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_tool_call.function.name, arguments=response_tool_call.function.arguments
)
tool_call = AssistantPromptMessage.ToolCall(
id=response_tool_call.id, type=response_tool_call.type, function=function
)
tool_calls.append(tool_call)
@staticmethod
def _convert_prompt_message_to_dict(message: PromptMessage):
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
else:
sub_messages = []
assert message.content is not None
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(TextPromptMessageContent, message_content)
sub_message_dict = {"type": "text", "text": message_content.data}
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
sub_message_dict = {
"type": "image_url",
"image_url": {"url": message_content.data, "detail": message_content.detail.value},
}
sub_messages.append(sub_message_dict)
message_dict = {"role": "user", "content": sub_messages}
elif isinstance(message, AssistantPromptMessage):
# message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls:
message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in message.tool_calls]
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {
"role": "tool",
"name": message.name,
"content": message.content,
"tool_call_id": message.tool_call_id,
}
else:
raise ValueError(f"Got unknown type {message}")
if message.name:
message_dict["name"] = message.name
return message_dict
def _num_tokens_from_string(
self, credentials: dict, text: str, tools: Optional[list[PromptMessageTool]] = None
) -> int:
try:
encoding = tiktoken.encoding_for_model(credentials["base_model_name"])
except KeyError:
encoding = tiktoken.get_encoding("cl100k_base")
num_tokens = len(encoding.encode(text))
if tools:
num_tokens += self._num_tokens_for_tools(encoding, tools)
return num_tokens
def _num_tokens_from_messages(
self, credentials: dict, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
model = credentials["base_model_name"]
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
logger.warning("Warning: model not found. Using cl100k_base encoding.")
model = "cl100k_base"
encoding = tiktoken.get_encoding(model)
if model.startswith("gpt-35-turbo-0301"):
# every message follows <im_start>{role/name}\n{content}<im_end>\n
tokens_per_message = 4
# if there's a name, the role is omitted
tokens_per_name = -1
elif model.startswith("gpt-35-turbo") or model.startswith("gpt-4"):
tokens_per_message = 3
tokens_per_name = 1
else:
raise NotImplementedError(
f"get_num_tokens_from_messages() is not presently implemented "
f"for model {model}."
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
"information on how messages are converted to tokens."
)
num_tokens = 0
messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
for message in messages_dict:
num_tokens += tokens_per_message
for key, value in message.items():
# Cast str(value) in case the message value is not a string
# This occurs with function messages
# TODO: Token calculation for the image type is not implemented yet;
# it would require downloading the image and reading its resolution,
# which would add extra request latency.
if isinstance(value, list):
text = ""
for item in value:
if isinstance(item, dict) and item["type"] == "text":
text += item["text"]
value = text
if key == "tool_calls":
for tool_call in value:
assert isinstance(tool_call, dict)
for t_key, t_value in tool_call.items():
num_tokens += len(encoding.encode(t_key))
if t_key == "function":
for f_key, f_value in t_value.items():
num_tokens += len(encoding.encode(f_key))
num_tokens += len(encoding.encode(f_value))
else:
num_tokens += len(encoding.encode(t_key))
num_tokens += len(encoding.encode(t_value))
else:
num_tokens += len(encoding.encode(str(value)))
if key == "name":
num_tokens += tokens_per_name
# every reply is primed with <im_start>assistant
num_tokens += 3
if tools:
num_tokens += self._num_tokens_for_tools(encoding, tools)
return num_tokens
@staticmethod
def _num_tokens_for_tools(encoding: tiktoken.Encoding, tools: list[PromptMessageTool]) -> int:
num_tokens = 0
for tool in tools:
num_tokens += len(encoding.encode("type"))
num_tokens += len(encoding.encode("function"))
# calculate num tokens for function object
num_tokens += len(encoding.encode("name"))
num_tokens += len(encoding.encode(tool.name))
num_tokens += len(encoding.encode("description"))
num_tokens += len(encoding.encode(tool.description))
parameters = tool.parameters
num_tokens += len(encoding.encode("parameters"))
if "title" in parameters:
num_tokens += len(encoding.encode("title"))
num_tokens += len(encoding.encode(parameters["title"]))
num_tokens += len(encoding.encode("type"))
num_tokens += len(encoding.encode(parameters["type"]))
if "properties" in parameters:
num_tokens += len(encoding.encode("properties"))
for key, value in parameters["properties"].items():
num_tokens += len(encoding.encode(key))
for field_key, field_value in value.items():
num_tokens += len(encoding.encode(field_key))
if field_key == "enum":
for enum_field in field_value:
num_tokens += 3
num_tokens += len(encoding.encode(enum_field))
else:
num_tokens += len(encoding.encode(field_key))
num_tokens += len(encoding.encode(str(field_value)))
if "required" in parameters:
num_tokens += len(encoding.encode("required"))
for required_field in parameters["required"]:
num_tokens += 3
num_tokens += len(encoding.encode(required_field))
return num_tokens
@staticmethod
def _get_ai_model_entity(base_model_name: str, model: str):
for ai_model_entity in LLM_BASE_MODELS:
if ai_model_entity.base_model_name == base_model_name:
ai_model_entity_copy = copy.deepcopy(ai_model_entity)
ai_model_entity_copy.entity.model = model
ai_model_entity_copy.entity.label.en_US = model
ai_model_entity_copy.entity.label.zh_Hans = model
return ai_model_entity_copy
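As a standalone illustration of the cookbook-style counting used in _num_tokens_from_messages above, reduced to plain role/content messages with three tokens of per-message overhead plus three for the primed assistant reply (the model name and sample messages are illustrative assumptions):

import tiktoken


def count_chat_tokens(messages: list[dict], model: str = "gpt-4") -> int:
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")
    num_tokens = 0
    for message in messages:
        num_tokens += 3  # every message follows <|im_start|>{role}\n{content}<|im_end|>\n
        for value in message.values():
            num_tokens += len(encoding.encode(str(value)))
    return num_tokens + 3  # every reply is primed with <|im_start|>assistant


print(count_chat_tokens([{"role": "user", "content": "ping"}]))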

@ -1,79 +0,0 @@
import copy
from typing import IO, Optional
from openai import AzureOpenAI
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
from core.model_runtime.model_providers.azure_openai._constant import SPEECH2TEXT_BASE_MODELS, AzureBaseModel
class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel):
"""
Model class for Azure OpenAI speech-to-text model.
"""
def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
"""
Invoke speech2text model
:param model: model name
:param credentials: model credentials
:param file: audio file
:param user: unique user id
:return: text for given audio file
"""
return self._speech2text_invoke(model, credentials, file)
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
audio_file_path = self._get_demo_file_path()
with open(audio_file_path, "rb") as audio_file:
self._speech2text_invoke(model, credentials, audio_file)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _speech2text_invoke(self, model: str, credentials: dict, file: IO[bytes]) -> str:
"""
Invoke speech2text model
:param model: model name
:param credentials: model credentials
:param file: audio file
:return: text for given audio file
"""
# transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials)
# init model client
client = AzureOpenAI(**credentials_kwargs)
response = client.audio.transcriptions.create(model=model, file=file)
return response.text
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model)
return ai_model_entity.entity if ai_model_entity else None
@staticmethod
def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel | None:
for ai_model_entity in SPEECH2TEXT_BASE_MODELS:
if ai_model_entity.base_model_name == base_model_name:
ai_model_entity_copy = copy.deepcopy(ai_model_entity)
ai_model_entity_copy.entity.model = model
ai_model_entity_copy.entity.label.en_US = model
ai_model_entity_copy.entity.label.zh_Hans = model
return ai_model_entity_copy
return None
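
For reference, the _speech2text_invoke call above is a thin wrapper around the Azure OpenAI transcription endpoint. A hedged usage sketch; the endpoint, API version, and deployment name below are placeholders, not values from this diff:

from openai import AzureOpenAI

# placeholder credentials; real values come from the provider credential form
client = AzureOpenAI(
    api_key="your-api-key",
    api_version="2024-02-01",
    azure_endpoint="https://your-resource.openai.azure.com",
)

with open("sample.mp3", "rb") as audio_file:
    # "whisper-1" stands in for whatever deployment name the tenant configured
    transcription = client.audio.transcriptions.create(model="whisper-1", file=audio_file)
    print(transcription.text)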

@ -1,128 +0,0 @@
import concurrent.futures
import copy
from typing import Optional
from openai import AzureOpenAI
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
from core.model_runtime.model_providers.azure_openai._constant import TTS_BASE_MODELS, AzureBaseModel
class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
"""
Model class for Azure OpenAI text-to-speech model.
"""
def _invoke(
self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None
) -> any:
"""
_invoke text2speech model
:param model: model name
:param tenant_id: user tenant id
:param credentials: model credentials
:param content_text: text content to be translated
:param voice: model timbre
:param user: unique user id
:return: text translated to audio file
"""
if not voice or voice not in [
d["value"] for d in self.get_tts_model_voices(model=model, credentials=credentials)
]:
voice = self._get_model_default_voice(model, credentials)
return self._tts_invoke_streaming(model=model, credentials=credentials, content_text=content_text, voice=voice)
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
validate credentials text2speech model
:param model: model name
:param credentials: model credentials
:return: text translated to audio file
"""
try:
self._tts_invoke_streaming(
model=model,
credentials=credentials,
content_text="Hello Dify!",
voice=self._get_model_default_voice(model, credentials),
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any:
"""
_tts_invoke_streaming text2speech model
:param model: model name
:param credentials: model credentials
:param content_text: text content to be translated
:param voice: model timbre
:return: text translated to audio file
"""
try:
# doc: https://platform.openai.com/docs/guides/text-to-speech
credentials_kwargs = self._to_credential_kwargs(credentials)
client = AzureOpenAI(**credentials_kwargs)
# the endpoint accepts at most 4096 characters of input; cap each request at 3500 to stay below the limit
max_length = 3500
if len(content_text) > max_length:
sentences = self._split_text_into_sentences(content_text, max_length=max_length)
executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences)))
futures = [
executor.submit(
client.audio.speech.with_streaming_response.create,
model=model,
response_format="mp3",
input=sentences[i],
voice=voice,
)
for i in range(len(sentences))
]
for future in futures:
yield from future.result().__enter__().iter_bytes(1024) # noqa:PLC2801
else:
response = client.audio.speech.with_streaming_response.create(
model=model, voice=voice, response_format="mp3", input=content_text.strip()
)
yield from response.__enter__().iter_bytes(1024) # noqa:PLC2801
except Exception as ex:
raise InvokeBadRequestError(str(ex))
def _process_sentence(self, sentence: str, model: str, voice, credentials: dict):
"""
_tts_invoke openai text2speech model api
:param model: model name
:param credentials: model credentials
:param voice: model timbre
:param sentence: text content to be translated
:return: text translated to audio file
"""
credentials_kwargs = self._to_credential_kwargs(credentials)
client = AzureOpenAI(**credentials_kwargs)
response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip())
# read the body once instead of calling response.read() twice
data = response.read()
if isinstance(data, bytes):
return data
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model)
return ai_model_entity.entity if ai_model_entity else None
@staticmethod
def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel | None:
for ai_model_entity in TTS_BASE_MODELS:
if ai_model_entity.base_model_name == base_model_name:
ai_model_entity_copy = copy.deepcopy(ai_model_entity)
ai_model_entity_copy.entity.model = model
ai_model_entity_copy.entity.label.en_US = model
ai_model_entity_copy.entity.label.zh_Hans = model
return ai_model_entity_copy
return None
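
The streaming path above splits long input into sentences and yields MP3 bytes chunk by chunk. A simplified sketch of the single-request case, assuming a placeholder Azure deployment named "tts-1" and the "alloy" voice:

from openai import AzureOpenAI

client = AzureOpenAI(
    api_key="your-api-key",
    api_version="2024-02-01",
    azure_endpoint="https://your-resource.openai.azure.com",
)

# stream the synthesized audio to disk in 1 KiB chunks, mirroring iter_bytes(1024) above
with client.audio.speech.with_streaming_response.create(
    model="tts-1", voice="alloy", response_format="mp3", input="Hello Dify!"
) as response:
    with open("hello.mp3", "wb") as f:
        for chunk in response.iter_bytes(1024):
            f.write(chunk)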

@ -1,19 +0,0 @@
<svg width="130" height="24" viewBox="0 0 130 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path fill-rule="evenodd" clip-rule="evenodd" d="M9.58154 1.7793H6.52779L4.34655 6.20409V17.7335L1.91602 22.2206H7.21333L9.58154 17.7335V1.7793ZM11.5761 1.7793H16.8111V22.2206H11.5761V1.7793ZM23.9166 1.7793H18.6816V6.01712H23.9166V1.7793ZM23.9166 7.38818H18.6816V22.2206H23.9166V7.38818Z" fill="url(#paint0_radial_11622_96091)"/>
<path d="M129.722 6.83203V18H127.482V6.83203H129.722Z" fill="#FF6A34"/>
<path d="M123.196 15.872H118.748L118.012 18H115.66L119.676 6.81604H122.284L126.3 18H123.932L123.196 15.872ZM122.588 14.08L120.972 9.40804L119.356 14.08H122.588Z" fill="#FF6A34"/>
<path d="M110.962 18H108.722L103.65 10.336V18H101.41V6.81598H103.65L108.722 14.496V6.81598H110.962V18Z" fill="#FF6A34"/>
<path d="M97.1258 15.872H92.6778L91.9418 18H89.5898L93.6058 6.81604H96.2138L100.23 18H97.8618L97.1258 15.872ZM96.5178 14.08L94.9018 9.40804L93.2858 14.08H96.5178Z" fill="#FF6A34"/>
<path d="M81.6482 6.83203V13.744C81.6482 14.5014 81.8455 15.0827 82.2402 15.488C82.6349 15.8827 83.1895 16.08 83.9042 16.08C84.6295 16.08 85.1895 15.8827 85.5842 15.488C85.9789 15.0827 86.1762 14.5014 86.1762 13.744V6.83203H88.4322V13.728C88.4322 14.6774 88.2242 15.4827 87.8082 16.144C87.4029 16.7947 86.8535 17.2854 86.1602 17.616C85.4775 17.9467 84.7149 18.112 83.8722 18.112C83.0402 18.112 82.2829 17.9467 81.6002 17.616C80.9282 17.2854 80.3949 16.7947 80.0002 16.144C79.6055 15.4827 79.4082 14.6774 79.4082 13.728V6.83203H81.6482Z" fill="#FF6A34"/>
<path d="M77.557 6.83203V18H75.317V13.248H70.533V18H68.293V6.83203H70.533V11.424H75.317V6.83203H77.557Z" fill="#FF6A34"/>
<path d="M55.7871 12.4C55.7871 11.3013 56.0324 10.32 56.5231 9.45599C57.0244 8.58132 57.7018 7.90399 58.5551 7.42399C59.4191 6.93332 60.3844 6.68799 61.4511 6.68799C62.6991 6.68799 63.7924 7.00799 64.7311 7.64799C65.6698 8.28799 66.3258 9.17332 66.6991 10.304H64.1231C63.8671 9.77065 63.5044 9.37065 63.0351 9.10399C62.5764 8.83732 62.0431 8.70399 61.4351 8.70399C60.7844 8.70399 60.2031 8.85865 59.6911 9.16799C59.1898 9.46665 58.7951 9.89332 58.5071 10.448C58.2298 11.0027 58.0911 11.6533 58.0911 12.4C58.0911 13.136 58.2298 13.7867 58.5071 14.352C58.7951 14.9067 59.1898 15.3387 59.6911 15.648C60.2031 15.9467 60.7844 16.096 61.4351 16.096C62.0431 16.096 62.5764 15.9627 63.0351 15.696C63.5044 15.4187 63.8671 15.0133 64.1231 14.48H66.6991C66.3258 15.6213 65.6698 16.512 64.7311 17.152C63.8031 17.7813 62.7098 18.096 61.4511 18.096C60.3844 18.096 59.4191 17.856 58.5551 17.376C57.7018 16.8853 57.0244 16.208 56.5231 15.344C56.0324 14.48 55.7871 13.4987 55.7871 12.4Z" fill="#FF6A34"/>
<path d="M54.4373 6.83203V18H52.1973V6.83203H54.4373Z" fill="#FF6A34"/>
<path d="M47.913 15.872H43.465L42.729 18H40.377L44.393 6.81598H47.001L51.017 18H48.649L47.913 15.872ZM47.305 14.08L45.689 9.40798L44.073 14.08H47.305Z" fill="#FF6A34"/>
<path d="M37.4395 12.272C38.0688 12.3893 38.5862 12.704 38.9915 13.216C39.3968 13.728 39.5995 14.3146 39.5995 14.976C39.5995 15.5733 39.4502 16.1013 39.1515 16.56C38.8635 17.008 38.4422 17.36 37.8875 17.616C37.3328 17.872 36.6768 18 35.9195 18H31.1035V6.83197H35.7115C36.4688 6.83197 37.1195 6.95464 37.6635 7.19997C38.2182 7.4453 38.6342 7.78664 38.9115 8.22397C39.1995 8.6613 39.3435 9.1573 39.3435 9.71197C39.3435 10.3626 39.1675 10.9066 38.8155 11.344C38.4742 11.7813 38.0155 12.0906 37.4395 12.272ZM33.3435 11.44H35.3915C35.9248 11.44 36.3355 11.3226 36.6235 11.088C36.9115 10.8426 37.0555 10.496 37.0555 10.048C37.0555 9.59997 36.9115 9.2533 36.6235 9.00797C36.3355 8.76264 35.9248 8.63997 35.3915 8.63997H33.3435V11.44ZM35.5995 16.176C36.1435 16.176 36.5648 16.048 36.8635 15.792C37.1728 15.536 37.3275 15.1733 37.3275 14.704C37.3275 14.224 37.1675 13.8506 36.8475 13.584C36.5275 13.3066 36.0955 13.168 35.5515 13.168H33.3435V16.176H35.5995Z" fill="#FF6A34"/>
<defs>
<radialGradient id="paint0_radial_11622_96091" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(6.5 5.5) rotate(45) scale(20.5061 22.0704)">
<stop stop-color="#FEBD3F"/>
<stop offset="0.77608" stop-color="#FF6933"/>
</radialGradient>
</defs>
</svg>

@ -1,11 +0,0 @@
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<g id="Baichuan">
<path id="Union" fill-rule="evenodd" clip-rule="evenodd" d="M8.58154 1.7793H5.52779L3.34655 6.20409V17.7335L0.916016 22.2206H6.21333L8.58154 17.7335V1.7793ZM10.5761 1.7793H15.8111V22.2206H10.5761V1.7793ZM22.9166 1.7793H17.6816V6.01712H22.9166V1.7793ZM22.9166 7.38818H17.6816V22.2206H22.9166V7.38818Z" fill="url(#paint0_radial_11622_96084)"/>
</g>
<defs>
<radialGradient id="paint0_radial_11622_96084" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(5.5 5.5) rotate(45) scale(20.5061 22.0704)">
<stop stop-color="#FEBD3F"/>
<stop offset="0.77608" stop-color="#FF6933"/>
</radialGradient>
</defs>
</svg>

@ -1,28 +0,0 @@
import logging
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class BaichuanProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
if validate failed, raise exception
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
try:
model_instance = self.get_model_instance(ModelType.LLM)
# use the `baichuan2-turbo` model to validate the credentials
model_instance.validate_credentials(model="baichuan2-turbo", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
raise ex

@ -1,29 +0,0 @@
provider: baichuan
label:
en_US: Baichuan
icon_small:
en_US: icon_s_en.svg
icon_large:
en_US: icon_l_en.svg
background: "#FFF6F2"
help:
title:
en_US: Get your API Key from BAICHUAN AI
zh_Hans: 从百川智能获取您的 API Key
url:
en_US: https://www.baichuan-ai.com
supported_model_types:
- llm
- text-embedding
configurate_methods:
- predefined-model
provider_credential_schema:
credential_form_schemas:
- variable: api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key

@ -1,46 +0,0 @@
model: baichuan2-53b
label:
en_US: Baichuan2-53B
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 1000
min: 1
max: 4000
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
default: 1
min: 1
max: 2
- name: with_search_enhance
label:
zh_Hans: 搜索增强
en_US: Search Enhance
type: boolean
help:
zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。
en_US: Allow the model to perform external search to enhance the generation results.
required: false
deprecated: true

@ -1,46 +0,0 @@
model: baichuan2-turbo-192k
label:
en_US: Baichuan2-Turbo-192K
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 192000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 8000
min: 1
max: 192000
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
default: 1
min: 1
max: 2
- name: with_search_enhance
label:
zh_Hans: 搜索增强
en_US: Search Enhance
type: boolean
help:
zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。
en_US: Allow the model to perform external search to enhance the generation results.
required: false
deprecated: true

@ -1,41 +0,0 @@
model: baichuan2-turbo
label:
en_US: Baichuan2-Turbo
model_type: llm
features:
- agent-thought
- multi-tool-call
model_properties:
mode: chat
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
default: 0.3
- name: top_p
use_template: top_p
default: 0.85
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
min: 0
max: 20
default: 5
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
default: 2048
- name: with_search_enhance
label:
zh_Hans: 搜索增强
en_US: Search Enhance
type: boolean
help:
zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。
en_US: Allow the model to perform external search to enhance the generation results.
required: false

@ -1,53 +0,0 @@
model: baichuan3-turbo-128k
label:
en_US: Baichuan3-Turbo-128k
model_type: llm
features:
- agent-thought
- multi-tool-call
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
default: 0.3
- name: top_p
use_template: top_p
default: 0.85
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
min: 0
max: 20
default: 5
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
default: 2048
- name: res_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: Specify the format that the model must output.
required: false
options:
- text
- json_object
- name: with_search_enhance
label:
zh_Hans: 搜索增强
en_US: Search Enhance
type: boolean
help:
zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。
en_US: Allow the model to perform external search to enhance the generation results.
required: false

@ -1,53 +0,0 @@
model: baichuan3-turbo
label:
en_US: Baichuan3-Turbo
model_type: llm
features:
- agent-thought
- multi-tool-call
model_properties:
mode: chat
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
default: 0.3
- name: top_p
use_template: top_p
default: 0.85
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
min: 0
max: 20
default: 5
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
default: 2048
- name: res_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: Specify the format that the model must output.
required: false
options:
- text
- json_object
- name: with_search_enhance
label:
zh_Hans: 搜索增强
en_US: Search Enhance
type: boolean
help:
zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。
en_US: Allow the model to perform external search to enhance the generation results.
required: false

@ -1,53 +0,0 @@
model: baichuan4
label:
en_US: Baichuan4
model_type: llm
features:
- agent-thought
- multi-tool-call
model_properties:
mode: chat
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
default: 0.3
- name: top_p
use_template: top_p
default: 0.85
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
min: 0
max: 20
default: 5
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
default: 2048
- name: res_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: Specify the format that the model must output.
required: false
options:
- text
- json_object
- name: with_search_enhance
label:
zh_Hans: 搜索增强
en_US: Search Enhance
type: boolean
help:
zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。
en_US: Allow the model to perform external search to enhance the generation results.
required: false

@ -1,21 +0,0 @@
import re
class BaichuanTokenizer:
@classmethod
def count_chinese_characters(cls, text: str) -> int:
return len(re.findall(r"[\u4e00-\u9fa5]", text))
@classmethod
def count_english_vocabularies(cls, text: str) -> int:
# remove all non-alphanumeric characters but keep spaces and other symbols like !, ., etc.
text = re.sub(r"[^a-zA-Z0-9\s]", "", text)
# count the number of words not characters
return len(text.split())
@classmethod
def _get_num_tokens(cls, text: str) -> int:
# tokens = number of Chinese characters + number of English words * 1.3
# (an estimate only; the authoritative count is returned by the API)
# https://platform.baichuan-ai.com/docs/text-Embedding
return int(cls.count_chinese_characters(text) + cls.count_english_vocabularies(text) * 1.3)
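
As a quick check of the estimate (one token per Chinese character plus 1.3 per English word), a short illustrative run against the class above; the import path is the pre-migration one removed by this PR:

from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer

# "你好世界" contributes 4 Chinese characters, "hello dify" contributes 2 English words
# estimate = 4 + 2 * 1.3 = 6.6, truncated to 6
print(BaichuanTokenizer._get_num_tokens("你好世界 hello dify"))  # 6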

@ -1,144 +0,0 @@
import json
from collections.abc import Iterator
from typing import Any, Optional, Union
from requests import post
from core.model_runtime.entities.message_entities import PromptMessageTool
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
BadRequestError,
InsufficientAccountBalanceError,
InternalServerError,
InvalidAPIKeyError,
InvalidAuthenticationError,
RateLimitReachedError,
)
class BaichuanModel:
api_key: str
def __init__(self, api_key: str) -> None:
self.api_key = api_key
@property
def _model_mapping(self) -> dict:
return {
"baichuan2-turbo": "Baichuan2-Turbo",
"baichuan3-turbo": "Baichuan3-Turbo",
"baichuan3-turbo-128k": "Baichuan3-Turbo-128k",
"baichuan4": "Baichuan4",
}
@property
def request_headers(self) -> dict[str, Any]:
return {
"Content-Type": "application/json",
"Authorization": "Bearer " + self.api_key,
}
def _build_parameters(
self,
model: str,
stream: bool,
messages: list[dict],
parameters: dict[str, Any],
tools: Optional[list[PromptMessageTool]] = None,
) -> dict[str, Any]:
if model in self._model_mapping:
# LargeLanguageModel._code_block_mode_wrapper() strips the `response_format` key from parameters,
# so the rule is exposed as `res_format` in the config and mapped back to the API field here
if parameters.get("res_format") == "json_object":
parameters["response_format"] = {"type": "json_object"}
if tools or parameters.get("with_search_enhance") is True:
parameters["tools"] = []
# with_search_enhance is deprecated, use web_search instead
if parameters.get("with_search_enhance") is True:
parameters["tools"].append(
{
"type": "web_search",
"web_search": {"enable": True},
}
)
if tools:
for tool in tools:
parameters["tools"].append(
{
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters,
},
}
)
# turbo api accepts flat parameters
return {
"model": self._model_mapping.get(model),
"stream": stream,
"messages": messages,
**parameters,
}
else:
raise BadRequestError(f"Unknown model: {model}")
def generate(
self,
model: str,
stream: bool,
messages: list[dict],
parameters: dict[str, Any],
timeout: int,
tools: Optional[list[PromptMessageTool]] = None,
) -> Union[Iterator, dict]:
if model in self._model_mapping:
api_base = "https://api.baichuan-ai.com/v1/chat/completions"
else:
raise BadRequestError(f"Unknown model: {model}")
data = self._build_parameters(model, stream, messages, parameters, tools)
try:
response = post(
url=api_base,
headers=self.request_headers,
data=json.dumps(data),
timeout=timeout,
stream=stream,
)
except Exception as e:
raise InternalServerError(f"Failed to invoke model: {e}")
if response.status_code != 200:
try:
resp = response.json()
# try to parse error message
err = resp["error"]["type"]
msg = resp["error"]["message"]
except Exception as e:
raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}")
if err == "invalid_api_key":
raise InvalidAPIKeyError(msg)
elif err == "insufficient_quota":
raise InsufficientAccountBalanceError(msg)
elif err == "invalid_authentication":
raise InvalidAuthenticationError(msg)
elif err == "invalid_request_error":
raise BadRequestError(msg)
elif "rate" in err:
raise RateLimitReachedError(msg)
elif "internal" in err:
raise InternalServerError(msg)
elif err == "api_key_empty":
raise InvalidAPIKeyError(msg)
else:
raise InternalServerError(f"Unknown error: {err} with message: {msg}")
if stream:
return response.iter_lines()
else:
return response.json()
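
To make the request shape concrete, a hedged sketch of a non-streaming call through the client above; the API key is a placeholder, and the commented payload is what _build_parameters assembles for the Baichuan chat completions endpoint:

from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanModel  # pre-migration path

client = BaichuanModel(api_key="sk-placeholder")
response = client.generate(
    model="baichuan3-turbo",
    stream=False,
    messages=[{"role": "user", "content": "ping"}],
    parameters={"max_tokens": 32, "temperature": 0.3, "with_search_enhance": True},
    timeout=60,
)
# body sent to https://api.baichuan-ai.com/v1/chat/completions:
# {"model": "Baichuan3-Turbo", "stream": false, "messages": [...], "max_tokens": 32,
#  "temperature": 0.3, "with_search_enhance": true,
#  "tools": [{"type": "web_search", "web_search": {"enable": true}}]}
print(response["choices"][0]["message"]["content"])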

@ -1,22 +0,0 @@
class InvalidAuthenticationError(Exception):
pass
class InvalidAPIKeyError(Exception):
pass
class RateLimitReachedError(Exception):
pass
class InsufficientAccountBalanceError(Exception):
pass
class InternalServerError(Exception):
pass
class BadRequestError(Exception):
pass

@ -1,296 +0,0 @@
import json
from collections.abc import Generator, Iterator
from typing import cast
from core.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
)
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanModel
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
BadRequestError,
InsufficientAccountBalanceError,
InternalServerError,
InvalidAPIKeyError,
InvalidAuthenticationError,
RateLimitReachedError,
)
class BaichuanLanguageModel(LargeLanguageModel):
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
return self._generate(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stream=stream,
)
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None,
) -> int:
return self._num_tokens_from_messages(prompt_messages)
def _num_tokens_from_messages(
self,
messages: list[PromptMessage],
) -> int:
"""Calculate num tokens for baichuan model"""
def tokens(text: str):
return BaichuanTokenizer._get_num_tokens(text)
tokens_per_message = 3
num_tokens = 0
messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
for message in messages_dict:
num_tokens += tokens_per_message
for key, value in message.items():
if isinstance(value, list):
text = ""
for item in value:
if isinstance(item, dict) and item["type"] == "text":
text += item["text"]
value = text
num_tokens += tokens(str(value))
num_tokens += 3
return num_tokens
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
"""
Convert PromptMessage to dict for Baichuan
"""
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
else:
raise ValueError("User message content must be str")
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls:
message_dict["tool_calls"] = [tool_call.dict() for tool_call in message.tool_calls]
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id}
else:
raise ValueError(f"Unknown message type {type(message)}")
return message_dict
def validate_credentials(self, model: str, credentials: dict) -> None:
# ping
instance = BaichuanModel(api_key=credentials["api_key"])
try:
instance.generate(
model=model,
stream=False,
messages=[{"content": "ping", "role": "user"}],
parameters={
"max_tokens": 1,
},
timeout=60,
)
except Exception as e:
raise CredentialsValidateFailedError(f"Invalid API key: {e}")
def _generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stream: bool = True,
) -> LLMResult | Generator:
instance = BaichuanModel(api_key=credentials["api_key"])
messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
# invoke model
response = instance.generate(
model=model,
stream=stream,
messages=messages,
parameters=model_parameters,
timeout=60,
tools=tools,
)
if stream:
return self._handle_chat_generate_stream_response(model, prompt_messages, credentials, response)
return self._handle_chat_generate_response(model, prompt_messages, credentials, response)
def _handle_chat_generate_response(
self,
model: str,
prompt_messages: list[PromptMessage],
credentials: dict,
response: dict,
) -> LLMResult:
choices = response.get("choices", [])
assistant_message = AssistantPromptMessage(content="", tool_calls=[])
if choices and choices[0]["finish_reason"] == "tool_calls":
for choice in choices:
for tool_call in choice["message"]["tool_calls"]:
tool = AssistantPromptMessage.ToolCall(
id=tool_call.get("id", ""),
type=tool_call.get("type", ""),
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool_call.get("function", {}).get("name", ""),
arguments=tool_call.get("function", {}).get("arguments", ""),
),
)
assistant_message.tool_calls.append(tool)
else:
for choice in choices:
assistant_message.content += choice["message"]["content"]
assistant_message.role = choice["message"]["role"]
usage = response.get("usage")
if usage:
# transform usage
prompt_tokens = usage["prompt_tokens"]
completion_tokens = usage["completion_tokens"]
else:
# calculate num tokens
prompt_tokens = self._num_tokens_from_messages(prompt_messages)
completion_tokens = self._num_tokens_from_messages([assistant_message])
usage = self._calc_response_usage(
model=model,
credentials=credentials,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
return LLMResult(
model=model,
prompt_messages=prompt_messages,
message=assistant_message,
usage=usage,
)
def _handle_chat_generate_stream_response(
self,
model: str,
prompt_messages: list[PromptMessage],
credentials: dict,
response: Iterator,
) -> Generator:
for line in response:
if not line:
continue
line = line.decode("utf-8")
# remove the first `data: ` prefix
if line.startswith("data:"):
line = line[5:].strip()
try:
data = json.loads(line)
except Exception:
if line.strip() == "[DONE]":
return
# skip malformed lines; otherwise the statements below would reuse stale (or undefined) data
continue
choices = data.get("choices", [])
stop_reason = ""
for choice in choices:
if choice.get("finish_reason"):
stop_reason = choice["finish_reason"]
if len(choice["delta"]["content"]) == 0:
continue
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=choice["delta"]["content"], tool_calls=[]),
finish_reason=stop_reason,
),
)
# if there is usage, the response is the last one, yield it and return
if "usage" in data:
usage = self._calc_response_usage(
model=model,
credentials=credentials,
prompt_tokens=data["usage"]["prompt_tokens"],
completion_tokens=data["usage"]["completion_tokens"],
)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content="", tool_calls=[]),
usage=usage,
finish_reason=stop_reason,
),
)
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [],
InvokeServerUnavailableError: [InternalServerError],
InvokeRateLimitError: [RateLimitReachedError],
InvokeAuthorizationError: [
InvalidAuthenticationError,
InsufficientAccountBalanceError,
InvalidAPIKeyError,
],
InvokeBadRequestError: [BadRequestError, KeyError],
}

@ -1,5 +0,0 @@
model: baichuan-text-embedding
model_type: text-embedding
model_properties:
context_size: 512
max_chunks: 16
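
The max_chunks: 16 property above caps how many texts the runtime sends in a single embedding request, so callers are expected to batch their inputs accordingly. A small illustrative helper, not taken from this diff:

def batch_texts(texts: list[str], max_chunks: int = 16) -> list[list[str]]:
    # split the inputs into API-sized batches while preserving order
    return [texts[i : i + max_chunks] for i in range(0, len(texts), max_chunks)]

print(batch_texts([f"doc {i}" for i in range(40)])[0])  # first batch of 16 documents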

@ -1,14 +0,0 @@
<svg width="140" height="24" viewBox="0 0 140 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M131.701 17.9999V6.8999H133.876V13.6049L136.531 10.3349H139.141L135.976 13.9949L139.381 17.9999H136.711L133.876 14.5049V17.9999H131.701Z" fill="#252F3E"/>
<path d="M129.847 17.6699C129.577 17.8299 129.252 17.9499 128.872 18.0299C128.492 18.1199 128.097 18.1649 127.687 18.1649C126.467 18.1649 125.532 17.8249 124.882 17.1449C124.242 16.4649 123.922 15.4849 123.922 14.2049C123.922 12.9349 124.262 11.9449 124.942 11.2349C125.622 10.5249 126.567 10.1699 127.777 10.1699C128.507 10.1699 129.182 10.3299 129.802 10.6499V12.1049C129.212 11.9349 128.672 11.8499 128.182 11.8499C127.482 11.8499 126.967 12.0299 126.637 12.3899C126.307 12.7399 126.142 13.2999 126.142 14.0699V14.2799C126.142 15.0399 126.302 15.5999 126.622 15.9599C126.952 16.3099 127.457 16.4849 128.137 16.4849C128.627 16.4849 129.197 16.3949 129.847 16.2149V17.6699Z" fill="#252F3E"/>
<path d="M118.51 18.2249C117.32 18.2249 116.39 17.8699 115.72 17.1599C115.05 16.4399 114.715 15.4399 114.715 14.1599C114.715 12.8899 115.05 11.8999 115.72 11.1899C116.39 10.4699 117.32 10.1099 118.51 10.1099C119.7 10.1099 120.63 10.4699 121.3 11.1899C121.97 11.8999 122.305 12.8899 122.305 14.1599C122.305 15.4399 121.97 16.4399 121.3 17.1599C120.63 17.8699 119.7 18.2249 118.51 18.2249ZM118.51 16.5449C119.56 16.5449 120.085 15.7499 120.085 14.1599C120.085 12.5799 119.56 11.7899 118.51 11.7899C117.46 11.7899 116.935 12.5799 116.935 14.1599C116.935 15.7499 117.46 16.5449 118.51 16.5449Z" fill="#252F3E"/>
<path d="M108.727 17.9998V10.3348H110.527L110.797 11.4748C111.197 11.0348 111.572 10.7248 111.922 10.5448C112.282 10.3548 112.662 10.2598 113.062 10.2598C113.252 10.2598 113.452 10.2748 113.662 10.3048V12.3298C113.382 12.2698 113.072 12.2398 112.732 12.2398C112.082 12.2398 111.477 12.3548 110.917 12.5848V17.9998H108.727Z" fill="#252F3E"/>
<path d="M104.417 17.9999L104.237 17.3249C103.617 17.8849 102.882 18.1649 102.032 18.1649C101.402 18.1649 100.847 18.0099 100.367 17.6999C99.8866 17.3799 99.5116 16.9199 99.2416 16.3199C98.9816 15.7199 98.8516 15.0149 98.8516 14.2049C98.8516 12.9649 99.1466 11.9749 99.7366 11.2349C100.327 10.4849 101.107 10.1099 102.077 10.1099C102.867 10.1099 103.552 10.3349 104.132 10.7849V6.8999H106.322V17.9999H104.417ZM102.752 16.5149C103.232 16.5149 103.692 16.3749 104.132 16.0949V12.1349C103.702 11.8849 103.207 11.7599 102.647 11.7599C102.117 11.7599 101.722 11.9599 101.462 12.3599C101.202 12.7499 101.072 13.3449 101.072 14.1449C101.072 14.9449 101.207 15.5399 101.477 15.9299C101.757 16.3199 102.182 16.5149 102.752 16.5149Z" fill="#252F3E"/>
<path d="M92.4625 14.6999C92.5025 15.3599 92.7025 15.8399 93.0625 16.1399C93.4225 16.4299 93.9875 16.5749 94.7575 16.5749C95.4275 16.5749 96.2075 16.4499 97.0975 16.1999V17.6549C96.7475 17.8349 96.3275 17.9749 95.8375 18.0749C95.3575 18.1749 94.8575 18.2249 94.3375 18.2249C93.0675 18.2249 92.0975 17.8799 91.4275 17.1899C90.7675 16.4999 90.4375 15.4899 90.4375 14.1599C90.4375 12.8799 90.7675 11.8849 91.4275 11.1749C92.0875 10.4649 93.0025 10.1099 94.1725 10.1099C95.1625 10.1099 95.9225 10.3849 96.4525 10.9349C96.9925 11.4749 97.2625 12.2499 97.2625 13.2599C97.2625 13.4799 97.2475 13.7299 97.2175 14.0099C97.1875 14.2899 97.1525 14.5199 97.1125 14.6999H92.4625ZM94.0975 11.6249C93.6075 11.6249 93.2175 11.7749 92.9275 12.0749C92.6475 12.3649 92.4875 12.7899 92.4475 13.3499H95.3875V13.0949C95.3875 12.1149 94.9575 11.6249 94.0975 11.6249Z" fill="#252F3E"/>
<path d="M81.1992 18V7.60498H84.9342C85.9342 7.60498 86.7392 7.85998 87.3492 8.36998C87.9692 8.87998 88.2792 9.54498 88.2792 10.365C88.2792 10.875 88.1592 11.315 87.9192 11.685C87.6892 12.045 87.3442 12.325 86.8842 12.525C87.5242 12.715 88.0092 13.03 88.3392 13.47C88.6792 13.9 88.8492 14.43 88.8492 15.06C88.8492 15.96 88.5142 16.675 87.8442 17.205C87.1742 17.735 86.2742 18 85.1442 18H81.1992ZM83.3292 13.47V16.395H85.0992C86.1192 16.395 86.6292 15.915 86.6292 14.955C86.6292 13.965 86.0842 13.47 84.9942 13.47H83.3292ZM83.3292 9.20998V11.94H84.6342C85.6042 11.94 86.0892 11.49 86.0892 10.59C86.0892 9.66998 85.6442 9.20998 84.7542 9.20998H83.3292Z" fill="#252F3E"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M62.0002 20.4425L58.6644 21.5548L57.3636 20.6872L58.7799 20.2142L58.3454 18.9107L55.9143 19.7206L55.1251 19.1953V15.4374C55.1251 15.1775 54.9779 14.9396 54.7456 14.8227L52.375 13.6375V10.3621L54.4376 9.33087L56.5001 10.3621V12.6873C56.5001 12.9486 56.6472 13.1864 56.8796 13.3033L59.6297 14.6783L60.2457 13.4477L57.8751 12.2624V10.3621L60.2457 9.17824C60.4781 9.06136 60.6252 8.82349 60.6252 8.56223V6.49969H59.2502V8.13735L57.1876 9.16862L55.1251 8.13735V4.80566L56.5001 3.88852V6.49969H57.8751V2.97275L58.6644 2.44612L62.0002 3.55851V20.4425ZM69.5628 18.8749C69.941 18.8749 70.2504 19.1829 70.2504 19.5624C70.2504 19.9419 69.941 20.2499 69.5628 20.2499C69.1847 20.2499 68.8753 19.9419 68.8753 19.5624C68.8753 19.1829 69.1847 18.8749 69.5628 18.8749ZM68.1878 3.74964C68.566 3.74964 68.8753 4.05765 68.8753 4.43715C68.8753 4.81666 68.566 5.12467 68.1878 5.12467C67.8097 5.12467 67.5003 4.81666 67.5003 4.43715C67.5003 4.05765 67.8097 3.74964 68.1878 3.74964ZM70.9379 11.9998C71.316 11.9998 71.6254 12.3078 71.6254 12.6873C71.6254 13.0668 71.316 13.3748 70.9379 13.3748C70.5597 13.3748 70.2504 13.0668 70.2504 12.6873C70.2504 12.3078 70.5597 11.9998 70.9379 11.9998ZM69.0018 13.3748C69.2865 14.1737 70.0427 14.7498 70.9379 14.7498C72.075 14.7498 73.0004 13.8258 73.0004 12.6873C73.0004 11.5502 72.075 10.6248 70.9379 10.6248C70.0427 10.6248 69.2865 11.2023 69.0018 11.9998H63.3752V9.24974H68.1878C68.5673 9.24974 68.8753 8.94311 68.8753 8.56223V6.37319C69.6742 6.08856 70.2504 5.3323 70.2504 4.43715C70.2504 3.30001 69.325 2.37462 68.1878 2.37462C67.0507 2.37462 66.1253 3.30001 66.1253 4.43715C66.1253 5.3323 66.7014 6.08856 67.5003 6.37319V7.87472H63.3752V3.06213C63.3752 2.7665 63.1855 2.50387 62.905 2.41037L58.7799 1.03534C58.5778 0.967964 58.3578 0.998214 58.1818 1.11509L54.0567 3.86514C53.8656 3.99302 53.7501 4.20752 53.7501 4.43715V8.13735L51.3795 9.32262C51.1471 9.4395 51 9.67738 51 9.93726V14.0623C51 14.3236 51.1471 14.5615 51.3795 14.6783L53.7501 15.8622V19.5624C53.7501 19.7921 53.8656 20.0079 54.0567 20.1344L58.1818 22.8845C58.2959 22.9615 58.4279 23 58.5626 23C58.6355 23 58.7084 22.989 58.7799 22.9642L62.905 21.5892C63.1855 21.4971 63.3752 21.2345 63.3752 20.9375V17.4999H66.5282L67.7011 18.6742L67.7189 18.6563C67.5842 18.9313 67.5003 19.2366 67.5003 19.5624C67.5003 20.6996 68.4257 21.625 69.5628 21.625C70.7 21.625 71.6254 20.6996 71.6254 19.5624C71.6254 18.4253 70.7 17.4999 69.5628 17.4999C69.2356 17.4999 68.9303 17.5838 68.6567 17.7199L68.6746 17.702L67.2996 16.327C67.1703 16.1977 66.9957 16.1249 66.8128 16.1249H63.3752V13.3748H69.0018Z" fill="#252F3E"/>
<line x1="43.25" y1="4" x2="43.25" y2="20" stroke="black" stroke-opacity="0.08" stroke-width="0.5"/>
<path d="M9.89554 9.62679C9.89554 10.0589 9.94226 10.4093 10.024 10.6663C10.1175 10.9232 10.2342 11.2035 10.3978 11.5072C10.4562 11.6006 10.4795 11.6941 10.4795 11.7758C10.4795 11.8926 10.4094 12.0094 10.2576 12.1262L9.52179 12.6168C9.41667 12.6869 9.31156 12.7219 9.21812 12.7219C9.10132 12.7219 8.98453 12.6635 8.86773 12.5584C8.70422 12.3832 8.56406 12.1963 8.44726 12.0094C8.33047 11.8109 8.21367 11.589 8.0852 11.3203C7.17419 12.3949 6.02958 12.9321 4.65139 12.9321C3.6703 12.9321 2.88777 12.6518 2.31546 12.0912C1.74316 11.5306 1.45117 10.7831 1.45117 9.84871C1.45117 8.85594 1.80156 8.05004 2.51402 7.4427C3.22647 6.83536 4.17252 6.53169 5.37552 6.53169C5.77263 6.53169 6.18142 6.56673 6.61356 6.62513C7.04571 6.68353 7.48954 6.77697 7.95672 6.88208V6.02947C7.95672 5.14182 7.76985 4.5228 7.40778 4.16073C7.03403 3.79866 6.40333 3.62347 5.504 3.62347C5.09521 3.62347 4.67475 3.67019 4.2426 3.7753C3.81046 3.88042 3.38999 4.00889 2.9812 4.17241C2.79433 4.25417 2.65417 4.30089 2.57242 4.32424C2.49066 4.3476 2.43226 4.35928 2.38554 4.35928C2.22203 4.35928 2.14027 4.24249 2.14027 3.99722V3.42491C2.14027 3.23804 2.16363 3.09788 2.22203 3.01613C2.28042 2.93437 2.38554 2.85261 2.54906 2.77085C2.95784 2.56062 3.44839 2.38543 4.02069 2.24527C4.59299 2.09344 5.20033 2.02336 5.84271 2.02336C7.23258 2.02336 8.24871 2.33871 8.90277 2.96941C9.54515 3.60011 9.87218 4.55784 9.87218 5.8426V9.62679H9.89554ZM5.15361 11.4021C5.53904 11.4021 5.93615 11.332 6.35661 11.1919C6.77708 11.0517 7.15083 10.7948 7.46618 10.4444C7.65305 10.2225 7.79321 9.97718 7.86328 9.69687C7.93336 9.41656 7.98008 9.07785 7.98008 8.68074V8.1902C7.64137 8.10844 7.2793 8.03836 6.90555 7.99165C6.53181 7.94493 6.16974 7.92157 5.80767 7.92157C5.02514 7.92157 4.45283 8.0734 4.06741 8.38875C3.68198 8.7041 3.49511 9.14793 3.49511 9.73191C3.49511 10.2809 3.63526 10.6896 3.92725 10.9699C4.20756 11.2619 4.61635 11.4021 5.15361 11.4021ZM14.5323 12.6635C14.3221 12.6635 14.182 12.6285 14.0885 12.5467C13.9951 12.4766 13.9133 12.3131 13.8432 12.0912L11.0985 3.06285C11.0285 2.82925 10.9934 2.67742 10.9934 2.59566C10.9934 2.40879 11.0869 2.30367 11.2737 2.30367H12.4183C12.6402 2.30367 12.7921 2.33871 12.8738 2.42047C12.9673 2.49054 13.0374 2.65406 13.1074 2.87597L15.0696 10.6079L16.8916 2.87597C16.95 2.64238 17.0201 2.49054 17.1135 2.42047C17.207 2.35039 17.3705 2.30367 17.5807 2.30367H18.5151C18.737 2.30367 18.8888 2.33871 18.9823 2.42047C19.0757 2.49054 19.1575 2.65406 19.2042 2.87597L21.0496 10.7013L23.0701 2.87597C23.1402 2.64238 23.222 2.49054 23.3037 2.42047C23.3972 2.35039 23.549 2.30367 23.7592 2.30367H24.8454C25.0323 2.30367 25.1374 2.39711 25.1374 2.59566C25.1374 2.65406 25.1258 2.71246 25.1141 2.78253C25.1024 2.85261 25.079 2.94605 25.0323 3.07453L22.2175 12.1029C22.1475 12.3365 22.0657 12.4883 21.9723 12.5584C21.8788 12.6285 21.727 12.6752 21.5284 12.6752H20.524C20.3021 12.6752 20.1502 12.6401 20.0568 12.5584C19.9634 12.4766 19.8816 12.3248 19.8349 12.0912L18.0246 4.55784L16.2259 12.0795C16.1675 12.3131 16.0974 12.4649 16.004 12.5467C15.9105 12.6285 15.747 12.6635 15.5368 12.6635H14.5323ZM29.5407 12.9788C28.9333 12.9788 28.326 12.9088 27.742 12.7686C27.158 12.6285 26.7025 12.4766 26.3988 12.3014C26.212 12.1963 26.0835 12.0795 26.0368 11.9744C25.9901 11.8693 25.9667 11.7525 25.9667 11.6474V11.0517C25.9667 10.8064 26.0601 10.6896 26.2353 10.6896C26.3054 10.6896 26.3755 10.7013 26.4456 10.7247C26.5156 10.748 26.6208 10.7948 26.7375 10.8415C27.1347 11.0167 27.5668 11.1568 28.0223 11.2503C28.4895 11.3437 28.945 11.3904 29.4122 11.3904C30.148 11.3904 
30.7203 11.2619 31.1174 11.005C31.5145 10.748 31.7247 10.3743 31.7247 9.89542C31.7247 9.56839 31.6196 9.29976 31.4094 9.07785C31.1992 8.85594 30.8021 8.65738 30.2298 8.47051L28.5362 7.94493C27.6836 7.6763 27.0529 7.27919 26.6675 6.75361C26.282 6.2397 26.0835 5.6674 26.0835 5.06006C26.0835 4.56952 26.1886 4.13737 26.3988 3.76362C26.6091 3.38987 26.8894 3.06285 27.2398 2.80589C27.5902 2.53726 27.9873 2.33871 28.4545 2.19855C28.9216 2.0584 29.4122 2 29.9261 2C30.183 2 30.4517 2.01168 30.7086 2.04672C30.9773 2.08176 31.2225 2.12848 31.4678 2.17519C31.7014 2.23359 31.9233 2.29199 32.1335 2.36207C32.3438 2.43215 32.5073 2.50222 32.6241 2.5723C32.7876 2.66574 32.9044 2.75918 32.9745 2.86429C33.0445 2.95773 33.0796 3.0862 33.0796 3.24972V3.79866C33.0796 4.04393 32.9861 4.17241 32.811 4.17241C32.7175 4.17241 32.5657 4.12569 32.3671 4.03225C31.7014 3.72858 30.9539 3.57675 30.1246 3.57675C29.4589 3.57675 28.9333 3.68187 28.5712 3.90378C28.2092 4.12569 28.0223 4.4644 28.0223 4.94326C28.0223 5.27029 28.1391 5.5506 28.3727 5.77252C28.6063 5.99443 29.0384 6.21634 29.6575 6.4149L31.316 6.94048C32.1569 7.20911 32.7642 7.58286 33.1263 8.06172C33.4884 8.54059 33.6636 9.08953 33.6636 9.69687C33.6636 10.1991 33.5584 10.6546 33.3599 11.0517C33.1497 11.4488 32.8693 11.7992 32.5073 12.0795C32.1452 12.3715 31.7131 12.5817 31.2108 12.7336C30.6853 12.8971 30.1363 12.9788 29.5407 12.9788Z" fill="#252F3E"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M31.749 18.6553C27.9064 21.4934 22.3235 23.0001 17.5232 23.0001C10.7957 23.0001 4.73399 20.5123 0.155575 16.3778C-0.206494 16.0507 0.120536 15.6069 0.552682 15.8639C5.50484 18.737 11.6133 20.4773 17.932 20.4773C22.195 20.4773 26.8786 19.5896 31.1883 17.7676C31.8307 17.4756 32.3797 18.1881 31.749 18.6553Z" fill="#FF9900"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M33.3507 16.833C32.8601 16.2023 30.1037 16.5293 28.854 16.6811C28.4803 16.7278 28.4219 16.4008 28.7606 16.1555C30.9564 14.6138 34.5654 15.0577 34.9858 15.5716C35.4063 16.0971 34.869 19.7062 32.8134 21.4347C32.4981 21.7034 32.1944 21.5632 32.3345 21.2128C32.8017 20.0565 33.8412 17.452 33.3507 16.833Z" fill="#FF9900"/>
</svg>

@ -1,15 +0,0 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<g clip-path="url(#clip0_16762_59518)">
<path d="M12.6667 0H3.33333C1.49238 0 0 1.49238 0 3.33333V12.6667C0 14.5076 1.49238 16 3.33333 16H12.6667C14.5076 16 16 14.5076 16 12.6667V3.33333C16 1.49238 14.5076 0 12.6667 0Z" fill="url(#paint0_linear_16762_59518)"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M7.99984 12.093L6.3825 12.6323L5.75184 12.2116L6.4385 11.9823L6.22784 11.3503L5.04917 11.743L4.6665 11.4883V9.66631C4.6665 9.54031 4.59517 9.42497 4.4825 9.3683L3.33317 8.79364V7.20564L4.33317 6.70564L5.33317 7.20564V8.33297C5.33317 8.45964 5.4045 8.57497 5.51717 8.63164L6.8505 9.29831L7.14917 8.70164L5.99984 8.12697V7.20564L7.14917 6.63164C7.26184 6.57497 7.33317 6.45964 7.33317 6.33297V5.33297H6.6665V6.12697L5.6665 6.62697L4.6665 6.12697V4.51164L5.33317 4.06697V5.33297H5.99984V3.62297L6.3825 3.36764L7.99984 3.90697V12.093ZM11.6665 11.333C11.8498 11.333 11.9998 11.4823 11.9998 11.6663C11.9998 11.8503 11.8498 11.9996 11.6665 11.9996C11.4832 11.9996 11.3332 11.8503 11.3332 11.6663C11.3332 11.4823 11.4832 11.333 11.6665 11.333ZM10.9998 3.99964C11.1832 3.99964 11.3332 4.14897 11.3332 4.33297C11.3332 4.51697 11.1832 4.6663 10.9998 4.6663C10.8165 4.6663 10.6665 4.51697 10.6665 4.33297C10.6665 4.14897 10.8165 3.99964 10.9998 3.99964ZM12.3332 7.99964C12.5165 7.99964 12.6665 8.14897 12.6665 8.33297C12.6665 8.51697 12.5165 8.66631 12.3332 8.66631C12.1498 8.66631 11.9998 8.51697 11.9998 8.33297C11.9998 8.14897 12.1498 7.99964 12.3332 7.99964ZM11.3945 8.66631C11.5325 9.05364 11.8992 9.33297 12.3332 9.33297C12.8845 9.33297 13.3332 8.88497 13.3332 8.33297C13.3332 7.78164 12.8845 7.33297 12.3332 7.33297C11.8992 7.33297 11.5325 7.61297 11.3945 7.99964H8.6665V6.66631H10.9998C11.1838 6.66631 11.3332 6.51764 11.3332 6.33297V5.27164C11.7205 5.13364 11.9998 4.76697 11.9998 4.33297C11.9998 3.78164 11.5512 3.33297 10.9998 3.33297C10.4485 3.33297 9.99984 3.78164 9.99984 4.33297C9.99984 4.76697 10.2792 5.13364 10.6665 5.27164V5.99964H8.6665V3.6663C8.6665 3.52297 8.5745 3.39564 8.4385 3.3503L6.4385 2.68364C6.3405 2.65097 6.23384 2.66564 6.1485 2.7223L4.1485 4.05564C4.05584 4.11764 3.99984 4.22164 3.99984 4.33297V6.12697L2.8505 6.70164C2.73784 6.75831 2.6665 6.87364 2.6665 6.99964V8.99964C2.6665 9.12631 2.73784 9.24164 2.8505 9.29831L3.99984 9.87231V11.6663C3.99984 11.7776 4.05584 11.8823 4.1485 11.9436L6.1485 13.277C6.20384 13.3143 6.26784 13.333 6.33317 13.333C6.3685 13.333 6.40384 13.3276 6.4385 13.3156L8.4385 12.649C8.5745 12.6043 8.6665 12.477 8.6665 12.333V10.6663H10.1952L10.7638 11.2356L10.7725 11.227C10.7072 11.3603 10.6665 11.5083 10.6665 11.6663C10.6665 12.2176 11.1152 12.6663 11.6665 12.6663C12.2178 12.6663 12.6665 12.2176 12.6665 11.6663C12.6665 11.115 12.2178 10.6663 11.6665 10.6663C11.5078 10.6663 11.3598 10.707 11.2272 10.773L11.2358 10.7643L10.5692 10.0976C10.5065 10.035 10.4218 9.99964 10.3332 9.99964H8.6665V8.66631H11.3945Z" fill="white"/>
</g>
<defs>
<linearGradient id="paint0_linear_16762_59518" x1="0" y1="1600" x2="1600" y2="0" gradientUnits="userSpaceOnUse">
<stop stop-color="#055F4E"/>
<stop offset="1" stop-color="#56C0A7"/>
</linearGradient>
<clipPath id="clip0_16762_59518">
<rect width="16" height="16" fill="white"/>
</clipPath>
</defs>
</svg>

@ -1,29 +0,0 @@
import logging
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class BedrockProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
if validate failed, raise exception
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
try:
model_instance = self.get_model_instance(ModelType.LLM)
# Use `amazon.titan-text-lite-v1` model by default for validating credentials
model_for_validation = credentials.get("model_for_validation", "amazon.titan-text-lite-v1")
model_instance.validate_credentials(model=model_for_validation, credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
raise ex

@ -1,89 +0,0 @@
provider: bedrock
label:
en_US: AWS
description:
en_US: AWS Bedrock's models.
icon_small:
en_US: icon_s_en.svg
icon_large:
en_US: icon_l_en.svg
background: "#FCFDFF"
help:
title:
en_US: Get your Access Key and Secret Access Key from AWS Console
url:
en_US: https://console.aws.amazon.com/
supported_model_types:
- llm
- text-embedding
configurate_methods:
- predefined-model
provider_credential_schema:
credential_form_schemas:
- variable: aws_access_key_id
required: false
label:
en_US: Access Key (If not provided, credentials are obtained from the running environment.)
zh_Hans: Access Key
type: secret-input
placeholder:
en_US: Enter your Access Key
zh_Hans: 在此输入您的 Access Key
- variable: aws_secret_access_key
required: false
label:
en_US: Secret Access Key
zh_Hans: Secret Access Key
type: secret-input
placeholder:
en_US: Enter your Secret Access Key
zh_Hans: 在此输入您的 Secret Access Key
- variable: aws_region
required: true
label:
en_US: AWS Region
zh_Hans: AWS 地区
type: select
default: us-east-1
options:
- value: us-east-1
label:
en_US: US East (N. Virginia)
zh_Hans: 美国东部 (弗吉尼亚北部)
- value: us-west-2
label:
en_US: US West (Oregon)
zh_Hans: 美国西部 (俄勒冈州)
- value: ap-southeast-1
label:
en_US: Asia Pacific (Singapore)
zh_Hans: 亚太地区 (新加坡)
- value: ap-northeast-1
label:
en_US: Asia Pacific (Tokyo)
zh_Hans: 亚太地区 (东京)
- value: eu-central-1
label:
en_US: Europe (Frankfurt)
zh_Hans: 欧洲 (法兰克福)
- value: eu-west-2
label:
en_US: Europe (London)
zh_Hans: 欧洲西部 (伦敦)
- value: us-gov-west-1
label:
en_US: AWS GovCloud (US-West)
zh_Hans: AWS GovCloud (US-West)
- value: ap-southeast-2
label:
en_US: Asia Pacific (Sydney)
zh_Hans: 亚太地区 (悉尼)
- variable: model_for_validation
required: false
label:
en_US: Available Model Name
zh_Hans: 可用模型名称
type: text-input
placeholder:
en_US: A model you have access to (e.g. amazon.titan-text-lite-v1) for validation.
zh_Hans: 为了进行验证,请输入一个您可用的模型名称 (例如amazon.titan-text-lite-v1)
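
The credential form above makes both keys optional: when they are left blank, the runtime falls back to the ambient AWS credential chain and only the region is required. A hedged sketch of turning such credentials into a Bedrock runtime client with boto3; the region and keys below are placeholders:

import boto3

credentials = {
    "aws_region": "us-east-1",
    # leave these empty to rely on environment variables or an instance role instead
    "aws_access_key_id": "",
    "aws_secret_access_key": "",
}

client_kwargs = {"region_name": credentials["aws_region"]}
if credentials["aws_access_key_id"] and credentials["aws_secret_access_key"]:
    client_kwargs["aws_access_key_id"] = credentials["aws_access_key_id"]
    client_kwargs["aws_secret_access_key"] = credentials["aws_secret_access_key"]

bedrock_runtime = boto3.client("bedrock-runtime", **client_kwargs)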

@ -1,24 +0,0 @@
- amazon.titan-text-express-v1
- amazon.titan-text-lite-v1
- anthropic.claude-instant-v1
- anthropic.claude-v1
- anthropic.claude-v2
- anthropic.claude-v2:1
- anthropic.claude-3-sonnet-v1:0
- anthropic.claude-3-haiku-v1:0
- cohere.command-light-text-v14
- cohere.command-text-v14
- cohere.command-r-plus-v1.0
- cohere.command-r-v1.0
- meta.llama3-1-8b-instruct-v1:0
- meta.llama3-1-70b-instruct-v1:0
- meta.llama3-1-405b-instruct-v1:0
- meta.llama3-8b-instruct-v1:0
- meta.llama3-70b-instruct-v1:0
- meta.llama2-13b-chat-v1
- meta.llama2-70b-chat-v1
- mistral.mistral-large-2407-v1:0
- mistral.mistral-small-2402-v1:0
- mistral.mistral-large-2402-v1:0
- mistral.mixtral-8x7b-instruct-v0:1
- mistral.mistral-7b-instruct-v0:2

@ -1,47 +0,0 @@
model: ai21.j2-mid-v1
label:
en_US: J2 Mid V1
model_type: llm
model_properties:
mode: completion
context_size: 8191
parameter_rules:
- name: temperature
use_template: temperature
- name: topP
use_template: top_p
- name: maxTokens
use_template: max_tokens
required: true
default: 2048
min: 1
max: 2048
- name: count_penalty
label:
en_US: Count Penalty
required: false
type: float
default: 0
min: 0
max: 1
- name: presence_penalty
label:
en_US: Presence Penalty
required: false
type: float
default: 0
min: 0
max: 5
- name: frequency_penalty
label:
en_US: Frequency Penalty
required: false
type: float
default: 0
min: 0
max: 500
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

@ -1,47 +0,0 @@
model: ai21.j2-ultra-v1
label:
en_US: J2 Ultra V1
model_type: llm
model_properties:
mode: completion
context_size: 8191
parameter_rules:
- name: temperature
use_template: temperature
- name: topP
use_template: top_p
- name: maxTokens
use_template: max_tokens
required: true
default: 2048
min: 1
max: 2048
- name: count_penalty
label:
en_US: Count Penalty
required: false
type: float
default: 0
min: 0
max: 1
- name: presence_penalty
label:
en_US: Presence Penalty
required: false
type: float
default: 0
min: 0
max: 5
- name: frequency_penalty
label:
en_US: Frequency Penalty
required: false
type: float
default: 0
min: 0
max: 500
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

@ -1,23 +0,0 @@
model: amazon.titan-text-express-v1
label:
en_US: Titan Text G1 - Express
model_type: llm
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
- name: topP
use_template: top_p
- name: maxTokenCount
use_template: max_tokens
required: true
default: 2048
min: 1
max: 8000
pricing:
input: '0.0008'
output: '0.0016'
unit: '0.001'
currency: USD

@ -1,23 +0,0 @@
model: amazon.titan-text-lite-v1
label:
en_US: Titan Text G1 - Lite
model_type: llm
model_properties:
mode: chat
context_size: 4096
parameter_rules:
- name: temperature
use_template: temperature
- name: topP
use_template: top_p
- name: maxTokenCount
use_template: max_tokens
required: true
default: 2048
min: 1
max: 2048
pricing:
input: '0.0003'
output: '0.0004'
unit: '0.001'
currency: USD

@ -1,61 +0,0 @@
model: anthropic.claude-3-haiku-20240307-v1:0
label:
en_US: Claude 3 Haiku
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 4096
min: 1
max: 4096
help:
zh_Hans: 停止前生成的最大令牌数。请注意Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
# docs: https://docs.anthropic.com/claude/docs/system-prompts
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# note: the AWS documentation for this parameter is incorrect; the actual maximum is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
- name: response_format
use_template: response_format
pricing:
input: '0.00025'
output: '0.00125'
unit: '0.001'
currency: USD
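
The parameter rules above (max_tokens, temperature, top_p, top_k) map directly onto the Anthropic messages body that Bedrock expects, per the AWS documentation linked in the comments. A hedged invocation sketch with boto3; the region, prompt, and parameter values are illustrative:

import json

import boto3

bedrock_runtime = boto3.client("bedrock-runtime", region_name="us-east-1")

body = {
    "anthropic_version": "bedrock-2023-05-31",
    "max_tokens": 4096,
    "temperature": 1.0,
    "top_p": 0.999,
    "top_k": 250,
    "messages": [{"role": "user", "content": "Say hello in one sentence."}],
}

response = bedrock_runtime.invoke_model(
    modelId="anthropic.claude-3-haiku-20240307-v1:0",
    body=json.dumps(body),
)
result = json.loads(response["body"].read())
print(result["content"][0]["text"])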

@ -1,61 +0,0 @@
model: anthropic.claude-3-opus-20240229-v1:0
label:
en_US: Claude 3 Opus
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 4096
min: 1
max: 4096
help:
zh_Hans: 停止前生成的最大令牌数。请注意Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
# docs: https://docs.anthropic.com/claude/docs/system-prompts
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# note: the AWS documentation for this parameter is incorrect; the actual maximum is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
- name: response_format
use_template: response_format
pricing:
input: '0.015'
output: '0.075'
unit: '0.001'
currency: USD

@ -1,60 +0,0 @@
model: anthropic.claude-3-5-sonnet-20240620-v1:0
label:
en_US: Claude 3.5 Sonnet
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 4096
min: 1
max: 4096
help:
zh_Hans: 停止前生成的最大令牌数。请注意Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
    # note: the AWS docs list the wrong maximum here; the actual max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
- name: response_format
use_template: response_format
pricing:
input: '0.003'
output: '0.015'
unit: '0.001'
currency: USD

@ -1,60 +0,0 @@
model: anthropic.claude-3-sonnet-20240229-v1:0
label:
en_US: Claude 3 Sonnet
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 4096
min: 1
max: 4096
help:
zh_Hans: 停止前生成的最大令牌数。请注意Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
    # note: the AWS docs list the wrong maximum here; the actual max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
- name: response_format
use_template: response_format
pricing:
input: '0.003'
output: '0.015'
unit: '0.001'
currency: USD
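Each parameter rule above declares a type, bounds, a default, and whether the value is required. The sketch below is not Dify's validator, only an illustration of what those fields imply when a rule is carried around as a plain dict mirroring the YAML (resolve_parameter is a hypothetical name):

def resolve_parameter(value, rule: dict):
    # Fall back to the declared default when the caller omits the value.
    if value is None:
        return rule.get("default")
    # Clamp into the declared [min, max] range, e.g. top_p into [0.0, 1.0].
    if rule.get("min") is not None:
        value = max(value, rule["min"])
    if rule.get("max") is not None:
        value = min(value, rule["max"])
    return value

# resolve_parameter(None, {"default": 0.999, "min": 0.0, "max": 1.0}) -> 0.999
# resolve_parameter(1.5, {"default": 0.999, "min": 0.0, "max": 1.0}) -> 1.0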

@ -1,52 +0,0 @@
model: anthropic.claude-instant-v1
label:
en_US: Claude Instant 1
model_type: llm
model_properties:
mode: chat
context_size: 100000
parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 4096
min: 1
max: 4096
help:
zh_Hans: 停止前生成的最大令牌数。请注意Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
    # note: the AWS docs list the wrong maximum here; the actual max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
input: '0.0008'
output: '0.0024'
unit: '0.001'
currency: USD

@ -1,53 +0,0 @@
model: anthropic.claude-v1
label:
en_US: Claude 1
model_type: llm
model_properties:
mode: chat
context_size: 100000
parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 4096
min: 1
max: 4096
help:
zh_Hans: 停止前生成的最大令牌数。请注意Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
    # note: the AWS docs list the wrong maximum here; the actual max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
input: '0.008'
output: '0.024'
unit: '0.001'
currency: USD
deprecated: true

@ -1,54 +0,0 @@
model: anthropic.claude-v2:1
label:
en_US: Claude 2.1
model_type: llm
model_properties:
mode: chat
context_size: 200000
parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 4096
min: 1
max: 4096
help:
zh_Hans: 停止前生成的最大令牌数。请注意Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
    # note: the AWS docs list the wrong maximum here; the actual max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
- name: response_format
use_template: response_format
pricing:
input: '0.008'
output: '0.024'
unit: '0.001'
currency: USD

@ -1,54 +0,0 @@
model: anthropic.claude-v2
label:
en_US: Claude 2
model_type: llm
model_properties:
mode: chat
context_size: 100000
parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 4096
min: 1
max: 4096
help:
zh_Hans: 停止前生成的最大令牌数。请注意Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
    # note: the AWS docs list the wrong maximum here; the actual max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
- name: response_format
use_template: response_format
pricing:
input: '0.008'
output: '0.024'
unit: '0.001'
currency: USD

@ -1,35 +0,0 @@
model: cohere.command-light-text-v14
label:
en_US: Command Light Text V14
model_type: llm
model_properties:
mode: completion
context_size: 4096
parameter_rules:
- name: temperature
use_template: temperature
- name: p
use_template: top_p
- name: k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
min: 0
max: 500
default: 0
- name: max_tokens
use_template: max_tokens
required: true
default: 4096
min: 1
max: 4096
pricing:
input: '0.0003'
output: '0.0006'
unit: '0.001'
currency: USD

@ -1,44 +0,0 @@
model: cohere.command-r-plus-v1:0
label:
en_US: Command R+
model_type: llm
features:
- tool-call
#- stream-tool-call
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
max: 5.0
- name: p
use_template: top_p
default: 0.75
min: 0.01
max: 0.99
- name: k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
default: 0
min: 0
max: 500
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 1024
max: 4096
pricing:
input: '3'
output: '15'
unit: '0.000001'
currency: USD

@ -1,43 +0,0 @@
model: cohere.command-r-v1:0
label:
en_US: Command R
model_type: llm
features:
- tool-call
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
max: 5.0
- name: p
use_template: top_p
default: 0.75
min: 0.01
max: 0.99
- name: k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
default: 0
min: 0
max: 500
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 1024
max: 4096
pricing:
input: '0.5'
output: '1.5'
unit: '0.000001'
currency: USD

@ -1,32 +0,0 @@
model: cohere.command-text-v14
label:
en_US: Command Text V14
model_type: llm
model_properties:
mode: completion
context_size: 4096
parameter_rules:
- name: temperature
use_template: temperature
- name: p
use_template: top_p
- name: k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 4096
min: 1
max: 4096
pricing:
input: '0.0015'
output: '0.0020'
unit: '0.001'
currency: USD

@ -1,59 +0,0 @@
model: eu.anthropic.claude-3-haiku-20240307-v1:0
label:
  en_US: Claude 3 Haiku (EU Cross-Region Inference)
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 4096
min: 1
max: 4096
help:
zh_Hans: 停止前生成的最大令牌数。请注意Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
# docs: https://docs.anthropic.com/claude/docs/system-prompts
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
    # note: the AWS docs list the wrong maximum here; the actual max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
input: '0.00025'
output: '0.00125'
unit: '0.001'
currency: USD

@ -1,58 +0,0 @@
model: eu.anthropic.claude-3-5-sonnet-20240620-v1:0
label:
  en_US: Claude 3.5 Sonnet (EU Cross-Region Inference)
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 4096
min: 1
max: 4096
help:
zh_Hans: 停止前生成的最大令牌数。请注意Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
    # note: the AWS docs list the wrong maximum here; the actual max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
input: '0.003'
output: '0.015'
unit: '0.001'
currency: USD

@ -1,58 +0,0 @@
model: eu.anthropic.claude-3-sonnet-20240229-v1:0
label:
  en_US: Claude 3 Sonnet (EU Cross-Region Inference)
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 4096
min: 1
max: 4096
help:
zh_Hans: 停止前生成的最大令牌数。请注意Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
    # note: the AWS docs list the wrong maximum here; the actual max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
input: '0.003'
output: '0.015'
unit: '0.001'
currency: USD

@ -1,903 +0,0 @@
# standard import
import base64
import json
import logging
import mimetypes
from collections.abc import Generator
from typing import Optional, Union, cast
# 3rd import
import boto3
import requests
from botocore.config import Config
from botocore.exceptions import (
ClientError,
EndpointConnectionError,
NoRegionError,
ServiceNotInRegionError,
UnknownServiceError,
)
# local import
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
logger = logging.getLogger(__name__)
ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object can be found in the instructions; use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.
<instructions>
{{instructions}}
</instructions>
""" # noqa: E501
class BedrockLargeLanguageModel(LargeLanguageModel):
# please refer to the documentation: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html
# TODO There is invoke issue: context limit on Cohere Model, will add them after fixed.
CONVERSE_API_ENABLED_MODEL_INFO = [
{"prefix": "anthropic.claude-v2", "support_system_prompts": True, "support_tool_use": False},
{"prefix": "anthropic.claude-v1", "support_system_prompts": True, "support_tool_use": False},
{"prefix": "us.anthropic.claude-3", "support_system_prompts": True, "support_tool_use": True},
{"prefix": "eu.anthropic.claude-3", "support_system_prompts": True, "support_tool_use": True},
{"prefix": "anthropic.claude-3", "support_system_prompts": True, "support_tool_use": True},
{"prefix": "meta.llama", "support_system_prompts": True, "support_tool_use": False},
{"prefix": "mistral.mistral-7b-instruct", "support_system_prompts": False, "support_tool_use": False},
{"prefix": "mistral.mixtral-8x7b-instruct", "support_system_prompts": False, "support_tool_use": False},
{"prefix": "mistral.mistral-large", "support_system_prompts": True, "support_tool_use": True},
{"prefix": "mistral.mistral-small", "support_system_prompts": True, "support_tool_use": True},
{"prefix": "cohere.command-r", "support_system_prompts": True, "support_tool_use": True},
{"prefix": "amazon.titan", "support_system_prompts": False, "support_tool_use": False},
]
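    # Note: _find_model_info below matches by prefix in list order, so an id such as
    # "anthropic.claude-3-sonnet-20240229-v1:0" resolves to the {"prefix": "anthropic.claude-3", ...}
    # entry (Converse API with system prompts and tool use), while an id that matches no
    # prefix, e.g. "cohere.command-text-v14", falls back to the native invoke_model path in _generate().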
@staticmethod
def _find_model_info(model_id):
for model in BedrockLargeLanguageModel.CONVERSE_API_ENABLED_MODEL_INFO:
if model_id.startswith(model["prefix"]):
return model
logger.info(f"current model id: {model_id} did not support by Converse API")
return None
def _code_block_mode_wrapper(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
        callbacks: Optional[list[Callback]] = None,
) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper for invoking large language model
"""
if model_parameters.get("response_format"):
stop = stop or []
if "```\n" not in stop:
stop.append("```\n")
if "\n```" not in stop:
stop.append("\n```")
response_format = model_parameters.pop("response_format")
format_prompt = SystemPromptMessage(
content=ANTHROPIC_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace(
"{{block}}", response_format
)
)
if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
prompt_messages[0] = format_prompt
else:
prompt_messages.insert(0, format_prompt)
prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
model_info = BedrockLargeLanguageModel._find_model_info(model)
if model_info:
model_info["model"] = model
# invoke models via boto3 converse API
return self._generate_with_converse(
model_info, credentials, prompt_messages, model_parameters, stop, stream, user, tools
)
# invoke other models via boto3 client
return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
def _generate_with_converse(
self,
model_info: dict,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
tools: Optional[list[PromptMessageTool]] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model with converse API
:param model_info: model information
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param stop: stop words
:param stream: is stream response
:return: full response or stream response chunk generator result
"""
bedrock_client = boto3.client(
service_name="bedrock-runtime",
aws_access_key_id=credentials.get("aws_access_key_id"),
aws_secret_access_key=credentials.get("aws_secret_access_key"),
region_name=credentials["aws_region"],
)
system, prompt_message_dicts = self._convert_converse_prompt_messages(prompt_messages)
inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop)
parameters = {
"modelId": model_info["model"],
"messages": prompt_message_dicts,
"inferenceConfig": inference_config,
"additionalModelRequestFields": additional_model_fields,
}
if model_info["support_system_prompts"] and system and len(system) > 0:
parameters["system"] = system
if model_info["support_tool_use"] and tools:
parameters["toolConfig"] = self._convert_converse_tool_config(tools=tools)
try:
if stream:
response = bedrock_client.converse_stream(**parameters)
return self._handle_converse_stream_response(
model_info["model"], credentials, response, prompt_messages
)
else:
response = bedrock_client.converse(**parameters)
return self._handle_converse_response(model_info["model"], credentials, response, prompt_messages)
except ClientError as ex:
error_code = ex.response["Error"]["Code"]
full_error_msg = f"{error_code}: {ex.response['Error']['Message']}"
raise self._map_client_to_invoke_error(error_code, full_error_msg)
except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex:
raise InvokeConnectionError(str(ex))
except UnknownServiceError as ex:
raise InvokeServerUnavailableError(str(ex))
except Exception as ex:
raise InvokeError(str(ex))
def _handle_converse_response(
self, model: str, credentials: dict, response: dict, prompt_messages: list[PromptMessage]
) -> LLMResult:
"""
Handle llm chat response
:param model: model name
:param credentials: credentials
:param response: response
:param prompt_messages: prompt messages
:return: full response chunk generator result
"""
response_content = response["output"]["message"]["content"]
# transform assistant message to prompt message
if response["stopReason"] == "tool_use":
tool_calls = []
text, tool_use = self._extract_tool_use(response_content)
tool_call = AssistantPromptMessage.ToolCall(
id=tool_use["toolUseId"],
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool_use["name"], arguments=json.dumps(tool_use["input"])
),
)
tool_calls.append(tool_call)
assistant_prompt_message = AssistantPromptMessage(content=text, tool_calls=tool_calls)
else:
assistant_prompt_message = AssistantPromptMessage(content=response_content[0]["text"])
# calculate num tokens
if response["usage"]:
# transform usage
prompt_tokens = response["usage"]["inputTokens"]
completion_tokens = response["usage"]["outputTokens"]
else:
# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
result = LLMResult(
model=model,
prompt_messages=prompt_messages,
message=assistant_prompt_message,
usage=usage,
)
return result
def _extract_tool_use(self, content: dict) -> tuple[str, dict]:
tool_use = {}
text = ""
for item in content:
if "toolUse" in item:
tool_use = item["toolUse"]
elif "text" in item:
text = item["text"]
else:
raise ValueError(f"Got unknown item: {item}")
return text, tool_use
def _handle_converse_stream_response(
self,
model: str,
credentials: dict,
response: dict,
prompt_messages: list[PromptMessage],
) -> Generator:
"""
Handle llm chat stream response
:param model: model name
:param credentials: credentials
:param response: response
:param prompt_messages: prompt messages
:return: full response or stream response chunk generator result
"""
try:
full_assistant_content = ""
return_model = None
input_tokens = 0
output_tokens = 0
finish_reason = None
index = 0
tool_calls: list[AssistantPromptMessage.ToolCall] = []
tool_use = {}
for chunk in response["stream"]:
if "messageStart" in chunk:
return_model = model
elif "messageStop" in chunk:
finish_reason = chunk["messageStop"]["stopReason"]
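                # A contentBlockStart carrying toolUse opens a tool call: the id and name arrive here,
                # its JSON arguments stream in piecewise via contentBlockDelta["toolUse"]["input"],
                # and the assembled call is appended to tool_calls on contentBlockStop.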
elif "contentBlockStart" in chunk:
tool = chunk["contentBlockStart"]["start"]["toolUse"]
tool_use["toolUseId"] = tool["toolUseId"]
tool_use["name"] = tool["name"]
elif "metadata" in chunk:
input_tokens = chunk["metadata"]["usage"]["inputTokens"]
output_tokens = chunk["metadata"]["usage"]["outputTokens"]
usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens)
yield LLMResultChunk(
model=return_model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=AssistantPromptMessage(content="", tool_calls=tool_calls),
finish_reason=finish_reason,
usage=usage,
),
)
elif "contentBlockDelta" in chunk:
delta = chunk["contentBlockDelta"]["delta"]
if "text" in delta:
chunk_text = delta["text"] or ""
full_assistant_content += chunk_text
assistant_prompt_message = AssistantPromptMessage(
content=chunk_text or "",
)
index = chunk["contentBlockDelta"]["contentBlockIndex"]
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index + 1,
message=assistant_prompt_message,
),
)
elif "toolUse" in delta:
if "input" not in tool_use:
tool_use["input"] = ""
tool_use["input"] += delta["toolUse"]["input"]
elif "contentBlockStop" in chunk:
if "input" in tool_use:
tool_call = AssistantPromptMessage.ToolCall(
id=tool_use["toolUseId"],
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool_use["name"], arguments=tool_use["input"]
),
)
tool_calls.append(tool_call)
tool_use = {}
except Exception as ex:
raise InvokeError(str(ex))
def _convert_converse_api_model_parameters(
self, model_parameters: dict, stop: Optional[list[str]] = None
) -> tuple[dict, dict]:
inference_config = {}
additional_model_fields = {}
if "max_tokens" in model_parameters:
inference_config["maxTokens"] = model_parameters["max_tokens"]
if "temperature" in model_parameters:
inference_config["temperature"] = model_parameters["temperature"]
if "top_p" in model_parameters:
inference_config["topP"] = model_parameters["temperature"]
if stop:
inference_config["stopSequences"] = stop
if "top_k" in model_parameters:
additional_model_fields["top_k"] = model_parameters["top_k"]
return inference_config, additional_model_fields
def _convert_converse_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]:
"""
Convert prompt messages to dict list and system
"""
system = []
prompt_message_dicts = []
for message in prompt_messages:
if isinstance(message, SystemPromptMessage):
message.content = message.content.strip()
system.append({"text": message.content})
else:
prompt_message_dicts.append(self._convert_prompt_message_to_dict(message))
return system, prompt_message_dicts
def _convert_converse_tool_config(self, tools: Optional[list[PromptMessageTool]] = None) -> dict:
tool_config = {}
configs = []
if tools:
for tool in tools:
configs.append(
{
"toolSpec": {
"name": tool.name,
"description": tool.description,
"inputSchema": {"json": tool.parameters},
}
}
)
tool_config["tools"] = configs
return tool_config
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
"""
Convert PromptMessage to dict
"""
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "user", "content": [{"text": message.content}]}
else:
sub_messages = []
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(TextPromptMessageContent, message_content)
sub_message_dict = {"text": message_content.data}
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
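                        # Image content is either a remote URL (fetched, with the MIME type guessed
                        # from the path) or a data: URI (split into MIME type and base64 payload);
                        # both paths end up as raw bytes plus a format string for the Converse image block.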
if not message_content.data.startswith("data:"):
# fetch image data from url
try:
url = message_content.data
image_content = requests.get(url).content
if "?" in url:
url = url.split("?")[0]
mime_type, _ = mimetypes.guess_type(url)
base64_data = base64.b64encode(image_content).decode("utf-8")
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
else:
data_split = message_content.data.split(";base64,")
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
image_content = base64.b64decode(base64_data)
if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
raise ValueError(
f"Unsupported image type {mime_type}, "
f"only support image/jpeg, image/png, image/gif, and image/webp"
)
sub_message_dict = {
"image": {"format": mime_type.replace("image/", ""), "source": {"bytes": image_content}}
}
sub_messages.append(sub_message_dict)
message_dict = {"role": "user", "content": sub_messages}
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
if message.tool_calls:
message_dict = {
"role": "assistant",
"content": [
{
"toolUse": {
"toolUseId": message.tool_calls[0].id,
"name": message.tool_calls[0].function.name,
"input": json.loads(message.tool_calls[0].function.arguments),
}
}
],
}
else:
message_dict = {"role": "assistant", "content": [{"text": message.content}]}
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = [{"text": message.content}]
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {
"role": "user",
"content": [
{
"toolResult": {
"toolUseId": message.tool_call_id,
"content": [{"json": {"text": message.content}}],
}
}
],
}
else:
raise ValueError(f"Got unknown type {message}")
return message_dict
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage] | str,
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages or message string
:param tools: tools for tool calling
        :return: number of tokens
"""
prefix = model.split(".")[0]
model_name = model.split(".")[1]
if isinstance(prompt_messages, str):
prompt = prompt_messages
else:
prompt = self._convert_messages_to_prompt(prompt_messages, prefix, model_name)
return self._get_num_tokens_by_gpt2(prompt)
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
required_params = {}
if "anthropic" in model:
required_params = {
"max_tokens": 32,
}
elif "ai21" in model:
# ValidationException: Malformed input request: #/temperature: expected type: Number,
# found: Null#/maxTokens: expected type: Integer, found: Null#/topP: expected type: Number, found: Null,
# please reformat your input and try again.
required_params = {
"temperature": 0.7,
"topP": 0.9,
"maxTokens": 32,
}
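        # Validation issues a minimal "ping" generation with the model's required
        # parameters; any client or invocation error is re-raised as CredentialsValidateFailedError.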
try:
ping_message = UserPromptMessage(content="ping")
self._invoke(
model=model,
credentials=credentials,
prompt_messages=[ping_message],
model_parameters=required_params,
stream=False,
)
except ClientError as ex:
error_code = ex.response["Error"]["Code"]
full_error_msg = f"{error_code}: {ex.response['Error']['Message']}"
raise CredentialsValidateFailedError(str(self._map_client_to_invoke_error(error_code, full_error_msg)))
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _convert_one_message_to_text(
self, message: PromptMessage, model_prefix: str, model_name: Optional[str] = None
) -> str:
"""
Convert a single message to a string.
:param message: PromptMessage to convert.
:return: String representation of the message.
"""
human_prompt_prefix = ""
human_prompt_postfix = ""
ai_prompt = ""
content = message.content
if isinstance(message, UserPromptMessage):
body = content
if isinstance(content, list):
body = "".join([c.data for c in content if c.type == PromptMessageContentType.TEXT])
message_text = f"{human_prompt_prefix} {body} {human_prompt_postfix}"
elif isinstance(message, AssistantPromptMessage):
message_text = f"{ai_prompt} {content}"
elif isinstance(message, SystemPromptMessage):
message_text = content
elif isinstance(message, ToolPromptMessage):
message_text = f"{human_prompt_prefix} {message.content}"
else:
raise ValueError(f"Got unknown type {message}")
return message_text
def _convert_messages_to_prompt(
self, messages: list[PromptMessage], model_prefix: str, model_name: Optional[str] = None
) -> str:
"""
Format a list of messages into a full prompt for the Anthropic, Amazon and Llama models
:param messages: List of PromptMessage to combine.
        :param model_name: specific model name. Optional, used only to distinguish llama2 and llama3.
:return: Combined string with necessary human_prompt and ai_prompt tags.
"""
if not messages:
return ""
messages = messages.copy() # don't mutate the original list
if not isinstance(messages[-1], AssistantPromptMessage):
messages.append(AssistantPromptMessage(content=""))
text = "".join(self._convert_one_message_to_text(message, model_prefix, model_name) for message in messages)
# trim off the trailing ' ' that might come from the "Assistant: "
return text.rstrip()
def _create_payload(
self,
model: str,
prompt_messages: list[PromptMessage],
model_parameters: dict,
stop: Optional[list[str]] = None,
stream: bool = True,
):
"""
Create payload for bedrock api call depending on model provider
"""
payload = {}
model_prefix = model.split(".")[0]
model_name = model.split(".")[1]
if model_prefix == "ai21":
payload["temperature"] = model_parameters.get("temperature")
payload["topP"] = model_parameters.get("topP")
payload["maxTokens"] = model_parameters.get("maxTokens")
payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
if model_parameters.get("presencePenalty"):
payload["presencePenalty"] = {model_parameters.get("presencePenalty")}
if model_parameters.get("frequencyPenalty"):
payload["frequencyPenalty"] = {model_parameters.get("frequencyPenalty")}
if model_parameters.get("countPenalty"):
payload["countPenalty"] = {model_parameters.get("countPenalty")}
elif model_prefix == "cohere":
payload = {**model_parameters}
payload["prompt"] = prompt_messages[0].content
payload["stream"] = stream
else:
raise ValueError(f"Got unknown model prefix {model_prefix}")
return payload
def _generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: credentials kwargs
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
client_config = Config(region_name=credentials["aws_region"])
runtime_client = boto3.client(
service_name="bedrock-runtime",
config=client_config,
aws_access_key_id=credentials.get("aws_access_key_id"),
aws_secret_access_key=credentials.get("aws_secret_access_key"),
)
model_prefix = model.split(".")[0]
payload = self._create_payload(model, prompt_messages, model_parameters, stop, stream)
# need workaround for ai21 models which doesn't support streaming
if stream and model_prefix != "ai21":
invoke = runtime_client.invoke_model_with_response_stream
else:
invoke = runtime_client.invoke_model
try:
body_jsonstr = json.dumps(payload)
response = invoke(modelId=model, contentType="application/json", accept="*/*", body=body_jsonstr)
except ClientError as ex:
error_code = ex.response["Error"]["Code"]
full_error_msg = f"{error_code}: {ex.response['Error']['Message']}"
raise self._map_client_to_invoke_error(error_code, full_error_msg)
except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex:
raise InvokeConnectionError(str(ex))
except UnknownServiceError as ex:
raise InvokeServerUnavailableError(str(ex))
except Exception as ex:
raise InvokeError(str(ex))
if stream:
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
return self._handle_generate_response(model, credentials, response, prompt_messages)
def _handle_generate_response(
self, model: str, credentials: dict, response: dict, prompt_messages: list[PromptMessage]
) -> LLMResult:
"""
Handle llm response
:param model: model name
:param credentials: credentials
:param response: response
:param prompt_messages: prompt messages
:return: llm response
"""
response_body = json.loads(response.get("body").read().decode("utf-8"))
finish_reason = response_body.get("error")
if finish_reason is not None:
raise InvokeError(finish_reason)
# get output text and calculate num tokens based on model / provider
model_prefix = model.split(".")[0]
if model_prefix == "ai21":
output = response_body.get("completions")[0].get("data").get("text")
prompt_tokens = len(response_body.get("prompt").get("tokens"))
completion_tokens = len(response_body.get("completions")[0].get("data").get("tokens"))
elif model_prefix == "cohere":
output = response_body.get("generations")[0].get("text")
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, output or "")
else:
raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
# construct assistant message from output
assistant_prompt_message = AssistantPromptMessage(content=output)
# calculate usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
# construct response
result = LLMResult(
model=model,
prompt_messages=prompt_messages,
message=assistant_prompt_message,
usage=usage,
)
return result
def _handle_generate_stream_response(
self, model: str, credentials: dict, response: dict, prompt_messages: list[PromptMessage]
) -> Generator:
"""
Handle llm stream response
:param model: model name
:param credentials: credentials
:param response: response
:param prompt_messages: prompt messages
:return: llm response chunk generator result
"""
model_prefix = model.split(".")[0]
if model_prefix == "ai21":
response_body = json.loads(response.get("body").read().decode("utf-8"))
content = response_body.get("completions")[0].get("data").get("text")
finish_reason = response_body.get("completions")[0].get("finish_reason")
prompt_tokens = len(response_body.get("prompt").get("tokens"))
completion_tokens = len(response_body.get("completions")[0].get("data").get("tokens"))
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0, message=AssistantPromptMessage(content=content), finish_reason=finish_reason, usage=usage
),
)
return
stream = response.get("body")
if not stream:
raise InvokeError("No response body")
index = -1
for event in stream:
chunk = event.get("chunk")
if not chunk:
exception_name = next(iter(event))
full_ex_msg = f"{exception_name}: {event[exception_name]['message']}"
raise self._map_client_to_invoke_error(exception_name, full_ex_msg)
payload = json.loads(chunk.get("bytes").decode())
model_prefix = model.split(".")[0]
if model_prefix == "cohere":
content_delta = payload.get("text")
finish_reason = payload.get("finish_reason")
else:
raise ValueError(f"Got unknown model prefix {model_prefix} when handling stream response")
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=content_delta or "",
)
index += 1
if not finish_reason:
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message),
)
else:
# get num tokens from metrics in last chunk
prompt_tokens = payload["amazon-bedrock-invocationMetrics"]["inputTokenCount"]
completion_tokens = payload["amazon-bedrock-invocationMetrics"]["outputTokenCount"]
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index, message=assistant_prompt_message, finish_reason=finish_reason, usage=usage
),
)
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
        The key is the error type thrown to the caller.
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.
        :return: invoke error mapping
"""
return {
InvokeConnectionError: [],
InvokeServerUnavailableError: [],
InvokeRateLimitError: [],
InvokeAuthorizationError: [],
InvokeBadRequestError: [],
}
def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[InvokeError]:
"""
Map client error to invoke error
:param error_code: error code
:param error_msg: error message
:return: invoke error
"""
if error_code == "AccessDeniedException":
return InvokeAuthorizationError(error_msg)
elif error_code in {"ResourceNotFoundException", "ValidationException"}:
return InvokeBadRequestError(error_msg)
elif error_code in {"ThrottlingException", "ServiceQuotaExceededException"}:
return InvokeRateLimitError(error_msg)
elif error_code in {
"ModelTimeoutException",
"ModelErrorException",
"InternalServerException",
"ModelNotReadyException",
}:
return InvokeServerUnavailableError(error_msg)
elif error_code == "ModelStreamErrorException":
return InvokeConnectionError(error_msg)
return InvokeError(error_msg)
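For reference, _convert_converse_api_model_parameters above maps Dify's snake_case parameters onto the Converse API's camelCase inferenceConfig and routes top_k through additionalModelRequestFields. A hedged sketch of the expected mapping; the input values below are made up for illustration:

model_parameters = {"max_tokens": 4096, "temperature": 0.7, "top_p": 0.999, "top_k": 250}
stop = ["\n\nHuman:"]

# inference_config        -> {"maxTokens": 4096, "temperature": 0.7, "topP": 0.999,
#                             "stopSequences": ["\n\nHuman:"]}
# additional_model_fields -> {"top_k": 250}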

@ -1,23 +0,0 @@
model: meta.llama2-13b-chat-v1
label:
en_US: Llama 2 Chat 13B
model_type: llm
model_properties:
mode: chat
context_size: 4096
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_gen_len
use_template: max_tokens
required: true
default: 2048
min: 1
max: 2048
pricing:
input: '0.00075'
output: '0.00100'
unit: '0.001'
currency: USD

@ -1,23 +0,0 @@
model: meta.llama2-70b-chat-v1
label:
en_US: Llama 2 Chat 70B
model_type: llm
model_properties:
mode: chat
context_size: 4096
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_gen_len
use_template: max_tokens
required: true
default: 2048
min: 1
max: 2048
pricing:
input: '0.00195'
output: '0.00256'
unit: '0.001'
currency: USD

@ -1,25 +0,0 @@
model: meta.llama3-1-405b-instruct-v1:0
label:
en_US: Llama 3.1 405B Instruct
model_type: llm
model_properties:
mode: completion
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
default: 0.5
- name: top_p
use_template: top_p
default: 0.9
- name: max_gen_len
use_template: max_tokens
required: true
default: 512
min: 1
max: 2048
pricing:
input: '0.00532'
output: '0.016'
unit: '0.001'
currency: USD

@ -1,25 +0,0 @@
model: meta.llama3-1-70b-instruct-v1:0
label:
en_US: Llama 3.1 Instruct 70B
model_type: llm
model_properties:
mode: completion
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
default: 0.5
- name: top_p
use_template: top_p
default: 0.9
- name: max_gen_len
use_template: max_tokens
required: true
default: 512
min: 1
max: 2048
pricing:
input: '0.00265'
output: '0.0035'
unit: '0.001'
currency: USD

@ -1,25 +0,0 @@
model: meta.llama3-1-8b-instruct-v1:0
label:
en_US: Llama 3.1 Instruct 8B
model_type: llm
model_properties:
mode: completion
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
default: 0.5
- name: top_p
use_template: top_p
default: 0.9
- name: max_gen_len
use_template: max_tokens
required: true
default: 512
min: 1
max: 2048
pricing:
input: '0.0003'
output: '0.0006'
unit: '0.001'
currency: USD

@ -1,23 +0,0 @@
model: meta.llama3-70b-instruct-v1:0
label:
en_US: Llama 3 Instruct 70B
model_type: llm
model_properties:
mode: completion
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_gen_len
use_template: max_tokens
required: true
default: 512
min: 1
max: 2048
pricing:
input: '0.00265'
output: '0.0035'
unit: '0.00001'
currency: USD

Some files were not shown because too many files have changed in this diff.