You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
gcgj-dify-1.7.0/api/services/image_generation_service.py

92 lines
3.2 KiB
Python

from enum import Enum
from configs import dify_config
from core.tools.tool_file_manager import ToolFileManager
from extensions.ext_database import db
from libs.helper import RateLimiter
from libs.infinite_scroll_pagination import MultiPagePagination
from models.model import EndUser, UserGeneratedImage
from tasks.image_generation_task import generate_image_task
# String-valued enum for the content_type of a generated image.
class ContentType(str, Enum):
    """Kinds of content an image can be generated for.

    Members also subclass ``str``, so each one compares equal to its plain
    string value (e.g. ``ContentType.SELF_MESSAGE == "self_message"``).
    """

    SELF_MESSAGE = "self_message"
    SUMMARY_ADVICE = "summary_advice"
# Extension passed to ToolFileManager.sign_file whenever a generated image's
# file_id is signed in this module.
DEFAULT_IMAGE_EXTENSION = ".png"
class ImageGenerationService:
    """Service helpers for requesting and reading user generated images."""

    # Shared limiter: at most IMAGE_GENERATION_DAILY_LIMIT attempts per key
    # within an 86400-second (one day) window.
    generate_image_rate_limiter = RateLimiter(
        prefix="generate_image_rate_limit",
        max_attempts=dify_config.IMAGE_GENERATION_DAILY_LIMIT,
        time_window=86400 * 1,
    )

    @staticmethod
    def _attach_signed_url(image: UserGeneratedImage) -> None:
        """Set ``image.image_url`` to a signed, temporary URL if a file exists.

        Images without a ``file_id`` are left untouched.
        """
        if image.file_id:
            image.image_url = ToolFileManager.sign_file(image.file_id, DEFAULT_IMAGE_EXTENSION)

    @staticmethod
    def generate_image(end_user: EndUser, content_type: ContentType) -> str:
        """Create a pending image record and queue its asynchronous generation.

        Args:
            end_user: End user requesting the image.
            content_type: Type of content to generate.

        Returns:
            The ID (as a string) of the created UserGeneratedImage entity that
            will be updated by the task.

        Raises:
            Exception: If the rate limiter reports the end user as limited.
        """
        # Refuse early so no record is created and no task is queued.
        # NOTE(review): this method only queries the limiter and never records
        # an attempt itself -- presumably that happens elsewhere (e.g. in the
        # task); confirm.
        if ImageGenerationService.generate_image_rate_limiter.is_rate_limited(end_user.id):
            raise Exception("Image generation rate limit exceeded")

        # Persist a placeholder row in the "pending" state.
        user_generated_image = UserGeneratedImage(
            app_id=end_user.app_id,
            end_user_id=end_user.id,
            content_type=content_type,
            status="pending",
        )
        db.session.add(user_generated_image)
        db.session.commit()

        # The ID is read after the commit and handed to the task so it can
        # refer back to this record.
        image_id = str(user_generated_image.id)
        generate_image_task.delay(end_user_id=str(end_user.id), content_type=content_type, image_id=image_id)

        # Callers use the returned ID as a reference for status checking.
        return image_id

    @staticmethod
    def pagination_image_list(end_user: EndUser, limit: int, offset: int) -> MultiPagePagination:
        """Return one page of the end user's images, newest first.

        Args:
            end_user: Owner of the images; rows are matched on both
                ``end_user.app_id`` and ``end_user.id``.
            limit: Maximum number of rows to fetch.
            offset: Number of matching rows to skip.

        Returns:
            A MultiPagePagination built with the fetched rows as ``data`` and
            the count of all matching rows (before limit/offset) as ``total``.
        """
        query = (
            db.session.query(UserGeneratedImage)
            .filter(UserGeneratedImage.app_id == end_user.app_id, UserGeneratedImage.end_user_id == end_user.id)
            .order_by(UserGeneratedImage.created_at.desc())
        )
        # Count before applying limit/offset so the total covers every match.
        total_count = query.count()
        images = query.limit(limit).offset(offset).all()

        for image in images:
            ImageGenerationService._attach_signed_url(image)

        return MultiPagePagination(data=images, total=total_count)

    @staticmethod
    def get_image_by_id(image_id: str) -> UserGeneratedImage:
        """Fetch a single image by ID, with a signed URL when a file exists.

        Args:
            image_id: ID of the UserGeneratedImage to load.

        Returns:
            The matching UserGeneratedImage.

        Raises:
            Exception: If no image with the given ID exists.
        """
        # NOTE(review): the lookup filters on the ID only; it is not scoped to
        # an app or end user, so any ownership check must happen in the caller.
        image = db.session.query(UserGeneratedImage).filter(UserGeneratedImage.id == image_id).first()
        if image is None:
            raise Exception("Image not found")

        ImageGenerationService._attach_signed_url(image)
        return image