Feat/add zhipu CogView 3 tool (#6210)
parent
a7b33b55e8
commit
07add06c59
Binary file not shown.
|
After Width: | Height: | Size: 22 KiB |
@ -0,0 +1,27 @@
|
|||||||
|
""" Provide the input parameters type for the cogview provider class """
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from core.tools.errors import ToolProviderCredentialValidationError
|
||||||
|
from core.tools.provider.builtin.cogview.tools.cogview3 import CogView3Tool
|
||||||
|
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||||
|
|
||||||
|
|
||||||
|
class COGVIEWProvider(BuiltinToolProviderController):
|
||||||
|
""" cogview provider """
|
||||||
|
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||||
|
try:
|
||||||
|
CogView3Tool().fork_tool_runtime(
|
||||||
|
runtime={
|
||||||
|
"credentials": credentials,
|
||||||
|
}
|
||||||
|
).invoke(
|
||||||
|
user_id='',
|
||||||
|
tool_parameters={
|
||||||
|
"prompt": "一个城市在水晶瓶中欢快生活的场景,水彩画风格,展现出微观与珠宝般的美丽。",
|
||||||
|
"size": "square",
|
||||||
|
"n": 1
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise ToolProviderCredentialValidationError(str(e)) from e
|
||||||
|
|
||||||
@ -0,0 +1,61 @@
|
|||||||
|
identity:
|
||||||
|
author: Waffle
|
||||||
|
name: cogview
|
||||||
|
label:
|
||||||
|
en_US: CogView
|
||||||
|
zh_Hans: CogView 绘画
|
||||||
|
pt_BR: CogView
|
||||||
|
description:
|
||||||
|
en_US: CogView art
|
||||||
|
zh_Hans: CogView 绘画
|
||||||
|
pt_BR: CogView art
|
||||||
|
icon: icon.png
|
||||||
|
tags:
|
||||||
|
- image
|
||||||
|
- productivity
|
||||||
|
credentials_for_provider:
|
||||||
|
zhipuai_api_key:
|
||||||
|
type: secret-input
|
||||||
|
required: true
|
||||||
|
label:
|
||||||
|
en_US: ZhipuAI API key
|
||||||
|
zh_Hans: ZhipuAI API key
|
||||||
|
pt_BR: ZhipuAI API key
|
||||||
|
help:
|
||||||
|
en_US: Please input your ZhipuAI API key
|
||||||
|
zh_Hans: 请输入你的 ZhipuAI API key
|
||||||
|
pt_BR: Please input your ZhipuAI API key
|
||||||
|
placeholder:
|
||||||
|
en_US: Please input your ZhipuAI API key
|
||||||
|
zh_Hans: 请输入你的 ZhipuAI API key
|
||||||
|
pt_BR: Please input your ZhipuAI API key
|
||||||
|
zhipuai_organizaion_id:
|
||||||
|
type: text-input
|
||||||
|
required: false
|
||||||
|
label:
|
||||||
|
en_US: ZhipuAI organization ID
|
||||||
|
zh_Hans: ZhipuAI organization ID
|
||||||
|
pt_BR: ZhipuAI organization ID
|
||||||
|
help:
|
||||||
|
en_US: Please input your ZhipuAI organization ID
|
||||||
|
zh_Hans: 请输入你的 ZhipuAI organization ID
|
||||||
|
pt_BR: Please input your ZhipuAI organization ID
|
||||||
|
placeholder:
|
||||||
|
en_US: Please input your ZhipuAI organization ID
|
||||||
|
zh_Hans: 请输入你的 ZhipuAI organization ID
|
||||||
|
pt_BR: Please input your ZhipuAI organization ID
|
||||||
|
zhipuai_base_url:
|
||||||
|
type: text-input
|
||||||
|
required: false
|
||||||
|
label:
|
||||||
|
en_US: ZhipuAI base URL
|
||||||
|
zh_Hans: ZhipuAI base URL
|
||||||
|
pt_BR: ZhipuAI base URL
|
||||||
|
help:
|
||||||
|
en_US: Please input your ZhipuAI base URL
|
||||||
|
zh_Hans: 请输入你的 ZhipuAI base URL
|
||||||
|
pt_BR: Please input your ZhipuAI base URL
|
||||||
|
placeholder:
|
||||||
|
en_US: Please input your ZhipuAI base URL
|
||||||
|
zh_Hans: 请输入你的 ZhipuAI base URL
|
||||||
|
pt_BR: Please input your ZhipuAI base URL
|
||||||
@ -0,0 +1,69 @@
|
|||||||
|
import random
|
||||||
|
from typing import Any, Union
|
||||||
|
|
||||||
|
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
|
||||||
|
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||||
|
from core.tools.tool.builtin_tool import BuiltinTool
|
||||||
|
|
||||||
|
|
||||||
|
class CogView3Tool(BuiltinTool):
|
||||||
|
""" CogView3 Tool """
|
||||||
|
|
||||||
|
def _invoke(self,
|
||||||
|
user_id: str,
|
||||||
|
tool_parameters: dict[str, Any]
|
||||||
|
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||||
|
"""
|
||||||
|
Invoke CogView3 tool
|
||||||
|
"""
|
||||||
|
client = ZhipuAI(
|
||||||
|
base_url=self.runtime.credentials['zhipuai_base_url'],
|
||||||
|
api_key=self.runtime.credentials['zhipuai_api_key'],
|
||||||
|
)
|
||||||
|
size_mapping = {
|
||||||
|
'square': '1024x1024',
|
||||||
|
'vertical': '1024x1792',
|
||||||
|
'horizontal': '1792x1024',
|
||||||
|
}
|
||||||
|
# prompt
|
||||||
|
prompt = tool_parameters.get('prompt', '')
|
||||||
|
if not prompt:
|
||||||
|
return self.create_text_message('Please input prompt')
|
||||||
|
# get size
|
||||||
|
print(tool_parameters.get('prompt', 'square'))
|
||||||
|
size = size_mapping[tool_parameters.get('size', 'square')]
|
||||||
|
# get n
|
||||||
|
n = tool_parameters.get('n', 1)
|
||||||
|
# get quality
|
||||||
|
quality = tool_parameters.get('quality', 'standard')
|
||||||
|
if quality not in ['standard', 'hd']:
|
||||||
|
return self.create_text_message('Invalid quality')
|
||||||
|
# get style
|
||||||
|
style = tool_parameters.get('style', 'vivid')
|
||||||
|
if style not in ['natural', 'vivid']:
|
||||||
|
return self.create_text_message('Invalid style')
|
||||||
|
# set extra body
|
||||||
|
seed_id = tool_parameters.get('seed_id', self._generate_random_id(8))
|
||||||
|
extra_body = {'seed': seed_id}
|
||||||
|
response = client.images.generations(
|
||||||
|
prompt=prompt,
|
||||||
|
model="cogview-3",
|
||||||
|
size=size,
|
||||||
|
n=n,
|
||||||
|
extra_body=extra_body,
|
||||||
|
style=style,
|
||||||
|
quality=quality,
|
||||||
|
response_format='b64_json'
|
||||||
|
)
|
||||||
|
result = []
|
||||||
|
for image in response.data:
|
||||||
|
result.append(self.create_image_message(image=image.url))
|
||||||
|
result.append(self.create_text_message(
|
||||||
|
f'\nGenerate image source to Seed ID: {seed_id}'))
|
||||||
|
return result
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _generate_random_id(length=8):
|
||||||
|
characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
|
||||||
|
random_id = ''.join(random.choices(characters, k=length))
|
||||||
|
return random_id
|
||||||
Loading…
Reference in New Issue