|
|
|
@ -4,7 +4,7 @@ from core.tools.entities.common_entities import I18nObject
|
|
|
|
from core.tools.errors import ToolProviderCredentialValidationError
|
|
|
|
from core.tools.errors import ToolProviderCredentialValidationError
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Any, Dict, List, Union
|
|
|
|
from typing import Any, Dict, List, Union
|
|
|
|
from httpx import post
|
|
|
|
from httpx import post, get
|
|
|
|
from os.path import join
|
|
|
|
from os.path import join
|
|
|
|
from base64 import b64decode, b64encode
|
|
|
|
from base64 import b64decode, b64encode
|
|
|
|
from PIL import Image
|
|
|
|
from PIL import Image
|
|
|
|
@ -59,6 +59,7 @@ DRAW_TEXT_OPTIONS = {
|
|
|
|
"alwayson_scripts": {}
|
|
|
|
"alwayson_scripts": {}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StableDiffusionTool(BuiltinTool):
|
|
|
|
class StableDiffusionTool(BuiltinTool):
|
|
|
|
def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \
|
|
|
|
def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \
|
|
|
|
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
|
|
|
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
|
|
|
@ -137,6 +138,30 @@ class StableDiffusionTool(BuiltinTool):
|
|
|
|
height=height,
|
|
|
|
height=height,
|
|
|
|
steps=steps)
|
|
|
|
steps=steps)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def validate_models(self) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
validate models
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
base_url = self.runtime.credentials.get('base_url', None)
|
|
|
|
|
|
|
|
if not base_url:
|
|
|
|
|
|
|
|
raise ToolProviderCredentialValidationError('Please input base_url')
|
|
|
|
|
|
|
|
model = self.runtime.credentials.get('model', None)
|
|
|
|
|
|
|
|
if not model:
|
|
|
|
|
|
|
|
raise ToolProviderCredentialValidationError('Please input model')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
response = get(url=f'{base_url}/sdapi/v1/sd-models', timeout=120)
|
|
|
|
|
|
|
|
if response.status_code != 200:
|
|
|
|
|
|
|
|
raise ToolProviderCredentialValidationError('Failed to get models')
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
models = [d['model_name'] for d in response.json()]
|
|
|
|
|
|
|
|
if len([d for d in models if d == model]) > 0:
|
|
|
|
|
|
|
|
return self.create_text_message(json.dumps(models))
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
raise ToolProviderCredentialValidationError(f'model {model} does not exist')
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
|
|
raise ToolProviderCredentialValidationError(f'Failed to get models, {e}')
|
|
|
|
|
|
|
|
|
|
|
|
def img2img(self, base_url: str, lora: str, image_binary: bytes,
|
|
|
|
def img2img(self, base_url: str, lora: str, image_binary: bytes,
|
|
|
|
prompt: str, negative_prompt: str,
|
|
|
|
prompt: str, negative_prompt: str,
|
|
|
|
width: int, height: int, steps: int) \
|
|
|
|
width: int, height: int, steps: int) \
|
|
|
|
@ -211,7 +236,6 @@ class StableDiffusionTool(BuiltinTool):
|
|
|
|
except Exception as e:
|
|
|
|
except Exception as e:
|
|
|
|
return self.create_text_message('Failed to generate image')
|
|
|
|
return self.create_text_message('Failed to generate image')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_runtime_parameters(self) -> List[ToolParameter]:
|
|
|
|
def get_runtime_parameters(self) -> List[ToolParameter]:
|
|
|
|
parameters = [
|
|
|
|
parameters = [
|
|
|
|
ToolParameter(name='prompt',
|
|
|
|
ToolParameter(name='prompt',
|
|
|
|
|