finsh async image generation

pull/21891/head
ytqh 1 year ago
parent d85267d5e1
commit 40eef9822c

@ -1,5 +1,3 @@
from datetime import datetime, timedelta
from configs import dify_config from configs import dify_config
from controllers.service_api_with_auth import api from controllers.service_api_with_auth import api
from controllers.service_api_with_auth.app.error import NotChatAppError, NotEnoughMessageCountError from controllers.service_api_with_auth.app.error import NotChatAppError, NotEnoughMessageCountError
@ -48,7 +46,7 @@ class ImageGenerateApi(Resource):
example: success example: success
image_id: image_id:
type: string type: string
description: ID of the generated image, futher to fetch the image details and status description: ID of the generated image (initially in pending state)
example: 123e4567-e89b-12d3-a456-426614174000 example: 123e4567-e89b-12d3-a456-426614174000
message: message:
type: string type: string
@ -75,15 +73,20 @@ class ImageGenerateApi(Resource):
raise NotEnoughMessageCountError() raise NotEnoughMessageCountError()
try: try:
# Use the service to generate the image # Create a pending image entity and start the generation task
image_id = ImageGenerationService.generate_image( image_id = ImageGenerationService.generate_image(
end_user=end_user, end_user=end_user,
content_type=content_type, content_type=content_type,
) )
return {"result": "success", "message": "Image generated successfully.", "image_id": image_id} # Return the image ID for status checking
except Exception: return {
raise InternalServerError("Failed to generate image") "result": "success",
"message": "Image generation started. Check status with the image ID.",
"image_id": image_id,
}
except Exception as e:
raise InternalServerError(f"Failed to generate image: {e}")
class ImageListApi(Resource): class ImageListApi(Resource):

@ -34,6 +34,7 @@ end_user_image_fields = {
"image_url": fields.String, "image_url": fields.String,
"content_type": fields.String, "content_type": fields.String,
"text_content": fields.String, "text_content": fields.String,
"error_message": fields.String,
"created_at": TimestampField, "created_at": TimestampField,
} }

@ -0,0 +1,39 @@
"""update image table
Revision ID: 329938947502
Revises: a215b993fa58
Create Date: 2025-03-31 21:12:55.412130
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '329938947502'
down_revision = 'a215b993fa58'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('user_generated_images', schema=None) as batch_op:
batch_op.add_column(sa.Column('error_message', sa.Text(), nullable=True))
batch_op.alter_column('workflow_run_id',
existing_type=sa.UUID(),
nullable=True)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('user_generated_images', schema=None) as batch_op:
batch_op.alter_column('workflow_run_id',
existing_type=sa.UUID(),
nullable=False)
batch_op.drop_column('error_message')
# ### end Alembic commands ###

@ -0,0 +1,33 @@
"""update status
Revision ID: 382f0cad6652
Revises: 329938947502
Create Date: 2025-03-31 21:15:32.929703
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '382f0cad6652'
down_revision = '329938947502'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('user_generated_images', schema=None) as batch_op:
batch_op.add_column(sa.Column('status', sa.String(length=20), server_default='pending', nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('user_generated_images', schema=None) as batch_op:
batch_op.drop_column('status')
# ### end Alembic commands ###

@ -1866,11 +1866,15 @@ class UserGeneratedImage(db.Model): # type: ignore[name-defined]
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
app_id = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=False)
end_user_id = db.Column(StringUUID, nullable=False) end_user_id = db.Column(StringUUID, nullable=False)
workflow_run_id = db.Column(StringUUID, nullable=False) # related generation id workflow_run_id = db.Column(StringUUID, nullable=True) # related generation id (nullable for pending status)
content_type = db.Column(db.String(255), nullable=False) # 'self_message' or 'summary_advice' content_type = db.Column(db.String(255), nullable=False) # 'self_message' or 'summary_advice'
image_url = db.Column(db.Text, nullable=True) image_url = db.Column(db.Text, nullable=True)
text_content = db.Column(db.Text, nullable=True) text_content = db.Column(db.Text, nullable=True)
raw_content = db.Column(db.JSON, nullable=True) # save raw llm outputs raw_content = db.Column(db.JSON, nullable=True) # save raw llm outputs
status = db.Column(
db.String(20), nullable=False, server_default="pending"
) # pending, processing, completed, failed
error_message = db.Column(db.Text, nullable=True)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
updated_at = db.Column( updated_at = db.Column(
db.DateTime, nullable=False, server_default=db.func.current_timestamp(), onupdate=db.func.current_timestamp() db.DateTime, nullable=False, server_default=db.func.current_timestamp(), onupdate=db.func.current_timestamp()
@ -1879,23 +1883,3 @@ class UserGeneratedImage(db.Model): # type: ignore[name-defined]
@property @property
def end_user(self): def end_user(self):
return db.session.query(EndUser).filter(EndUser.id == self.end_user_id).first() return db.session.query(EndUser).filter(EndUser.id == self.end_user_id).first()
@property
def status(self):
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self.workflow_run_id).first()
if workflow_run is None:
return WorkflowRunStatus.FAILED.value
return workflow_run.status
def refresh_status(self):
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self.workflow_run_id).first()
if workflow_run is None:
return
if workflow_run.status == WorkflowRunStatus.SUCCEEDED and self.image_url is None:
self.image_url = workflow_run.outputs.get("image_url")
self.text_content = workflow_run.outputs.get("text_content")
db.session.commit()

@ -8,6 +8,7 @@ from libs.helper import RateLimiter
from libs.infinite_scroll_pagination import MultiPagePagination from libs.infinite_scroll_pagination import MultiPagePagination
from models.model import App, EndUser, Message, UserGeneratedImage from models.model import App, EndUser, Message, UserGeneratedImage
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from tasks.image_generation_task import generate_image_task
# define string enum for content_type # define string enum for content_type
@ -24,83 +25,39 @@ class ImageGenerationService:
@staticmethod @staticmethod
def generate_image(end_user: EndUser, content_type: ContentType) -> str: def generate_image(end_user: EndUser, content_type: ContentType) -> str:
"""
Initiates asynchronous image generation process and creates a pending image record
Args:
end_user: End user object
content_type: Type of content to generate
Returns:
The ID of the created UserGeneratedImage entity that will be updated by the task
"""
# Check if rate limited before submitting task
if ImageGenerationService.generate_image_rate_limiter.is_rate_limited(end_user.id): if ImageGenerationService.generate_image_rate_limiter.is_rate_limited(end_user.id):
raise Exception("Image generation rate limit exceeded") raise Exception("Image generation rate limit exceeded")
if dify_config.IMAGE_GENERATION_APP_ID is None: # Create a pending UserGeneratedImage entity
raise Exception("Image generation app id is not set")
image_generation_app_model = App.query.filter(App.id == dify_config.IMAGE_GENERATION_APP_ID).first()
if image_generation_app_model is None:
raise Exception("Image generation app model is not found")
user_profile = end_user.extra_profile
recent_messages = (
db.session.query(Message)
.filter(Message.app_id == end_user.app_id, Message.from_end_user_id == end_user.id)
.order_by(Message.created_at.desc())
.limit(10)
.all()
)
recent_messages = [f"user: {message.query}\n\nassistant: {message.answer}" for message in recent_messages]
args = {
"inputs": {
"user_profile": json.dumps(user_profile),
"recent_messages": "\n\n".join(recent_messages),
"image_type": content_type,
}
}
try:
response = AppGenerateService.generate(
app_model=image_generation_app_model,
user=end_user,
args=args,
invoke_from=InvokeFrom.SCHEDULER,
streaming=False,
)
if not isinstance(response, dict):
raise Exception("Failed to generate image")
# load workflow id and save it to db for futher fetch image status
workflow_run_id = response.get("workflow_run_id")
raw_content = response.get("data", {}).get("outputs", {})
# parse url from response.data.outputs
image_objs = raw_content.get("files")
url = None
for image_obj in image_objs:
if image_obj.get("type") == "image":
url = image_obj.get("url")
break
if url is None:
raise Exception("Failed to generate image")
text_content = raw_content.get("text")
user_generated_image = UserGeneratedImage( user_generated_image = UserGeneratedImage(
app_id=end_user.app_id, app_id=end_user.app_id,
end_user_id=end_user.id, end_user_id=end_user.id,
workflow_run_id=workflow_run_id,
content_type=content_type, content_type=content_type,
image_url=url, status="pending", # Set initial status to pending
text_content=text_content,
raw_content=raw_content,
) )
db.session.add(user_generated_image) db.session.add(user_generated_image)
db.session.commit() db.session.commit()
return user_generated_image.id # Get the generated ID for tracking
image_id = str(user_generated_image.id)
# Submit the task asynchronously with the image_id
generate_image_task.delay(end_user_id=str(end_user.id), content_type=content_type, image_id=image_id)
except Exception as e: # Return the image ID as a reference for status checking
raise Exception(f"Failed to generate image: {e}") return image_id
@staticmethod @staticmethod
def pagination_image_list(end_user: EndUser, limit: int, offset: int) -> MultiPagePagination: def pagination_image_list(end_user: EndUser, limit: int, offset: int) -> MultiPagePagination:

@ -0,0 +1,152 @@
import json
import logging
import time
import click
from celery import shared_task # type: ignore
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from models.model import App, EndUser, Message, UserGeneratedImage
from services.app_generate_service import AppGenerateService
@shared_task(queue="generation")
def generate_image_task(
end_user_id: str,
content_type: str,
image_id: str,
) -> str:
"""
Asynchronously generate an image based on the end user's conversation data and update the existing UserGeneratedImage record
Args:
end_user_id: End user ID
content_type: Type of content to generate (self_message or summary_advice)
app_id: The app ID of the end user
image_id: ID of the existing pending UserGeneratedImage entity to update
Returns:
The ID of the updated image record
Usage: generate_image_task.delay(end_user_id, content_type, app_id, image_id)
"""
logging.info(click.style(f"Starting image generation for user {end_user_id}, image_id: {image_id}", fg="green"))
start_at = time.perf_counter()
try:
# Retrieve models for processing
end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first()
if not end_user:
raise Exception(f"End user {end_user_id} not found")
# Get the existing UserGeneratedImage entity
user_generated_image = db.session.query(UserGeneratedImage).filter(UserGeneratedImage.id == image_id).first()
if not user_generated_image:
raise Exception(f"UserGeneratedImage {image_id} not found")
# Update status to processing
user_generated_image.status = "processing"
db.session.commit()
# Get image generation app
if dify_config.IMAGE_GENERATION_APP_ID is None:
user_generated_image.status = "failed"
user_generated_image.error_message = "Image generation app id is not set"
db.session.commit()
raise Exception("Image generation app id is not set")
image_generation_app_model = App.query.filter(App.id == dify_config.IMAGE_GENERATION_APP_ID).first()
if image_generation_app_model is None:
user_generated_image.status = "failed"
user_generated_image.error_message = "Image generation app model is not found"
db.session.commit()
raise Exception("Image generation app model is not found")
# Get user profile and recent messages
user_profile = end_user.extra_profile
recent_messages = (
db.session.query(Message)
.filter(Message.app_id == end_user.app_id, Message.from_end_user_id == end_user.id)
.order_by(Message.created_at.desc())
.limit(10)
.all()
)
recent_messages = [f"user: {message.query}\n\nassistant: {message.answer}" for message in recent_messages]
# Prepare arguments for generation
args = {
"inputs": {
"user_profile": json.dumps(user_profile),
"recent_messages": "\n\n".join(recent_messages),
"image_type": content_type,
}
}
# Generate image through app service
response = AppGenerateService.generate(
app_model=image_generation_app_model,
user=end_user,
args=args,
invoke_from=InvokeFrom.SCHEDULER,
streaming=False,
)
if not isinstance(response, dict):
user_generated_image.status = "failed"
user_generated_image.error_message = "Failed to generate image"
db.session.commit()
raise Exception("Failed to generate image")
# Extract workflow run ID and content
workflow_run_id = response.get("workflow_run_id")
raw_content = response.get("data", {}).get("outputs", {})
# Parse URL from response
image_objs = raw_content.get("files")
url = None
for image_obj in image_objs:
if image_obj.get("type") == "image":
url = image_obj.get("url")
break
if url is None:
user_generated_image.status = "failed"
user_generated_image.error_message = "No image URL found in response"
db.session.commit()
raise Exception("Failed to generate image")
text_content = raw_content.get("text")
# Update the existing UserGeneratedImage with the generated content
user_generated_image.workflow_run_id = workflow_run_id
user_generated_image.image_url = url
user_generated_image.text_content = text_content
user_generated_image.raw_content = raw_content
user_generated_image.status = "completed"
db.session.commit()
end_at = time.perf_counter()
logging.info(
click.style(
f"Image generation completed for user {end_user_id}. Image ID: {image_id}. Latency: {end_at - start_at}",
fg="green",
)
)
return image_id
except Exception as e:
logging.exception(f"Failed to generate image: {str(e)}")
# Update status to failed if we have the entity
try:
user_generated_image = (
db.session.query(UserGeneratedImage).filter(UserGeneratedImage.id == image_id).first()
)
if user_generated_image:
user_generated_image.status = "failed"
user_generated_image.error_message = str(e)
db.session.commit()
except Exception:
logging.exception(f"Failed to update image status for {image_id}")
raise
Loading…
Cancel
Save