Feat/add azure dalle tool (#2276)
Co-authored-by: lux@njuelectronics.com <lux@njuelectronics.com> Co-authored-by: crazywoola <427733928@qq.com>pull/2281/head
parent
76cc19f525
commit
c97b7f6748
Binary file not shown.
|
After Width: | Height: | Size: 153 KiB |
@ -0,0 +1,23 @@
|
|||||||
|
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||||
|
from core.tools.provider.builtin.azuredalle.tools.dalle3 import DallE3Tool
|
||||||
|
from core.tools.errors import ToolProviderCredentialValidationError
|
||||||
|
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
class AzureDALLEProvider(BuiltinToolProviderController):
|
||||||
|
def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
|
||||||
|
try:
|
||||||
|
DallE3Tool().fork_tool_runtime(
|
||||||
|
meta={
|
||||||
|
"credentials": credentials,
|
||||||
|
}
|
||||||
|
).invoke(
|
||||||
|
user_id='',
|
||||||
|
tool_paramters={
|
||||||
|
"prompt": "cute girl, blue eyes, white hair, anime style",
|
||||||
|
"size": "square",
|
||||||
|
"n": 1
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise ToolProviderCredentialValidationError(str(e))
|
||||||
@ -0,0 +1,66 @@
|
|||||||
|
from typing import Any, Dict, List, Union
|
||||||
|
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||||
|
from core.tools.tool.builtin_tool import BuiltinTool
|
||||||
|
|
||||||
|
from base64 import b64decode
|
||||||
|
from os.path import join
|
||||||
|
|
||||||
|
from openai import AzureOpenAI
|
||||||
|
|
||||||
|
class DallE3Tool(BuiltinTool):
|
||||||
|
def _invoke(self,
|
||||||
|
user_id: str,
|
||||||
|
tool_paramters: Dict[str, Any],
|
||||||
|
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
||||||
|
"""
|
||||||
|
invoke tools
|
||||||
|
"""
|
||||||
|
client = AzureOpenAI(
|
||||||
|
api_version=self.runtime.credentials['azure_openai_api_version'],
|
||||||
|
azure_endpoint=self.runtime.credentials['azure_openai_base_url'],
|
||||||
|
api_key=self.runtime.credentials['azure_openai_api_key'],
|
||||||
|
)
|
||||||
|
|
||||||
|
SIZE_MAPPING = {
|
||||||
|
'square': '1024x1024',
|
||||||
|
'vertical': '1024x1792',
|
||||||
|
'horizontal': '1792x1024',
|
||||||
|
}
|
||||||
|
|
||||||
|
# prompt
|
||||||
|
prompt = tool_paramters.get('prompt', '')
|
||||||
|
if not prompt:
|
||||||
|
return self.create_text_message('Please input prompt')
|
||||||
|
# get size
|
||||||
|
size = SIZE_MAPPING[tool_paramters.get('size', 'square')]
|
||||||
|
# get n
|
||||||
|
n = tool_paramters.get('n', 1)
|
||||||
|
# get quality
|
||||||
|
quality = tool_paramters.get('quality', 'standard')
|
||||||
|
if quality not in ['standard', 'hd']:
|
||||||
|
return self.create_text_message('Invalid quality')
|
||||||
|
# get style
|
||||||
|
style = tool_paramters.get('style', 'vivid')
|
||||||
|
if style not in ['natural', 'vivid']:
|
||||||
|
return self.create_text_message('Invalid style')
|
||||||
|
|
||||||
|
# call openapi dalle3
|
||||||
|
model=self.runtime.credentials['azure_openai_api_model_name']
|
||||||
|
response = client.images.generate(
|
||||||
|
prompt=prompt,
|
||||||
|
model=model,
|
||||||
|
size=size,
|
||||||
|
n=n,
|
||||||
|
style=style,
|
||||||
|
quality=quality,
|
||||||
|
response_format='b64_json'
|
||||||
|
)
|
||||||
|
|
||||||
|
result = []
|
||||||
|
|
||||||
|
for image in response.data:
|
||||||
|
result.append(self.create_blob_message(blob=b64decode(image.b64_json),
|
||||||
|
meta={ 'mime_type': 'image/png' },
|
||||||
|
save_as=self.VARIABLE_KEY.IMAGE.value))
|
||||||
|
|
||||||
|
return result
|
||||||
Loading…
Reference in New Issue