diff --git a/api/configs/school/__init__.py b/api/configs/school/__init__.py index cf5e08dadd..e80fd72fa9 100644 --- a/api/configs/school/__init__.py +++ b/api/configs/school/__init__.py @@ -1,3 +1,5 @@ +from typing import Optional + from pydantic import Field from pydantic_settings import BaseSettings @@ -31,3 +33,18 @@ class SchoolConfig(BaseSettings): description="App id for health summary generation.", default="", ) + + IMAGE_GENERATION_DAILY_LIMIT: int = Field( + description="Daily limit for image generation.", + default=5, + ) + + IMAGE_GENERATION_MIN_CONVERSATION_ROUNDS: int = Field( + description="Minimum conversation rounds for image generation.", + default=10, + ) + + IMAGE_GENERATION_APP_ID: Optional[str] = Field( + description="App id for image generation.", + default=None, + ) diff --git a/api/controllers/service_api_with_auth/__init__.py b/api/controllers/service_api_with_auth/__init__.py index fd3996d781..0c19086ef8 100644 --- a/api/controllers/service_api_with_auth/__init__.py +++ b/api/controllers/service_api_with_auth/__init__.py @@ -4,6 +4,6 @@ from libs.external_api import ExternalApi bp = Blueprint("service_api_with_auth", __name__, url_prefix="/service") api = ExternalApi(bp) -from .app import app, audio, completion, conversation, file, message, workflow +from .app import app, audio, completion, conversation, file, image, message, workflow from .auth import login from .user import profile diff --git a/api/controllers/service_api_with_auth/app/error.py b/api/controllers/service_api_with_auth/app/error.py index ca91da80c1..2965ca3cdb 100644 --- a/api/controllers/service_api_with_auth/app/error.py +++ b/api/controllers/service_api_with_auth/app/error.py @@ -107,3 +107,9 @@ class UnsupportedFileTypeError(BaseHTTPException): error_code = "unsupported_file_type" description = "File type not allowed." code = 415 + + +class NotEnoughMessageCountError(BaseHTTPException): + error_code = "not_enough_message_count" + description = "I need to know more about you before creating an image. Please continue our conversation." + code = 400 diff --git a/api/controllers/service_api_with_auth/app/image_generate.py b/api/controllers/service_api_with_auth/app/image.py similarity index 51% rename from api/controllers/service_api_with_auth/app/image_generate.py rename to api/controllers/service_api_with_auth/app/image.py index 4e36954c16..026d1b89c5 100644 --- a/api/controllers/service_api_with_auth/app/image_generate.py +++ b/api/controllers/service_api_with_auth/app/image.py @@ -1,29 +1,16 @@ -import json -import logging -import os -import uuid from datetime import datetime, timedelta -from typing import Any, Dict, List, Optional, Tuple +from configs import dify_config from controllers.service_api_with_auth import api -from controllers.service_api_with_auth.app.error import NotChatAppError +from controllers.service_api_with_auth.app.error import NotChatAppError, NotEnoughMessageCountError from controllers.service_api_with_auth.wraps import validate_user_token_and_extract_info -from core.app.entities.app_invoke_entities import InvokeFrom -from core.file import FileTransferMethod, FileType from extensions.ext_database import db -from fields.end_user_fields import image_fields, image_list_fields -from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore -from libs.helper import TimestampField, uuid_value -from models.enums import CreatedByRole -from models.model import App, AppMode, Conversation, EndUser, Message, UserGeneratedImage -from models.types import StringUUID +from fields.end_user_fields import end_user_image_fields, end_user_image_list_pagination_fields +from flask_restful import Resource, marshal_with, reqparse # type: ignore +from libs.helper import uuid_value +from models.model import App, AppMode, EndUser, UserGeneratedImage from services.image_generation_service import ImageGenerationService -from sqlalchemy.orm import Session -from werkzeug.exceptions import BadRequest, InternalServerError, NotFound - -# Constants -DEFAULT_DAILY_LIMIT = 5 -MIN_CONVERSATION_ROUNDS = 10 +from werkzeug.exceptions import InternalServerError, NotFound class ImageGenerateApi(Resource): @@ -44,13 +31,8 @@ class ImageGenerateApi(Resource): schema: type: object required: - - conversation_id - content_type properties: - conversation_id: - type: string - format: uuid - description: ID of the conversation to use for image generation content_type: type: string enum: [self_message, summary_advice] @@ -64,6 +46,10 @@ class ImageGenerateApi(Resource): result: type: string example: success + image_id: + type: string + description: ID of the generated image, futher to fetch the image details and status + example: 123e4567-e89b-12d3-a456-426614174000 message: type: string example: Image generation started @@ -76,90 +62,33 @@ class ImageGenerateApi(Resource): 404: description: Conversation not found or not a chat app """ - app_mode = AppMode.value_of(app_model.mode) - if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: - raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument("conversation_id", required=True, type=uuid_value, location="json") parser.add_argument( "content_type", required=True, type=str, choices=["self_message", "summary_advice"], location="json" ) args = parser.parse_args() - conversation_id = str(args["conversation_id"]) content_type = args["content_type"] - # Check if conversation exists - conversation = ( - db.session.query(Conversation) - .filter( - Conversation.id == conversation_id, - Conversation.app_id == app_model.id, - Conversation.from_end_user_id == end_user.id, - ) - .first() - ) - - if not conversation: - raise NotFound("Conversation not found") - - # Check if conversation has enough rounds - messages_count = db.session.query(Message).filter(Message.conversation_id == conversation_id).count() - - if messages_count < MIN_CONVERSATION_ROUNDS: - return { - "result": "error", - "message": "I need to know more about you before creating an image. Please continue our conversation.", - }, 400 - - # Check if user has reached daily limit - today = datetime.utcnow().date() - tomorrow = today + timedelta(days=1) - today_start = datetime.combine(today, datetime.min.time()) - today_end = datetime.combine(tomorrow, datetime.min.time()) - - daily_count = ( - db.session.query(UserGeneratedImage) - .filter( - UserGeneratedImage.end_user_id == end_user.id, - UserGeneratedImage.created_at >= today_start, - UserGeneratedImage.created_at < today_end, - ) - .count() - ) - - if daily_count >= DEFAULT_DAILY_LIMIT: - return { - "result": "error", - "message": ( - f"You've reached your daily limit of {DEFAULT_DAILY_LIMIT} generated images. Please try again tomorrow." - ), - }, 403 + if end_user.total_messages_count < dify_config.IMAGE_GENERATION_MIN_CONVERSATION_ROUNDS: + raise NotEnoughMessageCountError() try: # Use the service to generate the image - # This would typically be done asynchronously in a background task - # For simplicity, we're doing it synchronously here - image_id = ImageGenerationService.process_image_generation_request( - app_id=str(app_model.id), - conversation_id=conversation_id, - end_user_id=str(end_user.id), + image_id = ImageGenerationService.generate_image( + end_user=end_user, content_type=content_type, ) - if image_id: - return {"result": "success", "message": "Image generated successfully.", "image_id": image_id} - else: - return {"result": "error", "message": "Failed to generate image. Please try again later."}, 500 - - except Exception as e: + return {"result": "success", "message": "Image generated successfully.", "image_id": image_id} + except Exception: raise InternalServerError("Failed to generate image") class ImageListApi(Resource): @validate_user_token_and_extract_info - @marshal_with(image_list_fields) + @marshal_with(end_user_image_list_pagination_fields) def get(self, app_model: App, end_user: EndUser): """Get user-generated images list. --- @@ -212,22 +141,12 @@ class ImageListApi(Resource): limit = min(args["limit"], 100) offset = max(args["offset"], 0) - # Get images for the user - query = ( - db.session.query(UserGeneratedImage) - .filter(UserGeneratedImage.app_id == app_model.id, UserGeneratedImage.end_user_id == end_user.id) - .order_by(UserGeneratedImage.created_at.desc()) - ) - - total_count = query.count() - images = query.limit(limit).offset(offset).all() - - return {"data": images, "has_more": (offset + limit) < total_count} + return ImageGenerationService.pagination_image_list(end_user=end_user, limit=limit, offset=offset) class ImageDetailApi(Resource): @validate_user_token_and_extract_info - @marshal_with(image_fields) + @marshal_with(end_user_image_fields) def get(self, app_model: App, end_user: EndUser, image_id): """Get a specific generated image. --- @@ -254,30 +173,11 @@ class ImageDetailApi(Resource): 404: description: Image not found or not a chat app """ - app_mode = AppMode.value_of(app_model.mode) - if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: - raise NotChatAppError() - image_id = str(image_id) - - # Get the image - image = ( - db.session.query(UserGeneratedImage) - .filter( - UserGeneratedImage.id == image_id, - UserGeneratedImage.app_id == app_model.id, - UserGeneratedImage.end_user_id == end_user.id, - ) - .first() - ) - - if not image: - raise NotFound("Image not found") - - return image + return ImageGenerationService.get_image_by_id(image_id=image_id) # Register API resources api.add_resource(ImageGenerateApi, "/images/generate") api.add_resource(ImageListApi, "/images") -api.add_resource(ImageDetailApi, "/images/", endpoint="image_detail") +api.add_resource(ImageDetailApi, "/images/") diff --git a/api/fields/end_user_fields.py b/api/fields/end_user_fields.py index eb03946919..0333d2435d 100644 --- a/api/fields/end_user_fields.py +++ b/api/fields/end_user_fields.py @@ -28,15 +28,16 @@ end_users_infinite_scroll_pagination_fields = { } # Image generation fields definition -image_fields = { +end_user_image_fields = { "id": fields.String, + "status": fields.String, "image_url": fields.String, "content_type": fields.String, "text_content": fields.String, "created_at": TimestampField, } -image_list_fields = { - "data": fields.List(fields.Nested(image_fields)), - "has_more": fields.Boolean, +end_user_image_list_pagination_fields = { + "total": fields.Integer, + "data": fields.List(fields.Nested(end_user_image_fields)), } diff --git a/api/migrations/versions/2025_03_30_1248-fc481e896b3a_update_user_gen_image.py b/api/migrations/versions/2025_03_30_1248-fc481e896b3a_update_user_gen_image.py new file mode 100644 index 0000000000..ddbf2bba62 --- /dev/null +++ b/api/migrations/versions/2025_03_30_1248-fc481e896b3a_update_user_gen_image.py @@ -0,0 +1,49 @@ +"""update user gen image + +Revision ID: fc481e896b3a +Revises: e4e52e0dfb56 +Create Date: 2025-03-30 12:48:01.725923 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'fc481e896b3a' +down_revision = 'e4e52e0dfb56' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('user_generated_images', schema=None) as batch_op: + batch_op.add_column(sa.Column('workflow_run_id', models.types.StringUUID(), nullable=False)) + batch_op.add_column(sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False)) + batch_op.alter_column('image_url', + existing_type=sa.TEXT(), + nullable=True) + batch_op.alter_column('text_content', + existing_type=sa.TEXT(), + nullable=True) + batch_op.drop_column('conversation_id') + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('user_generated_images', schema=None) as batch_op: + batch_op.add_column(sa.Column('conversation_id', sa.UUID(), autoincrement=False, nullable=False)) + batch_op.alter_column('text_content', + existing_type=sa.TEXT(), + nullable=False) + batch_op.alter_column('image_url', + existing_type=sa.TEXT(), + nullable=False) + batch_op.drop_column('updated_at') + batch_op.drop_column('workflow_run_id') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_03_31_2054-a215b993fa58_add_raw_content.py b/api/migrations/versions/2025_03_31_2054-a215b993fa58_add_raw_content.py new file mode 100644 index 0000000000..184672b3fa --- /dev/null +++ b/api/migrations/versions/2025_03_31_2054-a215b993fa58_add_raw_content.py @@ -0,0 +1,33 @@ +"""add raw content + +Revision ID: a215b993fa58 +Revises: fc481e896b3a +Create Date: 2025-03-31 20:54:49.204381 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'a215b993fa58' +down_revision = 'fc481e896b3a' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('user_generated_images', schema=None) as batch_op: + batch_op.add_column(sa.Column('raw_content', sa.JSON(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('user_generated_images', schema=None) as batch_op: + batch_op.drop_column('raw_content') + + # ### end Alembic commands ### diff --git a/api/models/model.py b/api/models/model.py index c4dfe632d5..d824c6db08 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -15,7 +15,7 @@ from flask import request from flask_login import UserMixin # type: ignore from libs.helper import generate_string from models.enums import CreatedByRole -from models.workflow import WorkflowRunStatus +from models.workflow import WorkflowRun, WorkflowRunStatus from sqlalchemy import Float, func, text from sqlalchemy.orm import Mapped, mapped_column @@ -1455,6 +1455,15 @@ class EndUser(UserMixin, db.Model): # type: ignore[name-defined] "memory": self.memory, } + @property + def total_messages_count(self): + return ( + db.session.query(Message) + .filter(Message.from_end_user_id == self.id) + .filter(Message.organization_id == self.organization_id) + .count() + ) + class Site(db.Model): # type: ignore[name-defined] __tablename__ = "sites" @@ -1857,16 +1866,36 @@ class UserGeneratedImage(db.Model): # type: ignore[name-defined] id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) end_user_id = db.Column(StringUUID, nullable=False) - conversation_id = db.Column(StringUUID, nullable=False) - image_url = db.Column(db.Text, nullable=False) + workflow_run_id = db.Column(StringUUID, nullable=False) # related generation id content_type = db.Column(db.String(255), nullable=False) # 'self_message' or 'summary_advice' - text_content = db.Column(db.Text, nullable=False) + image_url = db.Column(db.Text, nullable=True) + text_content = db.Column(db.Text, nullable=True) + raw_content = db.Column(db.JSON, nullable=True) # save raw llm outputs created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + updated_at = db.Column( + db.DateTime, nullable=False, server_default=db.func.current_timestamp(), onupdate=db.func.current_timestamp() + ) @property def end_user(self): return db.session.query(EndUser).filter(EndUser.id == self.end_user_id).first() @property - def conversation(self): - return db.session.query(Conversation).filter(Conversation.id == self.conversation_id).first() + def status(self): + workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self.workflow_run_id).first() + + if workflow_run is None: + return WorkflowRunStatus.FAILED.value + + return workflow_run.status + + def refresh_status(self): + workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self.workflow_run_id).first() + + if workflow_run is None: + return + + if workflow_run.status == WorkflowRunStatus.SUCCEEDED and self.image_url is None: + self.image_url = workflow_run.outputs.get("image_url") + self.text_content = workflow_run.outputs.get("text_content") + db.session.commit() diff --git a/api/services/image_generation_service.py b/api/services/image_generation_service.py index 1ee60e894a..cf39d9cd5c 100644 --- a/api/services/image_generation_service.py +++ b/api/services/image_generation_service.py @@ -1,169 +1,126 @@ import json -import logging -import os -import random -from datetime import datetime -from typing import Any, Dict, List, Optional, Tuple, Union - -# Import the UserGeneratedImage model from our controller module -# This is a bit of a circular import, but it's the simplest solution for now -from controllers.service_api_with_auth.app.image_generate import UserGeneratedImage +from enum import Enum + +from configs import dify_config +from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db -from models.enums import CreatedByRole -from models.model import App, Conversation, EndUser, Message, UploadFile -from sqlalchemy.orm import Session +from libs.helper import RateLimiter +from libs.infinite_scroll_pagination import MultiPagePagination +from models.model import App, EndUser, Message, UserGeneratedImage +from services.app_generate_service import AppGenerateService + -# Configure logging -logger = logging.getLogger(__name__) +# define string enum for content_type +class ContentType(str, Enum): + SELF_MESSAGE = "self_message" + SUMMARY_ADVICE = "summary_advice" class ImageGenerationService: - @staticmethod - def generate_motivational_text(conversation_id: str, content_type: str) -> str: - """ - Generate motivational text based on conversation history. - - Args: - conversation_id: The ID of the conversation - content_type: Type of content to generate ('self_message' or 'summary_advice') - - Returns: - str: Generated text content - """ - # In a real implementation, this would call a large language model - # Here we'll just use placeholders based on the content type - - with Session(db.engine) as session: - # Get the last few messages from the conversation to understand context - messages = ( - session.query(Message) - .filter(Message.conversation_id == conversation_id) - .order_by(Message.created_at.desc()) - .limit(20) - .all() - ) - # Reverse to get chronological order - messages.reverse() - - # Extract conversation context - context = "\n".join([f"User: {msg.query}\nAI: {msg.answer}" for msg in messages]) - - # In production, you would pass this context to a language model - # For demonstration, we'll return placeholder text - - if content_type == "self_message": - sample_messages = [ - "You've got this! Take one small step today.", - "Remember your strength - you've overcome challenges before.", - "Be kind to yourself today, you deserve it.", - "Your feelings are valid, and you have the power to work through them.", - "Small progress is still progress. Celebrate your wins today.", - ] - return random.choice(sample_messages) - else: # summary_advice - sample_advice = [ - "Based on our conversation, I notice you tend to be hard on yourself. Try practicing self-compassion by speaking to yourself as you would to a friend.", - "I've observed that you often describe feeling overwhelmed. Breaking tasks into smaller steps might help manage these feelings better.", - "In our discussions, I noticed patterns of negative self-talk. Consider challenging these thoughts by asking 'Is this really true?' when they arise.", - "From our conversations, it seems you might benefit from more self-care routines. Even 5 minutes of mindfulness daily could make a difference.", - "You've mentioned feeling anxious in social situations. Progressive exposure to small social interactions might help build confidence over time.", - ] - return random.choice(sample_advice) + generate_image_rate_limiter = RateLimiter( + prefix="generate_image_rate_limit", max_attempts=dify_config.IMAGE_GENERATION_DAILY_LIMIT, time_window=86400 * 1 + ) @staticmethod - def generate_background_image() -> str: - """ - Generate or select a background image. - - In a real implementation, this might call an image generation API - or select from pre-generated images. - - Returns: - str: URL of the generated/selected image - """ - # In a real implementation, this would integrate with an image generation API - # or select from a pool of pre-generated images - - # For this example, we'll return a placeholder - placeholder_images = [ - "https://example.com/background1.jpg", - "https://example.com/background2.jpg", - "https://example.com/background3.jpg", - "https://example.com/background4.jpg", - "https://example.com/background5.jpg", - ] - - return random.choice(placeholder_images) + def generate_image(end_user: EndUser, content_type: ContentType) -> str: - @staticmethod - def overlay_text_on_image(image_url: str, text: str) -> str: - """ - Overlay text on the image. + if ImageGenerationService.generate_image_rate_limiter.is_rate_limited(end_user.id): + raise Exception("Image generation rate limit exceeded") - In a real implementation, this would use image processing libraries. + if dify_config.IMAGE_GENERATION_APP_ID is None: + raise Exception("Image generation app id is not set") - Args: - image_url: URL of the background image - text: Text to overlay on the image + image_generation_app_model = App.query.filter(App.id == dify_config.IMAGE_GENERATION_APP_ID).first() + if image_generation_app_model is None: + raise Exception("Image generation app model is not found") - Returns: - str: URL of the final image with text - """ - # In a real implementation, this would use image processing libraries - # like Pillow to overlay text on the image + user_profile = end_user.extra_profile + recent_messages = ( + db.session.query(Message) + .filter(Message.app_id == end_user.app_id, Message.from_end_user_id == end_user.id) + .order_by(Message.created_at.desc()) + .limit(10) + .all() + ) - # For this example, we'll just return the same URL - # In production, you would process the image and save it to storage - return image_url + recent_messages = [f"user: {message.query}\n\nassistant: {message.answer}" for message in recent_messages] + + args = { + "inputs": { + "user_profile": json.dumps(user_profile), + "recent_messages": "\n\n".join(recent_messages), + "image_type": content_type, + } + } - @staticmethod - def process_image_generation_request( - app_id: str, conversation_id: str, end_user_id: str, content_type: str - ) -> Optional[str]: - """ - Process an image generation request. - - Args: - app_id: The ID of the app - conversation_id: The ID of the conversation - end_user_id: The ID of the end user - content_type: Type of content to generate ('self_message' or 'summary_advice') - - Returns: - Optional[str]: ID of the generated image if successful, None otherwise - """ try: - # 1. Generate motivational text based on conversation history - text_content = ImageGenerationService.generate_motivational_text( - conversation_id=conversation_id, content_type=content_type + response = AppGenerateService.generate( + app_model=image_generation_app_model, + user=end_user, + args=args, + invoke_from=InvokeFrom.SCHEDULER, + streaming=False, ) - # 2. Generate or select a background image - image_url = ImageGenerationService.generate_background_image() + if not isinstance(response, dict): + raise Exception("Failed to generate image") + + # load workflow id and save it to db for futher fetch image status + workflow_run_id = response.get("workflow_run_id") + + raw_content = response.get("data", {}).get("outputs", {}) + + # parse url from response.data.outputs + image_objs = raw_content.get("files") + url = None + for image_obj in image_objs: + if image_obj.get("type") == "image": + url = image_obj.get("url") + break - # 3. Overlay text on the image - final_image_url = ImageGenerationService.overlay_text_on_image(image_url=image_url, text=text_content) + if url is None: + raise Exception("Failed to generate image") - # 4. Create and save the user generated image record - with Session(db.engine) as session: - new_image = UserGeneratedImage( - app_id=app_id, - end_user_id=end_user_id, - conversation_id=conversation_id, - image_url=final_image_url, - content_type=content_type, - text_content=text_content, - ) + text_content = raw_content.get("text") - session.add(new_image) - session.commit() + user_generated_image = UserGeneratedImage( + app_id=end_user.app_id, + end_user_id=end_user.id, + workflow_run_id=workflow_run_id, + content_type=content_type, + image_url=url, + text_content=text_content, + raw_content=raw_content, + ) - image_id = str(new_image.id) - logger.info(f"Generated image {image_id} for user {end_user_id}") + db.session.add(user_generated_image) + db.session.commit() - return image_id + return user_generated_image.id except Exception as e: - logger.error(f"Error generating image: {str(e)}") - return None + raise Exception(f"Failed to generate image: {e}") + + @staticmethod + def pagination_image_list(end_user: EndUser, limit: int, offset: int) -> MultiPagePagination: + + query = ( + db.session.query(UserGeneratedImage) + .filter(UserGeneratedImage.app_id == end_user.app_id, UserGeneratedImage.end_user_id == end_user.id) + .order_by(UserGeneratedImage.created_at.desc()) + ) + + total_count = query.count() + images = query.limit(limit).offset(offset).all() + + return MultiPagePagination(data=images, total=total_count) + + @staticmethod + def get_image_by_id(image_id: str) -> UserGeneratedImage: + image = db.session.query(UserGeneratedImage).filter(UserGeneratedImage.id == image_id).first() + + if image is None: + raise Exception("Image not found") + + return image