impl user generate image

pull/21891/head
ytqh 1 year ago
parent 2c7c76aec3
commit d85267d5e1

@ -1,3 +1,5 @@
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
@ -31,3 +33,18 @@ class SchoolConfig(BaseSettings):
description="App id for health summary generation.",
default="",
)
IMAGE_GENERATION_DAILY_LIMIT: int = Field(
description="Daily limit for image generation.",
default=5,
)
IMAGE_GENERATION_MIN_CONVERSATION_ROUNDS: int = Field(
description="Minimum conversation rounds for image generation.",
default=10,
)
IMAGE_GENERATION_APP_ID: Optional[str] = Field(
description="App id for image generation.",
default=None,
)

@ -4,6 +4,6 @@ from libs.external_api import ExternalApi
bp = Blueprint("service_api_with_auth", __name__, url_prefix="/service")
api = ExternalApi(bp)
from .app import app, audio, completion, conversation, file, message, workflow
from .app import app, audio, completion, conversation, file, image, message, workflow
from .auth import login
from .user import profile

@ -107,3 +107,9 @@ class UnsupportedFileTypeError(BaseHTTPException):
error_code = "unsupported_file_type"
description = "File type not allowed."
code = 415
class NotEnoughMessageCountError(BaseHTTPException):
error_code = "not_enough_message_count"
description = "I need to know more about you before creating an image. Please continue our conversation."
code = 400

@ -1,29 +1,16 @@
import json
import logging
import os
import uuid
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple
from configs import dify_config
from controllers.service_api_with_auth import api
from controllers.service_api_with_auth.app.error import NotChatAppError
from controllers.service_api_with_auth.app.error import NotChatAppError, NotEnoughMessageCountError
from controllers.service_api_with_auth.wraps import validate_user_token_and_extract_info
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file import FileTransferMethod, FileType
from extensions.ext_database import db
from fields.end_user_fields import image_fields, image_list_fields
from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore
from libs.helper import TimestampField, uuid_value
from models.enums import CreatedByRole
from models.model import App, AppMode, Conversation, EndUser, Message, UserGeneratedImage
from models.types import StringUUID
from fields.end_user_fields import end_user_image_fields, end_user_image_list_pagination_fields
from flask_restful import Resource, marshal_with, reqparse # type: ignore
from libs.helper import uuid_value
from models.model import App, AppMode, EndUser, UserGeneratedImage
from services.image_generation_service import ImageGenerationService
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
# Constants
DEFAULT_DAILY_LIMIT = 5
MIN_CONVERSATION_ROUNDS = 10
from werkzeug.exceptions import InternalServerError, NotFound
class ImageGenerateApi(Resource):
@ -44,13 +31,8 @@ class ImageGenerateApi(Resource):
schema:
type: object
required:
- conversation_id
- content_type
properties:
conversation_id:
type: string
format: uuid
description: ID of the conversation to use for image generation
content_type:
type: string
enum: [self_message, summary_advice]
@ -64,6 +46,10 @@ class ImageGenerateApi(Resource):
result:
type: string
example: success
image_id:
type: string
description: ID of the generated image, futher to fetch the image details and status
example: 123e4567-e89b-12d3-a456-426614174000
message:
type: string
example: Image generation started
@ -76,90 +62,33 @@ class ImageGenerateApi(Resource):
404:
description: Conversation not found or not a chat app
"""
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = reqparse.RequestParser()
parser.add_argument("conversation_id", required=True, type=uuid_value, location="json")
parser.add_argument(
"content_type", required=True, type=str, choices=["self_message", "summary_advice"], location="json"
)
args = parser.parse_args()
conversation_id = str(args["conversation_id"])
content_type = args["content_type"]
# Check if conversation exists
conversation = (
db.session.query(Conversation)
.filter(
Conversation.id == conversation_id,
Conversation.app_id == app_model.id,
Conversation.from_end_user_id == end_user.id,
)
.first()
)
if not conversation:
raise NotFound("Conversation not found")
# Check if conversation has enough rounds
messages_count = db.session.query(Message).filter(Message.conversation_id == conversation_id).count()
if messages_count < MIN_CONVERSATION_ROUNDS:
return {
"result": "error",
"message": "I need to know more about you before creating an image. Please continue our conversation.",
}, 400
# Check if user has reached daily limit
today = datetime.utcnow().date()
tomorrow = today + timedelta(days=1)
today_start = datetime.combine(today, datetime.min.time())
today_end = datetime.combine(tomorrow, datetime.min.time())
daily_count = (
db.session.query(UserGeneratedImage)
.filter(
UserGeneratedImage.end_user_id == end_user.id,
UserGeneratedImage.created_at >= today_start,
UserGeneratedImage.created_at < today_end,
)
.count()
)
if daily_count >= DEFAULT_DAILY_LIMIT:
return {
"result": "error",
"message": (
f"You've reached your daily limit of {DEFAULT_DAILY_LIMIT} generated images. Please try again tomorrow."
),
}, 403
if end_user.total_messages_count < dify_config.IMAGE_GENERATION_MIN_CONVERSATION_ROUNDS:
raise NotEnoughMessageCountError()
try:
# Use the service to generate the image
# This would typically be done asynchronously in a background task
# For simplicity, we're doing it synchronously here
image_id = ImageGenerationService.process_image_generation_request(
app_id=str(app_model.id),
conversation_id=conversation_id,
end_user_id=str(end_user.id),
image_id = ImageGenerationService.generate_image(
end_user=end_user,
content_type=content_type,
)
if image_id:
return {"result": "success", "message": "Image generated successfully.", "image_id": image_id}
else:
return {"result": "error", "message": "Failed to generate image. Please try again later."}, 500
except Exception as e:
except Exception:
raise InternalServerError("Failed to generate image")
class ImageListApi(Resource):
@validate_user_token_and_extract_info
@marshal_with(image_list_fields)
@marshal_with(end_user_image_list_pagination_fields)
def get(self, app_model: App, end_user: EndUser):
"""Get user-generated images list.
---
@ -212,22 +141,12 @@ class ImageListApi(Resource):
limit = min(args["limit"], 100)
offset = max(args["offset"], 0)
# Get images for the user
query = (
db.session.query(UserGeneratedImage)
.filter(UserGeneratedImage.app_id == app_model.id, UserGeneratedImage.end_user_id == end_user.id)
.order_by(UserGeneratedImage.created_at.desc())
)
total_count = query.count()
images = query.limit(limit).offset(offset).all()
return {"data": images, "has_more": (offset + limit) < total_count}
return ImageGenerationService.pagination_image_list(end_user=end_user, limit=limit, offset=offset)
class ImageDetailApi(Resource):
@validate_user_token_and_extract_info
@marshal_with(image_fields)
@marshal_with(end_user_image_fields)
def get(self, app_model: App, end_user: EndUser, image_id):
"""Get a specific generated image.
---
@ -254,30 +173,11 @@ class ImageDetailApi(Resource):
404:
description: Image not found or not a chat app
"""
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
image_id = str(image_id)
# Get the image
image = (
db.session.query(UserGeneratedImage)
.filter(
UserGeneratedImage.id == image_id,
UserGeneratedImage.app_id == app_model.id,
UserGeneratedImage.end_user_id == end_user.id,
)
.first()
)
if not image:
raise NotFound("Image not found")
return image
return ImageGenerationService.get_image_by_id(image_id=image_id)
# Register API resources
api.add_resource(ImageGenerateApi, "/images/generate")
api.add_resource(ImageListApi, "/images")
api.add_resource(ImageDetailApi, "/images/<uuid:image_id>", endpoint="image_detail")
api.add_resource(ImageDetailApi, "/images/<uuid:image_id>")

@ -28,15 +28,16 @@ end_users_infinite_scroll_pagination_fields = {
}
# Image generation fields definition
image_fields = {
end_user_image_fields = {
"id": fields.String,
"status": fields.String,
"image_url": fields.String,
"content_type": fields.String,
"text_content": fields.String,
"created_at": TimestampField,
}
image_list_fields = {
"data": fields.List(fields.Nested(image_fields)),
"has_more": fields.Boolean,
end_user_image_list_pagination_fields = {
"total": fields.Integer,
"data": fields.List(fields.Nested(end_user_image_fields)),
}

@ -0,0 +1,49 @@
"""update user gen image
Revision ID: fc481e896b3a
Revises: e4e52e0dfb56
Create Date: 2025-03-30 12:48:01.725923
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'fc481e896b3a'
down_revision = 'e4e52e0dfb56'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('user_generated_images', schema=None) as batch_op:
batch_op.add_column(sa.Column('workflow_run_id', models.types.StringUUID(), nullable=False))
batch_op.add_column(sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False))
batch_op.alter_column('image_url',
existing_type=sa.TEXT(),
nullable=True)
batch_op.alter_column('text_content',
existing_type=sa.TEXT(),
nullable=True)
batch_op.drop_column('conversation_id')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('user_generated_images', schema=None) as batch_op:
batch_op.add_column(sa.Column('conversation_id', sa.UUID(), autoincrement=False, nullable=False))
batch_op.alter_column('text_content',
existing_type=sa.TEXT(),
nullable=False)
batch_op.alter_column('image_url',
existing_type=sa.TEXT(),
nullable=False)
batch_op.drop_column('updated_at')
batch_op.drop_column('workflow_run_id')
# ### end Alembic commands ###

@ -0,0 +1,33 @@
"""add raw content
Revision ID: a215b993fa58
Revises: fc481e896b3a
Create Date: 2025-03-31 20:54:49.204381
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'a215b993fa58'
down_revision = 'fc481e896b3a'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('user_generated_images', schema=None) as batch_op:
batch_op.add_column(sa.Column('raw_content', sa.JSON(), nullable=True))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('user_generated_images', schema=None) as batch_op:
batch_op.drop_column('raw_content')
# ### end Alembic commands ###

@ -15,7 +15,7 @@ from flask import request
from flask_login import UserMixin # type: ignore
from libs.helper import generate_string
from models.enums import CreatedByRole
from models.workflow import WorkflowRunStatus
from models.workflow import WorkflowRun, WorkflowRunStatus
from sqlalchemy import Float, func, text
from sqlalchemy.orm import Mapped, mapped_column
@ -1455,6 +1455,15 @@ class EndUser(UserMixin, db.Model): # type: ignore[name-defined]
"memory": self.memory,
}
@property
def total_messages_count(self):
return (
db.session.query(Message)
.filter(Message.from_end_user_id == self.id)
.filter(Message.organization_id == self.organization_id)
.count()
)
class Site(db.Model): # type: ignore[name-defined]
__tablename__ = "sites"
@ -1857,16 +1866,36 @@ class UserGeneratedImage(db.Model): # type: ignore[name-defined]
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
app_id = db.Column(StringUUID, nullable=False)
end_user_id = db.Column(StringUUID, nullable=False)
conversation_id = db.Column(StringUUID, nullable=False)
image_url = db.Column(db.Text, nullable=False)
workflow_run_id = db.Column(StringUUID, nullable=False) # related generation id
content_type = db.Column(db.String(255), nullable=False) # 'self_message' or 'summary_advice'
text_content = db.Column(db.Text, nullable=False)
image_url = db.Column(db.Text, nullable=True)
text_content = db.Column(db.Text, nullable=True)
raw_content = db.Column(db.JSON, nullable=True) # save raw llm outputs
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
updated_at = db.Column(
db.DateTime, nullable=False, server_default=db.func.current_timestamp(), onupdate=db.func.current_timestamp()
)
@property
def end_user(self):
return db.session.query(EndUser).filter(EndUser.id == self.end_user_id).first()
@property
def conversation(self):
return db.session.query(Conversation).filter(Conversation.id == self.conversation_id).first()
def status(self):
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self.workflow_run_id).first()
if workflow_run is None:
return WorkflowRunStatus.FAILED.value
return workflow_run.status
def refresh_status(self):
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self.workflow_run_id).first()
if workflow_run is None:
return
if workflow_run.status == WorkflowRunStatus.SUCCEEDED and self.image_url is None:
self.image_url = workflow_run.outputs.get("image_url")
self.text_content = workflow_run.outputs.get("text_content")
db.session.commit()

@ -1,169 +1,126 @@
import json
import logging
import os
import random
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Union
# Import the UserGeneratedImage model from our controller module
# This is a bit of a circular import, but it's the simplest solution for now
from controllers.service_api_with_auth.app.image_generate import UserGeneratedImage
from enum import Enum
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from models.enums import CreatedByRole
from models.model import App, Conversation, EndUser, Message, UploadFile
from sqlalchemy.orm import Session
from libs.helper import RateLimiter
from libs.infinite_scroll_pagination import MultiPagePagination
from models.model import App, EndUser, Message, UserGeneratedImage
from services.app_generate_service import AppGenerateService
# Configure logging
logger = logging.getLogger(__name__)
# define string enum for content_type
class ContentType(str, Enum):
SELF_MESSAGE = "self_message"
SUMMARY_ADVICE = "summary_advice"
class ImageGenerationService:
@staticmethod
def generate_motivational_text(conversation_id: str, content_type: str) -> str:
"""
Generate motivational text based on conversation history.
Args:
conversation_id: The ID of the conversation
content_type: Type of content to generate ('self_message' or 'summary_advice')
Returns:
str: Generated text content
"""
# In a real implementation, this would call a large language model
# Here we'll just use placeholders based on the content type
with Session(db.engine) as session:
# Get the last few messages from the conversation to understand context
messages = (
session.query(Message)
.filter(Message.conversation_id == conversation_id)
.order_by(Message.created_at.desc())
.limit(20)
.all()
)
# Reverse to get chronological order
messages.reverse()
# Extract conversation context
context = "\n".join([f"User: {msg.query}\nAI: {msg.answer}" for msg in messages])
# In production, you would pass this context to a language model
# For demonstration, we'll return placeholder text
if content_type == "self_message":
sample_messages = [
"You've got this! Take one small step today.",
"Remember your strength - you've overcome challenges before.",
"Be kind to yourself today, you deserve it.",
"Your feelings are valid, and you have the power to work through them.",
"Small progress is still progress. Celebrate your wins today.",
]
return random.choice(sample_messages)
else: # summary_advice
sample_advice = [
"Based on our conversation, I notice you tend to be hard on yourself. Try practicing self-compassion by speaking to yourself as you would to a friend.",
"I've observed that you often describe feeling overwhelmed. Breaking tasks into smaller steps might help manage these feelings better.",
"In our discussions, I noticed patterns of negative self-talk. Consider challenging these thoughts by asking 'Is this really true?' when they arise.",
"From our conversations, it seems you might benefit from more self-care routines. Even 5 minutes of mindfulness daily could make a difference.",
"You've mentioned feeling anxious in social situations. Progressive exposure to small social interactions might help build confidence over time.",
]
return random.choice(sample_advice)
generate_image_rate_limiter = RateLimiter(
prefix="generate_image_rate_limit", max_attempts=dify_config.IMAGE_GENERATION_DAILY_LIMIT, time_window=86400 * 1
)
@staticmethod
def generate_background_image() -> str:
"""
Generate or select a background image.
In a real implementation, this might call an image generation API
or select from pre-generated images.
Returns:
str: URL of the generated/selected image
"""
# In a real implementation, this would integrate with an image generation API
# or select from a pool of pre-generated images
# For this example, we'll return a placeholder
placeholder_images = [
"https://example.com/background1.jpg",
"https://example.com/background2.jpg",
"https://example.com/background3.jpg",
"https://example.com/background4.jpg",
"https://example.com/background5.jpg",
]
return random.choice(placeholder_images)
def generate_image(end_user: EndUser, content_type: ContentType) -> str:
@staticmethod
def overlay_text_on_image(image_url: str, text: str) -> str:
"""
Overlay text on the image.
if ImageGenerationService.generate_image_rate_limiter.is_rate_limited(end_user.id):
raise Exception("Image generation rate limit exceeded")
In a real implementation, this would use image processing libraries.
if dify_config.IMAGE_GENERATION_APP_ID is None:
raise Exception("Image generation app id is not set")
Args:
image_url: URL of the background image
text: Text to overlay on the image
image_generation_app_model = App.query.filter(App.id == dify_config.IMAGE_GENERATION_APP_ID).first()
if image_generation_app_model is None:
raise Exception("Image generation app model is not found")
Returns:
str: URL of the final image with text
"""
# In a real implementation, this would use image processing libraries
# like Pillow to overlay text on the image
user_profile = end_user.extra_profile
recent_messages = (
db.session.query(Message)
.filter(Message.app_id == end_user.app_id, Message.from_end_user_id == end_user.id)
.order_by(Message.created_at.desc())
.limit(10)
.all()
)
# For this example, we'll just return the same URL
# In production, you would process the image and save it to storage
return image_url
recent_messages = [f"user: {message.query}\n\nassistant: {message.answer}" for message in recent_messages]
args = {
"inputs": {
"user_profile": json.dumps(user_profile),
"recent_messages": "\n\n".join(recent_messages),
"image_type": content_type,
}
}
@staticmethod
def process_image_generation_request(
app_id: str, conversation_id: str, end_user_id: str, content_type: str
) -> Optional[str]:
"""
Process an image generation request.
Args:
app_id: The ID of the app
conversation_id: The ID of the conversation
end_user_id: The ID of the end user
content_type: Type of content to generate ('self_message' or 'summary_advice')
Returns:
Optional[str]: ID of the generated image if successful, None otherwise
"""
try:
# 1. Generate motivational text based on conversation history
text_content = ImageGenerationService.generate_motivational_text(
conversation_id=conversation_id, content_type=content_type
response = AppGenerateService.generate(
app_model=image_generation_app_model,
user=end_user,
args=args,
invoke_from=InvokeFrom.SCHEDULER,
streaming=False,
)
# 2. Generate or select a background image
image_url = ImageGenerationService.generate_background_image()
if not isinstance(response, dict):
raise Exception("Failed to generate image")
# load workflow id and save it to db for futher fetch image status
workflow_run_id = response.get("workflow_run_id")
# 3. Overlay text on the image
final_image_url = ImageGenerationService.overlay_text_on_image(image_url=image_url, text=text_content)
raw_content = response.get("data", {}).get("outputs", {})
# 4. Create and save the user generated image record
with Session(db.engine) as session:
new_image = UserGeneratedImage(
app_id=app_id,
end_user_id=end_user_id,
conversation_id=conversation_id,
image_url=final_image_url,
# parse url from response.data.outputs
image_objs = raw_content.get("files")
url = None
for image_obj in image_objs:
if image_obj.get("type") == "image":
url = image_obj.get("url")
break
if url is None:
raise Exception("Failed to generate image")
text_content = raw_content.get("text")
user_generated_image = UserGeneratedImage(
app_id=end_user.app_id,
end_user_id=end_user.id,
workflow_run_id=workflow_run_id,
content_type=content_type,
image_url=url,
text_content=text_content,
raw_content=raw_content,
)
session.add(new_image)
session.commit()
image_id = str(new_image.id)
logger.info(f"Generated image {image_id} for user {end_user_id}")
db.session.add(user_generated_image)
db.session.commit()
return image_id
return user_generated_image.id
except Exception as e:
logger.error(f"Error generating image: {str(e)}")
return None
raise Exception(f"Failed to generate image: {e}")
@staticmethod
def pagination_image_list(end_user: EndUser, limit: int, offset: int) -> MultiPagePagination:
query = (
db.session.query(UserGeneratedImage)
.filter(UserGeneratedImage.app_id == end_user.app_id, UserGeneratedImage.end_user_id == end_user.id)
.order_by(UserGeneratedImage.created_at.desc())
)
total_count = query.count()
images = query.limit(limit).offset(offset).all()
return MultiPagePagination(data=images, total=total_count)
@staticmethod
def get_image_by_id(image_id: str) -> UserGeneratedImage:
image = db.session.query(UserGeneratedImage).filter(UserGeneratedImage.id == image_id).first()
if image is None:
raise Exception("Image not found")
return image

Loading…
Cancel
Save