diff --git a/api/controllers/service_api_with_auth/app/image.py b/api/controllers/service_api_with_auth/app/image.py index 026d1b89c5..3e74922779 100644 --- a/api/controllers/service_api_with_auth/app/image.py +++ b/api/controllers/service_api_with_auth/app/image.py @@ -1,5 +1,3 @@ -from datetime import datetime, timedelta - from configs import dify_config from controllers.service_api_with_auth import api from controllers.service_api_with_auth.app.error import NotChatAppError, NotEnoughMessageCountError @@ -48,7 +46,7 @@ class ImageGenerateApi(Resource): example: success image_id: type: string - description: ID of the generated image, futher to fetch the image details and status + description: ID of the generated image (initially in pending state) example: 123e4567-e89b-12d3-a456-426614174000 message: type: string @@ -75,15 +73,20 @@ class ImageGenerateApi(Resource): raise NotEnoughMessageCountError() try: - # Use the service to generate the image + # Create a pending image entity and start the generation task image_id = ImageGenerationService.generate_image( end_user=end_user, content_type=content_type, ) - return {"result": "success", "message": "Image generated successfully.", "image_id": image_id} - except Exception: - raise InternalServerError("Failed to generate image") + # Return the image ID for status checking + return { + "result": "success", + "message": "Image generation started. Check status with the image ID.", + "image_id": image_id, + } + except Exception as e: + raise InternalServerError(f"Failed to generate image: {e}") class ImageListApi(Resource): diff --git a/api/fields/end_user_fields.py b/api/fields/end_user_fields.py index 0333d2435d..16e5d52921 100644 --- a/api/fields/end_user_fields.py +++ b/api/fields/end_user_fields.py @@ -34,6 +34,7 @@ end_user_image_fields = { "image_url": fields.String, "content_type": fields.String, "text_content": fields.String, + "error_message": fields.String, "created_at": TimestampField, } diff --git a/api/migrations/versions/2025_03_31_2112-329938947502_update_image_table.py b/api/migrations/versions/2025_03_31_2112-329938947502_update_image_table.py new file mode 100644 index 0000000000..fece3f6e07 --- /dev/null +++ b/api/migrations/versions/2025_03_31_2112-329938947502_update_image_table.py @@ -0,0 +1,39 @@ +"""update image table + +Revision ID: 329938947502 +Revises: a215b993fa58 +Create Date: 2025-03-31 21:12:55.412130 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '329938947502' +down_revision = 'a215b993fa58' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('user_generated_images', schema=None) as batch_op: + batch_op.add_column(sa.Column('error_message', sa.Text(), nullable=True)) + batch_op.alter_column('workflow_run_id', + existing_type=sa.UUID(), + nullable=True) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('user_generated_images', schema=None) as batch_op: + batch_op.alter_column('workflow_run_id', + existing_type=sa.UUID(), + nullable=False) + batch_op.drop_column('error_message') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_03_31_2115-382f0cad6652_update_status.py b/api/migrations/versions/2025_03_31_2115-382f0cad6652_update_status.py new file mode 100644 index 0000000000..82ff145278 --- /dev/null +++ b/api/migrations/versions/2025_03_31_2115-382f0cad6652_update_status.py @@ -0,0 +1,33 @@ +"""update status + +Revision ID: 382f0cad6652 +Revises: 329938947502 +Create Date: 2025-03-31 21:15:32.929703 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '382f0cad6652' +down_revision = '329938947502' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('user_generated_images', schema=None) as batch_op: + batch_op.add_column(sa.Column('status', sa.String(length=20), server_default='pending', nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('user_generated_images', schema=None) as batch_op: + batch_op.drop_column('status') + + # ### end Alembic commands ### diff --git a/api/models/model.py b/api/models/model.py index d824c6db08..24a0f9bc48 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1866,11 +1866,15 @@ class UserGeneratedImage(db.Model): # type: ignore[name-defined] id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) end_user_id = db.Column(StringUUID, nullable=False) - workflow_run_id = db.Column(StringUUID, nullable=False) # related generation id + workflow_run_id = db.Column(StringUUID, nullable=True) # related generation id (nullable for pending status) content_type = db.Column(db.String(255), nullable=False) # 'self_message' or 'summary_advice' image_url = db.Column(db.Text, nullable=True) text_content = db.Column(db.Text, nullable=True) raw_content = db.Column(db.JSON, nullable=True) # save raw llm outputs + status = db.Column( + db.String(20), nullable=False, server_default="pending" + ) # pending, processing, completed, failed + error_message = db.Column(db.Text, nullable=True) created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) updated_at = db.Column( db.DateTime, nullable=False, server_default=db.func.current_timestamp(), onupdate=db.func.current_timestamp() @@ -1879,23 +1883,3 @@ class UserGeneratedImage(db.Model): # type: ignore[name-defined] @property def end_user(self): return db.session.query(EndUser).filter(EndUser.id == self.end_user_id).first() - - @property - def status(self): - workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self.workflow_run_id).first() - - if workflow_run is None: - return WorkflowRunStatus.FAILED.value - - return workflow_run.status - - def refresh_status(self): - workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self.workflow_run_id).first() - - if workflow_run is None: - return - - if workflow_run.status == WorkflowRunStatus.SUCCEEDED and self.image_url is None: - self.image_url = workflow_run.outputs.get("image_url") - self.text_content = workflow_run.outputs.get("text_content") - db.session.commit() diff --git a/api/services/image_generation_service.py b/api/services/image_generation_service.py index cf39d9cd5c..66edf5c172 100644 --- a/api/services/image_generation_service.py +++ b/api/services/image_generation_service.py @@ -8,6 +8,7 @@ from libs.helper import RateLimiter from libs.infinite_scroll_pagination import MultiPagePagination from models.model import App, EndUser, Message, UserGeneratedImage from services.app_generate_service import AppGenerateService +from tasks.image_generation_task import generate_image_task # define string enum for content_type @@ -24,83 +25,39 @@ class ImageGenerationService: @staticmethod def generate_image(end_user: EndUser, content_type: ContentType) -> str: + """ + Initiates asynchronous image generation process and creates a pending image record + Args: + end_user: End user object + content_type: Type of content to generate + + Returns: + The ID of the created UserGeneratedImage entity that will be updated by the task + """ + # Check if rate limited before submitting task if ImageGenerationService.generate_image_rate_limiter.is_rate_limited(end_user.id): raise Exception("Image generation rate limit exceeded") - if dify_config.IMAGE_GENERATION_APP_ID is None: - raise Exception("Image generation app id is not set") + # Create a pending UserGeneratedImage entity + user_generated_image = UserGeneratedImage( + app_id=end_user.app_id, + end_user_id=end_user.id, + content_type=content_type, + status="pending", # Set initial status to pending + ) + + db.session.add(user_generated_image) + db.session.commit() - image_generation_app_model = App.query.filter(App.id == dify_config.IMAGE_GENERATION_APP_ID).first() - if image_generation_app_model is None: - raise Exception("Image generation app model is not found") + # Get the generated ID for tracking + image_id = str(user_generated_image.id) - user_profile = end_user.extra_profile - recent_messages = ( - db.session.query(Message) - .filter(Message.app_id == end_user.app_id, Message.from_end_user_id == end_user.id) - .order_by(Message.created_at.desc()) - .limit(10) - .all() - ) + # Submit the task asynchronously with the image_id + generate_image_task.delay(end_user_id=str(end_user.id), content_type=content_type, image_id=image_id) - recent_messages = [f"user: {message.query}\n\nassistant: {message.answer}" for message in recent_messages] - - args = { - "inputs": { - "user_profile": json.dumps(user_profile), - "recent_messages": "\n\n".join(recent_messages), - "image_type": content_type, - } - } - - try: - response = AppGenerateService.generate( - app_model=image_generation_app_model, - user=end_user, - args=args, - invoke_from=InvokeFrom.SCHEDULER, - streaming=False, - ) - - if not isinstance(response, dict): - raise Exception("Failed to generate image") - - # load workflow id and save it to db for futher fetch image status - workflow_run_id = response.get("workflow_run_id") - - raw_content = response.get("data", {}).get("outputs", {}) - - # parse url from response.data.outputs - image_objs = raw_content.get("files") - url = None - for image_obj in image_objs: - if image_obj.get("type") == "image": - url = image_obj.get("url") - break - - if url is None: - raise Exception("Failed to generate image") - - text_content = raw_content.get("text") - - user_generated_image = UserGeneratedImage( - app_id=end_user.app_id, - end_user_id=end_user.id, - workflow_run_id=workflow_run_id, - content_type=content_type, - image_url=url, - text_content=text_content, - raw_content=raw_content, - ) - - db.session.add(user_generated_image) - db.session.commit() - - return user_generated_image.id - - except Exception as e: - raise Exception(f"Failed to generate image: {e}") + # Return the image ID as a reference for status checking + return image_id @staticmethod def pagination_image_list(end_user: EndUser, limit: int, offset: int) -> MultiPagePagination: diff --git a/api/tasks/image_generation_task.py b/api/tasks/image_generation_task.py new file mode 100644 index 0000000000..1e822b9005 --- /dev/null +++ b/api/tasks/image_generation_task.py @@ -0,0 +1,152 @@ +import json +import logging +import time + +import click +from celery import shared_task # type: ignore +from configs import dify_config +from core.app.entities.app_invoke_entities import InvokeFrom +from extensions.ext_database import db +from models.model import App, EndUser, Message, UserGeneratedImage +from services.app_generate_service import AppGenerateService + + +@shared_task(queue="generation") +def generate_image_task( + end_user_id: str, + content_type: str, + image_id: str, +) -> str: + """ + Asynchronously generate an image based on the end user's conversation data and update the existing UserGeneratedImage record + + Args: + end_user_id: End user ID + content_type: Type of content to generate (self_message or summary_advice) + app_id: The app ID of the end user + image_id: ID of the existing pending UserGeneratedImage entity to update + + Returns: + The ID of the updated image record + + Usage: generate_image_task.delay(end_user_id, content_type, app_id, image_id) + """ + logging.info(click.style(f"Starting image generation for user {end_user_id}, image_id: {image_id}", fg="green")) + start_at = time.perf_counter() + + try: + # Retrieve models for processing + end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first() + if not end_user: + raise Exception(f"End user {end_user_id} not found") + + # Get the existing UserGeneratedImage entity + user_generated_image = db.session.query(UserGeneratedImage).filter(UserGeneratedImage.id == image_id).first() + if not user_generated_image: + raise Exception(f"UserGeneratedImage {image_id} not found") + + # Update status to processing + user_generated_image.status = "processing" + db.session.commit() + + # Get image generation app + if dify_config.IMAGE_GENERATION_APP_ID is None: + user_generated_image.status = "failed" + user_generated_image.error_message = "Image generation app id is not set" + db.session.commit() + raise Exception("Image generation app id is not set") + + image_generation_app_model = App.query.filter(App.id == dify_config.IMAGE_GENERATION_APP_ID).first() + if image_generation_app_model is None: + user_generated_image.status = "failed" + user_generated_image.error_message = "Image generation app model is not found" + db.session.commit() + raise Exception("Image generation app model is not found") + + # Get user profile and recent messages + user_profile = end_user.extra_profile + recent_messages = ( + db.session.query(Message) + .filter(Message.app_id == end_user.app_id, Message.from_end_user_id == end_user.id) + .order_by(Message.created_at.desc()) + .limit(10) + .all() + ) + + recent_messages = [f"user: {message.query}\n\nassistant: {message.answer}" for message in recent_messages] + + # Prepare arguments for generation + args = { + "inputs": { + "user_profile": json.dumps(user_profile), + "recent_messages": "\n\n".join(recent_messages), + "image_type": content_type, + } + } + + # Generate image through app service + response = AppGenerateService.generate( + app_model=image_generation_app_model, + user=end_user, + args=args, + invoke_from=InvokeFrom.SCHEDULER, + streaming=False, + ) + + if not isinstance(response, dict): + user_generated_image.status = "failed" + user_generated_image.error_message = "Failed to generate image" + db.session.commit() + raise Exception("Failed to generate image") + + # Extract workflow run ID and content + workflow_run_id = response.get("workflow_run_id") + raw_content = response.get("data", {}).get("outputs", {}) + + # Parse URL from response + image_objs = raw_content.get("files") + url = None + for image_obj in image_objs: + if image_obj.get("type") == "image": + url = image_obj.get("url") + break + + if url is None: + user_generated_image.status = "failed" + user_generated_image.error_message = "No image URL found in response" + db.session.commit() + raise Exception("Failed to generate image") + + text_content = raw_content.get("text") + + # Update the existing UserGeneratedImage with the generated content + user_generated_image.workflow_run_id = workflow_run_id + user_generated_image.image_url = url + user_generated_image.text_content = text_content + user_generated_image.raw_content = raw_content + user_generated_image.status = "completed" + db.session.commit() + + end_at = time.perf_counter() + logging.info( + click.style( + f"Image generation completed for user {end_user_id}. Image ID: {image_id}. Latency: {end_at - start_at}", + fg="green", + ) + ) + + return image_id + except Exception as e: + logging.exception(f"Failed to generate image: {str(e)}") + # Update status to failed if we have the entity + try: + user_generated_image = ( + db.session.query(UserGeneratedImage).filter(UserGeneratedImage.id == image_id).first() + ) + if user_generated_image: + user_generated_image.status = "failed" + user_generated_image.error_message = str(e) + db.session.commit() + except Exception: + logging.exception(f"Failed to update image status for {image_id}") + raise