|
|
|
|
@ -1,5 +1,6 @@
|
|
|
|
|
from abc import abstractmethod
|
|
|
|
|
from typing import List, Optional, Any, Union
|
|
|
|
|
import decimal
|
|
|
|
|
|
|
|
|
|
from langchain.callbacks.manager import Callbacks
|
|
|
|
|
from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
|
|
|
|
|
@ -10,6 +11,8 @@ from core.model_providers.models.entity.message import PromptMessage, MessageTyp
|
|
|
|
|
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
|
|
|
|
|
from core.model_providers.providers.base import BaseModelProvider
|
|
|
|
|
from core.third_party.langchain.llms.fake import FakeLLM
|
|
|
|
|
import logging
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BaseLLM(BaseProviderModel):
|
|
|
|
|
@ -60,6 +63,39 @@ class BaseLLM(BaseProviderModel):
|
|
|
|
|
def _init_client(self) -> Any:
|
|
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def base_model_name(self) -> str:
|
|
|
|
|
"""
|
|
|
|
|
get llm base model name
|
|
|
|
|
|
|
|
|
|
:return: str
|
|
|
|
|
"""
|
|
|
|
|
return self.name
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def price_config(self) -> dict:
|
|
|
|
|
def get_or_default():
|
|
|
|
|
default_price_config = {
|
|
|
|
|
'prompt': decimal.Decimal('0'),
|
|
|
|
|
'completion': decimal.Decimal('0'),
|
|
|
|
|
'unit': decimal.Decimal('0'),
|
|
|
|
|
'currency': 'USD'
|
|
|
|
|
}
|
|
|
|
|
rules = self.model_provider.get_rules()
|
|
|
|
|
price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config
|
|
|
|
|
price_config = {
|
|
|
|
|
'prompt': decimal.Decimal(price_config['prompt']),
|
|
|
|
|
'completion': decimal.Decimal(price_config['completion']),
|
|
|
|
|
'unit': decimal.Decimal(price_config['unit']),
|
|
|
|
|
'currency': price_config['currency']
|
|
|
|
|
}
|
|
|
|
|
return price_config
|
|
|
|
|
|
|
|
|
|
self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()
|
|
|
|
|
|
|
|
|
|
logger.debug(f"model: {self.name} price_config: {self._price_config}")
|
|
|
|
|
return self._price_config
|
|
|
|
|
|
|
|
|
|
def run(self, messages: List[PromptMessage],
|
|
|
|
|
stop: Optional[List[str]] = None,
|
|
|
|
|
callbacks: Callbacks = None,
|
|
|
|
|
@ -161,25 +197,48 @@ class BaseLLM(BaseProviderModel):
|
|
|
|
|
"""
|
|
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
|
def get_token_price(self, tokens: int, message_type: MessageType):
|
|
|
|
|
def calc_tokens_price(self, tokens:int, message_type: MessageType):
|
|
|
|
|
"""
|
|
|
|
|
get token price.
|
|
|
|
|
calc tokens total price.
|
|
|
|
|
|
|
|
|
|
:param tokens:
|
|
|
|
|
:param message_type:
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
raise NotImplementedError
|
|
|
|
|
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
|
|
|
|
|
unit_price = self.price_config['prompt']
|
|
|
|
|
else:
|
|
|
|
|
unit_price = self.price_config['completion']
|
|
|
|
|
unit = self.price_config['unit']
|
|
|
|
|
|
|
|
|
|
total_price = tokens * unit_price * unit
|
|
|
|
|
total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
|
|
|
|
logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}")
|
|
|
|
|
return total_price
|
|
|
|
|
|
|
|
|
|
def get_tokens_unit_price(self, message_type: MessageType):
|
|
|
|
|
"""
|
|
|
|
|
get token price.
|
|
|
|
|
|
|
|
|
|
:param message_type:
|
|
|
|
|
:return: decimal.Decimal('0.0001')
|
|
|
|
|
"""
|
|
|
|
|
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
|
|
|
|
|
unit_price = self.price_config['prompt']
|
|
|
|
|
else:
|
|
|
|
|
unit_price = self.price_config['completion']
|
|
|
|
|
unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP)
|
|
|
|
|
logging.debug(f"unit_price={unit_price}")
|
|
|
|
|
return unit_price
|
|
|
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
|
def get_currency(self):
|
|
|
|
|
"""
|
|
|
|
|
get token currency.
|
|
|
|
|
|
|
|
|
|
:return:
|
|
|
|
|
:return: get from price config, default 'USD'
|
|
|
|
|
"""
|
|
|
|
|
raise NotImplementedError
|
|
|
|
|
currency = self.price_config['currency']
|
|
|
|
|
return currency
|
|
|
|
|
|
|
|
|
|
def get_model_kwargs(self):
|
|
|
|
|
return self.model_kwargs
|
|
|
|
|
|