@ -1,17 +1,24 @@
import json
import os
import re
from abc import abstractmethod
from typing import List , Optional , Any , Union
from typing import List , Optional , Any , Union , Tuple
import decimal
from langchain . callbacks . manager import Callbacks
from langchain . memory . chat_memory import BaseChatMemory
from langchain . schema import LLMResult , SystemMessage , AIMessage , HumanMessage , BaseMessage , ChatGeneration
from core . callback_handler . std_out_callback_handler import DifyStreamingStdOutCallbackHandler , DifyStdOutCallbackHandler
from core . model_providers . models . base import BaseProviderModel
from core . model_providers . models . entity . message import PromptMessage , MessageType , LLMRunResult
from core . model_providers . models . entity . message import PromptMessage , MessageType , LLMRunResult , to_prompt_messages
from core . model_providers . models . entity . model_params import ModelType , ModelKwargs , ModelMode , ModelKwargsRules
from core . model_providers . providers . base import BaseModelProvider
from core . prompt . prompt_builder import PromptBuilder
from core . prompt . prompt_template import JinjaPromptTemplate
from core . third_party . langchain . llms . fake import FakeLLM
import logging
logger = logging . getLogger ( __name__ )
@ -76,13 +83,14 @@ class BaseLLM(BaseProviderModel):
def price_config ( self ) - > dict :
def get_or_default ( ) :
default_price_config = {
' prompt ' : decimal . Decimal ( ' 0 ' ) ,
' completion ' : decimal . Decimal ( ' 0 ' ) ,
' unit ' : decimal . Decimal ( ' 0 ' ) ,
' currency ' : ' USD '
}
' prompt ' : decimal . Decimal ( ' 0 ' ) ,
' completion ' : decimal . Decimal ( ' 0 ' ) ,
' unit ' : decimal . Decimal ( ' 0 ' ) ,
' currency ' : ' USD '
}
rules = self . model_provider . get_rules ( )
price_config = rules [ ' price_config ' ] [ self . base_model_name ] if ' price_config ' in rules else default_price_config
price_config = rules [ ' price_config ' ] [
self . base_model_name ] if ' price_config ' in rules else default_price_config
price_config = {
' prompt ' : decimal . Decimal ( price_config [ ' prompt ' ] ) ,
' completion ' : decimal . Decimal ( price_config [ ' completion ' ] ) ,
@ -90,7 +98,7 @@ class BaseLLM(BaseProviderModel):
' currency ' : price_config [ ' currency ' ]
}
return price_config
self . _price_config = self . _price_config if hasattr ( self , ' _price_config ' ) else get_or_default ( )
logger . debug ( f " model: { self . name } price_config: { self . _price_config } " )
@ -158,7 +166,8 @@ class BaseLLM(BaseProviderModel):
total_tokens = result . llm_output [ ' token_usage ' ] [ ' total_tokens ' ]
else :
prompt_tokens = self . get_num_tokens ( messages )
completion_tokens = self . get_num_tokens ( [ PromptMessage ( content = completion_content , type = MessageType . ASSISTANT ) ] )
completion_tokens = self . get_num_tokens (
[ PromptMessage ( content = completion_content , type = MessageType . ASSISTANT ) ] )
total_tokens = prompt_tokens + completion_tokens
self . model_provider . update_last_used ( )
@ -293,6 +302,119 @@ class BaseLLM(BaseProviderModel):
def support_streaming ( cls ) :
return False
def get_prompt(self, mode: str,
               pre_prompt: str, inputs: dict,
               query: str,
               context: Optional[str],
               memory: Optional[BaseChatMemory]) -> \
        Tuple[List[PromptMessage], Optional[List[str]]]:
    """Build the prompt messages and stop words for one generation call.

    Loads the rule file matching *mode*, renders the full prompt text from
    the pre-prompt, inputs, query, context and chat memory, and wraps the
    result in a single PromptMessage.

    :param mode: model mode name ('completion' or chat-style)
    :param pre_prompt: user-configured prefix prompt template
    :param inputs: variable values for the pre-prompt template
    :param query: the end-user query
    :param context: optional retrieved context text
    :param memory: optional chat memory for history injection
    :return: (prompt messages, stop words or None)
    """
    rules = self._read_prompt_rules_from_file(self.prompt_file_name(mode))
    prompt_text, stop_words = self._get_prompt_and_stop(
        rules, pre_prompt, inputs, query, context, memory)
    return [PromptMessage(content=prompt_text)], stop_words
def prompt_file_name(self, mode: str) -> str:
    """Return the generate-prompt rule file base name for *mode*.

    'completion' maps to the completion rules; every other mode falls
    back to the common chat rules.
    """
    return 'common_completion' if mode == 'completion' else 'common_chat'
def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
                         query: str,
                         context: Optional[str],
                         memory: Optional[BaseChatMemory]) -> Tuple[str, Optional[list]]:
    """Assemble the final prompt string per the rule file and return it with stop words.

    :param prompt_rules: parsed rule dict (templates + ordering + prefixes + stops)
    :param pre_prompt: user-configured prefix prompt template (may be empty)
    :param inputs: variable values available to the pre-prompt template
    :param query: end-user query text
    :param context: optional retrieved context
    :param memory: optional chat memory; when present and the rules define a
        histories template, chat history is injected under a token budget
    :return: (prompt text, stop word list or None)
    """
    # Render the context section only when both context text and a template exist.
    context_prompt_content = ''
    if context and 'context_prompt' in prompt_rules:
        prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['context_prompt'])
        context_prompt_content = prompt_template.format(
            context=context
        )

    # Render the user's pre-prompt, supplying only the variables it declares.
    pre_prompt_content = ''
    if pre_prompt:
        prompt_template = JinjaPromptTemplate.from_template(template=pre_prompt)
        prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
        pre_prompt_content = prompt_template.format(
            **prompt_inputs
        )

    # First assembly pass: system sections in rule-defined order
    # (no histories yet; separator after pre-prompt is a blank line).
    prompt = ''
    for order in prompt_rules['system_prompt_orders']:
        if order == 'context_prompt':
            prompt += context_prompt_content
        elif order == 'pre_prompt':
            prompt += (pre_prompt_content + '\n\n') if pre_prompt_content else ''

    query_prompt = prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{query}}'

    if memory and 'histories_prompt' in prompt_rules:
        # append chat histories
        # Measure the prompt-so-far (plus the query) to size the history budget.
        tmp_human_message = PromptBuilder.to_human_message(
            prompt_content=prompt + query_prompt,
            inputs={
                'query': query
            }
        )

        if self.model_rules.max_tokens.max:
            # Remaining budget = context window - completion reservation - current prompt.
            curr_message_tokens = self.get_num_tokens(to_prompt_messages([tmp_human_message]))
            max_tokens = self.model_kwargs.max_tokens
            rest_tokens = self.model_rules.max_tokens.max - max_tokens - curr_message_tokens
            rest_tokens = max(rest_tokens, 0)  # never pass a negative budget
        else:
            rest_tokens = 2000  # fallback budget when the model declares no hard max

    memory.human_prefix = prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human'
    memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'

    histories = self._get_history_messages_from_memory(memory, rest_tokens)
    prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['histories_prompt'])
    histories_prompt_content = prompt_template.format(
        histories=histories
    )

    # Second assembly pass: rebuild from scratch, now including histories.
    # NOTE(review): separator after pre-prompt is a single '\n' here vs the
    # '\n\n' of the first pass — looks intentional but worth confirming.
    prompt = ''
    for order in prompt_rules['system_prompt_orders']:
        if order == 'context_prompt':
            prompt += context_prompt_content
        elif order == 'pre_prompt':
            prompt += (pre_prompt_content + '\n') if pre_prompt_content else ''
        elif order == 'histories_prompt':
            prompt += histories_prompt_content

    # Append the rendered query section last.
    prompt_template = JinjaPromptTemplate.from_template(template=query_prompt)
    query_prompt_content = prompt_template.format(
        query=query
    )
    prompt += query_prompt_content

    # Strip special control tokens of the form <|...|> from the final prompt.
    prompt = re.sub(r'<\|.*?\|>', '', prompt)

    # Empty stop list in the rules means "no stops" — normalize to None.
    stops = prompt_rules.get('stops')
    if stops is not None and len(stops) == 0:
        stops = None

    return prompt, stops
def _read_prompt_rules_from_file(self, prompt_name: str) -> dict:
    """Load a prompt-rule definition from prompt/generate_prompts/<name>.json.

    :param prompt_name: rule file base name (without the .json extension)
    :return: the parsed rule dict
    :raises FileNotFoundError: if no rule file exists for *prompt_name*
    :raises json.JSONDecodeError: if the file is not valid JSON
    """
    # Get the absolute path of the subdirectory: the generate_prompts folder
    # lives three directory levels above this module's own directory.
    prompt_path = os.path.join(
        os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))),
        'prompt/generate_prompts')
    json_file_path = os.path.join(prompt_path, f'{prompt_name}.json')

    # Open the JSON file and read its content. Fix: force UTF-8 — rule files
    # are UTF-8 and the platform/locale default encoding breaks e.g. on Windows.
    with open(json_file_path, 'r', encoding='utf-8') as json_file:
        return json.load(json_file)
def _get_history_messages_from_memory(self, memory: BaseChatMemory,
                                      max_token_limit: int) -> str:
    """Render the chat history held in *memory*, capped to the token budget.

    Sets the memory's token limit, then reads back its (single) memory
    variable as the formatted history string.
    """
    memory.max_token_limit = max_token_limit
    history_key = memory.memory_variables[0]
    loaded = memory.load_memory_variables({})
    return loaded[history_key]
def _get_prompt_from_messages ( self , messages : List [ PromptMessage ] ,
model_mode : Optional [ ModelMode ] = None ) - > Union [ str | List [ BaseMessage ] ] :
if not model_mode :