@ -1,3 +1,5 @@
import base64
import io
import json
import random
import uuid
@ -6,45 +8,48 @@ import httpx
from websocket import WebSocket
from yarl import URL
from core . file . file_manager import _get_encoded_string
from core . file . models import File
class ComfyUiClient :
def __init__ ( self , base_url : str ) :
self . base_url = URL ( base_url )
def get_history ( self , prompt_id : str ) :
def get_history ( self , prompt_id : str ) - > dict :
res = httpx . get ( str ( self . base_url / " history " ) , params = { " prompt_id " : prompt_id } )
history = res . json ( ) [ prompt_id ]
return history
def get_image ( self , filename : str , subfolder : str , folder_type : str ) :
def get_image ( self , filename : str , subfolder : str , folder_type : str ) - > bytes :
response = httpx . get (
str ( self . base_url / " view " ) ,
params = { " filename " : filename , " subfolder " : subfolder , " type " : folder_type } ,
)
return response . content
def upload_image ( self , input_path : str , name : str , image_type : str = " input " , overwrite : bool = False ) :
# plan to support img2img in dify 0.10.0
with open ( input_path , " rb " ) as file :
files = { " image " : ( name , file , " image/png " ) }
data = { " type " : image_type , " overwrite " : str ( overwrite ) . lower ( ) }
res = httpx . post ( str ( self . base_url / " upload/image " ) , data = data , files = files )
return res
def upload_image ( self , image_file : File ) - > dict :
image_content = base64 . b64decode ( _get_encoded_string ( image_file ) )
file = io . BytesIO ( image_content )
files = { " image " : ( image_file . filename , file , image_file . mime_type ) , " overwrite " : " true " }
res = httpx . post ( str ( self . base_url / " upload/image " ) , files = files )
return res . json ( )
def queue_prompt ( self , client_id : str , prompt : dict ) :
def queue_prompt ( self , client_id : str , prompt : dict ) - > str :
res = httpx . post ( str ( self . base_url / " prompt " ) , json = { " client_id " : client_id , " prompt " : prompt } )
prompt_id = res . json ( ) [ " prompt_id " ]
return prompt_id
def open_websocket_connection ( self ) :
def open_websocket_connection ( self ) - > tuple [ WebSocket , str ] :
client_id = str ( uuid . uuid4 ( ) )
ws = WebSocket ( )
ws_address = f " ws:// { self . base_url . authority } /ws?clientId= { client_id } "
ws . connect ( ws_address )
return ws , client_id
def set_prompt ( self , origin_prompt : dict , positive_prompt : str , negative_prompt : str = " " ) :
def set_prompt (
self , origin_prompt : dict , positive_prompt : str , negative_prompt : str = " " , image_name : str = " "
) - > dict :
"""
find the first KSampler , then can find the prompt node through it .
"""
@ -58,6 +63,10 @@ class ComfyUiClient:
if negative_prompt != " " :
negative_input_id = prompt . get ( k_sampler ) [ " inputs " ] [ " negative " ] [ 0 ]
prompt . get ( negative_input_id ) [ " inputs " ] [ " text " ] = negative_prompt
if image_name != " " :
image_loader = [ key for key , value in id_to_class_type . items ( ) if value == " LoadImage " ] [ 0 ]
prompt . get ( image_loader ) [ " inputs " ] [ " image " ] = image_name
return prompt
def track_progress ( self , prompt : dict , ws : WebSocket , prompt_id : str ) :
@ -89,7 +98,7 @@ class ComfyUiClient:
else :
continue
def generate_image_by_prompt ( self , prompt : dict ) :
def generate_image_by_prompt ( self , prompt : dict ) - > list [ bytes ] :
try :
ws , client_id = self . open_websocket_connection ( )
prompt_id = self . queue_prompt ( client_id , prompt )