feat: server multi models support (#799)
parent d8b712b325
commit 5fa2161b05
@@ -0,0 +1,53 @@
import logging

import stripe
from flask import request, current_app
from flask_restful import Resource

from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import only_edition_cloud
from services.provider_checkout_service import ProviderCheckoutService


class StripeWebhookApi(Resource):
    @setup_required
    @only_edition_cloud
    def post(self):
        payload = request.data
        sig_header = request.headers.get('STRIPE_SIGNATURE')
        webhook_secret = current_app.config.get('STRIPE_WEBHOOK_SECRET')

        try:
            event = stripe.Webhook.construct_event(
                payload, sig_header, webhook_secret
            )
        except ValueError as e:
            # Invalid payload
            return 'Invalid payload', 400
        except stripe.error.SignatureVerificationError as e:
            # Invalid signature
            return 'Invalid signature', 400

        # Handle the checkout.session.completed event
        if event['type'] == 'checkout.session.completed':
            logging.debug(event['data']['object']['id'])
            logging.debug(event['data']['object']['amount_subtotal'])
            logging.debug(event['data']['object']['currency'])
            logging.debug(event['data']['object']['payment_intent'])
            logging.debug(event['data']['object']['payment_status'])
            logging.debug(event['data']['object']['metadata'])

            # Fulfill the purchase...
            provider_checkout_service = ProviderCheckoutService()

            try:
                provider_checkout_service.fulfill_provider_order(event)
            except Exception as e:
                logging.debug(str(e))
                return 'success', 200

        return 'success', 200


api.add_resource(StripeWebhookApi, '/webhook/stripe')
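
The handler above trusts nothing until stripe.Webhook.construct_event has verified the Stripe-Signature header against the endpoint secret. As a rough sketch of what that verification expects — useful when testing this endpoint — here is Stripe's documented t=<timestamp>,v1=<HMAC-SHA256> signing scheme; the helper name is made up and this is not code from this commit:

import hashlib
import hmac
import time

def make_stripe_signature(payload: bytes, webhook_secret: str) -> str:
    # Stripe signs "<timestamp>.<raw payload>" with HMAC-SHA256 keyed by the endpoint secret
    timestamp = int(time.time())
    signed_payload = f"{timestamp}.".encode() + payload
    digest = hmac.new(webhook_secret.encode(), signed_payload, hashlib.sha256).hexdigest()
    return f"t={timestamp},v1={digest}"
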
@@ -0,0 +1,108 @@
from flask_login import login_required, current_user
from flask_restful import Resource, reqparse

from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.model_provider_factory import ModelProviderFactory
from core.model_providers.models.entity.model_params import ModelType
from models.provider import ProviderType
from services.provider_service import ProviderService


class DefaultModelApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        parser = reqparse.RequestParser()
        parser.add_argument('model_type', type=str, required=True, nullable=False,
                            choices=['text-generation', 'embeddings', 'speech2text'], location='args')
        args = parser.parse_args()

        tenant_id = current_user.current_tenant_id

        provider_service = ProviderService()
        default_model = provider_service.get_default_model_of_model_type(
            tenant_id=tenant_id,
            model_type=args['model_type']
        )

        if not default_model:
            return None

        model_provider = ModelProviderFactory.get_preferred_model_provider(
            tenant_id,
            default_model.provider_name
        )

        if not model_provider:
            return {
                'model_name': default_model.model_name,
                'model_type': default_model.model_type,
                'model_provider': {
                    'provider_name': default_model.provider_name
                }
            }

        provider = model_provider.provider
        rst = {
            'model_name': default_model.model_name,
            'model_type': default_model.model_type,
            'model_provider': {
                'provider_name': provider.provider_name,
                'provider_type': provider.provider_type
            }
        }

        model_provider_rules = ModelProviderFactory.get_provider_rule(default_model.provider_name)
        if provider.provider_type == ProviderType.SYSTEM.value:
            rst['model_provider']['quota_type'] = provider.quota_type
            rst['model_provider']['quota_unit'] = model_provider_rules['system_config']['quota_unit']
            rst['model_provider']['quota_limit'] = provider.quota_limit
            rst['model_provider']['quota_used'] = provider.quota_used

        return rst

    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        parser = reqparse.RequestParser()
        parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
        parser.add_argument('model_type', type=str, required=True, nullable=False,
                            choices=['text-generation', 'embeddings', 'speech2text'], location='json')
        parser.add_argument('provider_name', type=str, required=True, nullable=False, location='json')
        args = parser.parse_args()

        provider_service = ProviderService()
        provider_service.update_default_model_of_model_type(
            tenant_id=current_user.current_tenant_id,
            model_type=args['model_type'],
            provider_name=args['provider_name'],
            model_name=args['model_name']
        )

        return {'result': 'success'}


class ValidModelApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, model_type):
        ModelType.value_of(model_type)

        provider_service = ProviderService()
        valid_models = provider_service.get_valid_model_list(
            tenant_id=current_user.current_tenant_id,
            model_type=model_type
        )

        return valid_models


api.add_resource(DefaultModelApi, '/workspaces/current/default-model')
api.add_resource(ValidModelApi, '/workspaces/current/models/model-type/<string:model_type>')
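
For context, a hypothetical client interaction with the two endpoints registered above; the base URL, port, and session handling are assumptions, not part of this diff:

import requests

BASE = 'http://localhost:5001/console/api'  # assumed console API prefix
session = requests.Session()  # assumes an already-authenticated console session

# Read the tenant's default text-generation model
resp = session.get(f'{BASE}/workspaces/current/default-model',
                   params={'model_type': 'text-generation'})
print(resp.json())

# Switch the default model to another provider
session.post(f'{BASE}/workspaces/current/default-model', json={
    'model_type': 'text-generation',
    'provider_name': 'anthropic',
    'model_name': 'claude-2',
})
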
@@ -0,0 +1,130 @@
# -*- coding:utf-8 -*-
from flask_login import login_required, current_user
from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden

from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.providers.base import CredentialsValidateFailedError
from models.provider import ProviderType
from services.provider_service import ProviderService


class ProviderListApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        tenant_id = current_user.current_tenant_id

        """
        If the provider type is AZURE_OPENAI, decode and return the four fields azure_api_type,
        azure_api_version, azure_api_base and azure_api_key as an object, where azure_api_key shows
        the first six and last two characters in plaintext and masks the rest with *.

        For any other provider type, decode and return the token field directly; it likewise shows
        the first six and last two characters in plaintext and masks the rest with *.
        """

        provider_service = ProviderService()
        provider_info_list = provider_service.get_provider_list(tenant_id)

        provider_list = [
            {
                'provider_name': p['provider_name'],
                'provider_type': p['provider_type'],
                'is_valid': p['is_valid'],
                'last_used': p['last_used'],
                'is_enabled': p['is_valid'],
                **({
                    'quota_type': p['quota_type'],
                    'quota_limit': p['quota_limit'],
                    'quota_used': p['quota_used']
                } if p['provider_type'] == ProviderType.SYSTEM.value else {}),
                'token': (p['config'] if p['provider_name'] != 'openai' else p['config']['openai_api_key'])
                if p['config'] else None
            }
            for name, provider_info in provider_info_list.items()
            for p in provider_info['providers']
        ]

        return provider_list


class ProviderTokenApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider):
        # The role of the current user in the ta table must be admin or owner
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()

        parser = reqparse.RequestParser()
        parser.add_argument('token', required=True, nullable=False, location='json')
        args = parser.parse_args()

        if provider == 'openai':
            args['token'] = {
                'openai_api_key': args['token']
            }

        provider_service = ProviderService()
        try:
            provider_service.save_custom_provider_config(
                tenant_id=current_user.current_tenant_id,
                provider_name=provider,
                config=args['token']
            )
        except CredentialsValidateFailedError as ex:
            raise ValueError(str(ex))

        return {'result': 'success'}, 201


class ProviderTokenValidateApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider):
        parser = reqparse.RequestParser()
        parser.add_argument('token', required=True, nullable=False, location='json')
        args = parser.parse_args()

        provider_service = ProviderService()

        if provider == 'openai':
            args['token'] = {
                'openai_api_key': args['token']
            }

        result = True
        error = None

        try:
            provider_service.custom_provider_config_validate(
                provider_name=provider,
                config=args['token']
            )
        except CredentialsValidateFailedError as ex:
            result = False
            error = str(ex)

        response = {'result': 'success' if result else 'error'}

        if not result:
            response['error'] = error

        return response


api.add_resource(ProviderTokenApi, '/workspaces/current/providers/<provider>/token',
                 endpoint='workspaces_current_providers_token')  # POST for updating provider token
api.add_resource(ProviderTokenValidateApi, '/workspaces/current/providers/<provider>/token-validate',
                 endpoint='workspaces_current_providers_token_validate')  # POST for validating provider token

api.add_resource(ProviderListApi, '/workspaces/current/providers')  # GET for getting providers list
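
The double list comprehension in ProviderListApi.get flattens the service's {name: {'providers': [...]}} mapping into one row per provider record. A standalone sketch of the same pattern, with made-up data:

provider_info_list = {
    'openai': {'providers': [{'provider_name': 'openai', 'provider_type': 'custom'}]},
    'anthropic': {'providers': [{'provider_name': 'anthropic', 'provider_type': 'system'}]},
}
rows = [p
        for name, provider_info in provider_info_list.items()
        for p in provider_info['providers']]
print(rows)  # two flat dicts, one per provider record
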
@@ -1,36 +0,0 @@
import os
from typing import Optional

import langchain
from flask import Flask
from pydantic import BaseModel

from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.prompt.prompt_template import OneLineFormatter


class HostedOpenAICredential(BaseModel):
    api_key: str


class HostedAnthropicCredential(BaseModel):
    api_key: str


class HostedLLMCredentials(BaseModel):
    openai: Optional[HostedOpenAICredential] = None
    anthropic: Optional[HostedAnthropicCredential] = None


hosted_llm_credentials = HostedLLMCredentials()


def init_app(app: Flask):
    if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
        langchain.verbose = True

    if app.config.get("OPENAI_API_KEY"):
        hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY"))

    if app.config.get("ANTHROPIC_API_KEY"):
        hosted_llm_credentials.anthropic = HostedAnthropicCredential(api_key=app.config.get("ANTHROPIC_API_KEY"))
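
The deleted module above appears to be the old core package initializer (code further down imports it as `from core import hosted_llm_credentials`): a process-global credentials singleton populated once at app startup. A sketch of how it was consumed — the import path is inferred and the key value is illustrative:

from flask import Flask

import core  # the module removed above

app = Flask(__name__)
app.config['OPENAI_API_KEY'] = 'sk-...'  # illustrative value
core.init_app(app)
print(core.hosted_llm_credentials.openai.api_key)  # 'sk-...'
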
@@ -0,0 +1,162 @@
import re
from typing import List, Tuple, Any, Union, Sequence, Optional, cast

from langchain import BasePromptTemplate
from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.tools import BaseTool
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX

from core.model_providers.models.llm.base import BaseLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool

FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
Valid "action" values: "Final Answer" or {tool_names}

Provide only ONE action per $JSON_BLOB, as shown:

```
{{{{
  "action": $TOOL_NAME,
  "action_input": $INPUT
}}}}
```

Follow this format:

Question: input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
```
{{{{
  "action": "Final Answer",
  "action_input": "Final response to human"
}}}}
```"""


class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
    model_instance: BaseLLM
    dataset_tools: Sequence[BaseTool]

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    def should_use_agent(self, query: str):
        """
        return should use agent

        Using the ReACT mode to determine whether an agent is needed is itself costly,
        so it is cheaper to skip that check and always use the agent for reasoning.

        :param query:
        :return:
        """
        return True

    def plan(
            self,
            intermediate_steps: List[Tuple[AgentAction, str]],
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date,
                along with observations
            callbacks: Callbacks to run.
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        if len(self.dataset_tools) == 0:
            return AgentFinish(return_values={"output": ''}, log='')
        elif len(self.dataset_tools) == 1:
            tool = next(iter(self.dataset_tools))
            tool = cast(DatasetRetrieverTool, tool)
            rst = tool.run(tool_input={'dataset_id': tool.dataset_id, 'query': kwargs['input']})
            return AgentFinish(return_values={"output": rst}, log=rst)

        full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
        full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)

        try:
            return self.output_parser.parse(full_output)
        except OutputParserException:
            return AgentFinish({"output": "I'm sorry, the model's answer is invalid; "
                                          "I don't know how to respond to that."}, "")

    @classmethod
    def create_prompt(
            cls,
            tools: Sequence[BaseTool],
            prefix: str = PREFIX,
            suffix: str = SUFFIX,
            human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
            format_instructions: str = FORMAT_INSTRUCTIONS,
            input_variables: Optional[List[str]] = None,
            memory_prompts: Optional[List[BasePromptTemplate]] = None,
    ) -> BasePromptTemplate:
        tool_strings = []
        for tool in tools:
            args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
            tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
        formatted_tools = "\n".join(tool_strings)
        unique_tool_names = set(tool.name for tool in tools)
        tool_names = ", ".join('"' + name + '"' for name in unique_tool_names)
        format_instructions = format_instructions.format(tool_names=tool_names)
        template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
        if input_variables is None:
            input_variables = ["input", "agent_scratchpad"]
        _memory_prompts = memory_prompts or []
        messages = [
            SystemMessagePromptTemplate.from_template(template),
            *_memory_prompts,
            HumanMessagePromptTemplate.from_template(human_message_template),
        ]
        return ChatPromptTemplate(input_variables=input_variables, messages=messages)

    @classmethod
    def from_llm_and_tools(
            cls,
            llm: BaseLanguageModel,
            tools: Sequence[BaseTool],
            callback_manager: Optional[BaseCallbackManager] = None,
            output_parser: Optional[AgentOutputParser] = None,
            prefix: str = PREFIX,
            suffix: str = SUFFIX,
            human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
            format_instructions: str = FORMAT_INSTRUCTIONS,
            input_variables: Optional[List[str]] = None,
            memory_prompts: Optional[List[BasePromptTemplate]] = None,
            **kwargs: Any,
    ) -> Agent:
        return super().from_llm_and_tools(
            llm=llm,
            tools=tools,
            callback_manager=callback_manager,
            output_parser=output_parser,
            prefix=prefix,
            suffix=suffix,
            human_message_template=human_message_template,
            format_instructions=format_instructions,
            input_variables=input_variables,
            memory_prompts=memory_prompts,
            dataset_tools=tools,
            **kwargs,
        )
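
A note on the quadrupled braces in create_prompt above: each str.format-style pass halves doubled braces, so escaping "{" as "{{{{" lets a literal brace in a tool's args schema survive two formatting passes (plain "{{" would survive only one). A standalone illustration:

quadrupled = 'args: {{{{"query": ...}}}}'
after_one_pass = quadrupled.format()        # 'args: {{"query": ...}}'
after_two_passes = after_one_pass.format()  # 'args: {"query": ...}'
print(after_two_passes)
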
@@ -1,109 +0,0 @@
from _decimal import Decimal

models = {
    'claude-instant-1': 'anthropic',  # 100,000 tokens
    'claude-2': 'anthropic',  # 100,000 tokens
    'gpt-4': 'openai',  # 8,192 tokens
    'gpt-4-32k': 'openai',  # 32,768 tokens
    'gpt-3.5-turbo': 'openai',  # 4,096 tokens
    'gpt-3.5-turbo-16k': 'openai',  # 16,384 tokens
    'text-davinci-003': 'openai',  # 4,097 tokens
    'text-davinci-002': 'openai',  # 4,097 tokens
    'text-curie-001': 'openai',  # 2,049 tokens
    'text-babbage-001': 'openai',  # 2,049 tokens
    'text-ada-001': 'openai',  # 2,049 tokens
    'text-embedding-ada-002': 'openai',  # 8,191 tokens, 1536 dimensions
    'whisper-1': 'openai'
}

max_context_token_length = {
    'claude-instant-1': 100000,
    'claude-2': 100000,
    'gpt-4': 8192,
    'gpt-4-32k': 32768,
    'gpt-3.5-turbo': 4096,
    'gpt-3.5-turbo-16k': 16384,
    'text-davinci-003': 4097,
    'text-davinci-002': 4097,
    'text-curie-001': 2049,
    'text-babbage-001': 2049,
    'text-ada-001': 2049,
    'text-embedding-ada-002': 8191,
}

models_by_mode = {
    'chat': [
        'claude-instant-1',  # 100,000 tokens
        'claude-2',  # 100,000 tokens
        'gpt-4',  # 8,192 tokens
        'gpt-4-32k',  # 32,768 tokens
        'gpt-3.5-turbo',  # 4,096 tokens
        'gpt-3.5-turbo-16k',  # 16,384 tokens
    ],
    'completion': [
        'claude-instant-1',  # 100,000 tokens
        'claude-2',  # 100,000 tokens
        'gpt-4',  # 8,192 tokens
        'gpt-4-32k',  # 32,768 tokens
        'gpt-3.5-turbo',  # 4,096 tokens
        'gpt-3.5-turbo-16k',  # 16,384 tokens
        'text-davinci-003',  # 4,097 tokens
        'text-davinci-002',  # 4,097 tokens
        'text-curie-001',  # 2,049 tokens
        'text-babbage-001',  # 2,049 tokens
        'text-ada-001'  # 2,049 tokens
    ],
    'embedding': [
        'text-embedding-ada-002'  # 8,191 tokens, 1536 dimensions
    ]
}

model_currency = 'USD'

model_prices = {
    'claude-instant-1': {
        'prompt': Decimal('0.00163'),
        'completion': Decimal('0.00551'),
    },
    'claude-2': {
        'prompt': Decimal('0.01102'),
        'completion': Decimal('0.03268'),
    },
    'gpt-4': {
        'prompt': Decimal('0.03'),
        'completion': Decimal('0.06'),
    },
    'gpt-4-32k': {
        'prompt': Decimal('0.06'),
        'completion': Decimal('0.12')
    },
    'gpt-3.5-turbo': {
        'prompt': Decimal('0.0015'),
        'completion': Decimal('0.002')
    },
    'gpt-3.5-turbo-16k': {
        'prompt': Decimal('0.003'),
        'completion': Decimal('0.004')
    },
    'text-davinci-003': {
        'prompt': Decimal('0.02'),
        'completion': Decimal('0.02')
    },
    'text-curie-001': {
        'prompt': Decimal('0.002'),
        'completion': Decimal('0.002')
    },
    'text-babbage-001': {
        'prompt': Decimal('0.0005'),
        'completion': Decimal('0.0005')
    },
    'text-ada-001': {
        'prompt': Decimal('0.0004'),
        'completion': Decimal('0.0004')
    },
    'text-embedding-ada-002': {
        'usage': Decimal('0.0001'),
    }
}

agent_model_name = 'text-davinci-003'
@@ -0,0 +1,20 @@
import base64

from extensions.ext_database import db
from libs import rsa

from models.account import Tenant


def obfuscated_token(token: str):
    return token[:6] + '*' * (len(token) - 8) + token[-2:]


def encrypt_token(tenant_id: str, token: str):
    tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).first()
    encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
    return base64.b64encode(encrypted_token).decode()


def decrypt_token(tenant_id: str, token: str):
    return rsa.decrypt(base64.b64decode(token), tenant_id)
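
Usage sketch for the helpers above (the sample key is made up):

token = 'sk-1234567890abcdef'   # 19 characters
print(obfuscated_token(token))  # 'sk-123***********ef' — first 6 and last 2 kept
# encrypt_token/decrypt_token round-trip through the tenant's RSA key pair:
# decrypt_token(tenant_id, encrypt_token(tenant_id, token)) == token
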
@@ -1,148 +0,0 @@
from typing import Union, Optional, List

from langchain.callbacks.base import BaseCallbackHandler

from core.constant import llm_constant
from core.llm.error import ProviderTokenNotInitError
from core.llm.provider.base import BaseProvider
from core.llm.provider.llm_provider_service import LLMProviderService
from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI
from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI
from core.llm.streamable_chat_anthropic import StreamableChatAnthropic
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
from models.provider import ProviderType, ProviderName


class LLMBuilder:
    """
    This class handles the following logic:
    1. For providers with the name 'OpenAI', the OPENAI_API_KEY value is stored directly in encrypted_config.
    2. For providers with the name 'Azure OpenAI', encrypted_config stores the serialized values of four fields, as shown below:
       OPENAI_API_TYPE=azure
       OPENAI_API_VERSION=2022-12-01
       OPENAI_API_BASE=https://your-resource-name.openai.azure.com
       OPENAI_API_KEY=<your Azure OpenAI API key>
    3. For providers with the name 'Anthropic', the ANTHROPIC_API_KEY value is stored directly in encrypted_config.
    4. For providers with the name 'Cohere', the COHERE_API_KEY value is stored directly in encrypted_config.
    5. For providers with the name 'HUGGINGFACEHUB', the HUGGINGFACEHUB_API_KEY value is stored directly in encrypted_config.
    6. Providers with the provider_type 'CUSTOM' can be created through the admin interface, while 'System' providers cannot be created through the admin interface.
    7. If both CUSTOM and System providers exist in the records, the CUSTOM provider is preferred by default, but this preference can be changed via an input parameter.
    8. For providers with the provider_type 'System', quota_used must not exceed quota_limit. If the quota is exceeded, the provider cannot be used. Currently, only the TRIAL quota_type is supported, which is permanently non-resetting.
    """

    @classmethod
    def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
        provider = cls.get_default_provider(tenant_id, model_name)

        model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)

        llm_cls = None
        mode = cls.get_mode_by_model(model_name)
        if mode == 'chat':
            if provider == ProviderName.OPENAI.value:
                llm_cls = StreamableChatOpenAI
            elif provider == ProviderName.AZURE_OPENAI.value:
                llm_cls = StreamableAzureChatOpenAI
            elif provider == ProviderName.ANTHROPIC.value:
                llm_cls = StreamableChatAnthropic
        elif mode == 'completion':
            if provider == ProviderName.OPENAI.value:
                llm_cls = StreamableOpenAI
            elif provider == ProviderName.AZURE_OPENAI.value:
                llm_cls = StreamableAzureOpenAI

        if not llm_cls:
            raise ValueError(f"model name {model_name} is not supported.")

        model_kwargs = {
            'model_name': model_name,
            'temperature': kwargs.get('temperature', 0),
            'max_tokens': kwargs.get('max_tokens', 256),
            'top_p': kwargs.get('top_p', 1),
            'frequency_penalty': kwargs.get('frequency_penalty', 0),
            'presence_penalty': kwargs.get('presence_penalty', 0),
            'callbacks': kwargs.get('callbacks', None),
            'streaming': kwargs.get('streaming', False),
        }

        model_kwargs.update(model_credentials)
        model_kwargs = llm_cls.get_kwargs_from_model_params(model_kwargs)

        return llm_cls(**model_kwargs)

    @classmethod
    def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
                          callbacks: Optional[List[BaseCallbackHandler]] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
        model_name = model.get("name")
        completion_params = model.get("completion_params", {})

        return cls.to_llm(
            tenant_id=tenant_id,
            model_name=model_name,
            temperature=completion_params.get('temperature', 0),
            max_tokens=completion_params.get('max_tokens', 256),
            top_p=completion_params.get('top_p', 0),
            frequency_penalty=completion_params.get('frequency_penalty', 0.1),
            presence_penalty=completion_params.get('presence_penalty', 0.1),
            streaming=streaming,
            callbacks=callbacks
        )

    @classmethod
    def get_mode_by_model(cls, model_name: str) -> str:
        if not model_name:
            raise ValueError("empty model name is not supported.")

        if model_name in llm_constant.models_by_mode['chat']:
            return "chat"
        elif model_name in llm_constant.models_by_mode['completion']:
            return "completion"
        else:
            raise ValueError(f"model name {model_name} is not supported.")

    @classmethod
    def get_model_credentials(cls, tenant_id: str, model_provider: str, model_name: str) -> dict:
        """
        Returns the API credentials for the given tenant_id and model_name, based on the model's provider.
        Raises an exception if the model_name is not found or if the provider is not found.
        """
        if not model_name:
            raise Exception('model name not found')

        # if model_name not in llm_constant.models:
        #     raise Exception('model {} not found'.format(model_name))

        # model_provider = llm_constant.models[model_name]

        provider_service = LLMProviderService(tenant_id=tenant_id, provider_name=model_provider)
        return provider_service.get_credentials(model_name)

    @classmethod
    def get_default_provider(cls, tenant_id: str, model_name: str) -> str:
        provider_name = llm_constant.models[model_name]

        if provider_name == 'openai':
            # get the default provider (openai / azure_openai) for the tenant
            openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.OPENAI.value)
            azure_openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.AZURE_OPENAI.value)

            provider = None
            if openai_provider and openai_provider.provider_type == ProviderType.CUSTOM.value:
                provider = openai_provider
            elif azure_openai_provider and azure_openai_provider.provider_type == ProviderType.CUSTOM.value:
                provider = azure_openai_provider
            elif openai_provider and openai_provider.provider_type == ProviderType.SYSTEM.value:
                provider = openai_provider
            elif azure_openai_provider and azure_openai_provider.provider_type == ProviderType.SYSTEM.value:
                provider = azure_openai_provider

            if not provider:
                raise ProviderTokenNotInitError(
                    f"No valid {provider_name} model provider credentials found. "
                    f"Please go to Settings -> Model Provider to complete your provider credentials."
                )

            provider_name = provider.provider_name

        return provider_name
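
Before its removal, the builder was typically driven as below (a sketch; the tenant ID and parameters are illustrative):

llm = LLMBuilder.to_llm(
    tenant_id='tenant-uuid',
    model_name='gpt-3.5-turbo',
    temperature=0.7,
    streaming=True,
)
# 'gpt-3.5-turbo' is listed under models_by_mode['chat'], so this resolves to
# StreamableChatOpenAI, or StreamableAzureChatOpenAI when the tenant's valid
# OpenAI-family provider is Azure.
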
@@ -1,15 +0,0 @@
import openai
from models.provider import ProviderName


class Moderation:

    def __init__(self, provider: str, api_key: str):
        self.provider = provider
        self.api_key = api_key

        if self.provider == ProviderName.OPENAI.value:
            self.client = openai.Moderation

    def moderate(self, text):
        return self.client.create(input=text, api_key=self.api_key)
@@ -1,138 +0,0 @@
import json
import logging
from typing import Optional, Union

import anthropic
from langchain.chat_models import ChatAnthropic
from langchain.schema import HumanMessage

from core import hosted_llm_credentials
from core.llm.error import ProviderTokenNotInitError
from core.llm.provider.base import BaseProvider
from core.llm.provider.errors import ValidateFailedError
from models.provider import ProviderName, ProviderType


class AnthropicProvider(BaseProvider):
    def get_models(self, model_id: Optional[str] = None) -> list[dict]:
        return [
            {
                'id': 'claude-instant-1',
                'name': 'claude-instant-1',
            },
            {
                'id': 'claude-2',
                'name': 'claude-2',
            },
        ]

    def get_credentials(self, model_id: Optional[str] = None) -> dict:
        return self.get_provider_api_key(model_id=model_id)

    def get_provider_name(self):
        return ProviderName.ANTHROPIC

    def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
        """
        Returns the provider configs.
        """
        try:
            config = self.get_provider_api_key(only_custom=only_custom)
        except:
            config = {
                'anthropic_api_key': ''
            }

        if obfuscated:
            if not config.get('anthropic_api_key'):
                config = {
                    'anthropic_api_key': ''
                }

            config['anthropic_api_key'] = self.obfuscated_token(config.get('anthropic_api_key'))
            return config

        return config

    def get_encrypted_token(self, config: Union[dict | str]):
        """
        Returns the encrypted token.
        """
        return json.dumps({
            'anthropic_api_key': self.encrypt_token(config['anthropic_api_key'])
        })

    def get_decrypted_token(self, token: str):
        """
        Returns the decrypted token.
        """
        config = json.loads(token)
        config['anthropic_api_key'] = self.decrypt_token(config['anthropic_api_key'])
        return config

    def get_token_type(self):
        return dict

    def config_validate(self, config: Union[dict | str]):
        """
        Validates the given config.
        """
        # check that the OpenAI / Azure OpenAI credential is valid
        openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.OPENAI.value)
        azure_openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.AZURE_OPENAI.value)

        provider = None
        if openai_provider:
            provider = openai_provider
        elif azure_openai_provider:
            provider = azure_openai_provider

        if not provider:
            raise ValidateFailedError("OpenAI or Azure OpenAI provider must be configured first.")

        if provider.provider_type == ProviderType.SYSTEM.value:
            quota_used = provider.quota_used if provider.quota_used is not None else 0
            quota_limit = provider.quota_limit if provider.quota_limit is not None else 0
            if quota_used >= quota_limit:
                raise ValidateFailedError("Your quota for Dify Hosted OpenAI has been exhausted, "
                                          "please configure an OpenAI or Azure OpenAI provider first.")

        try:
            if not isinstance(config, dict):
                raise ValueError('Config must be an object.')

            if 'anthropic_api_key' not in config:
                raise ValueError('anthropic_api_key must be provided.')

            chat_llm = ChatAnthropic(
                model='claude-instant-1',
                anthropic_api_key=config['anthropic_api_key'],
                max_tokens_to_sample=10,
                temperature=0,
                default_request_timeout=60
            )

            messages = [
                HumanMessage(
                    content="ping"
                )
            ]

            chat_llm(messages)
        except anthropic.APIConnectionError as ex:
            raise ValidateFailedError(f"Anthropic: Connection error, cause: {ex.__cause__}")
        except (anthropic.APIStatusError, anthropic.RateLimitError) as ex:
            raise ValidateFailedError(f"Anthropic: Error code: {ex.status_code} - "
                                      f"{ex.body['error']['type']}: {ex.body['error']['message']}")
        except Exception as ex:
            logging.exception('Anthropic config validation failed')
            raise ex

    def get_hosted_credentials(self) -> Union[str | dict]:
        if not hosted_llm_credentials.anthropic or not hosted_llm_credentials.anthropic.api_key:
            raise ProviderTokenNotInitError(
                f"No valid {self.get_provider_name().value} model provider credentials found. "
                f"Please go to Settings -> Model Provider to complete your provider credentials."
            )

        return {'anthropic_api_key': hosted_llm_credentials.anthropic.api_key}
@@ -1,145 +0,0 @@
import json
import logging
from typing import Optional, Union

import openai
import requests

from core.llm.provider.base import BaseProvider
from core.llm.provider.errors import ValidateFailedError
from models.provider import ProviderName


AZURE_OPENAI_API_VERSION = '2023-07-01-preview'


class AzureProvider(BaseProvider):
    def get_models(self, model_id: Optional[str] = None, credentials: Optional[dict] = None) -> list[dict]:
        return []

    def check_embedding_model(self, credentials: Optional[dict] = None):
        credentials = self.get_credentials('text-embedding-ada-002') if not credentials else credentials
        try:
            result = openai.Embedding.create(input=['test'],
                                             engine='text-embedding-ada-002',
                                             timeout=60,
                                             api_key=str(credentials.get('openai_api_key')),
                                             api_base=str(credentials.get('openai_api_base')),
                                             api_type='azure',
                                             api_version=str(credentials.get('openai_api_version')))["data"][0]["embedding"]
        except openai.error.AuthenticationError as e:
            raise AzureAuthenticationError(str(e))
        except openai.error.APIConnectionError as e:
            raise AzureRequestFailedError(
                'Failed to request Azure OpenAI, please check your API Base Endpoint. The format is `https://xxx.openai.azure.com/`')
        except openai.error.InvalidRequestError as e:
            if e.http_status == 404:
                raise AzureRequestFailedError("Please check that your 'gpt-3.5-turbo' or 'text-embedding-ada-002' "
                                              "deployment name exists in Azure AI")
            else:
                raise AzureRequestFailedError(
                    'Failed to request Azure OpenAI. cause: {}'.format(str(e)))
        except openai.error.OpenAIError as e:
            raise AzureRequestFailedError(
                'Failed to request Azure OpenAI. cause: {}'.format(str(e)))

        if not isinstance(result, list):
            raise AzureRequestFailedError('Failed to request Azure OpenAI.')

    def get_credentials(self, model_id: Optional[str] = None) -> dict:
        """
        Returns the API credentials for Azure OpenAI as a dictionary.
        """
        config = self.get_provider_api_key(model_id=model_id)
        config['openai_api_type'] = 'azure'
        config['openai_api_version'] = AZURE_OPENAI_API_VERSION
        if model_id == 'text-embedding-ada-002':
            config['deployment'] = model_id.replace('.', '') if model_id else None
            config['chunk_size'] = 16
        else:
            config['deployment_name'] = model_id.replace('.', '') if model_id else None
        return config

    def get_provider_name(self):
        return ProviderName.AZURE_OPENAI

    def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
        """
        Returns the provider configs.
        """
        try:
            config = self.get_provider_api_key(only_custom=only_custom)
        except:
            config = {
                'openai_api_type': 'azure',
                'openai_api_version': AZURE_OPENAI_API_VERSION,
                'openai_api_base': '',
                'openai_api_key': ''
            }

        if obfuscated:
            if not config.get('openai_api_key'):
                config = {
                    'openai_api_type': 'azure',
                    'openai_api_version': AZURE_OPENAI_API_VERSION,
                    'openai_api_base': '',
                    'openai_api_key': ''
                }

            config['openai_api_key'] = self.obfuscated_token(config.get('openai_api_key'))
            return config

        return config

    def get_token_type(self):
        return dict

    def config_validate(self, config: Union[dict | str]):
        """
        Validates the given config.
        """
        try:
            if not isinstance(config, dict):
                raise ValueError('Config must be an object.')

            if 'openai_api_version' not in config:
                config['openai_api_version'] = AZURE_OPENAI_API_VERSION

            self.check_embedding_model(credentials=config)
        except ValidateFailedError as e:
            raise e
        except AzureAuthenticationError:
            raise ValidateFailedError('Validation failed, please check your API Key.')
        except AzureRequestFailedError as ex:
            raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex)))
        except Exception as ex:
            logging.exception('Azure OpenAI Credentials validation failed')
            raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex)))

    def get_encrypted_token(self, config: Union[dict | str]):
        """
        Returns the encrypted token.
        """
        return json.dumps({
            'openai_api_type': 'azure',
            'openai_api_version': AZURE_OPENAI_API_VERSION,
            'openai_api_base': config['openai_api_base'],
            'openai_api_key': self.encrypt_token(config['openai_api_key'])
        })

    def get_decrypted_token(self, token: str):
        """
        Returns the decrypted token.
        """
        config = json.loads(token)
        config['openai_api_key'] = self.decrypt_token(config['openai_api_key'])
        return config


class AzureAuthenticationError(Exception):
    pass


class AzureRequestFailedError(Exception):
    pass
@@ -1,132 +0,0 @@
import base64
from abc import ABC, abstractmethod
from typing import Optional, Union

from core.constant import llm_constant
from core.llm.error import QuotaExceededError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError
from extensions.ext_database import db
from libs import rsa
from models.account import Tenant
from models.provider import Provider, ProviderType, ProviderName


class BaseProvider(ABC):
    def __init__(self, tenant_id: str):
        self.tenant_id = tenant_id

    def get_provider_api_key(self, model_id: Optional[str] = None, only_custom: bool = False) -> Union[str | dict]:
        """
        Returns the decrypted API key for the given tenant_id and provider_name.
        If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError.
        If the provider is not found or not valid, raises a ProviderTokenNotInitError.
        """
        provider = self.get_provider(only_custom)
        if not provider:
            raise ProviderTokenNotInitError(
                f"No valid {llm_constant.models[model_id]} model provider credentials found. "
                f"Please go to Settings -> Model Provider to complete your provider credentials."
            )

        if provider.provider_type == ProviderType.SYSTEM.value:
            quota_used = provider.quota_used if provider.quota_used is not None else 0
            quota_limit = provider.quota_limit if provider.quota_limit is not None else 0

            if model_id and model_id == 'gpt-4':
                raise ModelCurrentlyNotSupportError()

            if quota_used >= quota_limit:
                raise QuotaExceededError()

            return self.get_hosted_credentials()
        else:
            return self.get_decrypted_token(provider.encrypted_config)

    def get_provider(self, only_custom: bool = False) -> Optional[Provider]:
        """
        Returns the Provider instance for the given tenant_id and provider_name.
        If both CUSTOM and SYSTEM providers exist, the selection can be restricted via the only_custom flag.
        """
        return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, only_custom)

    @classmethod
    def get_valid_provider(cls, tenant_id: str, provider_name: str = None, only_custom: bool = False) -> Optional[Provider]:
        """
        Returns the Provider instance for the given tenant_id and provider_name.
        If both CUSTOM and SYSTEM providers exist, the CUSTOM provider is returned first.
        """
        query = db.session.query(Provider).filter(
            Provider.tenant_id == tenant_id
        )

        if provider_name:
            query = query.filter(Provider.provider_name == provider_name)

        if only_custom:
            query = query.filter(Provider.provider_type == ProviderType.CUSTOM.value)

        providers = query.order_by(Provider.provider_type.asc()).all()

        for provider in providers:
            if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config:
                return provider
            elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid:
                return provider

        return None

    def get_hosted_credentials(self) -> Union[str | dict]:
        raise ProviderTokenNotInitError(
            f"No valid {self.get_provider_name().value} model provider credentials found. "
            f"Please go to Settings -> Model Provider to complete your provider credentials."
        )

    def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
        """
        Returns the provider configs.
        """
        try:
            config = self.get_provider_api_key(only_custom=only_custom)
        except:
            config = ''

        if obfuscated:
            return self.obfuscated_token(config)

        return config

    def obfuscated_token(self, token: str):
        return token[:6] + '*' * (len(token) - 8) + token[-2:]

    def get_token_type(self):
        return str

    def get_encrypted_token(self, config: Union[dict | str]):
        return self.encrypt_token(config)

    def get_decrypted_token(self, token: str):
        return self.decrypt_token(token)

    def encrypt_token(self, token):
        tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
        encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
        return base64.b64encode(encrypted_token).decode()

    def decrypt_token(self, token):
        return rsa.decrypt(base64.b64decode(token), self.tenant_id)

    @abstractmethod
    def get_provider_name(self):
        raise NotImplementedError

    @abstractmethod
    def get_credentials(self, model_id: Optional[str] = None) -> dict:
        raise NotImplementedError

    @abstractmethod
    def get_models(self, model_id: Optional[str] = None) -> list[dict]:
        raise NotImplementedError

    @abstractmethod
    def config_validate(self, config: str):
        raise NotImplementedError
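
Why get_valid_provider prefers CUSTOM records: the query orders by provider_type ascending, and assuming the enum values are the strings 'custom' and 'system', 'custom' sorts first, so a valid custom provider is returned before a system one. A minimal illustration of that ordering assumption:

rows = ['system', 'custom']
print(sorted(rows))  # ['custom', 'system'] — the custom record wins the scan
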
@@ -1,2 +0,0 @@
class ValidateFailedError(Exception):
    description = "Provider Validate failed"
@@ -1,22 +0,0 @@
from typing import Optional

from core.llm.provider.base import BaseProvider
from models.provider import ProviderName


class HuggingfaceProvider(BaseProvider):
    def get_models(self, model_id: Optional[str] = None) -> list[dict]:
        credentials = self.get_credentials(model_id)
        # todo
        return []

    def get_credentials(self, model_id: Optional[str] = None) -> dict:
        """
        Returns the API credentials for Huggingface as a dictionary, for the given tenant_id.
        """
        return {
            'huggingface_api_key': self.get_provider_api_key(model_id=model_id)
        }

    def get_provider_name(self):
        return ProviderName.HUGGINGFACEHUB
@@ -1,53 +0,0 @@
from typing import Optional, Union

from core.llm.provider.anthropic_provider import AnthropicProvider
from core.llm.provider.azure_provider import AzureProvider
from core.llm.provider.base import BaseProvider
from core.llm.provider.huggingface_provider import HuggingfaceProvider
from core.llm.provider.openai_provider import OpenAIProvider
from models.provider import Provider


class LLMProviderService:

    def __init__(self, tenant_id: str, provider_name: str):
        self.provider = self.init_provider(tenant_id, provider_name)

    def init_provider(self, tenant_id: str, provider_name: str) -> BaseProvider:
        if provider_name == 'openai':
            return OpenAIProvider(tenant_id)
        elif provider_name == 'azure_openai':
            return AzureProvider(tenant_id)
        elif provider_name == 'anthropic':
            return AnthropicProvider(tenant_id)
        elif provider_name == 'huggingface':
            return HuggingfaceProvider(tenant_id)
        else:
            raise Exception('provider {} not found'.format(provider_name))

    def get_models(self, model_id: Optional[str] = None) -> list[dict]:
        return self.provider.get_models(model_id)

    def get_credentials(self, model_id: Optional[str] = None) -> dict:
        return self.provider.get_credentials(model_id)

    def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
        return self.provider.get_provider_configs(obfuscated=obfuscated, only_custom=only_custom)

    def get_provider_db_record(self) -> Optional[Provider]:
        return self.provider.get_provider()

    def config_validate(self, config: Union[dict | str]):
        """
        Validates the given config.

        :param config:
        :raises: ValidateFailedError
        """
        return self.provider.config_validate(config)

    def get_token_type(self):
        return self.provider.get_token_type()

    def get_encrypted_token(self, config: Union[dict | str]):
        return self.provider.get_encrypted_token(config)
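
The service above was a thin facade: callers chose a provider implementation purely by name. A usage sketch (the tenant ID is illustrative):

service = LLMProviderService(tenant_id='tenant-uuid', provider_name='azure_openai')
masked = service.get_provider_configs(obfuscated=True)  # safe to show in the console UI
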
@@ -1,55 +0,0 @@
import logging
from typing import Optional, Union

import openai
from openai.error import AuthenticationError, OpenAIError

from core import hosted_llm_credentials
from core.llm.error import ProviderTokenNotInitError
from core.llm.moderation import Moderation
from core.llm.provider.base import BaseProvider
from core.llm.provider.errors import ValidateFailedError
from models.provider import ProviderName


class OpenAIProvider(BaseProvider):
    def get_models(self, model_id: Optional[str] = None) -> list[dict]:
        credentials = self.get_credentials(model_id)
        response = openai.Model.list(**credentials)

        return [{
            'id': model['id'],
            'name': model['id'],
        } for model in response['data']]

    def get_credentials(self, model_id: Optional[str] = None) -> dict:
        """
        Returns the credentials for the given tenant_id and provider_name.
        """
        return {
            'openai_api_key': self.get_provider_api_key(model_id=model_id)
        }

    def get_provider_name(self):
        return ProviderName.OPENAI

    def config_validate(self, config: Union[dict | str]):
        """
        Validates the given config.
        """
        try:
            Moderation(self.get_provider_name().value, config).moderate('test')
        except (AuthenticationError, OpenAIError) as ex:
            raise ValidateFailedError(str(ex))
        except Exception as ex:
            logging.exception('OpenAI config validation failed')
            raise ex

    def get_hosted_credentials(self) -> Union[str | dict]:
        if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key:
            raise ProviderTokenNotInitError(
                f"No valid {self.get_provider_name().value} model provider credentials found. "
                f"Please go to Settings -> Model Provider to complete your provider credentials."
            )

        return hosted_llm_credentials.openai.api_key
@@ -1,62 +0,0 @@
from typing import List, Optional, Any, Dict

from httpx import Timeout
from langchain.callbacks.manager import Callbacks
from langchain.chat_models import ChatAnthropic
from langchain.schema import BaseMessage, LLMResult, SystemMessage, AIMessage, HumanMessage, ChatMessage
from pydantic import root_validator

from core.llm.wrappers.anthropic_wrapper import handle_anthropic_exceptions


class StreamableChatAnthropic(ChatAnthropic):
    """
    Wrapper around Anthropic's large language model.
    """

    default_request_timeout: Optional[float] = Timeout(timeout=300.0, connect=5.0)

    @root_validator()
    def prepare_params(cls, values: Dict) -> Dict:
        values['model_name'] = values.get('model')
        values['max_tokens'] = values.get('max_tokens_to_sample')
        return values

    @handle_anthropic_exceptions
    def generate(
            self,
            messages: List[List[BaseMessage]],
            stop: Optional[List[str]] = None,
            callbacks: Callbacks = None,
            *,
            tags: Optional[List[str]] = None,
            metadata: Optional[Dict[str, Any]] = None,
            **kwargs: Any,
    ) -> LLMResult:
        return super().generate(messages, stop, callbacks, tags=tags, metadata=metadata, **kwargs)

    @classmethod
    def get_kwargs_from_model_params(cls, params: dict):
        params['model'] = params.get('model_name')
        del params['model_name']

        params['max_tokens_to_sample'] = params.get('max_tokens')
        del params['max_tokens']

        del params['frequency_penalty']
        del params['presence_penalty']

        return params

    def _convert_one_message_to_text(self, message: BaseMessage) -> str:
        if isinstance(message, ChatMessage):
            message_text = f"\n\n{message.role.capitalize()}: {message.content}"
        elif isinstance(message, HumanMessage):
            message_text = f"{self.HUMAN_PROMPT} {message.content}"
        elif isinstance(message, AIMessage):
            message_text = f"{self.AI_PROMPT} {message.content}"
        elif isinstance(message, SystemMessage):
            message_text = f"<admin>{message.content}</admin>"
        else:
            raise ValueError(f"Got unknown type {message}")
        return message_text
@@ -1,41 +0,0 @@
import decimal
from typing import Optional

import tiktoken

from core.constant import llm_constant


class TokenCalculator:
    @classmethod
    def get_num_tokens(cls, model_name: str, text: str):
        if len(text) == 0:
            return 0

        enc = tiktoken.encoding_for_model(model_name)

        tokenized_text = enc.encode(text)

        # calculate the number of tokens in the encoded text
        return len(tokenized_text)

    @classmethod
    def get_token_price(cls, model_name: str, tokens: int, text_type: Optional[str] = None) -> decimal.Decimal:
        if model_name in llm_constant.models_by_mode['embedding']:
            unit_price = llm_constant.model_prices[model_name]['usage']
        elif text_type == 'prompt':
            unit_price = llm_constant.model_prices[model_name]['prompt']
        elif text_type == 'completion':
            unit_price = llm_constant.model_prices[model_name]['completion']
        else:
            raise Exception('Invalid text type')

        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
                                                                  rounding=decimal.ROUND_HALF_UP)

        total_price = tokens_per_1k * unit_price
        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)

    @classmethod
    def get_currency(cls, model_name: str):
        return llm_constant.model_currency
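
A worked example of the pricing arithmetic in get_token_price, for 1,530 prompt tokens of gpt-3.5-turbo at USD 0.0015 per 1K tokens (the rate from the deleted llm_constant above):

import decimal

tokens_per_1k = (decimal.Decimal(1530) / 1000).quantize(
    decimal.Decimal('0.001'), rounding=decimal.ROUND_HALF_UP)   # Decimal('1.530')
total = (tokens_per_1k * decimal.Decimal('0.0015')).quantize(
    decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
print(total)  # 0.0022950
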
@@ -1,26 +0,0 @@
import openai

from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
from models.provider import ProviderName
from core.llm.provider.base import BaseProvider


class Whisper:

    def __init__(self, provider: BaseProvider):
        self.provider = provider

        if self.provider.get_provider_name() == ProviderName.OPENAI:
            self.client = openai.Audio
            self.credentials = provider.get_credentials()

    @handle_openai_exceptions
    def transcribe(self, file):
        return self.client.transcribe(
            model='whisper-1',
            file=file,
            api_key=self.credentials.get('openai_api_key'),
            api_base=self.credentials.get('openai_api_base'),
            api_type=self.credentials.get('openai_api_type'),
            api_version=self.credentials.get('openai_api_version'),
        )
@@ -1,27 +0,0 @@
import logging
from functools import wraps

import anthropic

from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \
    LLMBadRequestError


def handle_anthropic_exceptions(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except anthropic.APIConnectionError as e:
            logging.exception("Failed to connect to Anthropic API.")
            raise LLMAPIConnectionError(f"Anthropic: The server could not be reached, cause: {e.__cause__}")
        except anthropic.RateLimitError:
            raise LLMRateLimitError("Anthropic: A 429 status code was received; we should back off a bit.")
        except anthropic.AuthenticationError as e:
            raise LLMAuthorizationError(f"Anthropic: {e.message}")
        except anthropic.BadRequestError as e:
            raise LLMBadRequestError(f"Anthropic: {e.message}")
        except anthropic.APIStatusError as e:
            raise LLMAPIUnavailableError(f"Anthropic: code: {e.status_code}, cause: {e.message}")

    return wrapper
@@ -1,31 +0,0 @@
import logging
from functools import wraps

import openai

from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \
    LLMBadRequestError


def handle_openai_exceptions(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except openai.error.InvalidRequestError as e:
            logging.exception("Invalid request to OpenAI API.")
            raise LLMBadRequestError(str(e))
        except openai.error.APIConnectionError as e:
            logging.exception("Failed to connect to OpenAI API.")
            raise LLMAPIConnectionError(e.__class__.__name__ + ":" + str(e))
        except (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout) as e:
            logging.exception("OpenAI service unavailable.")
            raise LLMAPIUnavailableError(e.__class__.__name__ + ":" + str(e))
        except openai.error.RateLimitError as e:
            raise LLMRateLimitError(str(e))
        except openai.error.AuthenticationError as e:
            raise LLMAuthorizationError(str(e))
        except openai.error.OpenAIError as e:
            raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e))

    return wrapper
@ -0,0 +1,293 @@
from typing import Optional

from langchain.callbacks.base import Callbacks

from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
from core.model_providers.model_provider_factory import ModelProviderFactory, DEFAULT_MODELS
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.speech2text.base import BaseSpeech2Text
from extensions.ext_database import db
from models.provider import TenantDefaultModel


class ModelFactory:

    @classmethod
    def get_text_generation_model_from_model_config(cls, tenant_id: str,
                                                    model_config: dict,
                                                    streaming: bool = False,
                                                    callbacks: Callbacks = None) -> Optional[BaseLLM]:
        provider_name = model_config.get("provider")
        model_name = model_config.get("name")
        completion_params = model_config.get("completion_params", {})

        return cls.get_text_generation_model(
            tenant_id=tenant_id,
            model_provider_name=provider_name,
            model_name=model_name,
            model_kwargs=ModelKwargs(
                temperature=completion_params.get('temperature', 0),
                max_tokens=completion_params.get('max_tokens', 256),
                top_p=completion_params.get('top_p', 0),
                frequency_penalty=completion_params.get('frequency_penalty', 0.1),
                presence_penalty=completion_params.get('presence_penalty', 0.1)
            ),
            streaming=streaming,
            callbacks=callbacks
        )

    @classmethod
    def get_text_generation_model(cls,
                                  tenant_id: str,
                                  model_provider_name: Optional[str] = None,
                                  model_name: Optional[str] = None,
                                  model_kwargs: Optional[ModelKwargs] = None,
                                  streaming: bool = False,
                                  callbacks: Callbacks = None) -> Optional[BaseLLM]:
        """
        get text generation model.

        :param tenant_id: a string representing the ID of the tenant.
        :param model_provider_name:
        :param model_name:
        :param model_kwargs:
        :param streaming:
        :param callbacks:
        :return:
        """
        is_default_model = False
        if model_provider_name is None and model_name is None:
            default_model = cls.get_default_model(tenant_id, ModelType.TEXT_GENERATION)

            if not default_model:
                raise LLMBadRequestError(f"Default model is not available. "
                                         f"Please configure a Default System Reasoning Model "
                                         f"in the Settings -> Model Provider.")

            model_provider_name = default_model.provider_name
            model_name = default_model.model_name
            is_default_model = True

        # get model provider
        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)

        if not model_provider:
            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials are not initialized.")

        # init text generation model
        model_class = model_provider.get_model_class(model_type=ModelType.TEXT_GENERATION)

        try:
            model_instance = model_class(
                model_provider=model_provider,
                name=model_name,
                model_kwargs=model_kwargs,
                streaming=streaming,
                callbacks=callbacks
            )
        except LLMBadRequestError:
            if is_default_model:
                raise LLMBadRequestError(f"Default model {model_name} is not available. "
                                         f"Please check your model provider credentials.")
            else:
                raise

        if is_default_model:
            model_instance.deduct_quota = False

        return model_instance

    @classmethod
    def get_embedding_model(cls,
                            tenant_id: str,
                            model_provider_name: Optional[str] = None,
                            model_name: Optional[str] = None) -> Optional[BaseEmbedding]:
        """
        get embedding model.

        :param tenant_id: a string representing the ID of the tenant.
        :param model_provider_name:
        :param model_name:
        :return:
        """
        if model_provider_name is None and model_name is None:
            default_model = cls.get_default_model(tenant_id, ModelType.EMBEDDINGS)

            if not default_model:
                raise LLMBadRequestError(f"Default model is not available. "
                                         f"Please configure a Default Embedding Model "
                                         f"in the Settings -> Model Provider.")

            model_provider_name = default_model.provider_name
            model_name = default_model.model_name

        # get model provider
        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)

        if not model_provider:
            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials are not initialized.")

        # init embedding model
        model_class = model_provider.get_model_class(model_type=ModelType.EMBEDDINGS)
        return model_class(
            model_provider=model_provider,
            name=model_name
        )

    @classmethod
    def get_speech2text_model(cls,
                              tenant_id: str,
                              model_provider_name: Optional[str] = None,
                              model_name: Optional[str] = None) -> Optional[BaseSpeech2Text]:
        """
        get speech to text model.

        :param tenant_id: a string representing the ID of the tenant.
        :param model_provider_name:
        :param model_name:
        :return:
        """
        if model_provider_name is None and model_name is None:
            default_model = cls.get_default_model(tenant_id, ModelType.SPEECH_TO_TEXT)

            if not default_model:
                raise LLMBadRequestError(f"Default model is not available. "
                                         f"Please configure a Default Speech-to-Text Model "
                                         f"in the Settings -> Model Provider.")

            model_provider_name = default_model.provider_name
            model_name = default_model.model_name

        # get model provider
        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)

        if not model_provider:
            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials are not initialized.")

        # init speech to text model
        model_class = model_provider.get_model_class(model_type=ModelType.SPEECH_TO_TEXT)
        return model_class(
            model_provider=model_provider,
            name=model_name
        )

    @classmethod
    def get_moderation_model(cls,
                             tenant_id: str,
                             model_provider_name: str,
                             model_name: str) -> Optional[BaseProviderModel]:
        """
        get moderation model.

        :param tenant_id: a string representing the ID of the tenant.
        :param model_provider_name:
        :param model_name:
        :return:
        """
        # get model provider
        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)

        if not model_provider:
            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials are not initialized.")

        # init moderation model
        model_class = model_provider.get_model_class(model_type=ModelType.MODERATION)
        return model_class(
            model_provider=model_provider,
            name=model_name
        )

    @classmethod
    def get_default_model(cls, tenant_id: str, model_type: ModelType) -> TenantDefaultModel:
        """
        get default model of model type.

        :param tenant_id:
        :param model_type:
        :return:
        """
        # get default model
        default_model = db.session.query(TenantDefaultModel) \
            .filter(
            TenantDefaultModel.tenant_id == tenant_id,
            TenantDefaultModel.model_type == model_type.value
        ).first()

        if not default_model:
            model_provider_rules = ModelProviderFactory.get_provider_rules()
            for model_provider_name, model_provider_rule in model_provider_rules.items():
                model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
                if not model_provider:
                    continue

                model_list = model_provider.get_supported_model_list(model_type)
                if model_list:
                    model_info = model_list[0]
                    default_model = TenantDefaultModel(
                        tenant_id=tenant_id,
                        model_type=model_type.value,
                        provider_name=model_provider_name,
                        model_name=model_info['id']
                    )
                    db.session.add(default_model)
                    db.session.commit()
                    break

        return default_model

    @classmethod
    def update_default_model(cls,
                             tenant_id: str,
                             model_type: ModelType,
                             provider_name: str,
                             model_name: str) -> TenantDefaultModel:
        """
        update default model of model type.

        :param tenant_id:
        :param model_type:
        :param provider_name:
        :param model_name:
        :return:
        """
        model_provider_names = ModelProviderFactory.get_provider_names()
        if provider_name not in model_provider_names:
            raise ValueError(f'Invalid provider name: {provider_name}')

        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, provider_name)

        if not model_provider:
            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials are not initialized.")

        model_list = model_provider.get_supported_model_list(model_type)
        model_ids = [model['id'] for model in model_list]
        if model_name not in model_ids:
            raise ValueError(f'Invalid model name: {model_name}')

        # get default model
        default_model = db.session.query(TenantDefaultModel) \
            .filter(
            TenantDefaultModel.tenant_id == tenant_id,
            TenantDefaultModel.model_type == model_type.value
        ).first()

        if default_model:
            # update default model
            default_model.provider_name = provider_name
            default_model.model_name = model_name
            db.session.commit()
        else:
            # create default model
            default_model = TenantDefaultModel(
                tenant_id=tenant_id,
                model_type=model_type.value,
                provider_name=provider_name,
                model_name=model_name,
            )
            db.session.add(default_model)
            db.session.commit()

        return default_model
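A minimal usage sketch of the factory above, for orientation only: the tenant ID and prompt are made-up values, the module path core.model_providers.model_factory is inferred from the other imports in this change, and working provider credentials plus an app/database context are assumed. Calling the factory with only a tenant ID exercises the default-model path: the tenant's default reasoning model is resolved (or lazily created) and quota deduction is disabled for it.

# hypothetical usage sketch; assumes configured credentials and app context
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage

model = ModelFactory.get_text_generation_model(
    tenant_id='tenant-1234'  # made-up tenant ID
)
result = model.run([PromptMessage(content='Say hello.')])
print(result.content, result.prompt_tokens, result.completion_tokens)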
@ -0,0 +1,228 @@
from typing import Type

from sqlalchemy.exc import IntegrityError

from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.rules import provider_rules
from extensions.ext_database import db
from models.provider import TenantPreferredModelProvider, ProviderType, Provider, ProviderQuotaType

DEFAULT_MODELS = {
    ModelType.TEXT_GENERATION.value: {
        'provider_name': 'openai',
        'model_name': 'gpt-3.5-turbo',
    },
    ModelType.EMBEDDINGS.value: {
        'provider_name': 'openai',
        'model_name': 'text-embedding-ada-002',
    },
    ModelType.SPEECH_TO_TEXT.value: {
        'provider_name': 'openai',
        'model_name': 'whisper-1',
    }
}


class ModelProviderFactory:
    @classmethod
    def get_model_provider_class(cls, provider_name: str) -> Type[BaseModelProvider]:
        if provider_name == 'openai':
            from core.model_providers.providers.openai_provider import OpenAIProvider
            return OpenAIProvider
        elif provider_name == 'anthropic':
            from core.model_providers.providers.anthropic_provider import AnthropicProvider
            return AnthropicProvider
        elif provider_name == 'minimax':
            from core.model_providers.providers.minimax_provider import MinimaxProvider
            return MinimaxProvider
        elif provider_name == 'spark':
            from core.model_providers.providers.spark_provider import SparkProvider
            return SparkProvider
        elif provider_name == 'tongyi':
            from core.model_providers.providers.tongyi_provider import TongyiProvider
            return TongyiProvider
        elif provider_name == 'wenxin':
            from core.model_providers.providers.wenxin_provider import WenxinProvider
            return WenxinProvider
        elif provider_name == 'chatglm':
            from core.model_providers.providers.chatglm_provider import ChatGLMProvider
            return ChatGLMProvider
        elif provider_name == 'azure_openai':
            from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider
            return AzureOpenAIProvider
        elif provider_name == 'replicate':
            from core.model_providers.providers.replicate_provider import ReplicateProvider
            return ReplicateProvider
        elif provider_name == 'huggingface_hub':
            from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
            return HuggingfaceHubProvider
        else:
            raise NotImplementedError

    @classmethod
    def get_provider_names(cls):
        """
        Returns a list of provider names.
        """
        return list(provider_rules.keys())

    @classmethod
    def get_provider_rules(cls):
        """
        Returns the provider rules dict, keyed by provider name.

        :return:
        """
        return provider_rules

    @classmethod
    def get_provider_rule(cls, provider_name: str):
        """
        Returns the rule of a single provider.
        """
        return provider_rules[provider_name]

    @classmethod
    def get_preferred_model_provider(cls, tenant_id: str, model_provider_name: str):
        """
        get preferred model provider.

        :param tenant_id: a string representing the ID of the tenant.
        :param model_provider_name:
        :return:
        """
        # get preferred provider
        preferred_provider = cls._get_preferred_provider(tenant_id, model_provider_name)
        if not preferred_provider or not preferred_provider.is_valid:
            return None

        # init model provider
        model_provider_class = ModelProviderFactory.get_model_provider_class(model_provider_name)
        return model_provider_class(provider=preferred_provider)

    @classmethod
    def get_preferred_type_by_preferred_model_provider(cls,
                                                       tenant_id: str,
                                                       model_provider_name: str,
                                                       preferred_model_provider: TenantPreferredModelProvider):
        """
        get preferred provider type by preferred model provider.

        :param tenant_id:
        :param model_provider_name:
        :param preferred_model_provider:
        :return:
        """
        if not preferred_model_provider:
            model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name)
            support_provider_types = model_provider_rules['support_provider_types']

            if ProviderType.CUSTOM.value in support_provider_types:
                custom_provider = db.session.query(Provider) \
                    .filter(
                    Provider.tenant_id == tenant_id,
                    Provider.provider_name == model_provider_name,
                    Provider.provider_type == ProviderType.CUSTOM.value,
                    Provider.is_valid == True
                ).first()

                if custom_provider:
                    return ProviderType.CUSTOM.value

            model_provider = cls.get_model_provider_class(model_provider_name)

            if ProviderType.SYSTEM.value in support_provider_types \
                    and model_provider.is_provider_type_system_supported():
                return ProviderType.SYSTEM.value
            elif ProviderType.CUSTOM.value in support_provider_types:
                return ProviderType.CUSTOM.value
        else:
            return preferred_model_provider.preferred_provider_type

    @classmethod
    def _get_preferred_provider(cls, tenant_id: str, model_provider_name: str):
        """
        get preferred provider of tenant.

        :param tenant_id:
        :param model_provider_name:
        :return:
        """
        # get preferred provider type
        preferred_provider_type = cls._get_preferred_provider_type(tenant_id, model_provider_name)

        # get providers by preferred provider type
        providers = db.session.query(Provider) \
            .filter(
            Provider.tenant_id == tenant_id,
            Provider.provider_name == model_provider_name,
            Provider.provider_type == preferred_provider_type
        ).all()

        no_system_provider = False
        if preferred_provider_type == ProviderType.SYSTEM.value:
            quota_type_to_provider_dict = {}
            for provider in providers:
                quota_type_to_provider_dict[provider.quota_type] = provider

            model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name)
            for quota_type_enum in ProviderQuotaType:
                quota_type = quota_type_enum.value
                if quota_type in model_provider_rules['system_config']['supported_quota_types'] \
                        and quota_type in quota_type_to_provider_dict.keys():
                    provider = quota_type_to_provider_dict[quota_type]
                    if provider.is_valid and provider.quota_limit > provider.quota_used:
                        return provider

            no_system_provider = True

        if no_system_provider:
            providers = db.session.query(Provider) \
                .filter(
                Provider.tenant_id == tenant_id,
                Provider.provider_name == model_provider_name,
                Provider.provider_type == ProviderType.CUSTOM.value
            ).all()

        if preferred_provider_type == ProviderType.CUSTOM.value or no_system_provider:
            if providers:
                return providers[0]
            else:
                try:
                    provider = Provider(
                        tenant_id=tenant_id,
                        provider_name=model_provider_name,
                        provider_type=ProviderType.CUSTOM.value,
                        is_valid=False
                    )
                    db.session.add(provider)
                    db.session.commit()
                except IntegrityError:
                    db.session.rollback()
                    provider = db.session.query(Provider) \
                        .filter(
                        Provider.tenant_id == tenant_id,
                        Provider.provider_name == model_provider_name,
                        Provider.provider_type == ProviderType.CUSTOM.value
                    ).first()

                return provider

        return None

    @classmethod
    def _get_preferred_provider_type(cls, tenant_id: str, model_provider_name: str):
        """
        get preferred provider type of tenant.

        :param tenant_id:
        :param model_provider_name:
        :return:
        """
        preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
            .filter(
            TenantPreferredModelProvider.tenant_id == tenant_id,
            TenantPreferredModelProvider.provider_name == model_provider_name
        ).first()

        return cls.get_preferred_type_by_preferred_model_provider(tenant_id, model_provider_name, preferred_model_provider)
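A sketch of how the resolution above behaves, under the same assumptions (hypothetical tenant ID, app and database context available): with no TenantPreferredModelProvider record, valid system quota providers are tried in ProviderQuotaType order before falling back to, and lazily creating, a custom provider row.

# hypothetical sketch; 'tenant-1234' is a made-up tenant ID
from core.model_providers.model_provider_factory import ModelProviderFactory

provider = ModelProviderFactory.get_preferred_model_provider('tenant-1234', 'openai')
if provider is None:
    # no valid credentials: quota exhausted or custom key not yet configured
    print('openai is not usable for this tenant')
else:
    print(type(provider).__name__)  # e.g. OpenAIProvider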
@ -0,0 +1,22 @@
from abc import ABC
from typing import Any

from core.model_providers.providers.base import BaseModelProvider


class BaseProviderModel(ABC):
    _client: Any
    _model_provider: BaseModelProvider

    def __init__(self, model_provider: BaseModelProvider, client: Any):
        self._model_provider = model_provider
        self._client = client

    @property
    def client(self):
        return self._client

    @property
    def model_provider(self):
        return self._model_provider
@ -0,0 +1,78 @@
import decimal
import logging

import openai
import tiktoken
from langchain.embeddings import OpenAIEmbeddings

from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMRateLimitError, \
    LLMAPIUnavailableError, LLMAPIConnectionError
from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.providers.base import BaseModelProvider

AZURE_OPENAI_API_VERSION = '2023-07-01-preview'


class AzureOpenAIEmbedding(BaseEmbedding):
    def __init__(self, model_provider: BaseModelProvider, name: str):
        self.credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )

        client = OpenAIEmbeddings(
            deployment=name,
            openai_api_type='azure',
            openai_api_version=AZURE_OPENAI_API_VERSION,
            chunk_size=16,
            max_retries=1,
            **self.credentials
        )

        super().__init__(model_provider, client, name)

    def get_num_tokens(self, text: str) -> int:
        """
        get num tokens of text.

        :param text:
        :return:
        """
        if len(text) == 0:
            return 0

        enc = tiktoken.encoding_for_model(self.credentials.get('base_model_name'))

        tokenized_text = enc.encode(text)

        # calculate the number of tokens in the encoded text
        return len(tokenized_text)

    def get_token_price(self, tokens: int):
        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
                                                                  rounding=decimal.ROUND_HALF_UP)

        total_price = tokens_per_1k * decimal.Decimal('0.0001')
        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)

    def get_currency(self):
        return 'USD'

    def handle_exceptions(self, ex: Exception) -> Exception:
        if isinstance(ex, openai.error.InvalidRequestError):
            logging.warning("Invalid request to Azure OpenAI API.")
            return LLMBadRequestError(str(ex))
        elif isinstance(ex, openai.error.APIConnectionError):
            logging.warning("Failed to connect to Azure OpenAI API.")
            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
            logging.warning("Azure OpenAI service unavailable.")
            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, openai.error.RateLimitError):
            return LLMRateLimitError('Azure ' + str(ex))
        elif isinstance(ex, openai.error.AuthenticationError):
            # return (not raise) so the caller can re-raise uniformly
            return LLMAuthorizationError('Azure ' + str(ex))
        elif isinstance(ex, openai.error.OpenAIError):
            return LLMBadRequestError('Azure ' + ex.__class__.__name__ + ":" + str(ex))
        else:
            return ex
@ -0,0 +1,40 @@
from abc import abstractmethod
from typing import Any

import tiktoken
from langchain.schema.language_model import _get_token_ids_default_method

from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import BaseModelProvider


class BaseEmbedding(BaseProviderModel):
    name: str
    type: ModelType = ModelType.EMBEDDINGS

    def __init__(self, model_provider: BaseModelProvider, client: Any, name: str):
        super().__init__(model_provider, client)
        self.name = name

    def get_num_tokens(self, text: str) -> int:
        """
        get num tokens of text.

        :param text:
        :return:
        """
        if len(text) == 0:
            return 0

        return len(_get_token_ids_default_method(text))

    def get_token_price(self, tokens: int):
        return 0

    def get_currency(self):
        return 'USD'

    @abstractmethod
    def handle_exceptions(self, ex: Exception) -> Exception:
        raise NotImplementedError
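For reference, the fallback tokenizer used by get_num_tokens above can be checked on its own; this is LangChain's private default (a GPT-2 tokenizer via transformers), so the exact count may vary between library versions:

from langchain.schema.language_model import _get_token_ids_default_method

print(len(_get_token_ids_default_method('hello world')))  # token count via the default tokenizer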
@ -0,0 +1,35 @@
import decimal
import logging

from langchain.embeddings import MiniMaxEmbeddings

from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.providers.base import BaseModelProvider


class MinimaxEmbedding(BaseEmbedding):
    def __init__(self, model_provider: BaseModelProvider, name: str):
        credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )

        client = MiniMaxEmbeddings(
            model=name,
            **credentials
        )

        super().__init__(model_provider, client, name)

    def get_token_price(self, tokens: int):
        return decimal.Decimal('0')

    def get_currency(self):
        return 'RMB'

    def handle_exceptions(self, ex: Exception) -> Exception:
        if isinstance(ex, ValueError):
            return LLMBadRequestError(f"Minimax: {str(ex)}")
        else:
            return ex
@ -0,0 +1,72 @@
import decimal
import logging

import openai
import tiktoken
from langchain.embeddings import OpenAIEmbeddings

from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
    LLMRateLimitError, LLMAuthorizationError
from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.providers.base import BaseModelProvider


class OpenAIEmbedding(BaseEmbedding):
    def __init__(self, model_provider: BaseModelProvider, name: str):
        credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )

        client = OpenAIEmbeddings(
            max_retries=1,
            **credentials
        )

        super().__init__(model_provider, client, name)

    def get_num_tokens(self, text: str) -> int:
        """
        get num tokens of text.

        :param text:
        :return:
        """
        if len(text) == 0:
            return 0

        enc = tiktoken.encoding_for_model(self.name)

        tokenized_text = enc.encode(text)

        # calculate the number of tokens in the encoded text
        return len(tokenized_text)

    def get_token_price(self, tokens: int):
        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
                                                                  rounding=decimal.ROUND_HALF_UP)

        total_price = tokens_per_1k * decimal.Decimal('0.0001')
        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)

    def get_currency(self):
        return 'USD'

    def handle_exceptions(self, ex: Exception) -> Exception:
        if isinstance(ex, openai.error.InvalidRequestError):
            logging.warning("Invalid request to OpenAI API.")
            return LLMBadRequestError(str(ex))
        elif isinstance(ex, openai.error.APIConnectionError):
            logging.warning("Failed to connect to OpenAI API.")
            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
            logging.warning("OpenAI service unavailable.")
            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, openai.error.RateLimitError):
            return LLMRateLimitError(str(ex))
        elif isinstance(ex, openai.error.AuthenticationError):
            # return (not raise) so the caller can re-raise uniformly
            return LLMAuthorizationError(str(ex))
        elif isinstance(ex, openai.error.OpenAIError):
            return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
        else:
            return ex
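The tiktoken counting used above is runnable on its own; the model name below is just an example:

import tiktoken

enc = tiktoken.encoding_for_model('text-embedding-ada-002')
print(len(enc.encode('hello world')))  # 2 tokens under cl100k_base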
@ -0,0 +1,36 @@
import decimal

from replicate.exceptions import ModelError, ReplicateError

from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.embeddings.replicate_embedding import ReplicateEmbeddings
from core.model_providers.models.embedding.base import BaseEmbedding


class ReplicateEmbedding(BaseEmbedding):
    def __init__(self, model_provider: BaseModelProvider, name: str):
        credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )

        client = ReplicateEmbeddings(
            model=name + ':' + credentials.get('model_version'),
            replicate_api_token=credentials.get('replicate_api_token')
        )

        super().__init__(model_provider, client, name)

    def get_token_price(self, tokens: int):
        # Replicate only charges for prediction seconds, not tokens
        return decimal.Decimal('0')

    def get_currency(self):
        return 'USD'

    def handle_exceptions(self, ex: Exception) -> Exception:
        if isinstance(ex, (ModelError, ReplicateError)):
            return LLMBadRequestError(f"Replicate: {str(ex)}")
        else:
            return ex
@ -0,0 +1,53 @@
import enum

from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage
from pydantic import BaseModel


class LLMRunResult(BaseModel):
    content: str
    prompt_tokens: int
    completion_tokens: int


class MessageType(enum.Enum):
    HUMAN = 'human'
    ASSISTANT = 'assistant'
    SYSTEM = 'system'


class PromptMessage(BaseModel):
    type: MessageType = MessageType.HUMAN
    content: str = ''


def to_lc_messages(messages: list[PromptMessage]):
    lc_messages = []
    for message in messages:
        if message.type == MessageType.HUMAN:
            lc_messages.append(HumanMessage(content=message.content))
        elif message.type == MessageType.ASSISTANT:
            lc_messages.append(AIMessage(content=message.content))
        elif message.type == MessageType.SYSTEM:
            lc_messages.append(SystemMessage(content=message.content))

    return lc_messages


def to_prompt_messages(messages: list[BaseMessage]):
    prompt_messages = []
    for message in messages:
        if isinstance(message, HumanMessage):
            prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
        elif isinstance(message, AIMessage):
            prompt_messages.append(PromptMessage(content=message.content, type=MessageType.ASSISTANT))
        elif isinstance(message, SystemMessage):
            prompt_messages.append(PromptMessage(content=message.content, type=MessageType.SYSTEM))
    return prompt_messages


def str_to_prompt_messages(texts: list[str]):
    prompt_messages = []
    for text in texts:
        prompt_messages.append(PromptMessage(content=text))
    return prompt_messages
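A small, self-contained example of the message helpers above (module path assumed from the imports elsewhere in this change); it round-trips PromptMessages through the LangChain message types:

from core.model_providers.models.entity.message import (
    PromptMessage, MessageType, to_lc_messages, to_prompt_messages
)

messages = [
    PromptMessage(type=MessageType.SYSTEM, content='You are terse.'),
    PromptMessage(type=MessageType.HUMAN, content='Hi'),
]

lc_messages = to_lc_messages(messages)        # [SystemMessage, HumanMessage]
round_trip = to_prompt_messages(lc_messages)  # back to PromptMessage objects
assert [m.type for m in round_trip] == [MessageType.SYSTEM, MessageType.HUMAN]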
@ -0,0 +1,59 @@
import enum
from typing import Optional, TypeVar, Generic

from langchain.load.serializable import Serializable
from pydantic import BaseModel


class ModelMode(enum.Enum):
    COMPLETION = 'completion'
    CHAT = 'chat'


class ModelType(enum.Enum):
    TEXT_GENERATION = 'text-generation'
    EMBEDDINGS = 'embeddings'
    SPEECH_TO_TEXT = 'speech2text'
    IMAGE = 'image'
    VIDEO = 'video'
    MODERATION = 'moderation'

    @staticmethod
    def value_of(value):
        for member in ModelType:
            if member.value == value:
                return member
        raise ValueError(f"No matching enum found for value '{value}'")


class ModelKwargs(BaseModel):
    max_tokens: Optional[int]
    temperature: Optional[float]
    top_p: Optional[float]
    presence_penalty: Optional[float]
    frequency_penalty: Optional[float]


class KwargRuleType(enum.Enum):
    STRING = 'string'
    INTEGER = 'integer'
    FLOAT = 'float'


T = TypeVar('T')


class KwargRule(Generic[T], BaseModel):
    enabled: bool = True
    min: Optional[T] = None
    max: Optional[T] = None
    default: Optional[T] = None
    alias: Optional[str] = None


class ModelKwargsRules(BaseModel):
    max_tokens: KwargRule = KwargRule[int](enabled=False)
    temperature: KwargRule = KwargRule[float](enabled=False)
    top_p: KwargRule = KwargRule[float](enabled=False)
    presence_penalty: KwargRule = KwargRule[float](enabled=False)
    frequency_penalty: KwargRule = KwargRule[float](enabled=False)
@ -0,0 +1,10 @@
from enum import Enum


class ProviderQuotaUnit(Enum):
    TIMES = 'times'
    TOKENS = 'tokens'


class ModelFeature(Enum):
    AGENT_THOUGHT = 'agent_thought'
@ -0,0 +1,107 @@
import decimal
import logging
from functools import wraps
from typing import List, Optional, Any

import anthropic
from langchain.callbacks.manager import Callbacks
from langchain.chat_models import ChatAnthropic
from langchain.schema import LLMResult

from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
    LLMRateLimitError, LLMAuthorizationError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs


class AnthropicModel(BaseLLM):
    model_mode: ModelMode = ModelMode.CHAT

    def _init_client(self) -> Any:
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        return ChatAnthropic(
            model=self.name,
            streaming=self.streaming,
            callbacks=self.callbacks,
            default_request_timeout=60,
            **self.credentials,
            **provider_model_kwargs
        )

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        run predict by prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return self._client.generate([prompts], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        get num tokens of prompt messages.

        :param messages:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)

    def get_token_price(self, tokens: int, message_type: MessageType):
        model_unit_prices = {
            'claude-instant-1': {
                'prompt': decimal.Decimal('1.63'),
                'completion': decimal.Decimal('5.51'),
            },
            'claude-2': {
                'prompt': decimal.Decimal('11.02'),
                'completion': decimal.Decimal('32.68'),
            },
        }

        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
            unit_price = model_unit_prices[self.name]['prompt']
        else:
            unit_price = model_unit_prices[self.name]['completion']

        tokens_per_1m = (decimal.Decimal(tokens) / 1000000).quantize(decimal.Decimal('0.000001'),
                                                                     rounding=decimal.ROUND_HALF_UP)

        total_price = tokens_per_1m * unit_price
        return total_price.quantize(decimal.Decimal('0.00000001'), rounding=decimal.ROUND_HALF_UP)

    def get_currency(self):
        return 'USD'

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        for k, v in provider_model_kwargs.items():
            if hasattr(self.client, k):
                setattr(self.client, k, v)

    def handle_exceptions(self, ex: Exception) -> Exception:
        if isinstance(ex, anthropic.APIConnectionError):
            logging.warning("Failed to connect to Anthropic API.")
            return LLMAPIConnectionError(f"Anthropic: The server could not be reached, cause: {ex.__cause__}")
        elif isinstance(ex, anthropic.RateLimitError):
            return LLMRateLimitError("Anthropic: A 429 status code was received; we should back off a bit.")
        elif isinstance(ex, anthropic.AuthenticationError):
            return LLMAuthorizationError(f"Anthropic: {ex.message}")
        elif isinstance(ex, anthropic.BadRequestError):
            return LLMBadRequestError(f"Anthropic: {ex.message}")
        elif isinstance(ex, anthropic.APIStatusError):
            return LLMAPIUnavailableError(f"Anthropic: code: {ex.status_code}, cause: {ex.message}")
        else:
            return ex

    @classmethod
    def support_streaming(cls):
        return True
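As a worked example of the price arithmetic above (the unit prices in the table are per million tokens, in USD): 1,500 prompt tokens on claude-instant-1 come out to 0.00244500 USD.

import decimal

tokens_per_1m = (decimal.Decimal(1500) / 1000000).quantize(
    decimal.Decimal('0.000001'), rounding=decimal.ROUND_HALF_UP)   # 0.001500
total = (tokens_per_1m * decimal.Decimal('1.63')).quantize(
    decimal.Decimal('0.00000001'), rounding=decimal.ROUND_HALF_UP)
print(total)  # 0.00244500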
@ -0,0 +1,177 @@
import decimal
import logging
from functools import wraps
from typing import List, Optional, Any

import openai
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult

from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.llms.azure_chat_open_ai import EnhanceAzureChatOpenAI
from core.third_party.langchain.llms.azure_open_ai import EnhanceAzureOpenAI
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
    LLMRateLimitError, LLMAuthorizationError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs

AZURE_OPENAI_API_VERSION = '2023-07-01-preview'


class AzureOpenAIModel(BaseLLM):
    def __init__(self, model_provider: BaseModelProvider,
                 name: str,
                 model_kwargs: ModelKwargs,
                 streaming: bool = False,
                 callbacks: Callbacks = None):
        if name == 'text-davinci-003':
            self.model_mode = ModelMode.COMPLETION
        else:
            self.model_mode = ModelMode.CHAT

        super().__init__(model_provider, name, model_kwargs, streaming, callbacks)

    def _init_client(self) -> Any:
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        if self.name == 'text-davinci-003':
            client = EnhanceAzureOpenAI(
                deployment_name=self.name,
                streaming=self.streaming,
                request_timeout=60,
                openai_api_type='azure',
                openai_api_version=AZURE_OPENAI_API_VERSION,
                openai_api_key=self.credentials.get('openai_api_key'),
                openai_api_base=self.credentials.get('openai_api_base'),
                callbacks=self.callbacks,
                **provider_model_kwargs
            )
        else:
            extra_model_kwargs = {
                'top_p': provider_model_kwargs.get('top_p'),
                'frequency_penalty': provider_model_kwargs.get('frequency_penalty'),
                'presence_penalty': provider_model_kwargs.get('presence_penalty'),
            }

            client = EnhanceAzureChatOpenAI(
                deployment_name=self.name,
                temperature=provider_model_kwargs.get('temperature'),
                max_tokens=provider_model_kwargs.get('max_tokens'),
                model_kwargs=extra_model_kwargs,
                streaming=self.streaming,
                request_timeout=60,
                openai_api_type='azure',
                openai_api_version=AZURE_OPENAI_API_VERSION,
                openai_api_key=self.credentials.get('openai_api_key'),
                openai_api_base=self.credentials.get('openai_api_base'),
                callbacks=self.callbacks,
            )

        return client

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        run predict by prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return self._client.generate([prompts], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        get num tokens of prompt messages.

        :param messages:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        if isinstance(prompts, str):
            return self._client.get_num_tokens(prompts)
        else:
            return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)

    def get_token_price(self, tokens: int, message_type: MessageType):
        model_unit_prices = {
            'gpt-4': {
                'prompt': decimal.Decimal('0.03'),
                'completion': decimal.Decimal('0.06'),
            },
            'gpt-4-32k': {
                'prompt': decimal.Decimal('0.06'),
                'completion': decimal.Decimal('0.12')
            },
            'gpt-35-turbo': {
                'prompt': decimal.Decimal('0.0015'),
                'completion': decimal.Decimal('0.002')
            },
            'gpt-35-turbo-16k': {
                'prompt': decimal.Decimal('0.003'),
                'completion': decimal.Decimal('0.004')
            },
            'text-davinci-003': {
                'prompt': decimal.Decimal('0.02'),
                'completion': decimal.Decimal('0.02')
            },
        }

        base_model_name = self.credentials.get("base_model_name")
        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
            unit_price = model_unit_prices[base_model_name]['prompt']
        else:
            unit_price = model_unit_prices[base_model_name]['completion']

        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
                                                                  rounding=decimal.ROUND_HALF_UP)

        total_price = tokens_per_1k * unit_price
        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)

    def get_currency(self):
        return 'USD'

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        if self.name == 'text-davinci-003':
            for k, v in provider_model_kwargs.items():
                if hasattr(self.client, k):
                    setattr(self.client, k, v)
        else:
            extra_model_kwargs = {
                'top_p': provider_model_kwargs.get('top_p'),
                'frequency_penalty': provider_model_kwargs.get('frequency_penalty'),
                'presence_penalty': provider_model_kwargs.get('presence_penalty'),
            }

            self.client.temperature = provider_model_kwargs.get('temperature')
            self.client.max_tokens = provider_model_kwargs.get('max_tokens')
            self.client.model_kwargs = extra_model_kwargs

    def handle_exceptions(self, ex: Exception) -> Exception:
        if isinstance(ex, openai.error.InvalidRequestError):
            logging.warning("Invalid request to Azure OpenAI API.")
            return LLMBadRequestError(str(ex))
        elif isinstance(ex, openai.error.APIConnectionError):
            logging.warning("Failed to connect to Azure OpenAI API.")
            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
            logging.warning("Azure OpenAI service unavailable.")
            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, openai.error.RateLimitError):
            return LLMRateLimitError('Azure ' + str(ex))
        elif isinstance(ex, openai.error.AuthenticationError):
            # return (not raise) so the caller can re-raise uniformly
            return LLMAuthorizationError('Azure ' + str(ex))
        elif isinstance(ex, openai.error.OpenAIError):
            return LLMBadRequestError('Azure ' + ex.__class__.__name__ + ":" + str(ex))
        else:
            return ex

    @classmethod
    def support_streaming(cls):
        return True
@ -0,0 +1,269 @@
from abc import abstractmethod
from typing import List, Optional, Any, Union

from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration

from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.llms.fake import FakeLLM


class BaseLLM(BaseProviderModel):
    model_mode: ModelMode = ModelMode.COMPLETION
    name: str
    model_kwargs: ModelKwargs
    credentials: dict
    streaming: bool = False
    type: ModelType = ModelType.TEXT_GENERATION
    deduct_quota: bool = True

    def __init__(self, model_provider: BaseModelProvider,
                 name: str,
                 model_kwargs: ModelKwargs,
                 streaming: bool = False,
                 callbacks: Callbacks = None):
        self.name = name
        self.model_rules = model_provider.get_model_parameter_rules(name, self.type)
        self.model_kwargs = model_kwargs if model_kwargs else ModelKwargs(
            max_tokens=None,
            temperature=None,
            top_p=None,
            presence_penalty=None,
            frequency_penalty=None
        )
        self.credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )
        self.streaming = streaming

        if streaming:
            default_callback = DifyStreamingStdOutCallbackHandler()
        else:
            default_callback = DifyStdOutCallbackHandler()

        if not callbacks:
            callbacks = [default_callback]
        else:
            callbacks.append(default_callback)

        self.callbacks = callbacks

        client = self._init_client()
        super().__init__(model_provider, client)

    @abstractmethod
    def _init_client(self) -> Any:
        raise NotImplementedError

    def run(self, messages: List[PromptMessage],
            stop: Optional[List[str]] = None,
            callbacks: Callbacks = None,
            **kwargs) -> LLMRunResult:
        """
        run predict by prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        if self.deduct_quota:
            self.model_provider.check_quota_over_limit()

        if not callbacks:
            callbacks = self.callbacks
        else:
            callbacks.extend(self.callbacks)

        if 'fake_response' in kwargs and kwargs['fake_response']:
            prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
            fake_llm = FakeLLM(
                response=kwargs['fake_response'],
                num_token_func=self.get_num_tokens,
                streaming=self.streaming,
                callbacks=callbacks
            )
            result = fake_llm.generate([prompts])
        else:
            try:
                result = self._run(
                    messages=messages,
                    stop=stop,
                    callbacks=callbacks if not (self.streaming and not self.support_streaming()) else None,
                    **kwargs
                )
            except Exception as ex:
                raise self.handle_exceptions(ex)

        if isinstance(result.generations[0][0], ChatGeneration):
            completion_content = result.generations[0][0].message.content
        else:
            completion_content = result.generations[0][0].text

        if self.streaming and not self.support_streaming():
            # use FakeLLM to simulate streaming when the current model does not support streaming but streaming is True
            prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
            fake_llm = FakeLLM(
                response=completion_content,
                num_token_func=self.get_num_tokens,
                streaming=self.streaming,
                callbacks=callbacks
            )
            fake_llm.generate([prompts])

        if result.llm_output and result.llm_output.get('token_usage'):
            prompt_tokens = result.llm_output['token_usage']['prompt_tokens']
            completion_tokens = result.llm_output['token_usage']['completion_tokens']
            total_tokens = result.llm_output['token_usage']['total_tokens']
        else:
            prompt_tokens = self.get_num_tokens(messages)
            completion_tokens = self.get_num_tokens([PromptMessage(content=completion_content, type=MessageType.ASSISTANT)])
            total_tokens = prompt_tokens + completion_tokens

        if self.deduct_quota:
            self.model_provider.deduct_quota(total_tokens)

        return LLMRunResult(
            content=completion_content,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens
        )

    @abstractmethod
    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        run predict by prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        raise NotImplementedError

    @abstractmethod
    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        get num tokens of prompt messages.

        :param messages:
        :return:
        """
        raise NotImplementedError

    @abstractmethod
    def get_token_price(self, tokens: int, message_type: MessageType):
        """
        get token price.

        :param tokens:
        :param message_type:
        :return:
        """
        raise NotImplementedError

    @abstractmethod
    def get_currency(self):
        """
        get token currency.

        :return:
        """
        raise NotImplementedError

    def get_model_kwargs(self):
        return self.model_kwargs

    def set_model_kwargs(self, model_kwargs: ModelKwargs):
        self.model_kwargs = model_kwargs
        self._set_model_kwargs(model_kwargs)

    @abstractmethod
    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        raise NotImplementedError

    @abstractmethod
    def handle_exceptions(self, ex: Exception) -> Exception:
        """
        Handle llm run exceptions.

        :param ex:
        :return:
        """
        raise NotImplementedError

    def add_callbacks(self, callbacks: Callbacks):
        """
        Add callbacks to client.

        :param callbacks:
        :return:
        """
        if not self.client.callbacks:
            self.client.callbacks = callbacks
        else:
            self.client.callbacks.extend(callbacks)

    @classmethod
    def support_streaming(cls):
        return False

    def _get_prompt_from_messages(self, messages: List[PromptMessage],
                                  model_mode: Optional[ModelMode] = None) -> Union[str, List[BaseMessage]]:
        if len(messages) == 0:
            raise ValueError("prompt must not be empty.")

        if not model_mode:
            model_mode = self.model_mode

        if model_mode == ModelMode.COMPLETION:
            return messages[0].content
        else:
            chat_messages = []
            for message in messages:
                if message.type == MessageType.HUMAN:
                    chat_messages.append(HumanMessage(content=message.content))
                elif message.type == MessageType.ASSISTANT:
                    chat_messages.append(AIMessage(content=message.content))
                elif message.type == MessageType.SYSTEM:
                    chat_messages.append(SystemMessage(content=message.content))

            return chat_messages

    def _to_model_kwargs_input(self, model_rules: ModelKwargsRules, model_kwargs: ModelKwargs) -> dict:
        """
        convert model kwargs to provider model kwargs.

        :param model_rules:
        :param model_kwargs:
        :return:
        """
        model_kwargs_input = {}
        for key, value in model_kwargs.dict().items():
            rule = getattr(model_rules, key)
            if not rule.enabled:
                continue

            if rule.alias:
                key = rule.alias

            if rule.default is not None and value is None:
                value = rule.default

            if rule.min is not None:
                value = max(value, rule.min)

            if rule.max is not None:
                value = min(value, rule.max)

            model_kwargs_input[key] = value

        return model_kwargs_input
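To make _to_model_kwargs_input concrete, a standalone copy of its loop with made-up rules: disabled fields are dropped, None falls back to the rule's default, values are clamped into [min, max], and an alias renames the key for the underlying client.

from core.model_providers.models.entity.model_params import (
    KwargRule, ModelKwargs, ModelKwargsRules
)


def to_model_kwargs_input(model_rules, model_kwargs):
    # standalone copy of the loop in BaseLLM._to_model_kwargs_input, for illustration
    out = {}
    for key, value in model_kwargs.dict().items():
        rule = getattr(model_rules, key)
        if not rule.enabled:
            continue
        if rule.alias:
            key = rule.alias
        if rule.default is not None and value is None:
            value = rule.default
        if rule.min is not None:
            value = max(value, rule.min)
        if rule.max is not None:
            value = min(value, rule.max)
        out[key] = value
    return out


# made-up rules and values for illustration only
rules = ModelKwargsRules(
    temperature=KwargRule[float](min=0.0, max=1.0, default=0.7),
    max_tokens=KwargRule[int](min=1, max=4096, alias='max_tokens_to_sample'),
)
kwargs = ModelKwargs(max_tokens=8192, temperature=None, top_p=None,
                     presence_penalty=None, frequency_penalty=None)

print(to_model_kwargs_input(rules, kwargs))
# {'max_tokens_to_sample': 4096, 'temperature': 0.7}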
@ -0,0 +1,70 @@
import decimal
from typing import List, Optional, Any

from langchain.callbacks.manager import Callbacks
from langchain.llms import ChatGLM
from langchain.schema import LLMResult

from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs


class ChatGLMModel(BaseLLM):
    model_mode: ModelMode = ModelMode.COMPLETION

    def _init_client(self) -> Any:
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        return ChatGLM(
            callbacks=self.callbacks,
            endpoint_url=self.credentials.get('api_base'),
            **provider_model_kwargs
        )

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        run predict by prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return self._client.generate([prompts], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        get num tokens of prompt messages.

        :param messages:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return max(self._client.get_num_tokens(prompts), 0)

    def get_token_price(self, tokens: int, message_type: MessageType):
        return decimal.Decimal('0')

    def get_currency(self):
        return 'RMB'

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        for k, v in provider_model_kwargs.items():
            if hasattr(self.client, k):
                setattr(self.client, k, v)

    def handle_exceptions(self, ex: Exception) -> Exception:
        if isinstance(ex, ValueError):
            return LLMBadRequestError(f"ChatGLM: {str(ex)}")
        else:
            return ex

    @classmethod
    def support_streaming(cls):
        return False
@ -0,0 +1,82 @@
import decimal
from typing import List, Optional, Any

from langchain import HuggingFaceHub
from langchain.callbacks.manager import Callbacks
from langchain.llms import HuggingFaceEndpoint
from langchain.schema import LLMResult

from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs


class HuggingfaceHubModel(BaseLLM):
    model_mode: ModelMode = ModelMode.COMPLETION

    def _init_client(self) -> Any:
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
            client = HuggingFaceEndpoint(
                endpoint_url=self.credentials['huggingfacehub_endpoint_url'],
                task='text2text-generation',
                model_kwargs=provider_model_kwargs,
                huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'],
                callbacks=self.callbacks,
            )
        else:
            client = HuggingFaceHub(
                repo_id=self.name,
                task=self.credentials['task_type'],
                model_kwargs=provider_model_kwargs,
                huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'],
                callbacks=self.callbacks,
            )

        return client

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        Run prediction from prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return self._client.generate([prompts], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        Get the number of tokens in the prompt messages.

        :param messages:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return self._client.get_num_tokens(prompts)

    def get_token_price(self, tokens: int, message_type: MessageType):
        # price calculation is not supported
        return decimal.Decimal('0')

    def get_currency(self):
        return 'USD'

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        self.client.model_kwargs = provider_model_kwargs

    def handle_exceptions(self, ex: Exception) -> Exception:
        return LLMBadRequestError(f"Huggingface Hub: {str(ex)}")

    @classmethod
    def support_streaming(cls):
        return False
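For reference, the two credential shapes the branch above dispatches on. The key names come from the code; the values, and the non-endpoint api_type string, are placeholders and assumptions:

# Dedicated inference endpoint: routed to HuggingFaceEndpoint
endpoint_credentials = {
    'huggingfacehub_api_type': 'inference_endpoints',
    'huggingfacehub_endpoint_url': 'https://example.endpoints.huggingface.cloud',  # placeholder
    'huggingfacehub_api_token': 'hf_xxx',  # placeholder
}

# Public hub model: routed to HuggingFaceHub; the task comes from credentials
hub_credentials = {
    'huggingfacehub_api_type': 'hosted_inference_api',  # assumed value; the code only tests for 'inference_endpoints'
    'task_type': 'text-generation',
    'huggingfacehub_api_token': 'hf_xxx',  # placeholder
}

for creds in (endpoint_credentials, hub_credentials):
    uses_endpoint = creds['huggingfacehub_api_type'] == 'inference_endpoints'
    print('HuggingFaceEndpoint' if uses_endpoint else 'HuggingFaceHub')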
@ -0,0 +1,70 @@
import decimal
from typing import List, Optional, Any

from langchain.callbacks.manager import Callbacks
from langchain.llms import Minimax
from langchain.schema import LLMResult

from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs


class MinimaxModel(BaseLLM):
    model_mode: ModelMode = ModelMode.COMPLETION

    def _init_client(self) -> Any:
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        return Minimax(
            model=self.name,
            model_kwargs={
                'stream': False
            },
            callbacks=self.callbacks,
            **self.credentials,
            **provider_model_kwargs
        )

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        Run prediction from prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return self._client.generate([prompts], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        Get the number of tokens in the prompt messages.

        :param messages:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return max(self._client.get_num_tokens(prompts), 0)

    def get_token_price(self, tokens: int, message_type: MessageType):
        return decimal.Decimal('0')

    def get_currency(self):
        return 'RMB'

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        for k, v in provider_model_kwargs.items():
            if hasattr(self.client, k):
                setattr(self.client, k, v)

    def handle_exceptions(self, ex: Exception) -> Exception:
        if isinstance(ex, ValueError):
            return LLMBadRequestError(f"Minimax: {str(ex)}")
        else:
            return ex
@ -0,0 +1,219 @@
import decimal
import logging
from typing import List, Optional, Any

import openai
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult

from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
    LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
from core.third_party.langchain.llms.open_ai import EnhanceOpenAI
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from models.provider import ProviderType, ProviderQuotaType

COMPLETION_MODELS = [
    'text-davinci-003',  # 4,097 tokens
]

CHAT_MODELS = [
    'gpt-4',  # 8,192 tokens
    'gpt-4-32k',  # 32,768 tokens
    'gpt-3.5-turbo',  # 4,096 tokens
    'gpt-3.5-turbo-16k',  # 16,384 tokens
]

MODEL_MAX_TOKENS = {
    'gpt-4': 8192,
    'gpt-4-32k': 32768,
    'gpt-3.5-turbo': 4096,
    'gpt-3.5-turbo-16k': 16384,
    'text-davinci-003': 4097,
}


class OpenAIModel(BaseLLM):
    def __init__(self, model_provider: BaseModelProvider,
                 name: str,
                 model_kwargs: ModelKwargs,
                 streaming: bool = False,
                 callbacks: Callbacks = None):
        if name in COMPLETION_MODELS:
            self.model_mode = ModelMode.COMPLETION
        else:
            self.model_mode = ModelMode.CHAT

        super().__init__(model_provider, name, model_kwargs, streaming, callbacks)

    def _init_client(self) -> Any:
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        if self.name in COMPLETION_MODELS:
            client = EnhanceOpenAI(
                model_name=self.name,
                streaming=self.streaming,
                callbacks=self.callbacks,
                request_timeout=60,
                **self.credentials,
                **provider_model_kwargs
            )
        else:
            # Fine-tuning is currently only available for the following base models:
            # davinci, curie, babbage, and ada.
            # This means that, apart from the fixed `completion` models above,
            # all other fine-tuned models are `completion` models as well.
            extra_model_kwargs = {
                'top_p': provider_model_kwargs.get('top_p'),
                'frequency_penalty': provider_model_kwargs.get('frequency_penalty'),
                'presence_penalty': provider_model_kwargs.get('presence_penalty'),
            }

            client = EnhanceChatOpenAI(
                model_name=self.name,
                temperature=provider_model_kwargs.get('temperature'),
                max_tokens=provider_model_kwargs.get('max_tokens'),
                model_kwargs=extra_model_kwargs,
                streaming=self.streaming,
                callbacks=self.callbacks,
                request_timeout=60,
                **self.credentials
            )

        return client

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        Run prediction from prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        if self.name == 'gpt-4' \
                and self.model_provider.provider.provider_type == ProviderType.SYSTEM.value \
                and self.model_provider.provider.quota_type == ProviderQuotaType.TRIAL.value:
            raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 is currently not supported.")

        prompts = self._get_prompt_from_messages(messages)
        return self._client.generate([prompts], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        Get the number of tokens in the prompt messages.

        :param messages:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        if isinstance(prompts, str):
            return self._client.get_num_tokens(prompts)
        else:
            return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)

    def get_token_price(self, tokens: int, message_type: MessageType):
        model_unit_prices = {
            'gpt-4': {
                'prompt': decimal.Decimal('0.03'),
                'completion': decimal.Decimal('0.06'),
            },
            'gpt-4-32k': {
                'prompt': decimal.Decimal('0.06'),
                'completion': decimal.Decimal('0.12')
            },
            'gpt-3.5-turbo': {
                'prompt': decimal.Decimal('0.0015'),
                'completion': decimal.Decimal('0.002')
            },
            'gpt-3.5-turbo-16k': {
                'prompt': decimal.Decimal('0.003'),
                'completion': decimal.Decimal('0.004')
            },
            'text-davinci-003': {
                'prompt': decimal.Decimal('0.02'),
                'completion': decimal.Decimal('0.02')
            },
        }

        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
            unit_price = model_unit_prices[self.name]['prompt']
        else:
            unit_price = model_unit_prices[self.name]['completion']

        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
                                                                  rounding=decimal.ROUND_HALF_UP)

        total_price = tokens_per_1k * unit_price
        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)

    def get_currency(self):
        return 'USD'

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        if self.name in COMPLETION_MODELS:
            for k, v in provider_model_kwargs.items():
                if hasattr(self.client, k):
                    setattr(self.client, k, v)
        else:
            extra_model_kwargs = {
                'top_p': provider_model_kwargs.get('top_p'),
                'frequency_penalty': provider_model_kwargs.get('frequency_penalty'),
                'presence_penalty': provider_model_kwargs.get('presence_penalty'),
            }

            self.client.temperature = provider_model_kwargs.get('temperature')
            self.client.max_tokens = provider_model_kwargs.get('max_tokens')
            self.client.model_kwargs = extra_model_kwargs

    def handle_exceptions(self, ex: Exception) -> Exception:
        if isinstance(ex, openai.error.InvalidRequestError):
            logging.warning("Invalid request to OpenAI API.")
            return LLMBadRequestError(str(ex))
        elif isinstance(ex, openai.error.APIConnectionError):
            logging.warning("Failed to connect to OpenAI API.")
            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
            logging.warning("OpenAI service unavailable.")
            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, openai.error.RateLimitError):
            return LLMRateLimitError(str(ex))
        elif isinstance(ex, openai.error.AuthenticationError):
            # return (not raise), so every branch honours the declared contract
            return LLMAuthorizationError(str(ex))
        elif isinstance(ex, openai.error.OpenAIError):
            return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
        else:
            return ex

    @classmethod
    def support_streaming(cls):
        return True

    # def is_model_valid_or_raise(self):
    #     """
    #     Check that this is a valid model.
    #
    #     :return:
    #     """
    #     credentials = self._model_provider.get_credentials()
    #
    #     try:
    #         result = openai.Model.retrieve(
    #             id=self.name,
    #             api_key=credentials.get('openai_api_key'),
    #             request_timeout=60
    #         )
    #
    #         if 'id' not in result or result['id'] != self.name:
    #             raise LLMNotExistsError(f"OpenAI Model {self.name} does not exist.")
    #     except openai.error.OpenAIError as e:
    #         raise LLMNotExistsError(f"OpenAI Model {self.name} does not exist, cause: {e.__class__.__name__}:{str(e)}")
    #     except Exception as e:
    #         logging.exception("OpenAI Model retrieve failed.")
    #         raise e
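The price arithmetic in get_token_price is easy to misread because of the two quantize steps, so here is a worked example using the gpt-3.5-turbo completion rate from the table above (it mirrors those steps exactly):

import decimal

tokens = 1234
unit_price = decimal.Decimal('0.002')  # gpt-3.5-turbo completion price per 1K tokens

tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(
    decimal.Decimal('0.001'), rounding=decimal.ROUND_HALF_UP)    # 1.234
total_price = (tokens_per_1k * unit_price).quantize(
    decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
assert str(total_price) == '0.0024680'                           # USD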
@ -0,0 +1,103 @@
import decimal
from typing import List, Optional, Any

from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult, get_buffer_string
from replicate.exceptions import ReplicateError, ModelError

from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.error import LLMBadRequestError
from core.third_party.langchain.llms.replicate_llm import EnhanceReplicate
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs


class ReplicateModel(BaseLLM):
    def __init__(self, model_provider: BaseModelProvider,
                 name: str,
                 model_kwargs: ModelKwargs,
                 streaming: bool = False,
                 callbacks: Callbacks = None):
        self.model_mode = ModelMode.CHAT if name.endswith('-chat') else ModelMode.COMPLETION

        super().__init__(model_provider, name, model_kwargs, streaming, callbacks)

    def _init_client(self) -> Any:
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)

        return EnhanceReplicate(
            model=self.name + ':' + self.credentials.get('model_version'),
            input=provider_model_kwargs,
            streaming=self.streaming,
            replicate_api_token=self.credentials.get('replicate_api_token'),
            callbacks=self.callbacks,
        )

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        Run prediction from prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)

        extra_kwargs = {}
        if isinstance(prompts, list):
            system_messages = [message for message in messages if message.type == 'system']
            if system_messages:
                system_message = system_messages[0]
                extra_kwargs['system_prompt'] = system_message.content
                prompts = [message for message in messages if message.type != 'system']

            prompts = get_buffer_string(prompts)

        # The maximum length the generated tokens can have.
        # Corresponds to the length of the input prompt + max_new_tokens.
        if 'max_length' in self._client.input:
            self._client.input['max_length'] = min(
                self._client.input['max_length'] + self.get_num_tokens(messages),
                self.model_rules.max_tokens.max
            )

        return self._client.generate([prompts], stop, callbacks, **extra_kwargs)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        Get the number of tokens in the prompt messages.

        :param messages:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        if isinstance(prompts, list):
            prompts = get_buffer_string(prompts)

        return self._client.get_num_tokens(prompts)

    def get_token_price(self, tokens: int, message_type: MessageType):
        # Replicate bills by prediction seconds only
        return decimal.Decimal('0')

    def get_currency(self):
        return 'USD'

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        self.client.input = provider_model_kwargs

    def handle_exceptions(self, ex: Exception) -> Exception:
        if isinstance(ex, (ModelError, ReplicateError)):
            return LLMBadRequestError(f"Replicate: {str(ex)}")
        else:
            return ex

    @classmethod
    def support_streaming(cls):
        return True
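To make the max_length adjustment in _run concrete: Replicate's max_length counts prompt and generated tokens together, so the requested generation budget is shifted by the prompt size and then capped. A small numeric sketch (the token counts and the 4096 cap are illustrative, not real rule values):

requested_max_length = 500   # user asked for up to 500 generated tokens
prompt_tokens = 300          # size of the rendered prompt
rule_cap = 4096              # hypothetical model_rules.max_tokens.max

# shift the budget so the prompt does not eat into the generation allowance,
# but never exceed the rule's hard cap
effective_max_length = min(requested_max_length + prompt_tokens, rule_cap)
assert effective_max_length == 800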
@ -0,0 +1,73 @@
import decimal
from typing import List, Optional, Any

from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult

from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.spark import ChatSpark
from core.third_party.spark.spark_llm import SparkError


class SparkModel(BaseLLM):
    model_mode: ModelMode = ModelMode.CHAT

    def _init_client(self) -> Any:
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        return ChatSpark(
            streaming=self.streaming,
            callbacks=self.callbacks,
            **self.credentials,
            **provider_model_kwargs
        )

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        Run prediction from prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return self._client.generate([prompts], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        Get the number of tokens in the prompt messages.

        :param messages:
        :return:
        """
        contents = [message.content for message in messages]
        return max(self._client.get_num_tokens("".join(contents)), 0)

    def get_token_price(self, tokens: int, message_type: MessageType):
        return decimal.Decimal('0')

    def get_currency(self):
        return 'RMB'

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        for k, v in provider_model_kwargs.items():
            if hasattr(self.client, k):
                setattr(self.client, k, v)

    def handle_exceptions(self, ex: Exception) -> Exception:
        if isinstance(ex, SparkError):
            return LLMBadRequestError(f"Spark: {str(ex)}")
        else:
            return ex

    @classmethod
    def support_streaming(cls):
        return True
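Note that get_num_tokens here concatenates message contents with no separator before counting, unlike the other providers. A self-contained sketch, with a whitespace tokenizer standing in for the Spark client's counter, of why the joined count can differ from summing per-message counts:

# stand-in tokenizer (whitespace split); the real count comes from the Spark client
def count_tokens(text: str) -> int:
    return len(text.split())

contents = ['Who are you?', 'I am an assistant.']
joined = "".join(contents)            # 'Who are you?I am an assistant.'
# boundary words fuse ('you?I'), so the joined count differs from the per-message sum
assert count_tokens(joined) == 6
assert sum(count_tokens(c) for c in contents) == 7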