|
|
|
|
@ -1,29 +1,30 @@
|
|
|
|
|
import base64
|
|
|
|
|
import io
|
|
|
|
|
import json
|
|
|
|
|
import os
|
|
|
|
|
import tempfile
|
|
|
|
|
import time
|
|
|
|
|
from collections.abc import Generator
|
|
|
|
|
from typing import Optional, Union, cast
|
|
|
|
|
from typing import Optional, Union
|
|
|
|
|
|
|
|
|
|
import google.ai.generativelanguage as glm
|
|
|
|
|
import google.generativeai as genai
|
|
|
|
|
import requests
|
|
|
|
|
from google.api_core import exceptions
|
|
|
|
|
from google.generativeai.client import _ClientManager
|
|
|
|
|
from google.generativeai.types import ContentType, GenerateContentResponse
|
|
|
|
|
from google.generativeai.types import ContentType, File, GenerateContentResponse
|
|
|
|
|
from google.generativeai.types.content_types import to_part
|
|
|
|
|
from PIL import Image
|
|
|
|
|
|
|
|
|
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
|
|
|
|
from core.model_runtime.entities.message_entities import (
|
|
|
|
|
AssistantPromptMessage,
|
|
|
|
|
DocumentPromptMessageContent,
|
|
|
|
|
ImagePromptMessageContent,
|
|
|
|
|
PromptMessage,
|
|
|
|
|
PromptMessageContent,
|
|
|
|
|
PromptMessageContentType,
|
|
|
|
|
PromptMessageTool,
|
|
|
|
|
SystemPromptMessage,
|
|
|
|
|
ToolPromptMessage,
|
|
|
|
|
UserPromptMessage,
|
|
|
|
|
VideoPromptMessageContent,
|
|
|
|
|
)
|
|
|
|
|
from core.model_runtime.errors.invoke import (
|
|
|
|
|
InvokeAuthorizationError,
|
|
|
|
|
@ -35,21 +36,7 @@ from core.model_runtime.errors.invoke import (
|
|
|
|
|
)
|
|
|
|
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
|
|
|
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
|
|
|
|
|
|
|
|
|
GOOGLE_AVAILABLE_MIMETYPE = [
|
|
|
|
|
"application/pdf",
|
|
|
|
|
"application/x-javascript",
|
|
|
|
|
"text/javascript",
|
|
|
|
|
"application/x-python",
|
|
|
|
|
"text/x-python",
|
|
|
|
|
"text/plain",
|
|
|
|
|
"text/html",
|
|
|
|
|
"text/css",
|
|
|
|
|
"text/md",
|
|
|
|
|
"text/csv",
|
|
|
|
|
"text/xml",
|
|
|
|
|
"text/rtf",
|
|
|
|
|
]
|
|
|
|
|
from extensions.ext_redis import redis_client
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GoogleLargeLanguageModel(LargeLanguageModel):
|
|
|
|
|
@ -201,29 +188,17 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
|
|
|
|
if stop:
|
|
|
|
|
config_kwargs["stop_sequences"] = stop
|
|
|
|
|
|
|
|
|
|
genai.configure(api_key=credentials["google_api_key"])
|
|
|
|
|
google_model = genai.GenerativeModel(model_name=model)
|
|
|
|
|
|
|
|
|
|
history = []
|
|
|
|
|
|
|
|
|
|
# hack for gemini-pro-vision, which currently does not support multi-turn chat
|
|
|
|
|
if model == "gemini-pro-vision":
|
|
|
|
|
last_msg = prompt_messages[-1]
|
|
|
|
|
content = self._format_message_to_glm_content(last_msg)
|
|
|
|
|
history.append(content)
|
|
|
|
|
else:
|
|
|
|
|
for msg in prompt_messages: # makes message roles strictly alternating
|
|
|
|
|
content = self._format_message_to_glm_content(msg)
|
|
|
|
|
if history and history[-1]["role"] == content["role"]:
|
|
|
|
|
history[-1]["parts"].extend(content["parts"])
|
|
|
|
|
else:
|
|
|
|
|
history.append(content)
|
|
|
|
|
|
|
|
|
|
# Create a new ClientManager with tenant's API key
|
|
|
|
|
new_client_manager = _ClientManager()
|
|
|
|
|
new_client_manager.configure(api_key=credentials["google_api_key"])
|
|
|
|
|
new_custom_client = new_client_manager.make_client("generative")
|
|
|
|
|
|
|
|
|
|
google_model._client = new_custom_client
|
|
|
|
|
for msg in prompt_messages: # makes message roles strictly alternating
|
|
|
|
|
content = self._format_message_to_glm_content(msg)
|
|
|
|
|
if history and history[-1]["role"] == content["role"]:
|
|
|
|
|
history[-1]["parts"].extend(content["parts"])
|
|
|
|
|
else:
|
|
|
|
|
history.append(content)
|
|
|
|
|
|
|
|
|
|
response = google_model.generate_content(
|
|
|
|
|
contents=history,
|
|
|
|
|
@ -346,7 +321,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
|
|
|
|
|
|
|
|
|
content = message.content
|
|
|
|
|
if isinstance(content, list):
|
|
|
|
|
content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE)
|
|
|
|
|
content = "".join(c.data for c in content if c.type == PromptMessageContentType.TEXT)
|
|
|
|
|
|
|
|
|
|
if isinstance(message, UserPromptMessage):
|
|
|
|
|
message_text = f"{human_prompt} {content}"
|
|
|
|
|
@ -359,6 +334,44 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
|
|
|
|
|
|
|
|
|
return message_text
|
|
|
|
|
|
|
|
|
|
def _upload_file_content_to_google(self, message_content: PromptMessageContent) -> File:
    """Upload a prompt-message payload to the Google File API and return its handle.

    The payload is either an inline ``data:`` URI (decoded from base64) or a
    plain URL (fetched over HTTP).  Uploaded file names are cached in Redis,
    keyed by content type + data hash, for 47 hours — Google deletes uploads
    after roughly 2 days, so the cache expires just before that.

    :param message_content: content whose ``data`` holds a ``data:`` URI or a
        fetchable URL (only image/video contents carry URLs here).
    :return: the uploaded Google ``File``, polled until processing finishes.
    :raises ValueError: if the URL cannot be fetched or the media type cannot
        be determined.
    """
    key = f"{message_content.type.value}:{hash(message_content.data)}"
    if redis_client.exists(key):
        try:
            # Cache hit: reuse the previously uploaded file if it still exists.
            return genai.get_file(redis_client.get(key).decode())
        except Exception:
            # Cache may reference an expired/deleted upload; re-upload below.
            pass
    with tempfile.NamedTemporaryFile(delete=False) as temp_file:
        if message_content.data.startswith("data:"):
            # Inline base64 payload: MIME type comes from the data-URI header.
            metadata, base64_data = message_content.data.split(",", 1)
            file_content = base64.b64decode(base64_data)
            mime_type = metadata.split(";", 1)[0].split(":")[1]
            temp_file.write(file_content)
        else:
            # only ImagePromptMessageContent and VideoPromptMessageContent has url
            try:
                # Timeout added: an unbounded requests.get can hang the worker.
                response = requests.get(message_content.data, timeout=(10, 60))
                response.raise_for_status()
                # FIX: the original compared the `type` enum attribute against
                # the content *classes* with `is`, which is always False and
                # left `prefix` unbound.  Compare against the enum members.
                if message_content.type == PromptMessageContentType.IMAGE:
                    prefix = "image/"
                elif message_content.type == PromptMessageContentType.VIDEO:
                    prefix = "video/"
                else:
                    raise ValueError(f"Unsupported content type {message_content.type} for url data")
                mime_type = prefix + message_content.format
                temp_file.write(response.content)
            except Exception as ex:
                raise ValueError(f"Failed to fetch data from url {message_content.data}, {ex}") from ex
        temp_file.flush()
    try:
        file = genai.upload_file(path=temp_file.name, mime_type=mime_type)
        # The File API processes uploads asynchronously; poll until ready.
        while file.state.name == "PROCESSING":
            time.sleep(5)
            file = genai.get_file(file.name)
        # google will delete your upload files in 2 days.
        redis_client.setex(key, 47 * 60 * 60, file.name)
        return file
    finally:
        # Always remove the local temp file, even if the upload failed.
        os.unlink(temp_file.name)
|
|
|
|
|
|
|
|
|
|
def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType:
|
|
|
|
|
"""
|
|
|
|
|
Format a single message into glm.Content for Google API
|
|
|
|
|
@ -374,28 +387,8 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
|
|
|
|
for c in message.content:
|
|
|
|
|
if c.type == PromptMessageContentType.TEXT:
|
|
|
|
|
glm_content["parts"].append(to_part(c.data))
|
|
|
|
|
elif c.type == PromptMessageContentType.IMAGE:
|
|
|
|
|
message_content = cast(ImagePromptMessageContent, c)
|
|
|
|
|
if message_content.data.startswith("data:"):
|
|
|
|
|
metadata, base64_data = c.data.split(",", 1)
|
|
|
|
|
mime_type = metadata.split(";", 1)[0].split(":")[1]
|
|
|
|
|
else:
|
|
|
|
|
# fetch image data from url
|
|
|
|
|
try:
|
|
|
|
|
image_content = requests.get(message_content.data).content
|
|
|
|
|
with Image.open(io.BytesIO(image_content)) as img:
|
|
|
|
|
mime_type = f"image/{img.format.lower()}"
|
|
|
|
|
base64_data = base64.b64encode(image_content).decode("utf-8")
|
|
|
|
|
except Exception as ex:
|
|
|
|
|
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
|
|
|
|
|
blob = {"inline_data": {"mime_type": mime_type, "data": base64_data}}
|
|
|
|
|
glm_content["parts"].append(blob)
|
|
|
|
|
elif c.type == PromptMessageContentType.DOCUMENT:
|
|
|
|
|
message_content = cast(DocumentPromptMessageContent, c)
|
|
|
|
|
if message_content.mime_type not in GOOGLE_AVAILABLE_MIMETYPE:
|
|
|
|
|
raise ValueError(f"Unsupported mime type {message_content.mime_type}")
|
|
|
|
|
blob = {"inline_data": {"mime_type": message_content.mime_type, "data": message_content.data}}
|
|
|
|
|
glm_content["parts"].append(blob)
|
|
|
|
|
else:
|
|
|
|
|
glm_content["parts"].append(self._upload_file_content_to_google(c))
|
|
|
|
|
|
|
|
|
|
return glm_content
|
|
|
|
|
elif isinstance(message, AssistantPromptMessage):
|
|
|
|
|
|