impl user generate image

pull/21891/head
ytqh 1 year ago
parent 2c7c76aec3
commit d85267d5e1

@ -1,3 +1,5 @@
from typing import Optional
from pydantic import Field from pydantic import Field
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
@ -31,3 +33,18 @@ class SchoolConfig(BaseSettings):
description="App id for health summary generation.", description="App id for health summary generation.",
default="", default="",
) )
# Passed as ``max_attempts`` to the image-generation RateLimiter, whose
# time window is 86400 seconds (one day); the limiter is keyed by end user id.
IMAGE_GENERATION_DAILY_LIMIT: int = Field(
description="Daily limit for image generation.",
default=5,
)
# Compared against ``end_user.total_messages_count`` by the generate endpoint;
# below this value the request is rejected with NotEnoughMessageCountError.
# NOTE(review): the compared value is a message count, not a count of
# conversation "rounds" — confirm the intended unit.
IMAGE_GENERATION_MIN_CONVERSATION_ROUNDS: int = Field(
description="Minimum conversation rounds for image generation.",
default=10,
)
# Id of the App that ImageGenerationService invokes to produce the image.
# When left as None the service raises "Image generation app id is not set".
IMAGE_GENERATION_APP_ID: Optional[str] = Field(
description="App id for image generation.",
default=None,
)

@ -4,6 +4,6 @@ from libs.external_api import ExternalApi
bp = Blueprint("service_api_with_auth", __name__, url_prefix="/service") bp = Blueprint("service_api_with_auth", __name__, url_prefix="/service")
api = ExternalApi(bp) api = ExternalApi(bp)
from .app import app, audio, completion, conversation, file, message, workflow from .app import app, audio, completion, conversation, file, image, message, workflow
from .auth import login from .auth import login
from .user import profile from .user import profile

@ -107,3 +107,9 @@ class UnsupportedFileTypeError(BaseHTTPException):
error_code = "unsupported_file_type" error_code = "unsupported_file_type"
description = "File type not allowed." description = "File type not allowed."
code = 415 code = 415
# HTTP 400 error raised by the image generation endpoint when the end user's
# total message count is below IMAGE_GENERATION_MIN_CONVERSATION_ROUNDS.
# The description is the user-facing message returned to the client.
class NotEnoughMessageCountError(BaseHTTPException):
error_code = "not_enough_message_count"
description = "I need to know more about you before creating an image. Please continue our conversation."
code = 400

@ -1,29 +1,16 @@
import json
import logging
import os
import uuid
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple
from configs import dify_config
from controllers.service_api_with_auth import api from controllers.service_api_with_auth import api
from controllers.service_api_with_auth.app.error import NotChatAppError from controllers.service_api_with_auth.app.error import NotChatAppError, NotEnoughMessageCountError
from controllers.service_api_with_auth.wraps import validate_user_token_and_extract_info from controllers.service_api_with_auth.wraps import validate_user_token_and_extract_info
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file import FileTransferMethod, FileType
from extensions.ext_database import db from extensions.ext_database import db
from fields.end_user_fields import image_fields, image_list_fields from fields.end_user_fields import end_user_image_fields, end_user_image_list_pagination_fields
from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore from flask_restful import Resource, marshal_with, reqparse # type: ignore
from libs.helper import TimestampField, uuid_value from libs.helper import uuid_value
from models.enums import CreatedByRole from models.model import App, AppMode, EndUser, UserGeneratedImage
from models.model import App, AppMode, Conversation, EndUser, Message, UserGeneratedImage
from models.types import StringUUID
from services.image_generation_service import ImageGenerationService from services.image_generation_service import ImageGenerationService
from sqlalchemy.orm import Session from werkzeug.exceptions import InternalServerError, NotFound
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
# Constants
DEFAULT_DAILY_LIMIT = 5
MIN_CONVERSATION_ROUNDS = 10
class ImageGenerateApi(Resource): class ImageGenerateApi(Resource):
@ -44,13 +31,8 @@ class ImageGenerateApi(Resource):
schema: schema:
type: object type: object
required: required:
- conversation_id
- content_type - content_type
properties: properties:
conversation_id:
type: string
format: uuid
description: ID of the conversation to use for image generation
content_type: content_type:
type: string type: string
enum: [self_message, summary_advice] enum: [self_message, summary_advice]
@ -64,6 +46,10 @@ class ImageGenerateApi(Resource):
result: result:
type: string type: string
example: success example: success
image_id:
type: string
description: ID of the generated image; use it later to fetch the image details and status
example: 123e4567-e89b-12d3-a456-426614174000
message: message:
type: string type: string
example: Image generation started example: Image generation started
@ -76,90 +62,33 @@ class ImageGenerateApi(Resource):
404: 404:
description: Conversation not found or not a chat app description: Conversation not found or not a chat app
""" """
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("conversation_id", required=True, type=uuid_value, location="json")
parser.add_argument( parser.add_argument(
"content_type", required=True, type=str, choices=["self_message", "summary_advice"], location="json" "content_type", required=True, type=str, choices=["self_message", "summary_advice"], location="json"
) )
args = parser.parse_args() args = parser.parse_args()
conversation_id = str(args["conversation_id"])
content_type = args["content_type"] content_type = args["content_type"]
# Check if conversation exists if end_user.total_messages_count < dify_config.IMAGE_GENERATION_MIN_CONVERSATION_ROUNDS:
conversation = ( raise NotEnoughMessageCountError()
db.session.query(Conversation)
.filter(
Conversation.id == conversation_id,
Conversation.app_id == app_model.id,
Conversation.from_end_user_id == end_user.id,
)
.first()
)
if not conversation:
raise NotFound("Conversation not found")
# Check if conversation has enough rounds
messages_count = db.session.query(Message).filter(Message.conversation_id == conversation_id).count()
if messages_count < MIN_CONVERSATION_ROUNDS:
return {
"result": "error",
"message": "I need to know more about you before creating an image. Please continue our conversation.",
}, 400
# Check if user has reached daily limit
today = datetime.utcnow().date()
tomorrow = today + timedelta(days=1)
today_start = datetime.combine(today, datetime.min.time())
today_end = datetime.combine(tomorrow, datetime.min.time())
daily_count = (
db.session.query(UserGeneratedImage)
.filter(
UserGeneratedImage.end_user_id == end_user.id,
UserGeneratedImage.created_at >= today_start,
UserGeneratedImage.created_at < today_end,
)
.count()
)
if daily_count >= DEFAULT_DAILY_LIMIT:
return {
"result": "error",
"message": (
f"You've reached your daily limit of {DEFAULT_DAILY_LIMIT} generated images. Please try again tomorrow."
),
}, 403
try: try:
# Use the service to generate the image # Use the service to generate the image
# This would typically be done asynchronously in a background task image_id = ImageGenerationService.generate_image(
# For simplicity, we're doing it synchronously here end_user=end_user,
image_id = ImageGenerationService.process_image_generation_request(
app_id=str(app_model.id),
conversation_id=conversation_id,
end_user_id=str(end_user.id),
content_type=content_type, content_type=content_type,
) )
if image_id: return {"result": "success", "message": "Image generated successfully.", "image_id": image_id}
return {"result": "success", "message": "Image generated successfully.", "image_id": image_id} except Exception:
else:
return {"result": "error", "message": "Failed to generate image. Please try again later."}, 500
except Exception as e:
raise InternalServerError("Failed to generate image") raise InternalServerError("Failed to generate image")
class ImageListApi(Resource): class ImageListApi(Resource):
@validate_user_token_and_extract_info @validate_user_token_and_extract_info
@marshal_with(image_list_fields) @marshal_with(end_user_image_list_pagination_fields)
def get(self, app_model: App, end_user: EndUser): def get(self, app_model: App, end_user: EndUser):
"""Get user-generated images list. """Get user-generated images list.
--- ---
@ -212,22 +141,12 @@ class ImageListApi(Resource):
limit = min(args["limit"], 100) limit = min(args["limit"], 100)
offset = max(args["offset"], 0) offset = max(args["offset"], 0)
# Get images for the user return ImageGenerationService.pagination_image_list(end_user=end_user, limit=limit, offset=offset)
query = (
db.session.query(UserGeneratedImage)
.filter(UserGeneratedImage.app_id == app_model.id, UserGeneratedImage.end_user_id == end_user.id)
.order_by(UserGeneratedImage.created_at.desc())
)
total_count = query.count()
images = query.limit(limit).offset(offset).all()
return {"data": images, "has_more": (offset + limit) < total_count}
class ImageDetailApi(Resource): class ImageDetailApi(Resource):
@validate_user_token_and_extract_info @validate_user_token_and_extract_info
@marshal_with(image_fields) @marshal_with(end_user_image_fields)
def get(self, app_model: App, end_user: EndUser, image_id): def get(self, app_model: App, end_user: EndUser, image_id):
"""Get a specific generated image. """Get a specific generated image.
--- ---
@ -254,30 +173,11 @@ class ImageDetailApi(Resource):
404: 404:
description: Image not found or not a chat app description: Image not found or not a chat app
""" """
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
image_id = str(image_id) image_id = str(image_id)
return ImageGenerationService.get_image_by_id(image_id=image_id)
# Get the image
image = (
db.session.query(UserGeneratedImage)
.filter(
UserGeneratedImage.id == image_id,
UserGeneratedImage.app_id == app_model.id,
UserGeneratedImage.end_user_id == end_user.id,
)
.first()
)
if not image:
raise NotFound("Image not found")
return image
# Register API resources # Register API resources
api.add_resource(ImageGenerateApi, "/images/generate") api.add_resource(ImageGenerateApi, "/images/generate")
api.add_resource(ImageListApi, "/images") api.add_resource(ImageListApi, "/images")
api.add_resource(ImageDetailApi, "/images/<uuid:image_id>", endpoint="image_detail") api.add_resource(ImageDetailApi, "/images/<uuid:image_id>")

@ -28,15 +28,16 @@ end_users_infinite_scroll_pagination_fields = {
} }
# Image generation fields definition # Image generation fields definition
image_fields = { end_user_image_fields = {
"id": fields.String, "id": fields.String,
"status": fields.String,
"image_url": fields.String, "image_url": fields.String,
"content_type": fields.String, "content_type": fields.String,
"text_content": fields.String, "text_content": fields.String,
"created_at": TimestampField, "created_at": TimestampField,
} }
image_list_fields = { end_user_image_list_pagination_fields = {
"data": fields.List(fields.Nested(image_fields)), "total": fields.Integer,
"has_more": fields.Boolean, "data": fields.List(fields.Nested(end_user_image_fields)),
} }

@ -0,0 +1,49 @@
"""update user gen image
Revision ID: fc481e896b3a
Revises: e4e52e0dfb56
Create Date: 2025-03-30 12:48:01.725923
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'fc481e896b3a'
down_revision = 'e4e52e0dfb56'
branch_labels = None
depends_on = None
def upgrade():
    """Switch ``user_generated_images`` from conversation-based to workflow-run-based rows.

    Adds ``workflow_run_id`` and ``updated_at``, relaxes ``image_url`` and
    ``text_content`` to nullable, and drops ``conversation_id``.

    ``workflow_run_id`` is NOT NULL. Adding a NOT NULL column without a
    default fails when the table already contains rows, so the column is
    created with a temporary nil-UUID server default (which backfills
    existing rows) and the default is removed again afterwards. The final
    schema is therefore the same as before: NOT NULL, no server default.
    """
    # Sentinel written into pre-existing rows, which have no workflow run.
    nil_uuid_default = sa.text("'00000000-0000-0000-0000-000000000000'")

    with op.batch_alter_table('user_generated_images', schema=None) as batch_op:
        batch_op.add_column(
            sa.Column(
                'workflow_run_id',
                models.types.StringUUID(),
                server_default=nil_uuid_default,
                nullable=False,
            )
        )
        batch_op.add_column(
            sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False)
        )
        batch_op.alter_column('image_url',
               existing_type=sa.TEXT(),
               nullable=True)
        batch_op.alter_column('text_content',
               existing_type=sa.TEXT(),
               nullable=True)
        batch_op.drop_column('conversation_id')

    # Remove the temporary default so new rows must supply a real workflow run id.
    with op.batch_alter_table('user_generated_images', schema=None) as batch_op:
        batch_op.alter_column('workflow_run_id',
               existing_type=models.types.StringUUID(),
               existing_nullable=False,
               server_default=None)
def downgrade():
    """Restore the conversation-based layout of ``user_generated_images``.

    Re-adds ``conversation_id``, makes ``image_url`` and ``text_content``
    NOT NULL again, and drops ``updated_at`` and ``workflow_run_id``.

    Two things would make a naive reversal fail on a populated table:

    * rows written after the upgrade may hold NULL in ``image_url`` or
      ``text_content``; they are backfilled with an empty string before the
      NOT NULL constraint is re-applied;
    * ``conversation_id`` is NOT NULL with no default; it is added with a
      temporary nil-UUID server default that is removed afterwards.
    """
    # Backfill NULLs that the upgraded schema allowed.
    op.execute(sa.text("UPDATE user_generated_images SET image_url = '' WHERE image_url IS NULL"))
    op.execute(sa.text("UPDATE user_generated_images SET text_content = '' WHERE text_content IS NULL"))

    # Sentinel for rows whose original conversation id no longer exists.
    nil_uuid_default = sa.text("'00000000-0000-0000-0000-000000000000'")

    with op.batch_alter_table('user_generated_images', schema=None) as batch_op:
        batch_op.add_column(
            sa.Column(
                'conversation_id',
                sa.UUID(),
                server_default=nil_uuid_default,
                autoincrement=False,
                nullable=False,
            )
        )
        batch_op.alter_column('text_content',
               existing_type=sa.TEXT(),
               nullable=False)
        batch_op.alter_column('image_url',
               existing_type=sa.TEXT(),
               nullable=False)
        batch_op.drop_column('updated_at')
        batch_op.drop_column('workflow_run_id')

    # Remove the temporary default to match the pre-upgrade column definition.
    with op.batch_alter_table('user_generated_images', schema=None) as batch_op:
        batch_op.alter_column('conversation_id',
               existing_type=sa.UUID(),
               existing_nullable=False,
               server_default=None)

@ -0,0 +1,33 @@
"""add raw content
Revision ID: a215b993fa58
Revises: fc481e896b3a
Create Date: 2025-03-31 20:54:49.204381
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'a215b993fa58'
down_revision = 'fc481e896b3a'
branch_labels = None
depends_on = None
def upgrade():
    """Add the nullable JSON column ``raw_content`` to ``user_generated_images``."""
    raw_content_column = sa.Column('raw_content', sa.JSON(), nullable=True)
    with op.batch_alter_table('user_generated_images', schema=None) as images_table:
        images_table.add_column(raw_content_column)
def downgrade():
    """Drop the ``raw_content`` column from ``user_generated_images``."""
    with op.batch_alter_table('user_generated_images', schema=None) as images_table:
        images_table.drop_column('raw_content')

@ -15,7 +15,7 @@ from flask import request
from flask_login import UserMixin # type: ignore from flask_login import UserMixin # type: ignore
from libs.helper import generate_string from libs.helper import generate_string
from models.enums import CreatedByRole from models.enums import CreatedByRole
from models.workflow import WorkflowRunStatus from models.workflow import WorkflowRun, WorkflowRunStatus
from sqlalchemy import Float, func, text from sqlalchemy import Float, func, text
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
@ -1455,6 +1455,15 @@ class EndUser(UserMixin, db.Model): # type: ignore[name-defined]
"memory": self.memory, "memory": self.memory,
} }
@property
def total_messages_count(self):
    """Return how many ``Message`` rows were sent by this end user.

    Only messages whose ``from_end_user_id`` is this user's id and whose
    ``organization_id`` matches this user's organization are counted.
    Every access issues a COUNT query; the value is not cached.
    """
    sent_by_this_user = Message.from_end_user_id == self.id
    in_same_organization = Message.organization_id == self.organization_id
    user_messages = db.session.query(Message).filter(sent_by_this_user).filter(in_same_organization)
    return user_messages.count()
class Site(db.Model): # type: ignore[name-defined] class Site(db.Model): # type: ignore[name-defined]
__tablename__ = "sites" __tablename__ = "sites"
@ -1857,16 +1866,36 @@ class UserGeneratedImage(db.Model): # type: ignore[name-defined]
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
app_id = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=False)
end_user_id = db.Column(StringUUID, nullable=False) end_user_id = db.Column(StringUUID, nullable=False)
conversation_id = db.Column(StringUUID, nullable=False) workflow_run_id = db.Column(StringUUID, nullable=False) # related generation id
image_url = db.Column(db.Text, nullable=False)
content_type = db.Column(db.String(255), nullable=False) # 'self_message' or 'summary_advice' content_type = db.Column(db.String(255), nullable=False) # 'self_message' or 'summary_advice'
text_content = db.Column(db.Text, nullable=False) image_url = db.Column(db.Text, nullable=True)
text_content = db.Column(db.Text, nullable=True)
raw_content = db.Column(db.JSON, nullable=True) # save raw llm outputs
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
updated_at = db.Column(
db.DateTime, nullable=False, server_default=db.func.current_timestamp(), onupdate=db.func.current_timestamp()
)
@property @property
def end_user(self): def end_user(self):
return db.session.query(EndUser).filter(EndUser.id == self.end_user_id).first() return db.session.query(EndUser).filter(EndUser.id == self.end_user_id).first()
@property @property
def conversation(self): def status(self):
return db.session.query(Conversation).filter(Conversation.id == self.conversation_id).first() workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self.workflow_run_id).first()
if workflow_run is None:
return WorkflowRunStatus.FAILED.value
return workflow_run.status
def refresh_status(self):
    """Copy a finished workflow run's outputs onto this image record.

    Looks up the ``WorkflowRun`` referenced by ``workflow_run_id``. When the
    run has succeeded and this record has no ``image_url`` yet, the run's
    ``image_url`` and ``text_content`` outputs are stored on the record and
    the session is committed. In every other case nothing is changed.

    The run's ``outputs`` may be missing or stored as a JSON-encoded string
    rather than a dict; both are tolerated instead of raising
    ``AttributeError`` on ``.get``.
    """
    import json

    workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self.workflow_run_id).first()
    if workflow_run is None:
        return

    # Accept both the enum member and its raw string value, so the check does
    # not depend on WorkflowRunStatus being a str-based enum.
    succeeded = workflow_run.status in (WorkflowRunStatus.SUCCEEDED, WorkflowRunStatus.SUCCEEDED.value)
    if not succeeded or self.image_url is not None:
        return

    outputs = workflow_run.outputs
    if isinstance(outputs, str):
        # NOTE(review): presumably outputs can be persisted as JSON text — confirm against WorkflowRun.
        try:
            outputs = json.loads(outputs)
        except ValueError:
            outputs = None
    if not isinstance(outputs, dict):
        # Nothing usable to copy; leave the record untouched.
        return

    self.image_url = outputs.get("image_url")
    self.text_content = outputs.get("text_content")
    db.session.commit()

@ -1,169 +1,126 @@
import json import json
import logging from enum import Enum
import os
import random from configs import dify_config
from datetime import datetime from core.app.entities.app_invoke_entities import InvokeFrom
from typing import Any, Dict, List, Optional, Tuple, Union
# Import the UserGeneratedImage model from our controller module
# This is a bit of a circular import, but it's the simplest solution for now
from controllers.service_api_with_auth.app.image_generate import UserGeneratedImage
from extensions.ext_database import db from extensions.ext_database import db
from models.enums import CreatedByRole from libs.helper import RateLimiter
from models.model import App, Conversation, EndUser, Message, UploadFile from libs.infinite_scroll_pagination import MultiPagePagination
from sqlalchemy.orm import Session from models.model import App, EndUser, Message, UserGeneratedImage
from services.app_generate_service import AppGenerateService
# Configure logging # define string enum for content_type
logger = logging.getLogger(__name__) class ContentType(str, Enum):
SELF_MESSAGE = "self_message"
SUMMARY_ADVICE = "summary_advice"
class ImageGenerationService: class ImageGenerationService:
@staticmethod
def generate_motivational_text(conversation_id: str, content_type: str) -> str:
"""
Generate motivational text based on conversation history.
Args:
conversation_id: The ID of the conversation
content_type: Type of content to generate ('self_message' or 'summary_advice')
Returns:
str: Generated text content
"""
# In a real implementation, this would call a large language model
# Here we'll just use placeholders based on the content type
with Session(db.engine) as session:
# Get the last few messages from the conversation to understand context
messages = (
session.query(Message)
.filter(Message.conversation_id == conversation_id)
.order_by(Message.created_at.desc())
.limit(20)
.all()
)
# Reverse to get chronological order generate_image_rate_limiter = RateLimiter(
messages.reverse() prefix="generate_image_rate_limit", max_attempts=dify_config.IMAGE_GENERATION_DAILY_LIMIT, time_window=86400 * 1
)
# Extract conversation context
context = "\n".join([f"User: {msg.query}\nAI: {msg.answer}" for msg in messages])
# In production, you would pass this context to a language model
# For demonstration, we'll return placeholder text
if content_type == "self_message":
sample_messages = [
"You've got this! Take one small step today.",
"Remember your strength - you've overcome challenges before.",
"Be kind to yourself today, you deserve it.",
"Your feelings are valid, and you have the power to work through them.",
"Small progress is still progress. Celebrate your wins today.",
]
return random.choice(sample_messages)
else: # summary_advice
sample_advice = [
"Based on our conversation, I notice you tend to be hard on yourself. Try practicing self-compassion by speaking to yourself as you would to a friend.",
"I've observed that you often describe feeling overwhelmed. Breaking tasks into smaller steps might help manage these feelings better.",
"In our discussions, I noticed patterns of negative self-talk. Consider challenging these thoughts by asking 'Is this really true?' when they arise.",
"From our conversations, it seems you might benefit from more self-care routines. Even 5 minutes of mindfulness daily could make a difference.",
"You've mentioned feeling anxious in social situations. Progressive exposure to small social interactions might help build confidence over time.",
]
return random.choice(sample_advice)
@staticmethod @staticmethod
def generate_background_image() -> str: def generate_image(end_user: EndUser, content_type: ContentType) -> str:
"""
Generate or select a background image.
In a real implementation, this might call an image generation API
or select from pre-generated images.
Returns:
str: URL of the generated/selected image
"""
# In a real implementation, this would integrate with an image generation API
# or select from a pool of pre-generated images
# For this example, we'll return a placeholder
placeholder_images = [
"https://example.com/background1.jpg",
"https://example.com/background2.jpg",
"https://example.com/background3.jpg",
"https://example.com/background4.jpg",
"https://example.com/background5.jpg",
]
return random.choice(placeholder_images)
@staticmethod if ImageGenerationService.generate_image_rate_limiter.is_rate_limited(end_user.id):
def overlay_text_on_image(image_url: str, text: str) -> str: raise Exception("Image generation rate limit exceeded")
"""
Overlay text on the image.
In a real implementation, this would use image processing libraries. if dify_config.IMAGE_GENERATION_APP_ID is None:
raise Exception("Image generation app id is not set")
Args: image_generation_app_model = App.query.filter(App.id == dify_config.IMAGE_GENERATION_APP_ID).first()
image_url: URL of the background image if image_generation_app_model is None:
text: Text to overlay on the image raise Exception("Image generation app model is not found")
Returns: user_profile = end_user.extra_profile
str: URL of the final image with text recent_messages = (
""" db.session.query(Message)
# In a real implementation, this would use image processing libraries .filter(Message.app_id == end_user.app_id, Message.from_end_user_id == end_user.id)
# like Pillow to overlay text on the image .order_by(Message.created_at.desc())
.limit(10)
.all()
)
# For this example, we'll just return the same URL recent_messages = [f"user: {message.query}\n\nassistant: {message.answer}" for message in recent_messages]
# In production, you would process the image and save it to storage
return image_url args = {
"inputs": {
"user_profile": json.dumps(user_profile),
"recent_messages": "\n\n".join(recent_messages),
"image_type": content_type,
}
}
@staticmethod
def process_image_generation_request(
app_id: str, conversation_id: str, end_user_id: str, content_type: str
) -> Optional[str]:
"""
Process an image generation request.
Args:
app_id: The ID of the app
conversation_id: The ID of the conversation
end_user_id: The ID of the end user
content_type: Type of content to generate ('self_message' or 'summary_advice')
Returns:
Optional[str]: ID of the generated image if successful, None otherwise
"""
try: try:
# 1. Generate motivational text based on conversation history response = AppGenerateService.generate(
text_content = ImageGenerationService.generate_motivational_text( app_model=image_generation_app_model,
conversation_id=conversation_id, content_type=content_type user=end_user,
args=args,
invoke_from=InvokeFrom.SCHEDULER,
streaming=False,
) )
# 2. Generate or select a background image if not isinstance(response, dict):
image_url = ImageGenerationService.generate_background_image() raise Exception("Failed to generate image")
# load the workflow run id and save it to the db so the image status can be fetched later
workflow_run_id = response.get("workflow_run_id")
raw_content = response.get("data", {}).get("outputs", {})
# parse url from response.data.outputs
image_objs = raw_content.get("files")
url = None
for image_obj in image_objs:
if image_obj.get("type") == "image":
url = image_obj.get("url")
break
# 3. Overlay text on the image if url is None:
final_image_url = ImageGenerationService.overlay_text_on_image(image_url=image_url, text=text_content) raise Exception("Failed to generate image")
# 4. Create and save the user generated image record text_content = raw_content.get("text")
with Session(db.engine) as session:
new_image = UserGeneratedImage(
app_id=app_id,
end_user_id=end_user_id,
conversation_id=conversation_id,
image_url=final_image_url,
content_type=content_type,
text_content=text_content,
)
session.add(new_image) user_generated_image = UserGeneratedImage(
session.commit() app_id=end_user.app_id,
end_user_id=end_user.id,
workflow_run_id=workflow_run_id,
content_type=content_type,
image_url=url,
text_content=text_content,
raw_content=raw_content,
)
image_id = str(new_image.id) db.session.add(user_generated_image)
logger.info(f"Generated image {image_id} for user {end_user_id}") db.session.commit()
return image_id return user_generated_image.id
except Exception as e: except Exception as e:
logger.error(f"Error generating image: {str(e)}") raise Exception(f"Failed to generate image: {e}")
return None
@staticmethod
def pagination_image_list(end_user: EndUser, limit: int, offset: int) -> MultiPagePagination:
    """Return one page of the end user's generated images, newest first.

    Args:
        end_user: Owner of the images; rows are matched on both the user's
            ``app_id`` and ``id``.
        limit: Maximum number of rows in the page.
        offset: Number of rows to skip before the page starts.

    Returns:
        A ``MultiPagePagination`` holding the page rows as ``data`` and the
        total number of matching rows as ``total``.
    """
    owned_by_user = (
        UserGeneratedImage.app_id == end_user.app_id,
        UserGeneratedImage.end_user_id == end_user.id,
    )
    newest_first = (
        db.session.query(UserGeneratedImage)
        .filter(*owned_by_user)
        .order_by(UserGeneratedImage.created_at.desc())
    )

    matching_rows = newest_first.count()
    page_rows = newest_first.limit(limit).offset(offset).all()
    return MultiPagePagination(data=page_rows, total=matching_rows)
@staticmethod
def get_image_by_id(image_id: str, end_user: "EndUser | None" = None) -> UserGeneratedImage:
    """Fetch a generated image by id, optionally scoped to its owner.

    Args:
        image_id: Id of the ``UserGeneratedImage`` row.
        end_user: When given, the image must belong to this user (same
            ``app_id`` and ``end_user_id``); otherwise it is treated as not
            found. Callers serving end-user requests should always pass it
            so one user cannot read another user's image by guessing an id.
            Defaults to ``None`` (no ownership filter) for backward
            compatibility.

    Returns:
        The matching ``UserGeneratedImage``.

    Raises:
        werkzeug.exceptions.NotFound: No matching image exists. This is an
            ``Exception`` subclass, so existing ``except Exception`` callers
            keep working, while HTTP handlers now answer 404 instead of 500.
    """
    # Imported locally so the module's import block does not need to change.
    from werkzeug.exceptions import NotFound

    query = db.session.query(UserGeneratedImage).filter(UserGeneratedImage.id == image_id)
    if end_user is not None:
        query = query.filter(
            UserGeneratedImage.app_id == end_user.app_id,
            UserGeneratedImage.end_user_id == end_user.id,
        )

    image = query.first()
    if image is None:
        raise NotFound("Image not found")
    return image

Loading…
Cancel
Save