|
|
|
|
@ -3,6 +3,7 @@ from typing import Optional, Generator, Union, List
|
|
|
|
|
import google.generativeai as genai
|
|
|
|
|
import google.api_core.exceptions as exceptions
|
|
|
|
|
import google.generativeai.client as client
|
|
|
|
|
from google.generativeai.types import HarmCategory, HarmBlockThreshold
|
|
|
|
|
|
|
|
|
|
from google.generativeai.types import GenerateContentResponse, ContentType
|
|
|
|
|
from google.generativeai.types.content_types import to_part
|
|
|
|
|
@ -140,12 +141,20 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
|
|
|
|
|
|
|
|
|
google_model._client = new_custom_client
|
|
|
|
|
|
|
|
|
|
safety_settings={
|
|
|
|
|
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
|
|
|
|
|
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
|
|
|
|
|
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
|
|
|
|
|
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
response = google_model.generate_content(
|
|
|
|
|
contents=history,
|
|
|
|
|
generation_config=genai.types.GenerationConfig(
|
|
|
|
|
**config_kwargs
|
|
|
|
|
),
|
|
|
|
|
stream=stream
|
|
|
|
|
stream=stream,
|
|
|
|
|
safety_settings=safety_settings
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if stream:
|
|
|
|
|
@ -169,7 +178,6 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
|
|
|
|
content=response.text
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# calculate num tokens
|
|
|
|
|
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
|
|
|
|
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
|
|
|
|
|
|