Feat(tool): fal ai flux image generation (#10606)
parent
bddecba9ed
commit
2a4783307a
@ -0,0 +1,4 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg version="1.1" xmlns="http://www.w3.org/2000/svg" width="32" height="32">
|
||||
<path d="M0 0 C3.96 0 7.92 0 12 0 C12.4125 0.928125 12.825 1.85625 13.25 2.8125 C15.56104487 7.02190315 17.49701732 8.49900577 22 10 C22 13.96 22 17.92 22 22 C21.071875 22.4125 20.14375 22.825 19.1875 23.25 C14.97809685 25.56104487 13.50099423 27.49701732 12 32 C8.04 32 4.08 32 0 32 C-0.4125 31.071875 -0.825 30.14375 -1.25 29.1875 C-3.56104487 24.97809685 -5.49701732 23.50099423 -10 22 C-10 18.04 -10 14.08 -10 10 C-9.071875 9.5875 -8.14375 9.175 -7.1875 8.75 C-2.97809685 6.43895513 -1.50099423 4.50298268 0 0 Z M-2 11 C-3.42662219 13.85324437 -3.31033868 15.83454549 -3 19 C-1.20006226 21.69990662 0.083773 23.5418865 3 25 C7.1364408 25.56406011 8.76045933 25.14638597 12.375 22.9375 C15.26054626 20.20817124 15.26054626 20.20817124 15.6875 16.5625 C14.76325283 11.77321919 13.68514918 10.2147046 10 7 C4.54838272 6.02649691 1.87056683 7.12943317 -2 11 Z " fill="#EC0648" transform="translate(10,0)"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.0 KiB |
@ -0,0 +1,20 @@
|
||||
import requests
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class FalProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
url = "https://fal.run/fal-ai/flux/dev"
|
||||
headers = {
|
||||
"Authorization": f"Key {credentials.get('fal_api_key')}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
data = {"prompt": "Cat"}
|
||||
|
||||
response = requests.post(url, json=data, headers=headers)
|
||||
if response.status_code == 401:
|
||||
raise ToolProviderCredentialValidationError("FAL API key is invalid")
|
||||
elif response.status_code != 200:
|
||||
raise ToolProviderCredentialValidationError(f"FAL API key validation failed: {response.text}")
|
||||
@ -0,0 +1,21 @@
|
||||
identity:
|
||||
author: Kalo Chin
|
||||
name: fal
|
||||
label:
|
||||
en_US: FAL
|
||||
zh_CN: FAL
|
||||
description:
|
||||
en_US: The image generation API provided by FAL.
|
||||
zh_CN: FAL 提供的图像生成 API。
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- image
|
||||
credentials_for_provider:
|
||||
fal_api_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: FAL API Key
|
||||
placeholder:
|
||||
en_US: Please input your FAL API key
|
||||
url: https://fal.ai/dashboard/keys
|
||||
@ -0,0 +1,46 @@
|
||||
from typing import Any, Union
|
||||
|
||||
import requests
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class Flux11ProTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
headers = {
|
||||
"Authorization": f"Key {self.runtime.credentials['fal_api_key']}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
prompt = tool_parameters.get("prompt", "")
|
||||
sanitized_prompt = prompt.replace("\\", "") # Remove backslashes from the prompt which may cause errors
|
||||
|
||||
payload = {
|
||||
"prompt": sanitized_prompt,
|
||||
"image_size": tool_parameters.get("image_size", "landscape_4_3"),
|
||||
"seed": tool_parameters.get("seed"),
|
||||
"sync_mode": tool_parameters.get("sync_mode", False),
|
||||
"num_images": tool_parameters.get("num_images", 1),
|
||||
"enable_safety_checker": tool_parameters.get("enable_safety_checker", True),
|
||||
"safety_tolerance": tool_parameters.get("safety_tolerance", "2"),
|
||||
}
|
||||
|
||||
url = "https://fal.run/fal-ai/flux-pro/v1.1"
|
||||
|
||||
response = requests.post(url, json=payload, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
return self.create_text_message(f"Got Error Response: {response.text}")
|
||||
|
||||
res = response.json()
|
||||
result = [self.create_json_message(res)]
|
||||
|
||||
for image_info in res.get("images", []):
|
||||
image_url = image_info.get("url")
|
||||
if image_url:
|
||||
result.append(self.create_image_message(image=image_url, save_as=self.VariableKey.IMAGE.value))
|
||||
|
||||
return result
|
||||
@ -0,0 +1,47 @@
|
||||
from typing import Any, Union
|
||||
|
||||
import requests
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class Flux11ProUltraTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
headers = {
|
||||
"Authorization": f"Key {self.runtime.credentials['fal_api_key']}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
prompt = tool_parameters.get("prompt", "")
|
||||
sanitized_prompt = prompt.replace("\\", "") # Remove backslashes from the prompt which may cause errors
|
||||
|
||||
payload = {
|
||||
"prompt": sanitized_prompt,
|
||||
"seed": tool_parameters.get("seed"),
|
||||
"sync_mode": tool_parameters.get("sync_mode", False),
|
||||
"num_images": tool_parameters.get("num_images", 1),
|
||||
"enable_safety_checker": tool_parameters.get("enable_safety_checker", True),
|
||||
"safety_tolerance": str(tool_parameters.get("safety_tolerance", "2")),
|
||||
"aspect_ratio": tool_parameters.get("aspect_ratio", "16:9"),
|
||||
"raw": tool_parameters.get("raw", False),
|
||||
}
|
||||
|
||||
url = "https://fal.run/fal-ai/flux-pro/v1.1-ultra"
|
||||
|
||||
response = requests.post(url, json=payload, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
return self.create_text_message(f"Got Error Response: {response.text}")
|
||||
|
||||
res = response.json()
|
||||
result = [self.create_json_message(res)]
|
||||
|
||||
for image_info in res.get("images", []):
|
||||
image_url = image_info.get("url")
|
||||
if image_url:
|
||||
result.append(self.create_image_message(image=image_url, save_as=self.VariableKey.IMAGE.value))
|
||||
|
||||
return result
|
||||
@ -0,0 +1,47 @@
|
||||
from typing import Any, Union
|
||||
|
||||
import requests
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class Flux1DevTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
headers = {
|
||||
"Authorization": f"Key {self.runtime.credentials['fal_api_key']}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
prompt = tool_parameters.get("prompt", "")
|
||||
sanitized_prompt = prompt.replace("\\", "") # Remove backslashes from the prompt which may cause errors
|
||||
|
||||
payload = {
|
||||
"prompt": sanitized_prompt,
|
||||
"image_size": tool_parameters.get("image_size", "landscape_4_3"),
|
||||
"num_inference_steps": tool_parameters.get("num_inference_steps", 28),
|
||||
"guidance_scale": tool_parameters.get("guidance_scale", 3.5),
|
||||
"seed": tool_parameters.get("seed"),
|
||||
"num_images": tool_parameters.get("num_images", 1),
|
||||
"enable_safety_checker": tool_parameters.get("enable_safety_checker", True),
|
||||
"sync_mode": tool_parameters.get("sync_mode", False),
|
||||
}
|
||||
|
||||
url = "https://fal.run/fal-ai/flux/dev"
|
||||
|
||||
response = requests.post(url, json=payload, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
return self.create_text_message(f"Got Error Response: {response.text}")
|
||||
|
||||
res = response.json()
|
||||
result = [self.create_json_message(res)]
|
||||
|
||||
for image_info in res.get("images", []):
|
||||
image_url = image_info.get("url")
|
||||
if image_url:
|
||||
result.append(self.create_image_message(image=image_url, save_as=self.VariableKey.IMAGE.value))
|
||||
|
||||
return result
|
||||
@ -0,0 +1,47 @@
|
||||
from typing import Any, Union
|
||||
|
||||
import requests
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class Flux1ProNewTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
headers = {
|
||||
"Authorization": f"Key {self.runtime.credentials['fal_api_key']}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
prompt = tool_parameters.get("prompt", "")
|
||||
sanitized_prompt = prompt.replace("\\", "") # Remove backslashes that may cause errors
|
||||
|
||||
payload = {
|
||||
"prompt": sanitized_prompt,
|
||||
"image_size": tool_parameters.get("image_size", "landscape_4_3"),
|
||||
"num_inference_steps": tool_parameters.get("num_inference_steps", 28),
|
||||
"guidance_scale": tool_parameters.get("guidance_scale", 3.5),
|
||||
"seed": tool_parameters.get("seed"),
|
||||
"num_images": tool_parameters.get("num_images", 1),
|
||||
"safety_tolerance": tool_parameters.get("safety_tolerance", "2"),
|
||||
"sync_mode": tool_parameters.get("sync_mode", False),
|
||||
}
|
||||
|
||||
url = "https://fal.run/fal-ai/flux-pro/new"
|
||||
|
||||
response = requests.post(url, json=payload, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
return self.create_text_message(f"Got Error Response: {response.text}")
|
||||
|
||||
res = response.json()
|
||||
result = [self.create_json_message(res)]
|
||||
|
||||
for image_info in res.get("images", []):
|
||||
image_url = image_info.get("url")
|
||||
if image_url:
|
||||
result.append(self.create_image_message(image=image_url, save_as=self.VariableKey.IMAGE.value))
|
||||
|
||||
return result
|
||||
Loading…
Reference in New Issue