feat: add zhipuai (#1188)
parent
c8bd76cd66
commit
827c97f0d3
@ -1,32 +1,34 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
from flask import current_app
|
|
||||||
|
|
||||||
from core.model_providers.error import LLMBadRequestError
|
from core.model_providers.error import LLMBadRequestError
|
||||||
from core.model_providers.providers.base import BaseModelProvider
|
from core.model_providers.providers.base import BaseModelProvider
|
||||||
|
from core.model_providers.providers.hosted import hosted_config, hosted_model_providers
|
||||||
from models.provider import ProviderType
|
from models.provider import ProviderType
|
||||||
|
|
||||||
|
|
||||||
def check_moderation(model_provider: BaseModelProvider, text: str) -> bool:
    """Check ``text`` against the hosted OpenAI moderation endpoint.

    Moderation only runs when hosted moderation is enabled, a hosted OpenAI
    provider is configured, the given provider is a SYSTEM provider and its
    name is listed in ``hosted_config.moderation.providers``. In every other
    case the text is accepted without calling the endpoint.

    :param model_provider: provider whose type and name decide whether moderation applies
    :param text: text to check
    :return: False as soon as any moderation result is flagged, True otherwise
    :raises LLMBadRequestError: if a moderation request fails for any reason
    """
    if not (hosted_config.moderation.enabled is True and hosted_model_providers.openai):
        return True

    if model_provider.provider.provider_type != ProviderType.SYSTEM.value \
            or model_provider.provider_name not in hosted_config.moderation.providers:
        return True

    # 2000 text per chunk
    length = 2000
    text_chunks = [text[i:i + length] for i in range(0, len(text), length)]

    # group the text chunks into batches of at most 32 inputs per request
    max_text_chunks = 32
    chunks = [text_chunks[i:i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)]

    for text_chunk in chunks:
        try:
            moderation_result = openai.Moderation.create(input=text_chunk,
                                                         api_key=hosted_model_providers.openai.api_key)
        except Exception as ex:
            logging.exception(ex)
            # keep the original failure attached as __cause__ for debugging
            raise LLMBadRequestError('Rate limit exceeded, please try again later.') from ex

        for result in moderation_result.results:
            if result['flagged'] is True:
                return False

    return True
|
||||||
|
|||||||
@ -0,0 +1,22 @@
|
|||||||
|
from core.model_providers.error import LLMBadRequestError
|
||||||
|
from core.model_providers.providers.base import BaseModelProvider
|
||||||
|
from core.model_providers.models.embedding.base import BaseEmbedding
|
||||||
|
from core.third_party.langchain.embeddings.zhipuai_embedding import ZhipuAIEmbeddings
|
||||||
|
|
||||||
|
|
||||||
|
class ZhipuAIEmbedding(BaseEmbedding):
    """Embedding model that delegates to a ``ZhipuAIEmbeddings`` client."""

    def __init__(self, model_provider: BaseModelProvider, name: str):
        """Create the ZhipuAI embeddings client from the provider's credentials.

        :param model_provider: provider used to look up the model credentials
        :param name: embedding model name, also passed to the client as ``model``
        """
        client_kwargs = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )
        zhipuai_client = ZhipuAIEmbeddings(model=name, **client_kwargs)

        super().__init__(model_provider, zhipuai_client, name)

    def handle_exceptions(self, ex: Exception) -> Exception:
        """Return ``ex`` wrapped as an ``LLMBadRequestError`` with a ZhipuAI embedding prefix."""
        return LLMBadRequestError(f"ZhipuAI embedding: {ex!s}")
|
||||||
@ -0,0 +1,61 @@
|
|||||||
|
from typing import List, Optional, Any
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import Callbacks
|
||||||
|
from langchain.schema import LLMResult
|
||||||
|
|
||||||
|
from core.model_providers.error import LLMBadRequestError
|
||||||
|
from core.model_providers.models.llm.base import BaseLLM
|
||||||
|
from core.model_providers.models.entity.message import PromptMessage
|
||||||
|
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
|
||||||
|
from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM
|
||||||
|
|
||||||
|
|
||||||
|
class ZhipuAIModel(BaseLLM):
    """Chat-mode LLM that delegates to a ``ZhipuAIChatLLM`` client."""

    model_mode: ModelMode = ModelMode.CHAT

    def _init_client(self) -> Any:
        """Build the ``ZhipuAIChatLLM`` client from credentials and model kwargs."""
        model_kwargs_input = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        llm = ZhipuAIChatLLM(
            streaming=self.streaming,
            callbacks=self.callbacks,
            **self.credentials,
            **model_kwargs_input
        )
        return llm

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        Generate a completion for the prompt messages.

        :param messages: prompt messages to send
        :param stop: stop words, passed through to the client
        :param callbacks: callbacks, passed through to the client
        :return: the client's generation result
        """
        prompt_messages = self._get_prompt_from_messages(messages)
        return self._client.generate([prompt_messages], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        Count the tokens of the prompt messages (never negative).

        :param messages: prompt messages to count
        :return: token count reported by the client, floored at 0
        """
        prompt_messages = self._get_prompt_from_messages(messages)
        token_count = self._client.get_num_tokens_from_messages(prompt_messages)
        return max(token_count, 0)

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        """Copy the converted model kwargs onto the client, skipping unknown attributes."""
        converted = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        for key, value in converted.items():
            if not hasattr(self.client, key):
                continue
            setattr(self.client, key, value)

    def handle_exceptions(self, ex: Exception) -> Exception:
        """Return ``ex`` wrapped as an ``LLMBadRequestError`` with a ZhipuAI prefix."""
        return LLMBadRequestError(f"ZhipuAI: {ex!s}")

    @property
    def support_streaming(self):
        """Streaming is supported by this model."""
        return True
|
||||||
@ -0,0 +1,176 @@
|
|||||||
|
import json
|
||||||
|
from json import JSONDecodeError
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
|
from langchain.schema import HumanMessage
|
||||||
|
|
||||||
|
from core.helper import encrypter
|
||||||
|
from core.model_providers.models.base import BaseProviderModel
|
||||||
|
from core.model_providers.models.embedding.zhipuai_embedding import ZhipuAIEmbedding
|
||||||
|
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
|
||||||
|
from core.model_providers.models.llm.zhipuai_model import ZhipuAIModel
|
||||||
|
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||||
|
from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM
|
||||||
|
from models.provider import ProviderType, ProviderQuotaType
|
||||||
|
|
||||||
|
|
||||||
|
class ZhipuAIProvider(BaseModelProvider):
    """Model provider for ZhipuAI text-generation and embedding models."""

    @property
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        return 'zhipuai'

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        """Return the fixed id/name entries for ``model_type``; empty for unsupported types."""
        if model_type == ModelType.TEXT_GENERATION:
            return [
                {
                    'id': 'chatglm_pro',
                    'name': 'chatglm_pro',
                },
                {
                    'id': 'chatglm_std',
                    'name': 'chatglm_std',
                },
                {
                    'id': 'chatglm_lite',
                    'name': 'chatglm_lite',
                },
                {
                    'id': 'chatglm_lite_32k',
                    'name': 'chatglm_lite_32k',
                }
            ]
        elif model_type == ModelType.EMBEDDINGS:
            return [
                {
                    'id': 'text_embedding',
                    'name': 'text_embedding',
                }
            ]
        else:
            return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type: type of model requested
        :return: ``ZhipuAIModel`` for text generation, ``ZhipuAIEmbedding`` for embeddings
        :raises NotImplementedError: for any other model type
        """
        if model_type == ModelType.TEXT_GENERATION:
            model_class = ZhipuAIModel
        elif model_type == ModelType.EMBEDDINGS:
            model_class = ZhipuAIEmbedding
        else:
            raise NotImplementedError

        return model_class

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        get model parameter rules.

        The same rules are returned for every model name and type.

        :param model_name: unused
        :param model_type: unused
        :return: rules enabling temperature and top_p; the penalties and max_tokens are disabled
        """
        return ModelKwargsRules(
            temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2),
            top_p=KwargRule[float](min=0.1, max=0.9, default=0.8, precision=1),
            presence_penalty=KwargRule[float](enabled=False),
            frequency_penalty=KwargRule[float](enabled=False),
            max_tokens=KwargRule[int](enabled=False),
        )

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        """
        Validates the given credentials.

        Sends a single 'ping' message through ``ZhipuAIChatLLM`` using the api_key.

        :param credentials: dict that must contain 'api_key'
        :raises CredentialsValidateFailedError: if 'api_key' is missing or the call fails
        """
        if 'api_key' not in credentials:
            raise CredentialsValidateFailedError('ZhipuAI api_key must be provided.')

        try:
            credential_kwargs = {
                'api_key': credentials['api_key']
            }

            llm = ZhipuAIChatLLM(
                temperature=0.01,
                **credential_kwargs
            )

            llm([HumanMessage(content='ping')])
        except Exception as ex:
            # keep the underlying failure attached as __cause__
            raise CredentialsValidateFailedError(str(ex)) from ex

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        """Encrypt ``credentials['api_key']`` in place for the tenant and return the same dict."""
        credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
        return credentials

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        """Return the decrypted provider credentials.

        Credentials are only returned for CUSTOM providers and for SYSTEM
        providers with the FREE quota type; otherwise an empty dict is returned.

        :param obfuscated: when True, the decrypted api_key is obfuscated before returning
        :return: credentials dict; 'api_key' is None when no usable config is stored
        """
        if self.provider.provider_type == ProviderType.CUSTOM.value \
                or (self.provider.provider_type == ProviderType.SYSTEM.value
                    and self.provider.quota_type == ProviderQuotaType.FREE.value):
            try:
                credentials = json.loads(self.provider.encrypted_config)
            except (JSONDecodeError, TypeError):
                # TypeError covers a missing (None) encrypted_config, which
                # json.loads rejects before it ever attempts to decode.
                credentials = None

            if not isinstance(credentials, dict):
                credentials = {
                    'api_key': None,
                }

            # .get() so a stored config without an api_key does not raise KeyError
            if credentials.get('api_key'):
                credentials['api_key'] = encrypter.decrypt_token(
                    self.provider.tenant_id,
                    credentials['api_key']
                )

                if obfuscated:
                    credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])

            return credentials
        else:
            return {}

    def should_deduct_quota(self):
        """Always returns True."""
        return True

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        No per-model validation is performed; this is a no-op.

        :param model_name: unused
        :param model_type: unused
        :param credentials: unused
        """
        return

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        Always returns an empty dict; no per-model credentials are produced here.

        :param tenant_id: unused
        :param model_name: unused
        :param model_type: unused
        :param credentials: unused
        :return: empty dict
        """
        return {}

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        Delegates to ``get_provider_credentials``; the model name and type are not used.

        :param model_name: unused
        :param model_type: unused
        :param obfuscated: forwarded to ``get_provider_credentials``
        :return: provider credentials
        """
        return self.get_provider_credentials(obfuscated)
|
||||||
@ -0,0 +1,44 @@
|
|||||||
|
{
|
||||||
|
"support_provider_types": [
|
||||||
|
"system",
|
||||||
|
"custom"
|
||||||
|
],
|
||||||
|
"system_config": {
|
||||||
|
"supported_quota_types": [
|
||||||
|
"free"
|
||||||
|
],
|
||||||
|
"quota_unit": "tokens"
|
||||||
|
},
|
||||||
|
"model_flexibility": "fixed",
|
||||||
|
"price_config": {
|
||||||
|
"chatglm_pro": {
|
||||||
|
"prompt": "0.01",
|
||||||
|
"completion": "0.01",
|
||||||
|
"unit": "0.001",
|
||||||
|
"currency": "RMB"
|
||||||
|
},
|
||||||
|
"chatglm_std": {
|
||||||
|
"prompt": "0.005",
|
||||||
|
"completion": "0.005",
|
||||||
|
"unit": "0.001",
|
||||||
|
"currency": "RMB"
|
||||||
|
},
|
||||||
|
"chatglm_lite": {
|
||||||
|
"prompt": "0.002",
|
||||||
|
"completion": "0.002",
|
||||||
|
"unit": "0.001",
|
||||||
|
"currency": "RMB"
|
||||||
|
},
|
||||||
|
"chatglm_lite_32k": {
|
||||||
|
"prompt": "0.0004",
|
||||||
|
"completion": "0.0004",
|
||||||
|
"unit": "0.001",
|
||||||
|
"currency": "RMB"
|
||||||
|
},
|
||||||
|
"text_embedding": {
|
||||||
|
"completion": "0",
|
||||||
|
"unit": "0.001",
|
||||||
|
"currency": "RMB"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -0,0 +1,64 @@
|
|||||||
|
"""Wrapper around ZhipuAI embedding models."""
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Extra, root_validator
|
||||||
|
|
||||||
|
from langchain.embeddings.base import Embeddings
|
||||||
|
from langchain.utils import get_from_dict_or_env
|
||||||
|
|
||||||
|
from core.third_party.langchain.llms.zhipuai_llm import ZhipuModelAPI
|
||||||
|
|
||||||
|
|
||||||
|
class ZhipuAIEmbeddings(BaseModel, Embeddings):
    """Wrapper around ZhipuAI embedding models.

    1024 dimensions.
    """

    client: Any  #: :meta private:
    # Model name to use.
    model: str

    base_url: str = "https://open.bigmodel.cn/api/paas/v3/model-api"
    api_key: Optional[str] = None

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the api key (value or ZHIPUAI_API_KEY env var) and build the API client."""
        api_key = get_from_dict_or_env(values, "api_key", "ZHIPUAI_API_KEY")
        values["api_key"] = api_key
        values['client'] = ZhipuModelAPI(api_key=api_key, base_url=values['base_url'])
        return values

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call out to ZhipuAI's embedding endpoint, one request per text.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        # all requests are issued first; float conversion happens afterwards
        raw_vectors = [
            self.client.invoke(model=self.model, prompt=text)["data"].get('embedding')
            for text in texts
        ]

        return [list(map(float, vector)) for vector in raw_vectors]

    def embed_query(self, text: str) -> List[float]:
        """Call out to ZhipuAI's embedding endpoint for a single text.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        (vector,) = self.embed_documents([text])[:1]
        return vector
|
||||||
@ -0,0 +1,315 @@
|
|||||||
|
"""Wrapper around ZhipuAI APIs."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import posixpath
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Optional, Iterator, Sequence,
|
||||||
|
)
|
||||||
|
|
||||||
|
import zhipuai
|
||||||
|
from langchain.chat_models.base import BaseChatModel
|
||||||
|
from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage
|
||||||
|
from langchain.schema.messages import AIMessageChunk
|
||||||
|
from langchain.schema.output import ChatResult, ChatGenerationChunk, ChatGeneration
|
||||||
|
from pydantic import Extra, root_validator, BaseModel
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import (
|
||||||
|
CallbackManagerForLLMRun,
|
||||||
|
)
|
||||||
|
from langchain.utils import get_from_dict_or_env
|
||||||
|
from zhipuai.model_api.api import InvokeType
|
||||||
|
from zhipuai.utils import jwt_token
|
||||||
|
from zhipuai.utils.http_client import post, stream
|
||||||
|
from zhipuai.utils.sse_client import SSEClient
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ZhipuModelAPI(BaseModel):
    """Small client for the ZhipuAI model API (sync and SSE invocation)."""

    base_url: str
    api_key: str
    api_timeout_seconds = 60

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def invoke(self, **kwargs):
        """Send a synchronous request and return the response.

        ``kwargs`` must contain ``model``; it selects the URL and is removed
        from the payload before the remaining kwargs are posted.

        :raises ValueError: if the response's ``success`` field is falsy
        """
        url = self._build_api_url(kwargs, InvokeType.SYNC)
        response = post(url, self._generate_token(), kwargs, self.api_timeout_seconds)
        if not response['success']:
            raise ValueError(
                f"Error Code: {response['code']}, Message: {response['msg']} "
            )
        return response

    def sse_invoke(self, **kwargs):
        """Send a streaming request and return an ``SSEClient`` over the response.

        ``kwargs`` must contain ``model``; it is removed from the payload as in ``invoke``.
        """
        url = self._build_api_url(kwargs, InvokeType.SSE)
        data = stream(url, self._generate_token(), kwargs, self.api_timeout_seconds)
        return SSEClient(data)

    def _build_api_url(self, kwargs, *path):
        """Join base_url, the model name and ``path``; pops ``model`` out of ``kwargs``.

        With empty ``kwargs`` the model segment is ``"-"``.

        :raises ValueError: if ``kwargs`` is non-empty but has no ``model`` key
        """
        if kwargs:
            if "model" not in kwargs:
                raise ValueError("model param missed")
            model = kwargs.pop("model")
        else:
            model = "-"

        return posixpath.join(self.base_url, model, *path)

    def _generate_token(self):
        """Generate the JWT token for the configured api_key.

        :raises ValueError: if no api_key is set or token generation fails
        """
        if not self.api_key:
            raise ValueError(
                "api_key not provided, you could provide it."
            )

        try:
            return jwt_token.generate_token(self.api_key)
        except Exception as ex:
            # chain the original error so the real cause is not lost
            raise ValueError(
                "Your api_key is invalid, please check it."
            ) from ex
|
||||||
|
|
||||||
|
|
||||||
|
class ZhipuAIChatLLM(BaseChatModel):
    """Wrapper around ZhipuAI large language models.

    To use, you should pass the api_key as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM
            model = ZhipuAIChatLLM(model="<model_name>", api_key="my-api-key")
    """

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"api_key": "API_KEY"}

    @property
    def lc_serializable(self) -> bool:
        return True

    client: Any = None  #: :meta private:
    # Model name to use.
    model: str = "chatglm_lite"
    # A non-negative float that tunes the degree of randomness in generation.
    temperature: float = 0.95
    # Total probability mass of tokens to consider at each step.
    top_p: float = 0.7
    # Whether to stream the response or return it all at once.
    streaming: bool = False
    api_key: Optional[str] = None

    base_url: str = "https://open.bigmodel.cn/api/paas/v3/model-api"

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the api key (value or ZHIPUAI_API_KEY env var) and build the API client."""
        values["api_key"] = get_from_dict_or_env(
            values, "api_key", "ZHIPUAI_API_KEY"
        )

        # a base_url containing 'test' forces the test model name
        if 'test' in values['base_url']:
            values['model'] = 'chatglm_130b_test'

        values['client'] = ZhipuModelAPI(api_key=values['api_key'], base_url=values['base_url'])
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the ZhipuAI API (a fresh dict per access)."""
        return {
            "model": self.model,
            "temperature": self.temperature,
            "top_p": self.top_p
        }

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return self._default_params

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "zhipuai"

    def _convert_message_to_dict(self, message: BaseMessage) -> dict:
        """Convert a message into a ``{"role", "content"}`` dict.

        :raises ValueError: for message types other than chat/human/AI/system
        """
        if isinstance(message, ChatMessage):
            message_dict = {"role": message.role, "content": message.content}
        elif isinstance(message, HumanMessage):
            message_dict = {"role": "user", "content": message.content}
        elif isinstance(message, AIMessage):
            message_dict = {"role": "assistant", "content": message.content}
        elif isinstance(message, SystemMessage):
            # system messages are sent with the "user" role
            message_dict = {"role": "user", "content": message.content}
        else:
            raise ValueError(f"Got unknown type {message}")
        return message_dict

    def _convert_dict_to_message(self, _dict: Dict[str, Any]) -> BaseMessage:
        """Convert a ``{"role", "content"}`` dict into the matching message class."""
        role = _dict["role"]
        if role == "user":
            return HumanMessage(content=_dict["content"])
        elif role == "assistant":
            return AIMessage(content=_dict["content"])
        elif role == "system":
            return SystemMessage(content=_dict["content"])
        else:
            return ChatMessage(content=_dict["content"], role=role)

    def _create_message_dicts(
            self, messages: List[BaseMessage]
    ) -> List[Dict[str, Any]]:
        """Convert messages to dicts, merging consecutive same-role messages with a newline."""
        dict_messages = []
        for m in messages:
            message = self._convert_message_to_dict(m)
            if dict_messages and dict_messages[-1]['role'] == message['role']:
                dict_messages[-1]['content'] += f"\n{message['content']}"
            else:
                dict_messages.append(message)

        return dict_messages

    def _normalize_token_usage(self, token_usage: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        """Fill in missing prompt/completion token counts in place.

        ``prompt_tokens`` defaults to 0 and ``completion_tokens`` defaults to
        ``total_tokens`` (0 if that is absent too). ``None`` is passed through.
        """
        if token_usage is None:
            return None
        token_usage.setdefault('prompt_tokens', 0)
        if 'completion_tokens' not in token_usage:
            token_usage['completion_tokens'] = token_usage.get('total_tokens', 0)
        return token_usage

    def _generate(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> ChatResult:
        """Generate a chat result, aggregating the stream when ``streaming`` is set.

        NOTE: ``stop`` is accepted for interface compatibility but is never
        added to the request sent to the API.

        :raises ValueError: if a streamed response yields no content chunk
        """
        if self.streaming:
            generation: Optional[ChatGenerationChunk] = None
            llm_output: Optional[Dict] = None
            for chunk in self._stream(
                    messages=messages, stop=stop, run_manager=run_manager, **kwargs
            ):
                # the usage-carrying chunk only feeds llm_output, not the text
                if chunk.generation_info is not None \
                        and 'token_usage' in chunk.generation_info:
                    llm_output = {"token_usage": chunk.generation_info['token_usage'], "model_name": self.model}
                    continue

                if generation is None:
                    generation = chunk
                else:
                    generation += chunk
            if generation is None:
                # explicit check instead of assert, which is stripped under -O
                raise ValueError("ZhipuAI stream returned no content.")
            return ChatResult(generations=[generation], llm_output=llm_output)
        else:
            message_dicts = self._create_message_dicts(messages)
            request = self._default_params
            request["prompt"] = message_dicts
            request.update(kwargs)
            response = self.client.invoke(**request)
            return self._create_chat_result(response)

    def _stream(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream chunks from the SSE endpoint.

        "add" events yield content chunks, "finish" yields a final chunk whose
        ``generation_info`` carries the token usage, and "error"/"interrupted"
        raise ``ValueError``. ``stop`` is not forwarded to the API.
        """
        message_dicts = self._create_message_dicts(messages)
        request = self._default_params
        request["prompt"] = message_dicts
        request.update(kwargs)

        for event in self.client.sse_invoke(incremental=True, **request).events():
            if event.event == "add":
                yield ChatGenerationChunk(message=AIMessageChunk(content=event.data))
                if run_manager:
                    run_manager.on_llm_new_token(event.data)
            elif event.event == "error" or event.event == "interrupted":
                raise ValueError(
                    f"{event.data}"
                )
            elif event.event == "finish":
                meta = json.loads(event.meta)
                # .get() mirrors _create_chat_result: missing usage becomes None
                token_usage = self._normalize_token_usage(meta.get('usage'))

                yield ChatGenerationChunk(
                    message=AIMessageChunk(content=event.data),
                    generation_info=dict({'token_usage': token_usage})
                )

    def _create_chat_result(self, response: Dict[str, Any]) -> ChatResult:
        """Build a ``ChatResult`` from the ``data`` section of a sync response."""
        data = response["data"]
        generations = []
        for res in data["choices"]:
            message = self._convert_dict_to_message(res)
            gen = ChatGeneration(
                message=message
            )
            generations.append(gen)
        token_usage = self._normalize_token_usage(data.get("usage"))

        llm_output = {"token_usage": token_usage, "model_name": self.model}
        return ChatResult(generations=generations, llm_output=llm_output)

    # def get_token_ids(self, text: str) -> List[int]:
    #     """Return the ordered ids of the tokens in a text.
    #
    #     Args:
    #         text: The string input to tokenize.
    #
    #     Returns:
    #         A list of ids corresponding to the tokens in the text, in order they occur
    #             in the text.
    #     """
    #     from core.third_party.transformers.Token import ChatGLMTokenizer
    #
    #     tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm2-6b")
    #     return tokenizer.encode(text)

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Get the number of tokens in the messages.

        Useful for checking if an input will fit in a model's context window.

        Args:
            messages: The message inputs to tokenize.

        Returns:
            The sum of the number of tokens across the messages.
        """
        return sum(self.get_num_tokens(m.content) for m in messages)

    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        """Sum the token usage of all outputs, key by key."""
        overall_token_usage: dict = {}
        for output in llm_outputs:
            if output is None:
                # Happens in streaming
                continue
            # token_usage may be None when the API reported no usage
            token_usage = output.get("token_usage") or {}
            for k, v in token_usage.items():
                if k in overall_token_usage:
                    overall_token_usage[k] += v
                else:
                    overall_token_usage[k] = v
        return {"token_usage": overall_token_usage, "model_name": self.model}
|
||||||
@ -0,0 +1,50 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from core.model_providers.models.embedding.zhipuai_embedding import ZhipuAIEmbedding
|
||||||
|
from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
|
||||||
|
from models.provider import Provider, ProviderType
|
||||||
|
|
||||||
|
|
||||||
|
def get_mock_provider(valid_api_key):
    """Build a CUSTOM 'zhipuai' ``Provider`` whose JSON config holds ``valid_api_key``."""
    config_json = json.dumps({'api_key': valid_api_key})
    return Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='zhipuai',
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=config_json,
        is_valid=True,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def get_mock_embedding_model():
    """Create a 'text_embedding' ``ZhipuAIEmbedding`` using ZHIPUAI_API_KEY from the environment."""
    api_key = os.environ['ZHIPUAI_API_KEY']
    zhipuai_provider = ZhipuAIProvider(provider=get_mock_provider(api_key))
    return ZhipuAIEmbedding(model_provider=zhipuai_provider, name='text_embedding')
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt_side_effect(tenant_id, encrypted_api_key):
    """Stand-in for ``decrypt_token``: hand the key back unchanged."""
    return encrypted_api_key
|
||||||
|
|
||||||
|
|
||||||
|
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embedding(mock_decrypt):
    """Embedding a single query returns a list of 1024 values."""
    vector = get_mock_embedding_model().client.embed_query('test')
    assert isinstance(vector, list)
    assert len(vector) == 1024
|
||||||
|
|
||||||
|
|
||||||
|
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_doc_embedding(mock_decrypt):
    """Embedding two documents returns a list whose first vector has 1024 values."""
    documents = ['test', 'test2']
    vectors = get_mock_embedding_model().client.embed_documents(documents)
    assert isinstance(vectors, list)
    assert len(vectors[0]) == 1024
|
||||||
@ -0,0 +1,79 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
|
||||||
|
from core.model_providers.models.entity.message import PromptMessage, MessageType
|
||||||
|
from core.model_providers.models.entity.model_params import ModelKwargs
|
||||||
|
from core.model_providers.models.llm.zhipuai_model import ZhipuAIModel
|
||||||
|
from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
|
||||||
|
from models.provider import Provider, ProviderType
|
||||||
|
|
||||||
|
|
||||||
|
def get_mock_provider(valid_api_key):
    """Build a CUSTOM 'zhipuai' ``Provider`` whose JSON config holds ``valid_api_key``."""
    config_json = json.dumps({'api_key': valid_api_key})
    return Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='zhipuai',
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=config_json,
        is_valid=True,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def get_mock_model(model_name: str, streaming: bool = False):
    """Create a ``ZhipuAIModel`` with temperature 0.01 using ZHIPUAI_API_KEY from the environment.

    :param model_name: name of the model to instantiate
    :param streaming: whether the model should stream its output
    """
    kwargs = ModelKwargs(temperature=0.01)
    api_key = os.environ['ZHIPUAI_API_KEY']
    provider = ZhipuAIProvider(provider=get_mock_provider(api_key))
    return ZhipuAIModel(
        model_provider=provider,
        name=model_name,
        model_kwargs=kwargs,
        streaming=streaming,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt_side_effect(tenant_id, encrypted_api_key):
    """Stand-in for ``decrypt_token``: hand the key back unchanged."""
    return encrypted_api_key
|
||||||
|
|
||||||
|
|
||||||
|
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_get_num_tokens(mock_decrypt):
    """Token counting over a system + human prompt yields a positive number."""
    prompt = [
        PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?'),
    ]
    token_count = get_mock_model('chatglm_lite').get_num_tokens(prompt)
    assert token_count > 0
|
||||||
|
|
||||||
|
|
||||||
|
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_run(mock_decrypt, mocker):
    """Run a non-streaming chat completion and expect non-empty content."""
    # `update_last_used` is stubbed to return None.
    # NOTE(review): presumably this avoids touching persistent state — confirm.
    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)

    model = get_mock_model('chatglm_lite')
    messages = [
        PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
    ]
    rst = model.run(
        messages,
    )
    assert len(rst.content) > 0
|
||||||
|
|
||||||
|
|
||||||
|
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_stream_run(mock_decrypt, mocker):
    """Run a chat completion with `streaming=True` and expect non-empty content."""
    # `update_last_used` is stubbed to return None.
    # NOTE(review): presumably this avoids touching persistent state — confirm.
    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)

    model = get_mock_model('chatglm_lite', streaming=True)
    messages = [
        PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
    ]
    rst = model.run(
        messages
    )
    assert len(rst.content) > 0
|
||||||
@ -0,0 +1,88 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import patch
|
||||||
|
import json
|
||||||
|
|
||||||
|
from langchain.schema import ChatResult, ChatGeneration, AIMessage
|
||||||
|
|
||||||
|
from core.model_providers.providers.base import CredentialsValidateFailedError
|
||||||
|
from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
|
||||||
|
from models.provider import ProviderType, Provider
|
||||||
|
|
||||||
|
|
||||||
|
# Shared fixtures for the provider tests: the provider name, the provider
# class under test, and the credential dict used by the tests.
PROVIDER_NAME = 'zhipuai'
MODEL_PROVIDER_CLASS = ZhipuAIProvider
VALIDATE_CREDENTIAL = {
    'api_key': 'valid_key',
}
|
||||||
|
|
||||||
|
|
||||||
|
def encrypt_side_effect(tenant_id, encrypt_key):
    """Stand-in for `encrypt_token` that prefixes the key with `encrypted_`.

    :param tenant_id: ignored.
    :param encrypt_key: value to mark as encrypted.
    :return: the string `encrypted_<encrypt_key>`.
    """
    return f'encrypted_{encrypt_key}'
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt_side_effect(tenant_id, encrypted_key):
    """Stand-in for `decrypt_token` that strips the `encrypted_` marker.

    :param tenant_id: ignored.
    :param encrypted_key: string from which every occurrence of
        `encrypted_` is removed.
    :return: the key with the marker removed.
    """
    return encrypted_key.replace('encrypted_', '')
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_provider_credentials_valid_or_raise_valid(mocker):
    """Validation must not raise when the chat LLM's `_generate` succeeds."""
    # `_generate` is stubbed to return a fixed single-message ChatResult.
    mocker.patch(
        'core.third_party.langchain.llms.zhipuai_llm.ZhipuAIChatLLM._generate',
        return_value=ChatResult(generations=[ChatGeneration(message=AIMessage(content='abc'))])
    )

    MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_provider_credentials_valid_or_raise_invalid():
    """Validation must raise for missing and for invalid API keys."""
    # raise CredentialsValidateFailedError if api_key is not in credentials
    with pytest.raises(CredentialsValidateFailedError):
        MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})

    credential = VALIDATE_CREDENTIAL.copy()
    credential['api_key'] = 'invalid_key'

    # raise CredentialsValidateFailedError if api_key is invalid
    # NOTE(review): nothing is mocked here, so this presumably depends on the
    # provider rejecting the key at call time — confirm.
    with pytest.raises(CredentialsValidateFailedError):
        MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential)
|
||||||
|
|
||||||
|
|
||||||
|
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_credentials(mock_encrypt):
    """`encrypt_provider_credentials` must return the `api_key` in encrypted form."""
    result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy())
    assert result['api_key'] == f'encrypted_{VALIDATE_CREDENTIAL["api_key"]}'
|
||||||
|
|
||||||
|
|
||||||
|
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_custom(mock_decrypt):
    """`get_provider_credentials` must return the decrypted `api_key`."""
    encrypted_credential = VALIDATE_CREDENTIAL.copy()
    encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']

    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name=PROVIDER_NAME,
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=json.dumps(encrypted_credential),
        is_valid=True,
    )
    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
    result = model_provider.get_provider_credentials()
    assert result['api_key'] == 'valid_key'
|
||||||
|
|
||||||
|
|
||||||
|
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_obfuscated(mock_decrypt):
    """With `obfuscated=True` the middle of the `api_key` must be masked."""
    encrypted_credential = VALIDATE_CREDENTIAL.copy()
    encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']

    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name=PROVIDER_NAME,
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=json.dumps(encrypted_credential),
        is_valid=True,
    )
    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
    result = model_provider.get_provider_credentials(obfuscated=True)

    # Everything between the first 6 and the last 2 characters must be `*`,
    # and that span must be 8 characters shorter than the plain key
    # (never negative).
    middle_token = result['api_key'][6:-2]
    assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['api_key']) - 8, 0)
    assert all(char == '*' for char in middle_token)
|
||||||
Loading…
Reference in New Issue